package chat

import (
	"context"
	"errors"
	"fmt"
	"strconv"

	"wechat-api/ent"
	"wechat-api/ent/custom_types"
	"wechat-api/ent/predicate"
	"wechat-api/ent/usagedetail"
	"wechat-api/ent/usagetotal"
	"wechat-api/internal/svc"
	"wechat-api/internal/types"
	"wechat-api/internal/utils/compapi"
	"wechat-api/internal/utils/contextkey"

	"github.com/zeromicro/go-zero/core/logx"
)

// chatCompletionsUsageLogType is the Type value written to both the
// UsageDetail and UsageTotal rows created by this logic.
// NOTE(review): the meaning of 5 is not defined in this file — confirm it
// against the project's usage-log type enumeration.
const chatCompletionsUsageLogType = 5

// ChatCompletionsLogic serves the chat completions endpoint. It reads the
// caller's credentials from the request context, forwards the request to
// FastGPT and records token usage for successful calls.
type ChatCompletionsLogic struct {
	logx.Logger
	ctx    context.Context
	svcCtx *svc.ServiceContext
}

// NewChatCompletionsLogic builds a ChatCompletionsLogic bound to ctx.
func NewChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ChatCompletionsLogic {
	return &ChatCompletionsLogic{
		Logger: logx.WithContext(ctx),
		ctx:    ctx,
		svcCtx: svcCtx}
}

// ChatCompletions handles one chat completion request:
//  1. read the authenticated API key stored in the request context;
//  2. resolve the upstream work token from req.EventType / req.WorkId and
//     require both it and the key's OpenaiBase to be non-empty;
//  3. forward the request to FastGPT with that token and base URL.
//
// When the upstream call succeeds, the request/response pair is recorded in
// the usage tables. Recording is best-effort: a failure is logged and never
// alters the returned response or error.
func (l *ChatCompletionsLogic) ChatCompletions(req *types.CompApiReq) (resp *types.CompOpenApiResp, err error) {
	workToken := compapi.GetWorkTokenByID(req.EventType, req.WorkId)

	apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(l.ctx)
	if !ok || apiKeyObj == nil {
		return nil, errors.New("content get token err")
	}
	if len(apiKeyObj.OpenaiBase) == 0 || len(workToken) == 0 {
		return nil, errors.New("not auth info")
	}

	apiResp, err := l.workForFastgpt(req, workToken, apiKeyObj.OpenaiBase)
	if err == nil && apiResp != nil {
		// Usage accounting must not fail a request that already succeeded
		// upstream, but its failures should be visible in the logs.
		if logErr := l.doRequestLog(req, apiResp); logErr != nil {
			l.Errorf("chat completions: record usage: %v", logErr)
		}
	}
	return apiResp, err
}

// doRequestLog records the usage of a successful completion. The usage rows
// are keyed by the OpenAPI token read from the request context.
func (l *ChatCompletionsLogic) doRequestLog(req *types.CompApiReq, resp *types.CompOpenApiResp) error {
	authToken, ok := contextkey.OpenapiTokenKey.GetValue(l.ctx)
	if !ok {
		return errors.New("content get auth token err")
	}
	return l.appendUsageDetailLog(authToken, req, resp)
}

// getUsagetotalIdByToken returns the id of the first UsageTotal row whose
// bot_id equals authToken. A missing row is reported as an ent not-found error
// (callers test it with ent.IsNotFound).
func (l *ChatCompletionsLogic) getUsagetotalIdByToken(authToken string) (uint64, error) {
	var predicates []predicate.UsageTotal
	predicates = append(predicates, usagetotal.BotIDEQ(authToken))
	return l.svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(l.ctx)
}

// replaceUsagetotalTokens writes sumTotalTokens to the UsageTotal row of
// authToken. An existing row gets TotalTokens and EndIndex overwritten;
// otherwise a new row is created carrying the organization id.
//
// NOTE(review): lookup and create are not atomic, so two concurrent first
// requests for the same token may each create a row.
func (l *ChatCompletionsLogic) replaceUsagetotalTokens(authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
	id, err := l.getUsagetotalIdByToken(authToken)
	if err != nil && !ent.IsNotFound(err) {
		return err
	}

	if id > 0 {
		// A UsageTotal row already exists: refresh its counters.
		_, err = l.svcCtx.DB.UsageTotal.UpdateOneID(id).
			SetTotalTokens(sumTotalTokens).
			SetEndIndex(newUsageDetailId).
			Save(l.ctx)
		return err
	}

	// No row yet: create one pointing at the newest UsageDetail.
	logType := chatCompletionsUsageLogType
	_, err = l.svcCtx.DB.UsageTotal.Create().
		SetNotNilBotID(&authToken).
		SetNotNilEndIndex(&newUsageDetailId).
		SetNotNilTotalTokens(&sumTotalTokens).
		SetNillableType(&logType).
		SetNotNilOrganizationID(&orgId).
		Save(l.ctx)
	return err
}

// updateUsageTotal recomputes the token total of authToken from UsageDetail,
// then writes it (creating the row if needed) to UsageTotal.
func (l *ChatCompletionsLogic) updateUsageTotal(authToken string, newUsageDetailId uint64, orgId uint64) error {
	sumTotalTokens, err := l.sumTotalTokensByAuthToken(authToken)
	if err != nil {
		return err
	}
	return l.replaceUsagetotalTokens(authToken, sumTotalTokens, newUsageDetailId, orgId)
}

// sumTotalTokensByAuthToken returns SUM(total_tokens) over the UsageDetail rows
// whose bot_id equals authToken; 0 if the query yields no row.
func (l *ChatCompletionsLogic) sumTotalTokensByAuthToken(authToken string) (uint64, error) {
	var predicates []predicate.UsageDetail
	predicates = append(predicates, usagedetail.BotIDEQ(authToken))

	// Only the sum is consumed, so only the sum is aggregated.
	var res []struct {
		Sum uint64
	}
	err := l.svcCtx.DB.UsageDetail.Query().
		Where(predicates...).
		Aggregate(ent.Sum(usagedetail.FieldTotalTokens)).
		Scan(l.ctx, &res)
	if err != nil {
		return 0, err
	}
	if len(res) == 0 {
		return 0, nil
	}
	return res[0].Sum, nil
}

// appendUsageDetailLog inserts one UsageDetail row for a completed request and
// then refreshes the UsageTotal aggregate for authToken.
//
// The first request message and the first response choice are stored as the
// request/response text when present; the full pair is also stored as
// OriginalData. A response id that is not a non-negative integer is stored as
// session id 0.
func (l *ChatCompletionsLogic) appendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
	logType := chatCompletionsUsageLogType
	workIdx := compapi.GetWorkIdxByID(req.EventType, req.WorkId)
	rawReqResp := custom_types.OriginalData{Request: req, Response: resp}

	// ParseUint rejects negative ids instead of letting them wrap around.
	sessionId, err := strconv.ParseUint(resp.ID, 10, 64)
	if err != nil {
		sessionId = 0
	}

	orgId := uint64(0)
	if apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(l.ctx); ok && apiKeyObj != nil {
		orgId = apiKeyObj.OrganizationID
	}

	promptTokens := uint64(resp.Usage.PromptTokens)
	completionTokens := uint64(resp.Usage.CompletionTokens)
	totalTokens := promptTokens + completionTokens

	builder := l.svcCtx.DB.UsageDetail.Create().
		SetNotNilType(&logType).
		SetNotNilBotID(&authToken).
		SetNotNilReceiverID(&req.EventType).
		SetNotNilSessionID(&sessionId).
		SetNillableApp(&workIdx).
		SetNillableOrganizationID(&orgId).
		SetOriginalData(rawReqResp).
		SetNillablePromptTokens(&promptTokens).
		SetNillableCompletionTokens(&completionTokens).
		SetNillableTotalTokens(&totalTokens)
	// Indexing an empty Messages/Choices slice would panic; leave the
	// corresponding column unset instead.
	if len(req.Messages) > 0 {
		builder = builder.SetNillableRequest(&req.Messages[0].Content)
	}
	if len(resp.Choices) > 0 {
		builder = builder.SetNillableResponse(&resp.Choices[0].Message.Content)
	}

	detail, err := builder.Save(l.ctx)
	if err != nil {
		return err
	}

	// Recompute UsageTotal only after the UsageDetail row has been inserted.
	if err := l.updateUsageTotal(authToken, detail.ID, orgId); err != nil {
		return fmt.Errorf("update usage total after usage detail %d: %w", detail.ID, err)
	}
	return nil
}

// workForFastgpt adapts req for FastGPT and forwards it through
// compapi.NewFastgptChatCompletions using apiKey and apiBase.
//
// req is modified in place: ChatId is copied into FastgptChatId when the
// latter is empty, and a non-empty Model is passed as the "model" variable.
func (l *ChatCompletionsLogic) workForFastgpt(req *types.CompApiReq, apiKey string, apiBase string) (resp *types.CompOpenApiResp, err error) {
	if len(req.ChatId) > 0 && len(req.FastgptChatId) == 0 {
		req.FastgptChatId = req.ChatId
	}
	if len(req.Model) > 0 {
		if req.Variables == nil {
			req.Variables = make(map[string]string)
		}
		req.Variables["model"] = req.Model
	}
	return compapi.NewFastgptChatCompletions(l.ctx, apiKey, apiBase, req)
}