|
@@ -3,6 +3,7 @@ package chat
|
|
|
import (
|
|
|
"context"
|
|
|
"errors"
|
|
|
+ "strconv"
|
|
|
|
|
|
"wechat-api/ent"
|
|
|
"wechat-api/internal/svc"
|
|
@@ -10,6 +11,11 @@ import (
|
|
|
"wechat-api/internal/utils/compapi"
|
|
|
"wechat-api/internal/utils/contextkey"
|
|
|
|
|
|
+ "wechat-api/ent/custom_types"
|
|
|
+ "wechat-api/ent/predicate"
|
|
|
+ "wechat-api/ent/usagedetail"
|
|
|
+ "wechat-api/ent/usagetotal"
|
|
|
+
|
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|
|
)
|
|
|
|
|
@@ -56,7 +62,126 @@ func (l *ChatCompletionsLogic) ChatCompletions(req *types.CompApiReq) (resp *typ
|
|
|
if len(apiKeyObj.OpenaiBase) == 0 || len(workToken) == 0 {
|
|
|
return nil, errors.New("not auth info")
|
|
|
}
|
|
|
- return l.workForFastgpt(req, workToken, apiKeyObj.OpenaiBase)
|
|
|
+
|
|
|
+ apiResp, err := l.workForFastgpt(req, workToken, apiKeyObj.OpenaiBase)
|
|
|
+ if err == nil {
|
|
|
+ l.doRequestLog(req, apiResp) //请求记录
|
|
|
+ }
|
|
|
+
|
|
|
+ return apiResp, err
|
|
|
+}
|
|
|
+
|
|
|
+func (l *ChatCompletionsLogic) doRequestLog(req *types.CompApiReq, resp *types.CompOpenApiResp) error {
|
|
|
+ authToken, ok := contextkey.OpenapiTokenKey.GetValue(l.ctx)
|
|
|
+ if !ok {
|
|
|
+ return errors.New("content get auth token err")
|
|
|
+ }
|
|
|
+
|
|
|
+ return l.appendUsageDetailLog(authToken, req, resp)
|
|
|
+}
|
|
|
+
|
|
|
+func (l *ChatCompletionsLogic) getUsagetotalIdByToken(authToken string) (uint64, error) {
|
|
|
+
|
|
|
+ var predicates []predicate.UsageTotal
|
|
|
+ predicates = append(predicates, usagetotal.BotIDEQ(authToken))
|
|
|
+ return l.svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(l.ctx)
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+func (l *ChatCompletionsLogic) replaceUsagetotalTokens(authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
|
|
|
+
|
|
|
+ Id, err := l.getUsagetotalIdByToken(authToken)
|
|
|
+ if err != nil && !ent.IsNotFound(err) {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if Id > 0 { //UsageTotal have record by newUsageDetailId
|
|
|
+ _, err = l.svcCtx.DB.UsageTotal.UpdateOneID(Id).
|
|
|
+ SetTotalTokens(sumTotalTokens).
|
|
|
+ SetEndIndex(newUsageDetailId).
|
|
|
+ Save(l.ctx)
|
|
|
+ } else { //create new record by newUsageDetailId
|
|
|
+ logType := 5
|
|
|
+ _, err = l.svcCtx.DB.UsageTotal.Create().
|
|
|
+ SetNotNilBotID(&authToken).
|
|
|
+ SetNotNilEndIndex(&newUsageDetailId).
|
|
|
+ SetNotNilTotalTokens(&sumTotalTokens).
|
|
|
+ SetNillableType(&logType).
|
|
|
+ SetNotNilOrganizationID(&orgId).
|
|
|
+ Save(l.ctx)
|
|
|
+ }
|
|
|
+
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+func (l *ChatCompletionsLogic) updateUsageTotal(authToken string, newUsageDetailId uint64, orgId uint64) error {
|
|
|
+
|
|
|
+ sumTotalTokens, err := l.sumTotalTokensByAuthToken(authToken) //首先sum UsageDetail的TotalTokens
|
|
|
+ if err == nil {
|
|
|
+ err = l.replaceUsagetotalTokens(authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens
|
|
|
+ }
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+// sum total_tokens from usagedetail by AuthToken
|
|
|
+func (l *ChatCompletionsLogic) sumTotalTokensByAuthToken(authToken string) (uint64, error) {
|
|
|
+
|
|
|
+ var predicates []predicate.UsageDetail
|
|
|
+ predicates = append(predicates, usagedetail.BotIDEQ(authToken))
|
|
|
+
|
|
|
+ var res []struct {
|
|
|
+ Sum, Min, Max, Count uint64
|
|
|
+ }
|
|
|
+ totalTokens := uint64(0)
|
|
|
+ var err error = nil
|
|
|
+ err = l.svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"),
|
|
|
+ ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(l.ctx, &res)
|
|
|
+ if err == nil {
|
|
|
+ if len(res) > 0 {
|
|
|
+ totalTokens = res[0].Sum
|
|
|
+ } else {
|
|
|
+ totalTokens = 0
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return totalTokens, err
|
|
|
+}
|
|
|
+
|
|
|
+func (l *ChatCompletionsLogic) appendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
|
|
|
+
|
|
|
+ logType := 5
|
|
|
+ workIdx := compapi.GetWorkIdxByID(req.EventType, req.WorkId)
|
|
|
+ rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
|
|
|
+
|
|
|
+ tmpId := 0
|
|
|
+ tmpId, _ = strconv.Atoi(resp.ID)
|
|
|
+ sessionId := uint64(tmpId)
|
|
|
+ orgId := uint64(0)
|
|
|
+ apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(l.ctx)
|
|
|
+ if ok {
|
|
|
+ orgId = apiKeyObj.OrganizationID
|
|
|
+ }
|
|
|
+ promptTokens := uint64(resp.Usage.PromptTokens)
|
|
|
+ completionToken := uint64(resp.Usage.CompletionTokens)
|
|
|
+ totalTokens := promptTokens + completionToken
|
|
|
+
|
|
|
+ res, err := l.svcCtx.DB.UsageDetail.Create().
|
|
|
+ SetNotNilType(&logType).
|
|
|
+ SetNotNilBotID(&authToken).
|
|
|
+ SetNotNilReceiverID(&req.EventType).
|
|
|
+ SetNotNilSessionID(&sessionId).
|
|
|
+ SetNillableApp(&workIdx).
|
|
|
+ SetNillableRequest(&req.Messages[0].Content).
|
|
|
+ SetNillableResponse(&resp.Choices[0].Message.Content).
|
|
|
+ SetNillableOrganizationID(&orgId).
|
|
|
+ SetOriginalData(rawReqResp).
|
|
|
+ SetNillablePromptTokens(&promptTokens).
|
|
|
+ SetNillableCompletionTokens(&completionToken).
|
|
|
+ SetNillableTotalTokens(&totalTokens).
|
|
|
+ Save(l.ctx)
|
|
|
+
|
|
|
+ if err == nil { //插入UsageDetai之后再统计UsageTotal
|
|
|
+ l.updateUsageTotal(authToken, res.ID, orgId)
|
|
|
+ }
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
func (l *ChatCompletionsLogic) workForFastgpt(req *types.CompApiReq, apiKey string, apiBase string) (resp *types.CompOpenApiResp, err error) {
|