upload_agent_data_logic.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. package agent
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/csv"
  6. "fmt"
  7. "github.com/suyuan32/simple-admin-common/msg/errormsg"
  8. "io"
  9. "mime/multipart"
  10. agentModel "wechat-api/ent/agent"
  11. "wechat-api/hook/fastgpt"
  12. "wechat-api/internal/utils/dberrorhandler"
  13. "wechat-api/internal/svc"
  14. "wechat-api/internal/types"
  15. "github.com/zeromicro/go-zero/core/logx"
  16. "golang.org/x/text/encoding/simplifiedchinese"
  17. "golang.org/x/text/transform"
  18. )
  19. type UploadAgentDataLogic struct {
  20. logx.Logger
  21. ctx context.Context
  22. svcCtx *svc.ServiceContext
  23. }
  24. func NewUploadAgentDataLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UploadAgentDataLogic {
  25. return &UploadAgentDataLogic{
  26. Logger: logx.WithContext(ctx),
  27. ctx: ctx,
  28. svcCtx: svcCtx}
  29. }
  30. func (l *UploadAgentDataLogic) UploadAgentData(req *types.UploadDataReq, file multipart.File, agentId uint64) (*types.BaseDataInfo, error) {
  31. var count uint64 = 0
  32. reader := csv.NewReader(file)
  33. records, err := reader.ReadAll()
  34. if err != nil {
  35. return nil, err
  36. }
  37. agent, err := l.svcCtx.DB.Agent.Query().Where(agentModel.ID(agentId)).Only(l.ctx)
  38. if err != nil {
  39. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  40. }
  41. var params fastgpt.CreateBulkDataReq
  42. params.CollectionID = agent.CollectionID
  43. params.TrainingMode = "chunk"
  44. qas := make([]fastgpt.DataQuestion, 0, 100)
  45. for idx, record := range records {
  46. if idx == 0 && record[1] == "答案" {
  47. continue
  48. }
  49. // 空内容过滤
  50. if record[0] == "" || record[1] == "" {
  51. continue
  52. }
  53. fmt.Printf("转换前:question=%s, answer=%s \n", record[0], record[1])
  54. var question, answer []byte
  55. reader0 := transform.NewReader(bytes.NewReader([]byte(record[0])), simplifiedchinese.GBK.NewDecoder())
  56. question, _ = io.ReadAll(reader0)
  57. reader1 := transform.NewReader(bytes.NewReader([]byte(record[1])), simplifiedchinese.GBK.NewDecoder())
  58. answer, _ = io.ReadAll(reader1)
  59. fmt.Printf("转换后:question=%s, answer=%s \n", question, answer)
  60. qas = append(qas, fastgpt.DataQuestion{
  61. Q: string(question),
  62. A: string(answer),
  63. })
  64. length := len(qas)
  65. if length > 0 && length%100 == 0 {
  66. params.Data = qas
  67. //fmt.Printf("params=%+v\n", params)
  68. response, err := fastgpt.CreateBulkData(&params)
  69. if err != nil {
  70. l.Logger.Errorf("batch insert data to fastgpt failed. collection=%s error=%s", agent.CollectionID, err.Error())
  71. return nil, err
  72. }
  73. count += response.Data.InsertLen
  74. qas = make([]fastgpt.DataQuestion, 0, 100)
  75. }
  76. }
  77. if len(qas) > 0 {
  78. params.Data = qas
  79. response, err := fastgpt.CreateBulkData(&params)
  80. if err != nil {
  81. l.Logger.Errorf("batch insert data to fastgpt failed. collection=%s error=%s", agent.CollectionID, err.Error())
  82. return nil, err
  83. }
  84. count += response.Data.InsertLen
  85. qas = make([]fastgpt.DataQuestion, 0, 100)
  86. }
  87. resp := &types.BaseDataInfo{}
  88. resp.Code = 0
  89. resp.Msg = errormsg.Success
  90. resp.Data = fmt.Sprintf("upload %d rows", count)
  91. return resp, nil
  92. }