|
@@ -7,17 +7,159 @@ import (
|
|
|
"fmt"
|
|
|
"net/http"
|
|
|
"reflect"
|
|
|
+ "strconv"
|
|
|
"strings"
|
|
|
|
|
|
+ "wechat-api/ent"
|
|
|
+ "wechat-api/ent/custom_types"
|
|
|
+ "wechat-api/ent/predicate"
|
|
|
+ "wechat-api/ent/usagedetail"
|
|
|
+ "wechat-api/ent/usagetotal"
|
|
|
+ "wechat-api/internal/config"
|
|
|
"wechat-api/internal/types"
|
|
|
"wechat-api/internal/utils/contextkey"
|
|
|
|
|
|
openai "github.com/openai/openai-go"
|
|
|
"github.com/openai/openai-go/option"
|
|
|
"github.com/openai/openai-go/packages/ssestream"
|
|
|
+ "github.com/redis/go-redis/v9"
|
|
|
"github.com/zeromicro/go-zero/rest/httpx"
|
|
|
)
|
|
|
|
|
|
+type ServiceContext struct {
|
|
|
+ Config config.Config
|
|
|
+ DB *ent.Client
|
|
|
+ Rds redis.UniversalClient
|
|
|
+}
|
|
|
+
|
|
|
+func AppendUsageDetailLog(ctx context.Context, svcCtx *ServiceContext,
|
|
|
+
|
|
|
+ authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
|
|
|
+
|
|
|
+ logType := 5
|
|
|
+ rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
|
|
|
+ tmpId := 0
|
|
|
+ tmpId, _ = strconv.Atoi(resp.ID)
|
|
|
+ sessionId := uint64(tmpId)
|
|
|
+ orgId := uint64(0)
|
|
|
+ apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(ctx)
|
|
|
+ if ok {
|
|
|
+ orgId = apiKeyObj.OrganizationID
|
|
|
+ }
|
|
|
+ promptTokens := uint64(resp.Usage.PromptTokens)
|
|
|
+ completionToken := uint64(resp.Usage.CompletionTokens)
|
|
|
+ totalTokens := promptTokens + completionToken
|
|
|
+
|
|
|
+ msgContent := getMessageContentStr(req.Messages[0].Content)
|
|
|
+
|
|
|
+ _, _, _ = logType, sessionId, totalTokens
|
|
|
+ res, err := svcCtx.DB.UsageDetail.Create().
|
|
|
+ SetNotNilType(&logType).
|
|
|
+ SetNotNilBotID(&authToken).
|
|
|
+ SetNotNilReceiverID(&req.EventType).
|
|
|
+ SetNotNilSessionID(&sessionId).
|
|
|
+ SetNillableRequest(&msgContent).
|
|
|
+ SetNillableResponse(&resp.Choices[0].Message.Content).
|
|
|
+ SetNillableOrganizationID(&orgId).
|
|
|
+ SetOriginalData(rawReqResp).
|
|
|
+ SetNillablePromptTokens(&promptTokens).
|
|
|
+ SetNillableCompletionTokens(&completionToken).
|
|
|
+ SetNillableTotalTokens(&totalTokens).
|
|
|
+ Save(ctx)
|
|
|
+
|
|
|
+ if err == nil { //插入UsageDetai之后再统计UsageTotal
|
|
|
+ updateUsageTotal(ctx, svcCtx, authToken, res.ID, orgId)
|
|
|
+ }
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+func getUsagetotalIdByToken(ctx context.Context, svcCtx *ServiceContext,
|
|
|
+ authToken string) (uint64, error) {
|
|
|
+
|
|
|
+ var predicates []predicate.UsageTotal
|
|
|
+ predicates = append(predicates, usagetotal.BotIDEQ(authToken))
|
|
|
+ return svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(ctx)
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+func replaceUsagetotalTokens(ctx context.Context, svcCtx *ServiceContext,
|
|
|
+
|
|
|
+ authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
|
|
|
+
|
|
|
+ Id, err := getUsagetotalIdByToken(ctx, svcCtx, authToken)
|
|
|
+ if err != nil && !ent.IsNotFound(err) {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if Id > 0 { //UsageTotal have record by newUsageDetailId
|
|
|
+ _, err = svcCtx.DB.UsageTotal.UpdateOneID(Id).
|
|
|
+ SetTotalTokens(sumTotalTokens).
|
|
|
+ SetEndIndex(newUsageDetailId).
|
|
|
+ Save(ctx)
|
|
|
+ } else { //create new record by newUsageDetailId
|
|
|
+ logType := 5
|
|
|
+ _, err = svcCtx.DB.UsageTotal.Create().
|
|
|
+ SetNotNilBotID(&authToken).
|
|
|
+ SetNotNilEndIndex(&newUsageDetailId).
|
|
|
+ SetNotNilTotalTokens(&sumTotalTokens).
|
|
|
+ SetNillableType(&logType).
|
|
|
+ SetNotNilOrganizationID(&orgId).
|
|
|
+ Save(ctx)
|
|
|
+ }
|
|
|
+
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+func updateUsageTotal(ctx context.Context, svcCtx *ServiceContext,
|
|
|
+ authToken string, newUsageDetailId uint64, orgId uint64) error {
|
|
|
+
|
|
|
+ sumTotalTokens, err := sumTotalTokensByAuthToken(ctx, svcCtx, authToken) //首先sum UsageDetail的TotalTokens
|
|
|
+ if err == nil {
|
|
|
+ err = replaceUsagetotalTokens(ctx, svcCtx, authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens
|
|
|
+ }
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+// sum total_tokens from usagedetail by AuthToken
|
|
|
+func sumTotalTokensByAuthToken(ctx context.Context, svcCtx *ServiceContext,
|
|
|
+ authToken string) (uint64, error) {
|
|
|
+
|
|
|
+ var predicates []predicate.UsageDetail
|
|
|
+ predicates = append(predicates, usagedetail.BotIDEQ(authToken))
|
|
|
+
|
|
|
+ var res []struct {
|
|
|
+ Sum, Min, Max, Count uint64
|
|
|
+ }
|
|
|
+ totalTokens := uint64(0)
|
|
|
+ var err error = nil
|
|
|
+ err = svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"),
|
|
|
+ ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(ctx, &res)
|
|
|
+ if err == nil {
|
|
|
+ if len(res) > 0 {
|
|
|
+ totalTokens = res[0].Sum
|
|
|
+ } else {
|
|
|
+ totalTokens = 0
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return totalTokens, err
|
|
|
+}
|
|
|
+
|
|
|
+func getMessageContentStr(input any) string {
|
|
|
+ str := ""
|
|
|
+ switch val := input.(type) {
|
|
|
+ case string:
|
|
|
+ str = val
|
|
|
+ case []interface{}:
|
|
|
+ if len(val) > 0 {
|
|
|
+ if valc, ok := val[0].(map[string]interface{}); ok {
|
|
|
+ if valcc, ok := valc["text"]; ok {
|
|
|
+ str, _ = valcc.(string)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return str
|
|
|
+}
|
|
|
+
|
|
|
func IsOpenaiModel(model string) bool {
|
|
|
|
|
|
prefixes := []string{"gpt-4", "gpt-3", "o1", "o3"}
|