123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- package agent
import (
	"bytes"
	"context"
	"encoding/csv"
	"fmt"
	"io"
	"mime/multipart"
	"strings"
	"unicode/utf8"

	"github.com/saintfish/chardet"
	"github.com/suyuan32/simple-admin-common/msg/errormsg"
	"github.com/zeromicro/go-zero/core/logx"
	"golang.org/x/text/encoding"
	"golang.org/x/text/encoding/charmap"
	"golang.org/x/text/encoding/japanese"
	"golang.org/x/text/encoding/korean"
	"golang.org/x/text/encoding/simplifiedchinese"
	"golang.org/x/text/encoding/traditionalchinese"
	"golang.org/x/text/encoding/unicode"
	"golang.org/x/text/transform"

	agentModel "wechat-api/ent/agent"
	"wechat-api/hook/fastgpt"
	"wechat-api/internal/svc"
	"wechat-api/internal/types"
	"wechat-api/internal/utils/dberrorhandler"
)
- type UploadAgentDataLogic struct {
- logx.Logger
- ctx context.Context
- svcCtx *svc.ServiceContext
- }
// NewUploadAgentDataLogic builds an UploadAgentDataLogic bound to the given
// request context and service context. The embedded logger is derived from
// ctx so log lines carry the request's trace information.
func NewUploadAgentDataLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UploadAgentDataLogic {
	logic := new(UploadAgentDataLogic)
	logic.Logger = logx.WithContext(ctx)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	return logic
}
// UploadAgentData imports question/answer pairs from an uploaded CSV file into
// the FastGPT collection bound to the agent identified by agentId.
//
// Column 0 of each row is the question and column 1 the answer; extra columns
// are ignored. These rows are skipped:
//   - a first row whose answer column is "答案" (a header);
//   - rows with fewer than two columns;
//   - rows whose question or answer is blank after trimming.
//
// Every cell is passed through transCharset before use. Pairs are sent to
// FastGPT in batches of batchSize via fastgpt.CreateBulkData.
//
// On success, BaseDataInfo.Data reports how many rows FastGPT says it
// inserted. An error is returned if any of these fail:
//   - parsing the CSV;
//   - loading the agent;
//   - any batch insert.
//
// Batches already sent before a failing batch are not rolled back.
func (l *UploadAgentDataLogic) UploadAgentData(req *types.UploadDataReq, file multipart.File, agentId uint64) (*types.BaseDataInfo, error) {
	const batchSize = 100

	reader := csv.NewReader(file)
	// Accept rows with differing column counts. Short rows are skipped below
	// rather than failing the whole upload or panicking on record[1].
	reader.FieldsPerRecord = -1
	records, err := reader.ReadAll()
	if err != nil {
		return nil, err
	}

	agent, err := l.svcCtx.DB.Agent.Query().Where(agentModel.ID(agentId)).Only(l.ctx)
	if err != nil {
		return nil, dberrorhandler.DefaultEntError(l.Logger, err, req)
	}

	var params fastgpt.CreateBulkDataReq
	params.CollectionID = agent.CollectionID
	params.TrainingMode = "chunk"

	var count uint64
	qas := make([]fastgpt.DataQuestion, 0, batchSize)

	// flush sends the buffered pairs to FastGPT, adds the reported insert
	// count to count, and starts a fresh buffer.
	flush := func() error {
		params.Data = qas
		response, err := fastgpt.CreateBulkData(&params)
		if err != nil {
			l.Logger.Errorf("batch insert data to fastgpt failed. collection=%s error=%s", agent.CollectionID, err.Error())
			return err
		}
		count += response.Data.InsertLen
		qas = make([]fastgpt.DataQuestion, 0, batchSize)
		return nil
	}

	for idx, record := range records {
		if len(record) < 2 {
			continue
		}
		rawQuestion, rawAnswer := record[0], record[1]
		if idx == 0 {
			// Some spreadsheet exports prepend a UTF-8 BOM, which would
			// otherwise stay glued to the first question.
			rawQuestion = strings.TrimPrefix(rawQuestion, "\ufeff")
		}

		question := trim(transCharset(rawQuestion))
		answer := trim(transCharset(rawAnswer))

		// Check the header on both the raw and the transcoded cell, so that a
		// header saved in a non-UTF-8 encoding (e.g. GBK) is recognized too.
		if idx == 0 && (rawAnswer == "答案" || answer == "答案") {
			continue
		}
		// Skip rows with empty content.
		if question == "" || answer == "" {
			continue
		}

		qas = append(qas, fastgpt.DataQuestion{Q: question, A: answer})
		if len(qas) == batchSize {
			if err := flush(); err != nil {
				return nil, err
			}
		}
	}

	if len(qas) > 0 {
		if err := flush(); err != nil {
			return nil, err
		}
	}

	resp := &types.BaseDataInfo{}
	resp.Code = 0
	resp.Msg = errormsg.Success
	resp.Data = fmt.Sprintf("upload %d rows", count)
	return resp, nil
}
- func trim(s string) string {
- s = strings.TrimLeft(s, " \r\n\t")
- s = strings.TrimRight(s, " \r\n\t")
- return s
- }
// transCharset converts s to UTF-8, guessing its source encoding with chardet.
//
// Strings that are already valid UTF-8 (including pure ASCII) are returned
// unchanged without running detection. chardet is unreliable on the short
// strings found in single CSV cells, and a misdetection would otherwise
// re-decode correct UTF-8 text into mojibake.
//
// s is also returned unchanged when any of these happen:
//   - detection fails;
//   - the detected charset has no mapping in getEncoding;
//   - decoding fails.
func transCharset(s string) string {
	if utf8.ValidString(s) {
		return s
	}

	// 1. Detect the most likely source encoding.
	detector := chardet.NewTextDetector()
	result, err := detector.DetectBest([]byte(s))
	if err != nil {
		logx.Errorf("encoding detection failed: %v", err)
		return s
	}

	// 2. Look up the matching decoder.
	enc := getEncoding(result.Charset)
	if enc == nil {
		// Keep the original bytes rather than guessing.
		logx.Infof("unsupported charset %q, keeping original text", result.Charset)
		return s
	}

	// 3. Decode into UTF-8.
	rd := transform.NewReader(bytes.NewReader([]byte(s)), enc.NewDecoder())
	utf8Bytes, err := io.ReadAll(rd)
	if err != nil {
		logx.Errorf("encoding conversion from %s failed: %v", result.Charset, err)
		return s
	}

	// 4. Return the UTF-8 text.
	return string(utf8Bytes)
}
// getEncoding maps a charset name reported by the detector to the matching
// golang.org/x/text encoding.
//
// Matching is case-insensitive because detector output is not consistently
// cased (ICU-style detectors such as chardet report e.g. "windows-1252").
// It returns encoding.Nop for charsets that need no conversion (UTF-8/ASCII)
// and nil for charsets that are not supported.
func getEncoding(charset string) encoding.Encoding {
	switch strings.ToUpper(charset) {
	case "UTF-8", "ASCII":
		return encoding.Nop // no conversion needed
	case "ISO-8859-1":
		return charmap.ISO8859_1
	case "ISO-8859-2":
		return charmap.ISO8859_2
	case "ISO-8859-15":
		return charmap.ISO8859_15
	case "WINDOWS-1252":
		return charmap.Windows1252
	case "BIG5":
		return traditionalchinese.Big5
	case "GB-2312", "GBK":
		return simplifiedchinese.GBK
	case "GB-18030":
		// GB18030 is a superset of GBK. It adds 4-byte sequences that the
		// GBK decoder would reject, and it still decodes plain GBK text.
		return simplifiedchinese.GB18030
	case "SHIFT_JIS":
		return japanese.ShiftJIS
	case "EUC-JP":
		return japanese.EUCJP
	case "EUC-KR":
		return korean.EUCKR
	case "UTF-16LE":
		return unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM)
	case "UTF-16BE":
		return unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM)
	default:
		return nil
	}
}
|