Browse Source

1.修改workID和workToken映射字典结构,增加函数可以获得WorkId对应的token及数字索引值 2.compapi权限验证增加rds读取 3.每次请求后增加数据库日志记录及实时token统计

liwei 1 week ago
parent
commit
257d52affd

+ 126 - 1
internal/logic/chat/chat_completions_logic.go

@@ -3,6 +3,7 @@ package chat
 import (
 	"context"
 	"errors"
+	"strconv"
 
 	"wechat-api/ent"
 	"wechat-api/internal/svc"
@@ -10,6 +11,11 @@ import (
 	"wechat-api/internal/utils/compapi"
 	"wechat-api/internal/utils/contextkey"
 
+	"wechat-api/ent/custom_types"
+	"wechat-api/ent/predicate"
+	"wechat-api/ent/usagedetail"
+	"wechat-api/ent/usagetotal"
+
 	"github.com/zeromicro/go-zero/core/logx"
 )
 
@@ -56,7 +62,126 @@ func (l *ChatCompletionsLogic) ChatCompletions(req *types.CompApiReq) (resp *typ
 	if len(apiKeyObj.OpenaiBase) == 0 || len(workToken) == 0 {
 		return nil, errors.New("not auth info")
 	}
-	return l.workForFastgpt(req, workToken, apiKeyObj.OpenaiBase)
+
+	apiResp, err := l.workForFastgpt(req, workToken, apiKeyObj.OpenaiBase)
+	if err == nil {
+		l.doRequestLog(req, apiResp) //请求记录
+	}
+
+	return apiResp, err
+}
+
+func (l *ChatCompletionsLogic) doRequestLog(req *types.CompApiReq, resp *types.CompOpenApiResp) error {
+	authToken, ok := contextkey.OpenapiTokenKey.GetValue(l.ctx)
+	if !ok {
+		return errors.New("content get auth token err")
+	}
+
+	return l.appendUsageDetailLog(authToken, req, resp)
+}
+
+func (l *ChatCompletionsLogic) getUsagetotalIdByToken(authToken string) (uint64, error) {
+
+	var predicates []predicate.UsageTotal
+	predicates = append(predicates, usagetotal.BotIDEQ(authToken))
+	return l.svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(l.ctx)
+
+}
+
+func (l *ChatCompletionsLogic) replaceUsagetotalTokens(authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
+
+	Id, err := l.getUsagetotalIdByToken(authToken)
+	if err != nil && !ent.IsNotFound(err) {
+		return err
+	}
+	if Id > 0 { //UsageTotal have record by  newUsageDetailId
+		_, err = l.svcCtx.DB.UsageTotal.UpdateOneID(Id).
+			SetTotalTokens(sumTotalTokens).
+			SetEndIndex(newUsageDetailId).
+			Save(l.ctx)
+	} else { //create new record by  newUsageDetailId
+		logType := 5
+		_, err = l.svcCtx.DB.UsageTotal.Create().
+			SetNotNilBotID(&authToken).
+			SetNotNilEndIndex(&newUsageDetailId).
+			SetNotNilTotalTokens(&sumTotalTokens).
+			SetNillableType(&logType).
+			SetNotNilOrganizationID(&orgId).
+			Save(l.ctx)
+	}
+
+	return err
+}
+
+func (l *ChatCompletionsLogic) updateUsageTotal(authToken string, newUsageDetailId uint64, orgId uint64) error {
+
+	sumTotalTokens, err := l.sumTotalTokensByAuthToken(authToken) //首先sum UsageDetail的TotalTokens
+	if err == nil {
+		err = l.replaceUsagetotalTokens(authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens
+	}
+	return err
+}
+
+// sum total_tokens from usagedetail by AuthToken
+func (l *ChatCompletionsLogic) sumTotalTokensByAuthToken(authToken string) (uint64, error) {
+
+	var predicates []predicate.UsageDetail
+	predicates = append(predicates, usagedetail.BotIDEQ(authToken))
+
+	var res []struct {
+		Sum, Min, Max, Count uint64
+	}
+	totalTokens := uint64(0)
+	var err error = nil
+	err = l.svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"),
+		ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(l.ctx, &res)
+	if err == nil {
+		if len(res) > 0 {
+			totalTokens = res[0].Sum
+		} else {
+			totalTokens = 0
+		}
+	}
+	return totalTokens, err
+}
+
+func (l *ChatCompletionsLogic) appendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
+
+	logType := 5
+	workIdx := compapi.GetWorkIdxByID(req.EventType, req.WorkId)
+	rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
+
+	tmpId := 0
+	tmpId, _ = strconv.Atoi(resp.ID)
+	sessionId := uint64(tmpId)
+	orgId := uint64(0)
+	apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(l.ctx)
+	if ok {
+		orgId = apiKeyObj.OrganizationID
+	}
+	promptTokens := uint64(resp.Usage.PromptTokens)
+	completionToken := uint64(resp.Usage.CompletionTokens)
+	totalTokens := promptTokens + completionToken
+
+	res, err := l.svcCtx.DB.UsageDetail.Create().
+		SetNotNilType(&logType).
+		SetNotNilBotID(&authToken).
+		SetNotNilReceiverID(&req.EventType).
+		SetNotNilSessionID(&sessionId).
+		SetNillableApp(&workIdx).
+		SetNillableRequest(&req.Messages[0].Content).
+		SetNillableResponse(&resp.Choices[0].Message.Content).
+		SetNillableOrganizationID(&orgId).
+		SetOriginalData(rawReqResp).
+		SetNillablePromptTokens(&promptTokens).
+		SetNillableCompletionTokens(&completionToken).
+		SetNillableTotalTokens(&totalTokens).
+		Save(l.ctx)
+
+	if err == nil { //插入UsageDetai之后再统计UsageTotal
+		l.updateUsageTotal(authToken, res.ID, orgId)
+	}
+	return err
 }
 
 func (l *ChatCompletionsLogic) workForFastgpt(req *types.CompApiReq, apiKey string, apiBase string) (resp *types.CompOpenApiResp, err error) {

+ 57 - 50
internal/middleware/openauthority_middleware.go

@@ -3,9 +3,8 @@ package middleware
 import (
 	"context"
 	"encoding/json"
-	"fmt"
+	"errors"
 	"net/http"
-	"reflect"
 
 	"wechat-api/ent"
 	"wechat-api/ent/apikey"
@@ -42,59 +41,74 @@ func NewOpenAuthorityMiddleware(db *ent.Client, rds redis.UniversalClient, c con
 
 func (m *OpenAuthorityMiddleware) checkTokenUserInfo(ctx context.Context, authToken string) (*ent.ApiKey, int, error) {
 	var (
-		rc  int
-		err error
-		val *ent.ApiKey
+		rc        int
+		err       error
+		apiKeyObj *ent.ApiKey
+		fromId    = -1
 	)
-	val, rc, err = m.getTokenUserInfoByDb(ctx, authToken)
-	return val, rc, err
-
-	/*
-		r, e = m.getTokenUserInfoByRds(ctx, loginToken)
-		fmt.Println("redis:", "code-", r, "err-", e)
-	*/
-	/*
-		//首先从redis取数据
-		_, rc, err = m.getTokenUserInfoByRds(ctx, authToken)
-		fmt.Printf("++++++++++++++++++++++++get authinfo from rds out:%d/err:%s\n", rc, err)
-		if rc <= 0 || err != nil { //无法获得后再从数据库获得
-			rc = 0
-			err = nil
-			val, rc, err = m.getTokenUserInfoByDb(ctx, authToken)
-			fmt.Println("----------------------After m.getTokenUserInfoByDb:", val)
-			err = m.saveTokenUserInfoToRds(ctx, authToken, val)
-			fmt.Println("------------save saveTokenUserInfoToRd err:", err)
+	_ = fromId
+
+	//首先从redis取数据
+	apiKeyObj, rc, err = m.getTokenUserInfoByRds(ctx, authToken)
+	if rc <= 0 || err != nil { //无法获得后再从数据库获得	{
+
+		rc = 0
+		err = nil
+		apiKeyObj, rc, err = m.getTokenUserInfoByDb(ctx, authToken)
+		if err == nil {
+			//get apiKeyObj from db succ
+			fromId = 1
+			err = m.saveTokenUserInfoToRds(ctx, authToken, apiKeyObj)
 		}
+	} else {
+		fromId = 2
+	}
 
-		_ = rc
-		if err != nil {
-			return nil, 0, err
+	/*
+		if err == nil {
+
+			fromStr := ""
+			switch fromId {
+			case 1:
+				fromStr = "DB"
+			case 2:
+				fromStr = "RDS"
+			}
+			fmt.Println("=========================================>>>")
+			fmt.Printf("In checkTokenUserInfo Get Token Info From %s\n", fromStr)
+			fmt.Printf("Key:'%s'\n", apiKeyObj.Key)
+			fmt.Printf("Title:'%s'\n", apiKeyObj.Title)
+			fmt.Printf("OpenaiBase:'%s'\n", apiKeyObj.OpenaiBase)
+			fmt.Printf("OpenaiKey:'%s'\n", apiKeyObj.OpenaiKey)
+			fmt.Println("<<<=========================================")
 		}
-		return val, 0, nil
 	*/
+	return apiKeyObj, rc, err
 }
 
 func (m *OpenAuthorityMiddleware) saveTokenUserInfoToRds(ctx context.Context, authToken string, saveInfo *ent.ApiKey) error {
-	if bs, err := json.Marshal(saveInfo); err == nil {
-		return err
-	} else {
-		rc, err := m.Rds.HSet(ctx, compapi.APIAuthInfoKey, authToken, string(bs)).Result()
-		fmt.Printf("#~~~~~~~~~~~~~~~++~~~~~~~~~~~~~HSet Val:%s get Result:%d/%s\n", string(bs), rc, err)
-		return err
+	bs, err := json.Marshal(saveInfo)
+	if err == nil {
+
+		_, err = m.Rds.HSet(ctx, compapi.APIAuthInfoKey, authToken, string(bs)).Result()
 	}
+	return err
 }
 func (m *OpenAuthorityMiddleware) getTokenUserInfoByRds(ctx context.Context, authToken string) (*ent.ApiKey, int, error) {
 
 	rcode := -1
-	val, err := m.Rds.HGet(ctx, compapi.APIAuthInfoKey, authToken).Result()
-	if err == redis.Nil {
-		rcode = 0
-	} else if err == nil {
-		rcode = 1
+	result := ent.ApiKey{}
+	jsonStr, err := m.Rds.HGet(ctx, compapi.APIAuthInfoKey, authToken).Result()
+	if errors.Is(err, redis.Nil) {
+		rcode = 0 //key not exist
+	} else if err == nil { //find key
+
+		err := json.Unmarshal([]byte(jsonStr), &result)
+		if err == nil {
+			rcode = 1
+		}
 	}
-	fmt.Printf("#####################From Redis By Key:'%s' Get '%s'(%s/%T)\n", authToken, val, reflect.TypeOf(val), val)
-	fmt.Println(val)
-	return nil, rcode, err
+	return &result, rcode, err
 }
 
 func (m *OpenAuthorityMiddleware) getTokenUserInfoByDb(ctx context.Context, loginToken string) (*ent.ApiKey, int, error) {
@@ -124,16 +138,9 @@ func (m *OpenAuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc
 			httpx.Error(w, errorx.NewApiError(http.StatusForbidden, "无法获取合适的授权信息"))
 			return
 		}
-		//ctx = contextkey.OpenapiTokenKey.WithValue(ctx, apiToken)
 		ctx = contextkey.AuthTokenInfoKey.WithValue(ctx, apiKeyObj)
-		/*
-			fmt.Println("=========================================")
-			fmt.Printf("In Middleware Get Token Info:\nKey:'%s'\n", apiKeyObj.Key)
-			fmt.Printf("Title:'%s'\n", apiKeyObj.Title)
-			fmt.Printf("OpenaiBase:'%s'\n", apiKeyObj.OpenaiBase)
-			fmt.Printf("OpenaiKey:'%s'\n", apiKeyObj.OpenaiKey)
-			fmt.Println("=========================================")
-		*/
+		ctx = contextkey.OpenapiTokenKey.WithValue(ctx, authToken)
+
 		newReq := r.WithContext(ctx)
 		// Passthrough to next handler if need
 		next(w, newReq)

+ 6 - 10
internal/utils/compapi/config.go

@@ -11,16 +11,12 @@ const (
 	APIAuthInfoKey string = "COMPAPI_AUTHINFO"
 )
 
-var fastgptWorkIdMap = map[string]string{
-	"default":              "fastgpt-jcDATa9aH4vtUsjDpCU773BxmLU50IxKUX9nUT0mCTLQkEoo1hPxPEdNQeOEWGTn",
-	"OPTIMIZE_CALL_SCRIPT": "fastgpt-bcQ9cWKd6y9a2LfizweeWEnukkQi1Oq46yoiRg9yDNLm8NPTWXsyFwcB",
+type workIdInfo struct {
+	Id  string
+	Idx int
 }
 
-// 获取workToken
-func GetWorkTokenByID(eventType string, workId string) string {
-	val, exist := fastgptWorkIdMap[workId]
-	if !exist {
-		val = fastgptWorkIdMap["default"]
-	}
-	return val
+var fastgptWorkIdMap = map[string]workIdInfo{
+	"default":              {"fastgpt-jcDATa9aH4vtUsjDpCU773BxmLU50IxKUX9nUT0mCTLQkEoo1hPxPEdNQeOEWGTn", 0},
+	"OPTIMIZE_CALL_SCRIPT": {"fastgpt-bcQ9cWKd6y9a2LfizweeWEnukkQi1Oq46yoiRg9yDNLm8NPTWXsyFwcB", 1},
 }

+ 18 - 0
internal/utils/compapi/func.go

@@ -188,3 +188,21 @@ func DoChatCompletionsStreamOld(ctx context.Context, client *openai.Client, chat
 	}
 	return nil, nil
 }
+
+// 获取workToken
+func GetWorkTokenByID(eventType string, workId string) string {
+	val, exist := fastgptWorkIdMap[workId]
+	if !exist {
+		val = fastgptWorkIdMap["default"]
+	}
+	return val.Id
+}
+
+// 获取workIdx
+func GetWorkIdxByID(eventType string, workId string) int {
+	val, exist := fastgptWorkIdMap[workId]
+	if !exist {
+		val = fastgptWorkIdMap["default"]
+	}
+	return val.Idx
+}