package chat import ( "context" "encoding/json" "errors" "fmt" "net" "net/url" "regexp" "strconv" "strings" "wechat-api/ent" "wechat-api/internal/svc" "wechat-api/internal/types" "wechat-api/internal/utils/compapi" "wechat-api/internal/utils/contextkey" "wechat-api/ent/custom_types" "wechat-api/ent/predicate" "wechat-api/ent/usagedetail" "wechat-api/ent/usagetotal" "github.com/zeromicro/go-zero/core/logx" ) type ChatCompletionsLogic struct { logx.Logger ctx context.Context svcCtx *svc.ServiceContext } type FastgptChatLogic struct { ChatCompletionsLogic } type MismatchChatLogic struct { ChatCompletionsLogic } type baseLogicWorkflow interface { AppendAsyncRequest(apiKeyObj *ent.ApiKey, req *types.CompApiReq) error DoSyncRequest(apiKeyObj *ent.ApiKey, req *types.CompApiReq) (*types.CompOpenApiResp, error) AppendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error AdjustRequest(req *types.CompApiReq, apiKeyObj *ent.ApiKey) } func NewChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ChatCompletionsLogic { return &ChatCompletionsLogic{ Logger: logx.WithContext(ctx), ctx: ctx, svcCtx: svcCtx} } func (l *FastgptChatLogic) AdjustRequest(req *types.CompApiReq, apiKeyObj *ent.ApiKey) { l.ChatCompletionsLogic.AdjustRequest(req, apiKeyObj) //先父类的参数调整 if req.EventType != "fastgpt" { return } if len(req.Model) > 0 { if req.Variables == nil { req.Variables = make(map[string]string) } req.Variables["model"] = req.Model } if len(req.ChatId) > 0 && len(req.FastgptChatId) == 0 { req.FastgptChatId = req.ChatId } else if len(req.ChatId) == 0 && len(req.FastgptChatId) > 0 { req.ChatId = req.FastgptChatId } } func (l *ChatCompletionsLogic) ChatCompletions(req *types.CompApiReq) (asyncMode bool, resp *types.CompOpenApiResp, err error) { // todo: add your logic here and delete this line var ( apiKeyObj *ent.ApiKey ok bool ) asyncMode = false //从上下文中获取鉴权中间件埋下的apiAuthInfo apiKeyObj, ok = contextkey.AuthTokenInfoKey.GetValue(l.ctx) if !ok { return asyncMode, nil, errors.New("content get auth info err") } if req.WorkId == "TEST_DOUYIN" || req.WorkId == "TEST_DOUYIN_CN" || req.WorkId == "travel" || req.WorkId == "loreal" || req.WorkId == "xiulike-dc" { //临时加 apiKeyObj.OpenaiBase = "http://cn-agent.gkscrm.com/api/v1/" if req.WorkId == "TEST_DOUYIN" || req.WorkId == "TEST_DOUYIN_CN" { workToken = "fastgpt-jsMmQKEM5uX7tDimT1zHlZHkBhMHRT2k61YaxyDJRZTUHehID7sG8BKXADNIU" } else if req.WorkId == "travel" { workToken = "fastgpt-bcnfFtw1lXWdmYGOv165UVD5R1kY28tyXX8SJv8MHhrSMOgVJsuU" } else if req.WorkId == "loreal" { workToken = "fastgpt-qqJeBEkwhgx7wR9fvGToQygmOb7FVjAbGBTFjOAMbd95InEtndke" } else if req.WorkId == "xiulike-dc" { workToken = "fastgpt-ir9RgnKHMT9HIOnPsUCFChN15ZbW9kt1lbd5Y0ohfLw9gOz3KcPrfaZWHRB" } } /* fmt.Println("=========================================") fmt.Printf("In ChatCompletion Get Token Info:\nKey:'%s'\n", apiKeyObj.Key) fmt.Printf("Auth Token:'%s'\n", apiKeyObj.Key) fmt.Printf("ApiKey AgentID:%d\n", apiKeyObj.AgentID) fmt.Printf("ApiKey APIBase:'%s'\n", apiKeyObj.Edges.Agent.APIBase) fmt.Printf("ApiKey APIKey:'%s'\n", apiKeyObj.Edges.Agent.APIKey) fmt.Printf("ApiKey Type:%d\n", apiKeyObj.Edges.Agent.Type) fmt.Printf("ApiKey Model:'%s'\n", apiKeyObj.Edges.Agent.Model) fmt.Printf("EventType:'%s'\n", req.EventType) fmt.Printf("req.ChatId:'%s VS req.FastgptChatId:'%s'\n", req.ChatId, req.FastgptChatId) fmt.Println("=========================================") */ //根据请求产生相关的工作流接口集 wf, err := l.getLogicWorkflow(apiKeyObj, req) if err != nil { return false, nil, err } //微调部分请求参数 wf.AdjustRequest(req, apiKeyObj) if isAsyncReqest(req) { //异步请求处理模式 asyncMode = true err = wf.AppendAsyncRequest(apiKeyObj, req) } else { //同步请求处理模式 resp, err = wf.DoSyncRequest(apiKeyObj, req) if err == nil && resp != nil && len(resp.Choices) > 0 { wf.AppendUsageDetailLog(apiKeyObj.Key, req, resp) //请求记录 } else if resp != nil && len(resp.Choices) == 0 { err = errors.New("返回结果缺失,请检查访问地址及权限") } } return asyncMode, resp, err } func (l *ChatCompletionsLogic) getLogicWorkflow(apiKeyObj *ent.ApiKey, req *types.CompApiReq) (baseLogicWorkflow, error) { var ( err error wf baseLogicWorkflow ) if apiKeyObj.Edges.Agent.Type != 2 { err = fmt.Errorf("api agent type not support(%d)", apiKeyObj.Edges.Agent.Type) } else if req.EventType == "mismatch" { wf = &MismatchChatLogic{ChatCompletionsLogic: *l} } else { wf = &FastgptChatLogic{ChatCompletionsLogic: *l} } return wf, err } func (l *ChatCompletionsLogic) AdjustRequest(req *types.CompApiReq, apiKeyObj *ent.ApiKey) { if len(req.EventType) == 0 { req.EventType = "fastgpt" } if len(req.Model) == 0 && len(apiKeyObj.Edges.Agent.Model) > 0 { req.Model = apiKeyObj.Edges.Agent.Model } //异步任务相关参数调整 if req.IsBatch { //流模式暂时不支持异步模式 //Callback格式非法则取消批量模式 if req.Stream || !IsValidURL(&req.Callback, true) { req.IsBatch = false } } } func (l *ChatCompletionsLogic) DoSyncRequest(apiKeyObj *ent.ApiKey, req *types.CompApiReq) (*types.CompOpenApiResp, error) { //return compapi.NewFastgptChatCompletions(l.ctx, apiKeyObj.Edges.Agent.APIKey, apiKeyObj.Edges.Agent.APIBase, req) return compapi.NewClient(l.ctx, compapi.WithApiBase(apiKeyObj.Edges.Agent.APIBase), compapi.WithApiKey(apiKeyObj.Edges.Agent.APIKey)). Chat(req) } func (l *ChatCompletionsLogic) AppendAsyncRequest(apiKeyObj *ent.ApiKey, req *types.CompApiReq) error { rawReqBs, err := json.Marshal(*req) if err != nil { return err } rawReqStr := string(rawReqBs) res, err := l.svcCtx.DB.CompapiAsynctask.Create(). SetNotNilAuthToken(&apiKeyObj.Key). SetNotNilOpenaiBase(&apiKeyObj.Edges.Agent.APIBase). SetNotNilOpenaiKey(&apiKeyObj.Edges.Agent.APIKey). SetNotNilOrganizationID(&apiKeyObj.OrganizationID). SetNotNilEventType(&req.EventType). SetNillableModel(&req.Model). SetNillableChatID(&req.ChatId). SetNillableResponseChatItemID(&req.ResponseChatItemId). SetNotNilRequestRaw(&rawReqStr). SetNotNilCallbackURL(&req.Callback). Save(l.ctx) if err == nil { logx.Infof("appendAsyncRequest succ,get id:%d", res.ID) } return err } func (l *ChatCompletionsLogic) AppendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error { logType := 5 rawReqResp := custom_types.OriginalData{Request: req, Response: resp} tmpId := 0 tmpId, _ = strconv.Atoi(resp.ID) sessionId := uint64(tmpId) orgId := uint64(0) apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(l.ctx) if ok { orgId = apiKeyObj.OrganizationID } promptTokens := uint64(resp.Usage.PromptTokens) completionToken := uint64(resp.Usage.CompletionTokens) totalTokens := promptTokens + completionToken msgContent := getMessageContentStr(req.Messages[0].Content) _, _, _ = logType, sessionId, totalTokens res, err := l.svcCtx.DB.UsageDetail.Create(). SetNotNilType(&logType). SetNotNilBotID(&authToken). SetNotNilReceiverID(&req.EventType). SetNotNilSessionID(&sessionId). SetNillableRequest(&msgContent). SetNillableResponse(&resp.Choices[0].Message.Content). SetNillableOrganizationID(&orgId). SetOriginalData(rawReqResp). SetNillablePromptTokens(&promptTokens). SetNillableCompletionTokens(&completionToken). SetNillableTotalTokens(&totalTokens). Save(l.ctx) if err == nil { //插入UsageDetai之后再统计UsageTotal l.updateUsageTotal(authToken, res.ID, orgId) } return err } func (l *ChatCompletionsLogic) getUsagetotalIdByToken(authToken string) (uint64, error) { var predicates []predicate.UsageTotal predicates = append(predicates, usagetotal.BotIDEQ(authToken)) return l.svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(l.ctx) } func (l *ChatCompletionsLogic) replaceUsagetotalTokens(authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error { Id, err := l.getUsagetotalIdByToken(authToken) if err != nil && !ent.IsNotFound(err) { return err } if Id > 0 { //UsageTotal have record by newUsageDetailId _, err = l.svcCtx.DB.UsageTotal.UpdateOneID(Id). SetTotalTokens(sumTotalTokens). SetEndIndex(newUsageDetailId). Save(l.ctx) } else { //create new record by newUsageDetailId logType := 5 _, err = l.svcCtx.DB.UsageTotal.Create(). SetNotNilBotID(&authToken). SetNotNilEndIndex(&newUsageDetailId). SetNotNilTotalTokens(&sumTotalTokens). SetNillableType(&logType). SetNotNilOrganizationID(&orgId). Save(l.ctx) } return err } func (l *ChatCompletionsLogic) updateUsageTotal(authToken string, newUsageDetailId uint64, orgId uint64) error { sumTotalTokens, err := l.sumTotalTokensByAuthToken(authToken) //首先sum UsageDetail的TotalTokens if err == nil { err = l.replaceUsagetotalTokens(authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens } return err } // sum total_tokens from usagedetail by AuthToken func (l *ChatCompletionsLogic) sumTotalTokensByAuthToken(authToken string) (uint64, error) { var predicates []predicate.UsageDetail predicates = append(predicates, usagedetail.BotIDEQ(authToken)) var res []struct { Sum, Min, Max, Count uint64 } totalTokens := uint64(0) var err error = nil err = l.svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"), ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(l.ctx, &res) if err == nil { if len(res) > 0 { totalTokens = res[0].Sum } else { totalTokens = 0 } } return totalTokens, err } func apiKeyObjAdjust(eventType string, workId string, obj *ent.ApiKey) { if eventType != "fastgpt" { return } obj.OpenaiKey, _ = compapi.GetWorkInfoByID(eventType, workId) } // 合法域名正则(支持通配符、中文域名等场景按需调整) var domainRegex = regexp.MustCompile( // 多级域名(如 example.com) `^([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}$` + `|` + // 单级域名(如 localhost 或 mytest-svc) `^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$`, ) func IsValidURL(input *string, adjust bool) bool { // 空值直接返回 if *input == "" { return false } inputStr := *input // --- 预处理输入:自动补全协议 --- // 若输入不包含协议头,默认添加 http:// if !strings.Contains(*input, "://") { inputStr = "http://" + *input } // --- 解析 URL --- u, err := url.Parse(inputStr) if err != nil { return false } // --- 校验协议 --- // 只允许常见协议(按需扩展) switch u.Scheme { case "http", "https", "ftp", "ftps": default: return false } // --- 拆分 Host 和 Port --- host, port, err := net.SplitHostPort(u.Host) if err != nil { // 无端口时,整个 Host 作为主机名 host = u.Host port = "" } // --- 校验主机名 --- // 场景1:IPv4 或 IPv6 if ip := net.ParseIP(host); ip != nil { // 允许私有或保留 IP(按需调整) // 示例中允许所有合法 IP } else { // 场景2:域名(包括 localhost) if !domainRegex.MatchString(host) { return false } } // --- 校验端口 --- if port != "" { p, err := net.LookupPort("tcp", port) // 动态获取端口(如 "http" 对应 80) if err != nil { // 直接尝试解析为数字端口 numPort, err := strconv.Atoi(port) if err != nil || numPort < 1 || numPort > 65535 { return false } } else if p == 0 { // 动态端口为 0 时无效 return false } } if adjust { *input = inputStr } return true } func getMessageContentStr(input any) string { str := "" switch val := input.(type) { case string: str = val case []interface{}: if len(val) > 0 { if valc, ok := val[0].(map[string]interface{}); ok { if valcc, ok := valc["text"]; ok { str, _ = valcc.(string) } } } } return str } func isAsyncReqest(req *types.CompApiReq) bool { return req.IsBatch }