package chat import ( "context" "errors" "fmt" "net" "net/url" "strconv" "strings" "time" "unicode" "wechat-api/ent" "wechat-api/internal/svc" "wechat-api/internal/types" "wechat-api/internal/utils/compapi" "wechat-api/internal/utils/contextkey" "wechat-api/internal/utils/typekit" "wechat-api/ent/custom_types" "wechat-api/ent/predicate" "wechat-api/ent/usagedetail" "wechat-api/ent/usagetotal" "github.com/zeromicro/go-zero/core/logx" ) type ChatCompletionsLogic struct { logx.Logger ctx context.Context svcCtx *svc.ServiceContext } func NewChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ChatCompletionsLogic { return &ChatCompletionsLogic{ Logger: logx.WithContext(ctx), ctx: ctx, svcCtx: svcCtx} } func (l *ChatCompletionsLogic) ChatCompletions(req *types.CompApiReq) (asyncMode bool, resp *types.CompOpenApiResp, err error) { // todo: add your logic here and delete this line /* 1.鉴权获得token 2.必要参数检测及转换 3. 根据event_type选择不同处理路由 */ var ( apiKeyObj *ent.ApiKey ok bool ) asyncMode = false //微调部分请求参数 reqAdjust(req) //从上下文中获取鉴权中间件埋下的apiAuthInfo apiKeyObj, ok = contextkey.AuthTokenInfoKey.GetValue(l.ctx) if !ok { return asyncMode, nil, errors.New("content get auth info err") } //微调apiKeyObj的openaikey apiKeyObjAdjust(req.EventType, req.WorkId, apiKeyObj) apiKeyObj.CreatedAt = time.Now() fmt.Println(typekit.PrettyPrint(apiKeyObj)) fmt.Println("=========================================") fmt.Printf("In ChatCompletion Get Token Info:\nKey:'%s'\n", apiKeyObj.Key) fmt.Printf("Title:'%s'\n", apiKeyObj.Title) fmt.Printf("OpenaiBase:'%s'\n", apiKeyObj.OpenaiBase) fmt.Printf("OpenaiKey:'%s'\n", apiKeyObj.OpenaiKey) fmt.Printf("workToken:'%s' because %s/%s\n", apiKeyObj.OpenaiKey, req.EventType, req.WorkId) fmt.Printf("req.ChatId:'%s VS req.FastgptChatId:'%s'\n", req.ChatId, req.FastgptChatId) fmt.Printf("apiKeyObj.CreatedAt:'%v' || apiKeyObj.UpdatedAt:'%v'\n", apiKeyObj.CreatedAt, apiKeyObj.UpdatedAt) fmt.Println("=========================================") if isAsyncReqest(req) { //异步请求处理模式 fmt.Println("~~~~~~~~~~~~~~~~~~~isAsyncReqest:", req.Callback) asyncMode = true err = l.appendAsyncRequest(apiKeyObj, req) } else { //同步请求处理模式 fmt.Println("~~~~~~~~~~~~~~~~~~~isSyncReqest") resp, err = l.workForFastgpt(apiKeyObj, req) if err == nil && resp != nil { l.doSyncRequestLog(apiKeyObj, req, resp) //请求记录 } } return asyncMode, resp, err } func (l *ChatCompletionsLogic) appendAsyncRequest(apiKeyObj *ent.ApiKey, req *types.CompApiReq) error { workIDIdx := int8(compapi.GetWorkIdxByID(req.EventType, req.WorkId)) rawReqResp := custom_types.OriginalData{Request: req} res, err := l.svcCtx.DB.CompapiJob.Create(). SetNotNilCallbackURL(&req.Callback). SetNotNilAuthToken(&apiKeyObj.Key). SetNotNilEventType(&req.EventType). SetNillableWorkidIdx(&workIDIdx). SetNotNilRequestJSON(&rawReqResp). SetNillableChatID(&req.ChatId). Save(l.ctx) if err == nil { logx.Infof("appendAsyncRequest succ,get id:%d", res.ID) } return err } func (l *ChatCompletionsLogic) doSyncRequestLog(obj *ent.ApiKey, req *types.CompApiReq, resp *types.CompOpenApiResp) error { return l.appendUsageDetailLog(obj.Key, req, resp) } func (l *ChatCompletionsLogic) getUsagetotalIdByToken(authToken string) (uint64, error) { var predicates []predicate.UsageTotal predicates = append(predicates, usagetotal.BotIDEQ(authToken)) return l.svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(l.ctx) } func (l *ChatCompletionsLogic) replaceUsagetotalTokens(authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error { Id, err := l.getUsagetotalIdByToken(authToken) if err != nil && !ent.IsNotFound(err) { return err } if Id > 0 { //UsageTotal have record by newUsageDetailId _, err = l.svcCtx.DB.UsageTotal.UpdateOneID(Id). SetTotalTokens(sumTotalTokens). SetEndIndex(newUsageDetailId). Save(l.ctx) } else { //create new record by newUsageDetailId logType := 5 _, err = l.svcCtx.DB.UsageTotal.Create(). SetNotNilBotID(&authToken). SetNotNilEndIndex(&newUsageDetailId). SetNotNilTotalTokens(&sumTotalTokens). SetNillableType(&logType). SetNotNilOrganizationID(&orgId). Save(l.ctx) } return err } func (l *ChatCompletionsLogic) updateUsageTotal(authToken string, newUsageDetailId uint64, orgId uint64) error { sumTotalTokens, err := l.sumTotalTokensByAuthToken(authToken) //首先sum UsageDetail的TotalTokens if err == nil { err = l.replaceUsagetotalTokens(authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens } return err } // sum total_tokens from usagedetail by AuthToken func (l *ChatCompletionsLogic) sumTotalTokensByAuthToken(authToken string) (uint64, error) { var predicates []predicate.UsageDetail predicates = append(predicates, usagedetail.BotIDEQ(authToken)) var res []struct { Sum, Min, Max, Count uint64 } totalTokens := uint64(0) var err error = nil err = l.svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"), ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(l.ctx, &res) if err == nil { if len(res) > 0 { totalTokens = res[0].Sum } else { totalTokens = 0 } } return totalTokens, err } func (l *ChatCompletionsLogic) appendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error { logType := 5 workIdx := int(compapi.GetWorkIdxByID(req.EventType, req.WorkId)) rawReqResp := custom_types.OriginalData{Request: req, Response: resp} tmpId := 0 tmpId, _ = strconv.Atoi(resp.ID) sessionId := uint64(tmpId) orgId := uint64(0) apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(l.ctx) if ok { orgId = apiKeyObj.OrganizationID } promptTokens := uint64(resp.Usage.PromptTokens) completionToken := uint64(resp.Usage.CompletionTokens) totalTokens := promptTokens + completionToken msgContent := "" switch val := req.Messages[0].Content.(type) { case string: msgContent = val case []interface{}: if len(val) > 0 { if valc, ok := val[0].(map[string]interface{}); ok { if valcc, ok := valc["text"]; ok { msgContent, _ = valcc.(string) } } } } res, err := l.svcCtx.DB.UsageDetail.Create(). SetNotNilType(&logType). SetNotNilBotID(&authToken). SetNotNilReceiverID(&req.EventType). SetNotNilSessionID(&sessionId). SetNillableApp(&workIdx). SetNillableRequest(&msgContent). SetNillableResponse(&resp.Choices[0].Message.Content). SetNillableOrganizationID(&orgId). SetOriginalData(rawReqResp). SetNillablePromptTokens(&promptTokens). SetNillableCompletionTokens(&completionToken). SetNillableTotalTokens(&totalTokens). Save(l.ctx) if err == nil { //插入UsageDetai之后再统计UsageTotal l.updateUsageTotal(authToken, res.ID, orgId) } return err } func (l *ChatCompletionsLogic) workForFastgpt(apiKeyObj *ent.ApiKey, req *types.CompApiReq) (resp *types.CompOpenApiResp, err error) { //apiKey := "fastgpt-d2uehCb2T40h9chNGjf4bpFrVKmMkCFPbrjfVLZ6DAL2zzqzOFJWP" return compapi.NewFastgptChatCompletions(l.ctx, apiKeyObj.OpenaiKey, apiKeyObj.OpenaiBase, req) } func reqAdjust(req *types.CompApiReq) { if req.EventType != "fastgpt" { return } if len(req.Model) > 0 { if req.Variables == nil { req.Variables = make(map[string]string) } req.Variables["model"] = req.Model } if len(req.ChatId) > 0 && len(req.FastgptChatId) == 0 { req.FastgptChatId = req.ChatId } else if len(req.ChatId) == 0 && len(req.FastgptChatId) > 0 { req.ChatId = req.FastgptChatId } } func apiKeyObjAdjust(eventType string, workId string, obj *ent.ApiKey) { if eventType != "fastgpt" { return } obj.OpenaiKey, _ = compapi.GetWorkInfoByID(eventType, workId) } func IsValidURL(s string) bool { // 阶段1:快速预检 if len(s) < 10 || len(s) > 2048 { // 常见URL长度范围 return false } // 阶段2:标准库解析 u, err := url.Parse(s) if err != nil || u.Scheme == "" || u.Host == "" { return false } // 阶段3:协议校验(支持常见网络协议) switch u.Scheme { case "http", "https", "ftp", "ftps", "sftp": // 允许的协议类型 default: return false } // 阶段4:主机名深度校验 host := u.Hostname() if strings.Contains(host, "..") || strings.ContainsAny(host, "!#$%&'()*,:;<=>?[]^`{|}~") { return false } // IPv4/IPv6 校验 if ip := net.ParseIP(host); ip != nil { return true // 有效IP地址 } // 域名格式校验 if !isValidDomain(host) { return false } // 阶段5:端口校验(可选) if port := u.Port(); port != "" { for _, r := range port { if !unicode.IsDigit(r) { return false } } if len(port) > 5 || port == "0" { return false } } return true } // 高性能域名校验 (支持国际化域名IDNA) func isValidDomain(host string) bool { // 快速排除非法字符 if strings.ContainsAny(host, " _+/\\") { return false } // 分段检查 labels := strings.Split(host, ".") if len(labels) < 2 { // 至少包含顶级域和二级域 return false } for _, label := range labels { if len(label) < 1 || len(label) > 63 { return false } if label[0] == '-' || label[len(label)-1] == '-' { return false } } // 最终DNS格式校验 if _, err := net.LookupHost(host); err == nil { return true // 实际DNS解析验证(根据需求开启) } return true // 若不需要实际解析可始终返回true } func isAsyncReqest(req *types.CompApiReq) bool { if !req.IsBatch || !IsValidURL(req.Callback) { return false } return true }