upload_agent_data_logic.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. package agent
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/csv"
  6. "fmt"
  7. "github.com/saintfish/chardet"
  8. "github.com/suyuan32/simple-admin-common/msg/errormsg"
  9. "golang.org/x/text/encoding"
  10. "golang.org/x/text/encoding/charmap"
  11. "golang.org/x/text/encoding/japanese"
  12. "golang.org/x/text/encoding/korean"
  13. "golang.org/x/text/encoding/traditionalchinese"
  14. "golang.org/x/text/encoding/unicode"
  15. "io"
  16. "mime/multipart"
  17. "strings"
  18. agentModel "wechat-api/ent/agent"
  19. "wechat-api/hook/fastgpt"
  20. "wechat-api/internal/utils/dberrorhandler"
  21. "wechat-api/internal/svc"
  22. "wechat-api/internal/types"
  23. "github.com/zeromicro/go-zero/core/logx"
  24. "golang.org/x/text/encoding/simplifiedchinese"
  25. "golang.org/x/text/transform"
  26. )
  27. type UploadAgentDataLogic struct {
  28. logx.Logger
  29. ctx context.Context
  30. svcCtx *svc.ServiceContext
  31. }
  32. func NewUploadAgentDataLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UploadAgentDataLogic {
  33. return &UploadAgentDataLogic{
  34. Logger: logx.WithContext(ctx),
  35. ctx: ctx,
  36. svcCtx: svcCtx}
  37. }
  38. func (l *UploadAgentDataLogic) UploadAgentData(req *types.UploadDataReq, file multipart.File, agentId uint64) (*types.BaseDataInfo, error) {
  39. var count uint64 = 0
  40. reader := csv.NewReader(file)
  41. records, err := reader.ReadAll()
  42. if err != nil {
  43. return nil, err
  44. }
  45. agent, err := l.svcCtx.DB.Agent.Query().Where(agentModel.ID(agentId)).Only(l.ctx)
  46. if err != nil {
  47. return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
  48. }
  49. var params fastgpt.CreateBulkDataReq
  50. params.CollectionID = agent.CollectionID
  51. params.TrainingMode = "chunk"
  52. qas := make([]fastgpt.DataQuestion, 0, 100)
  53. for idx, record := range records {
  54. if idx == 0 && record[1] == "答案" {
  55. continue
  56. }
  57. // 空内容过滤
  58. if record[0] == "" || record[1] == "" {
  59. continue
  60. }
  61. fmt.Printf("转换前:question=%s, answer=%s \n", record[0], record[1])
  62. question := transCharset(record[0])
  63. answer := transCharset(record[1])
  64. fmt.Printf("转换后:question=%s, answer=%s \n", question, answer)
  65. qas = append(qas, fastgpt.DataQuestion{
  66. Q: string(question),
  67. A: string(answer),
  68. })
  69. length := len(qas)
  70. if length > 0 && length%100 == 0 {
  71. params.Data = qas
  72. //fmt.Printf("params=%+v\n", params)
  73. response, err := fastgpt.CreateBulkData(&params)
  74. if err != nil {
  75. l.Logger.Errorf("batch insert data to fastgpt failed. collection=%s error=%s", agent.CollectionID, err.Error())
  76. return nil, err
  77. }
  78. count += response.Data.InsertLen
  79. qas = make([]fastgpt.DataQuestion, 0, 100)
  80. }
  81. }
  82. if len(qas) > 0 {
  83. params.Data = qas
  84. response, err := fastgpt.CreateBulkData(&params)
  85. if err != nil {
  86. l.Logger.Errorf("batch insert data to fastgpt failed. collection=%s error=%s", agent.CollectionID, err.Error())
  87. return nil, err
  88. }
  89. count += response.Data.InsertLen
  90. qas = make([]fastgpt.DataQuestion, 0, 100)
  91. }
  92. resp := &types.BaseDataInfo{}
  93. resp.Code = 0
  94. resp.Msg = errormsg.Success
  95. resp.Data = fmt.Sprintf("upload %d rows", count)
  96. return resp, nil
  97. }
  98. func trim(s string) string {
  99. s = strings.TrimLeft(s, " \r\n\t")
  100. s = strings.TrimRight(s, " \r\n\t")
  101. return s
  102. }
  103. // transCharset 自动检测编码并转换为 UTF-8
  104. func transCharset(s string) string {
  105. // 1. 自动检测编码
  106. detector := chardet.NewTextDetector()
  107. result, err := detector.DetectBest([]byte(s))
  108. if err != nil {
  109. fmt.Println("Encoding detection failed:", err)
  110. return s
  111. }
  112. // 2. 找到相应的编码
  113. fmt.Println("result.Charset:", result.Charset)
  114. enc := getEncoding(result.Charset)
  115. fmt.Println("enc:", enc)
  116. if enc == nil {
  117. // 直接返回原始字符串
  118. fmt.Println("Unsupported charset:", result.Charset)
  119. return s
  120. }
  121. // 3. 转换为 UTF-8
  122. rd := transform.NewReader(bytes.NewReader([]byte(s)), enc.NewDecoder())
  123. utf8Bytes, err := io.ReadAll(rd)
  124. if err != nil {
  125. fmt.Println("Encoding conversion failed:", err)
  126. return s
  127. }
  128. // 4. 返回转换后的 UTF-8 字符串
  129. return string(utf8Bytes)
  130. }
  131. // 根据字符集名称获取 `encoding.Encoding`
  132. func getEncoding(charset string) encoding.Encoding {
  133. switch charset {
  134. case "UTF-8", "ASCII":
  135. return encoding.Nop // 无需转换
  136. case "ISO-8859-1":
  137. return charmap.ISO8859_1
  138. case "ISO-8859-2":
  139. return charmap.ISO8859_2
  140. case "ISO-8859-15":
  141. return charmap.ISO8859_15
  142. case "Windows-1252":
  143. return charmap.Windows1252
  144. case "Big5":
  145. return traditionalchinese.Big5
  146. case "GB-2312", "GBK", "GB-18030":
  147. return simplifiedchinese.GBK
  148. case "Shift_JIS":
  149. return japanese.ShiftJIS
  150. case "EUC-JP":
  151. return japanese.EUCJP
  152. case "EUC-KR":
  153. return korean.EUCKR
  154. case "UTF-16LE":
  155. return unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM)
  156. case "UTF-16BE":
  157. return unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM)
  158. default:
  159. return nil
  160. }
  161. }