123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125 |
- package agent
import (
	"bytes"
	"context"
	"encoding/csv"
	"fmt"
	"io"
	"mime/multipart"
	"strings"
	"unicode/utf8"

	"github.com/suyuan32/simple-admin-common/msg/errormsg"
	"github.com/zeromicro/go-zero/core/logx"
	"golang.org/x/text/encoding/simplifiedchinese"
	"golang.org/x/text/transform"

	agentModel "wechat-api/ent/agent"
	"wechat-api/hook/fastgpt"
	"wechat-api/internal/svc"
	"wechat-api/internal/types"
	"wechat-api/internal/utils/dberrorhandler"
)
- type UploadAgentDataLogic struct {
- logx.Logger
- ctx context.Context
- svcCtx *svc.ServiceContext
- }
- func NewUploadAgentDataLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UploadAgentDataLogic {
- return &UploadAgentDataLogic{
- Logger: logx.WithContext(ctx),
- ctx: ctx,
- svcCtx: svcCtx}
- }
- func (l *UploadAgentDataLogic) UploadAgentData(req *types.UploadDataReq, file multipart.File, agentId uint64) (*types.BaseDataInfo, error) {
- var count uint64 = 0
- reader := csv.NewReader(file)
- records, err := reader.ReadAll()
- if err != nil {
- return nil, err
- }
- agent, err := l.svcCtx.DB.Agent.Query().Where(agentModel.ID(agentId)).Only(l.ctx)
- if err != nil {
- return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
- }
- var params fastgpt.CreateBulkDataReq
- params.CollectionID = agent.CollectionID
- params.TrainingMode = "chunk"
- qas := make([]fastgpt.DataQuestion, 0, 100)
- for idx, record := range records {
- if idx == 0 && record[1] == "答案" {
- continue
- }
- // 空内容过滤
- if record[0] == "" || record[1] == "" {
- continue
- }
- fmt.Printf("转换前:question=%s, answer=%s \n", record[0], record[1])
- question := transCharset(record[0])
- answer := transCharset(record[1])
- fmt.Printf("转换后:question=%s, answer=%s \n", question, answer)
- qas = append(qas, fastgpt.DataQuestion{
- Q: string(question),
- A: string(answer),
- })
- length := len(qas)
- if length > 0 && length%100 == 0 {
- params.Data = qas
- //fmt.Printf("params=%+v\n", params)
- response, err := fastgpt.CreateBulkData(¶ms)
- if err != nil {
- l.Logger.Errorf("batch insert data to fastgpt failed. collection=%s error=%s", agent.CollectionID, err.Error())
- return nil, err
- }
- count += response.Data.InsertLen
- qas = make([]fastgpt.DataQuestion, 0, 100)
- }
- }
- if len(qas) > 0 {
- params.Data = qas
- response, err := fastgpt.CreateBulkData(¶ms)
- if err != nil {
- l.Logger.Errorf("batch insert data to fastgpt failed. collection=%s error=%s", agent.CollectionID, err.Error())
- return nil, err
- }
- count += response.Data.InsertLen
- qas = make([]fastgpt.DataQuestion, 0, 100)
- }
- resp := &types.BaseDataInfo{}
- resp.Code = 0
- resp.Msg = errormsg.Success
- resp.Data = fmt.Sprintf("upload %d rows", count)
- return resp, nil
- }
// trim returns s with all leading and trailing spaces, carriage returns,
// newlines and tabs removed. Characters inside the string are left untouched.
func trim(s string) string {
	const whitespace = " \r\n\t"
	return strings.Trim(s, whitespace)
}
- func transCharset(s string) string {
- s = trim(s)
- return s
- rd := transform.NewReader(bytes.NewReader([]byte(s)), simplifiedchinese.GBK.NewDecoder())
- bytes, err := io.ReadAll(rd)
- fmt.Printf("bytes=%s err=%v\n", bytes, err)
- return string(bytes)
- }
|