Эх сурвалжийг харах

将token统计函数移到compapi包中,为后续异步请求调用提供便利

liwei 1 өдөр өмнө
parent
commit
089c8fb8b0

+ 3 - 105
internal/logic/chat/chat_completions_logic.go

@@ -18,11 +18,6 @@ import (
 	"wechat-api/internal/utils/contextkey"
 	"wechat-api/internal/utils/typekit"
 
-	"wechat-api/ent/custom_types"
-	"wechat-api/ent/predicate"
-	"wechat-api/ent/usagedetail"
-	"wechat-api/ent/usagetotal"
-
 	"github.com/zeromicro/go-zero/core/logx"
 )
 
@@ -185,106 +180,9 @@ func (l *ChatCompletionsLogic) AppendAsyncRequest(apiKeyObj *ent.ApiKey, req *ty
 
 func (l *ChatCompletionsLogic) AppendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
 
-	logType := 5
-	rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
-	tmpId := 0
-	tmpId, _ = strconv.Atoi(resp.ID)
-	sessionId := uint64(tmpId)
-	orgId := uint64(0)
-	apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(l.ctx)
-	if ok {
-		orgId = apiKeyObj.OrganizationID
-	}
-	promptTokens := uint64(resp.Usage.PromptTokens)
-	completionToken := uint64(resp.Usage.CompletionTokens)
-	totalTokens := promptTokens + completionToken
-
-	msgContent := getMessageContentStr(req.Messages[0].Content)
-
-	_, _, _ = logType, sessionId, totalTokens
-	res, err := l.svcCtx.DB.UsageDetail.Create().
-		SetNotNilType(&logType).
-		SetNotNilBotID(&authToken).
-		SetNotNilReceiverID(&req.EventType).
-		SetNotNilSessionID(&sessionId).
-		SetNillableRequest(&msgContent).
-		SetNillableResponse(&resp.Choices[0].Message.Content).
-		SetNillableOrganizationID(&orgId).
-		SetOriginalData(rawReqResp).
-		SetNillablePromptTokens(&promptTokens).
-		SetNillableCompletionTokens(&completionToken).
-		SetNillableTotalTokens(&totalTokens).
-		Save(l.ctx)
-
-	if err == nil { //插入UsageDetai之后再统计UsageTotal
-		l.updateUsageTotal(authToken, res.ID, orgId)
-	}
-	return err
-}
-
-func (l *ChatCompletionsLogic) getUsagetotalIdByToken(authToken string) (uint64, error) {
-
-	var predicates []predicate.UsageTotal
-	predicates = append(predicates, usagetotal.BotIDEQ(authToken))
-	return l.svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(l.ctx)
-
-}
-
-func (l *ChatCompletionsLogic) replaceUsagetotalTokens(authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
-
-	Id, err := l.getUsagetotalIdByToken(authToken)
-	if err != nil && !ent.IsNotFound(err) {
-		return err
-	}
-	if Id > 0 { //UsageTotal have record by  newUsageDetailId
-		_, err = l.svcCtx.DB.UsageTotal.UpdateOneID(Id).
-			SetTotalTokens(sumTotalTokens).
-			SetEndIndex(newUsageDetailId).
-			Save(l.ctx)
-	} else { //create new record by  newUsageDetailId
-		logType := 5
-		_, err = l.svcCtx.DB.UsageTotal.Create().
-			SetNotNilBotID(&authToken).
-			SetNotNilEndIndex(&newUsageDetailId).
-			SetNotNilTotalTokens(&sumTotalTokens).
-			SetNillableType(&logType).
-			SetNotNilOrganizationID(&orgId).
-			Save(l.ctx)
-	}
-
-	return err
-}
-
-func (l *ChatCompletionsLogic) updateUsageTotal(authToken string, newUsageDetailId uint64, orgId uint64) error {
-
-	sumTotalTokens, err := l.sumTotalTokensByAuthToken(authToken) //首先sum UsageDetail的TotalTokens
-	if err == nil {
-		err = l.replaceUsagetotalTokens(authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens
-	}
-	return err
-}
-
-// sum total_tokens from usagedetail by AuthToken
-func (l *ChatCompletionsLogic) sumTotalTokensByAuthToken(authToken string) (uint64, error) {
-
-	var predicates []predicate.UsageDetail
-	predicates = append(predicates, usagedetail.BotIDEQ(authToken))
-
-	var res []struct {
-		Sum, Min, Max, Count uint64
-	}
-	totalTokens := uint64(0)
-	var err error = nil
-	err = l.svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"),
-		ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(l.ctx, &res)
-	if err == nil {
-		if len(res) > 0 {
-			totalTokens = res[0].Sum
-		} else {
-			totalTokens = 0
-		}
-	}
-	return totalTokens, err
+	svcCtx := &compapi.ServiceContext{Config: l.svcCtx.Config, DB: l.svcCtx.DB,
+		Rds: l.svcCtx.Rds}
+	return compapi.AppendUsageDetailLog(l.ctx, svcCtx, authToken, req, resp)
 }
 
 func (l *FastgptChatLogic) AdjustRequest(req *types.CompApiReq, apiKeyObj *ent.ApiKey) {

+ 142 - 0
internal/utils/compapi/func.go

@@ -7,17 +7,159 @@ import (
 	"fmt"
 	"net/http"
 	"reflect"
+	"strconv"
 	"strings"
 
+	"wechat-api/ent"
+	"wechat-api/ent/custom_types"
+	"wechat-api/ent/predicate"
+	"wechat-api/ent/usagedetail"
+	"wechat-api/ent/usagetotal"
+	"wechat-api/internal/config"
 	"wechat-api/internal/types"
 	"wechat-api/internal/utils/contextkey"
 
 	openai "github.com/openai/openai-go"
 	"github.com/openai/openai-go/option"
 	"github.com/openai/openai-go/packages/ssestream"
+	"github.com/redis/go-redis/v9"
 	"github.com/zeromicro/go-zero/rest/httpx"
 )
 
+type ServiceContext struct {
+	Config config.Config
+	DB     *ent.Client
+	Rds    redis.UniversalClient
+}
+
+func AppendUsageDetailLog(ctx context.Context, svcCtx *ServiceContext,
+
+	authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
+
+	logType := 5
+	rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
+	tmpId := 0
+	tmpId, _ = strconv.Atoi(resp.ID)
+	sessionId := uint64(tmpId)
+	orgId := uint64(0)
+	apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(ctx)
+	if ok {
+		orgId = apiKeyObj.OrganizationID
+	}
+	promptTokens := uint64(resp.Usage.PromptTokens)
+	completionToken := uint64(resp.Usage.CompletionTokens)
+	totalTokens := promptTokens + completionToken
+
+	msgContent := getMessageContentStr(req.Messages[0].Content)
+
+	_, _, _ = logType, sessionId, totalTokens
+	res, err := svcCtx.DB.UsageDetail.Create().
+		SetNotNilType(&logType).
+		SetNotNilBotID(&authToken).
+		SetNotNilReceiverID(&req.EventType).
+		SetNotNilSessionID(&sessionId).
+		SetNillableRequest(&msgContent).
+		SetNillableResponse(&resp.Choices[0].Message.Content).
+		SetNillableOrganizationID(&orgId).
+		SetOriginalData(rawReqResp).
+		SetNillablePromptTokens(&promptTokens).
+		SetNillableCompletionTokens(&completionToken).
+		SetNillableTotalTokens(&totalTokens).
+		Save(ctx)
+
+	if err == nil { //插入UsageDetai之后再统计UsageTotal
+		updateUsageTotal(ctx, svcCtx, authToken, res.ID, orgId)
+	}
+	return err
+}
+
+func getUsagetotalIdByToken(ctx context.Context, svcCtx *ServiceContext,
+	authToken string) (uint64, error) {
+
+	var predicates []predicate.UsageTotal
+	predicates = append(predicates, usagetotal.BotIDEQ(authToken))
+	return svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(ctx)
+
+}
+
+func replaceUsagetotalTokens(ctx context.Context, svcCtx *ServiceContext,
+
+	authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
+
+	Id, err := getUsagetotalIdByToken(ctx, svcCtx, authToken)
+	if err != nil && !ent.IsNotFound(err) {
+		return err
+	}
+	if Id > 0 { //UsageTotal have record by  newUsageDetailId
+		_, err = svcCtx.DB.UsageTotal.UpdateOneID(Id).
+			SetTotalTokens(sumTotalTokens).
+			SetEndIndex(newUsageDetailId).
+			Save(ctx)
+	} else { //create new record by  newUsageDetailId
+		logType := 5
+		_, err = svcCtx.DB.UsageTotal.Create().
+			SetNotNilBotID(&authToken).
+			SetNotNilEndIndex(&newUsageDetailId).
+			SetNotNilTotalTokens(&sumTotalTokens).
+			SetNillableType(&logType).
+			SetNotNilOrganizationID(&orgId).
+			Save(ctx)
+	}
+
+	return err
+}
+
+func updateUsageTotal(ctx context.Context, svcCtx *ServiceContext,
+	authToken string, newUsageDetailId uint64, orgId uint64) error {
+
+	sumTotalTokens, err := sumTotalTokensByAuthToken(ctx, svcCtx, authToken) //首先sum UsageDetail的TotalTokens
+	if err == nil {
+		err = replaceUsagetotalTokens(ctx, svcCtx, authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens
+	}
+	return err
+}
+
+// sum total_tokens from usagedetail by AuthToken
+func sumTotalTokensByAuthToken(ctx context.Context, svcCtx *ServiceContext,
+	authToken string) (uint64, error) {
+
+	var predicates []predicate.UsageDetail
+	predicates = append(predicates, usagedetail.BotIDEQ(authToken))
+
+	var res []struct {
+		Sum, Min, Max, Count uint64
+	}
+	totalTokens := uint64(0)
+	var err error = nil
+	err = svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"),
+		ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(ctx, &res)
+	if err == nil {
+		if len(res) > 0 {
+			totalTokens = res[0].Sum
+		} else {
+			totalTokens = 0
+		}
+	}
+	return totalTokens, err
+}
+
+func getMessageContentStr(input any) string {
+	str := ""
+	switch val := input.(type) {
+	case string:
+		str = val
+	case []interface{}:
+		if len(val) > 0 {
+			if valc, ok := val[0].(map[string]interface{}); ok {
+				if valcc, ok := valc["text"]; ok {
+					str, _ = valcc.(string)
+				}
+			}
+		}
+	}
+	return str
+}
+
 func IsOpenaiModel(model string) bool {
 
 	prefixes := []string{"gpt-4", "gpt-3", "o1", "o3"}