package chat import ( "context" "encoding/json" "errors" "net" "net/url" "regexp" "strconv" "strings" "wechat-api/ent" "wechat-api/internal/svc" "wechat-api/internal/types" "wechat-api/internal/utils/compapi" "wechat-api/internal/utils/contextkey" "wechat-api/ent/custom_types" "wechat-api/ent/predicate" "wechat-api/ent/usagedetail" "wechat-api/ent/usagetotal" "github.com/zeromicro/go-zero/core/logx" ) type ChatCompletionsLogic struct { logx.Logger ctx context.Context svcCtx *svc.ServiceContext } func NewChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ChatCompletionsLogic { return &ChatCompletionsLogic{ Logger: logx.WithContext(ctx), ctx: ctx, svcCtx: svcCtx} } func (l *ChatCompletionsLogic) ChatCompletions(req *types.CompApiReq) (asyncMode bool, resp *types.CompOpenApiResp, err error) { // todo: add your logic here and delete this line /* 1.鉴权获得token 2.必要参数检测及转换 3. 根据event_type选择不同处理路由 */ var ( apiKeyObj *ent.ApiKey ok bool ) asyncMode = false //微调部分请求参数 reqAdjust(req) //从上下文中获取鉴权中间件埋下的apiAuthInfo apiKeyObj, ok = contextkey.AuthTokenInfoKey.GetValue(l.ctx) if !ok { return asyncMode, nil, errors.New("content get auth info err") } //微调apiKeyObj的openaikey //apiKeyObjAdjust(req.EventType, req.WorkId, apiKeyObj) /* fmt.Println("=========================================") fmt.Printf("In ChatCompletion Get Token Info:\nKey:'%s'\n", apiKeyObj.Key) fmt.Printf("Title:'%s'\n", apiKeyObj.Title) fmt.Printf("OpenaiBase:'%s'\n", apiKeyObj.OpenaiBase) fmt.Printf("OpenaiKey:'%s'\n", apiKeyObj.OpenaiKey) fmt.Printf("workToken:'%s' because %s/%s\n", apiKeyObj.OpenaiKey, req.EventType, req.WorkId) fmt.Printf("req.ChatId:'%s VS req.FastgptChatId:'%s'\n", req.ChatId, req.FastgptChatId) fmt.Printf("apiKeyObj.CreatedAt:'%v' || apiKeyObj.UpdatedAt:'%v'\n", apiKeyObj.CreatedAt, apiKeyObj.UpdatedAt) fmt.Println("=========================================") */ if isAsyncReqest(req) { //异步请求处理模式 //fmt.Println("~~~~~~~~~~~~~~~~~~~isAsyncReqest:", req.Callback) asyncMode = true err = l.appendAsyncRequest(apiKeyObj, req) } else { //同步请求处理模式 //fmt.Println("~~~~~~~~~~~~~~~~~~~isSyncReqest") resp, err = l.workForFastgpt(apiKeyObj, req) if err == nil && resp != nil && len(resp.Choices) > 0 { l.doSyncRequestLog(apiKeyObj, req, resp) //请求记录 } else if resp != nil && len(resp.Choices) == 0 { err = errors.New("返回结果缺失,请检查访问地址及权限") } } return asyncMode, resp, err } func (l *ChatCompletionsLogic) appendAsyncRequest(apiKeyObj *ent.ApiKey, req *types.CompApiReq) error { //workIDIdx := int8(compapi.GetWorkIdxByID(req.EventType, req.WorkId)) rawReqBs, err := json.Marshal(*req) if err != nil { return err } rawReqStr := string(rawReqBs) res, err := l.svcCtx.DB.CompapiAsynctask.Create(). SetNotNilAuthToken(&apiKeyObj.Key). SetNotNilEventType(&req.EventType). SetNillableChatID(&req.ChatId). //SetNillableWorkidIdx(&workIDIdx). SetNotNilOpenaiBase(&apiKeyObj.OpenaiBase). SetNotNilOpenaiKey(&apiKeyObj.OpenaiKey). SetNotNilRequestRaw(&rawReqStr). SetNotNilCallbackURL(&req.Callback). Save(l.ctx) if err == nil { logx.Infof("appendAsyncRequest succ,get id:%d", res.ID) } return err } func (l *ChatCompletionsLogic) doSyncRequestLog(obj *ent.ApiKey, req *types.CompApiReq, resp *types.CompOpenApiResp) error { return l.appendUsageDetailLog(obj.Key, req, resp) } func (l *ChatCompletionsLogic) getUsagetotalIdByToken(authToken string) (uint64, error) { var predicates []predicate.UsageTotal predicates = append(predicates, usagetotal.BotIDEQ(authToken)) return l.svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(l.ctx) } func (l *ChatCompletionsLogic) replaceUsagetotalTokens(authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error { Id, err := l.getUsagetotalIdByToken(authToken) if err != nil && !ent.IsNotFound(err) { return err } if Id > 0 { //UsageTotal have record by newUsageDetailId _, err = l.svcCtx.DB.UsageTotal.UpdateOneID(Id). SetTotalTokens(sumTotalTokens). SetEndIndex(newUsageDetailId). Save(l.ctx) } else { //create new record by newUsageDetailId logType := 5 _, err = l.svcCtx.DB.UsageTotal.Create(). SetNotNilBotID(&authToken). SetNotNilEndIndex(&newUsageDetailId). SetNotNilTotalTokens(&sumTotalTokens). SetNillableType(&logType). SetNotNilOrganizationID(&orgId). Save(l.ctx) } return err } func (l *ChatCompletionsLogic) updateUsageTotal(authToken string, newUsageDetailId uint64, orgId uint64) error { sumTotalTokens, err := l.sumTotalTokensByAuthToken(authToken) //首先sum UsageDetail的TotalTokens if err == nil { err = l.replaceUsagetotalTokens(authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens } return err } // sum total_tokens from usagedetail by AuthToken func (l *ChatCompletionsLogic) sumTotalTokensByAuthToken(authToken string) (uint64, error) { var predicates []predicate.UsageDetail predicates = append(predicates, usagedetail.BotIDEQ(authToken)) var res []struct { Sum, Min, Max, Count uint64 } totalTokens := uint64(0) var err error = nil err = l.svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"), ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(l.ctx, &res) if err == nil { if len(res) > 0 { totalTokens = res[0].Sum } else { totalTokens = 0 } } return totalTokens, err } func getMessageContentStr(input any) string { str := "" switch val := input.(type) { case string: str = val case []interface{}: if len(val) > 0 { if valc, ok := val[0].(map[string]interface{}); ok { if valcc, ok := valc["text"]; ok { str, _ = valcc.(string) } } } } return str } func (l *ChatCompletionsLogic) appendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error { logType := 5 //workIdx := int(compapi.GetWorkIdxByID(req.EventType, req.WorkId)) rawReqResp := custom_types.OriginalData{Request: req, Response: resp} tmpId := 0 tmpId, _ = strconv.Atoi(resp.ID) sessionId := uint64(tmpId) orgId := uint64(0) apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(l.ctx) if ok { orgId = apiKeyObj.OrganizationID } promptTokens := uint64(resp.Usage.PromptTokens) completionToken := uint64(resp.Usage.CompletionTokens) totalTokens := promptTokens + completionToken msgContent := getMessageContentStr(req.Messages[0].Content) _, _, _ = logType, sessionId, totalTokens res, err := l.svcCtx.DB.UsageDetail.Create(). SetNotNilType(&logType). SetNotNilBotID(&authToken). SetNotNilReceiverID(&req.EventType). SetNotNilSessionID(&sessionId). //SetNillableApp(&workIdx). SetNillableRequest(&msgContent). SetNillableResponse(&resp.Choices[0].Message.Content). SetNillableOrganizationID(&orgId). SetOriginalData(rawReqResp). SetNillablePromptTokens(&promptTokens). SetNillableCompletionTokens(&completionToken). SetNillableTotalTokens(&totalTokens). Save(l.ctx) if err == nil { //插入UsageDetai之后再统计UsageTotal l.updateUsageTotal(authToken, res.ID, orgId) } return err } func (l *ChatCompletionsLogic) workForFastgpt(apiKeyObj *ent.ApiKey, req *types.CompApiReq) (resp *types.CompOpenApiResp, err error) { //apiKey := "fastgpt-d2uehCb2T40h9chNGjf4bpFrVKmMkCFPbrjfVLZ6DAL2zzqzOFJWP" return compapi.NewFastgptChatCompletions(l.ctx, apiKeyObj.OpenaiKey, apiKeyObj.OpenaiBase, req) } func reqAdjust(req *types.CompApiReq) { if len(req.EventType) == 0 { req.EventType = "fastgpt" } if req.EventType != "fastgpt" { return } if len(req.Model) > 0 { if req.Variables == nil { req.Variables = make(map[string]string) } req.Variables["model"] = req.Model } if len(req.ChatId) > 0 && len(req.FastgptChatId) == 0 { req.FastgptChatId = req.ChatId } else if len(req.ChatId) == 0 && len(req.FastgptChatId) > 0 { req.ChatId = req.FastgptChatId } } func apiKeyObjAdjust(eventType string, workId string, obj *ent.ApiKey) { if eventType != "fastgpt" { return } obj.OpenaiKey, _ = compapi.GetWorkInfoByID(eventType, workId) } // 合法域名正则(支持通配符、中文域名等场景按需调整) var domainRegex = regexp.MustCompile( // 多级域名(如 example.com) `^([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}$` + `|` + // 单级域名(如 localhost 或 mytest-svc) `^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$`, ) func IsValidURL(input *string, adjust bool) bool { // 空值直接返回 if *input == "" { return false } inputStr := *input // --- 预处理输入:自动补全协议 --- // 若输入不包含协议头,默认添加 http:// if !strings.Contains(*input, "://") { inputStr = "http://" + *input } // --- 解析 URL --- u, err := url.Parse(inputStr) if err != nil { return false } // --- 校验协议 --- // 只允许常见协议(按需扩展) switch u.Scheme { case "http", "https", "ftp", "ftps": default: return false } // --- 拆分 Host 和 Port --- host, port, err := net.SplitHostPort(u.Host) if err != nil { // 无端口时,整个 Host 作为主机名 host = u.Host port = "" } // --- 校验主机名 --- // 场景1:IPv4 或 IPv6 if ip := net.ParseIP(host); ip != nil { // 允许私有或保留 IP(按需调整) // 示例中允许所有合法 IP } else { // 场景2:域名(包括 localhost) if !domainRegex.MatchString(host) { return false } } // --- 校验端口 --- if port != "" { p, err := net.LookupPort("tcp", port) // 动态获取端口(如 "http" 对应 80) if err != nil { // 直接尝试解析为数字端口 numPort, err := strconv.Atoi(port) if err != nil || numPort < 1 || numPort > 65535 { return false } } else if p == 0 { // 动态端口为 0 时无效 return false } } if adjust { *input = inputStr } return true } func isAsyncReqest(req *types.CompApiReq) bool { if !req.IsBatch || !IsValidURL(&req.Callback, true) { return false } if req.Stream { //异步模式暂时不支持流模式 return false } return true }