ソースを参照

Merge branch 'fixbug/000-bowen-usage'

* fixbug/000-bowen-usage:
  AI 通用接口增加积分记录
  commit 本分支asynctask运行依赖的yaml配制文件
  将token统计函数移到compapi包中,为后续异步请求调用提供便利
  fix:修改model价格&&消耗积分不记录beforeNumber afterNumber
  统一 token 和积分记录
  增加一些使用注释

# Conflicts:
#	internal/utils/compapi/func.go
boweniac 7 時間 前
コミット
664ecf7970

+ 10 - 3
cli/asynctask/asynctask.go

@@ -323,8 +323,8 @@ func (me *AsyncTask) requestAPI(taskData *ent.CompapiAsynctask) (int, error) {
 		err     error
 		apiResp *types.CompOpenApiResp
 	)
-	req := types.CompApiReq{}
-	if err = json.Unmarshal([]byte(taskData.RequestRaw), &req); err != nil {
+	req := &types.CompApiReq{}
+	if err = json.Unmarshal([]byte(taskData.RequestRaw), req); err != nil {
 		return 0, err
 	}
 	//初始化client
@@ -340,7 +340,7 @@ func (me *AsyncTask) requestAPI(taskData *ent.CompapiAsynctask) (int, error) {
 		case <-me.ctx.Done(): //接到信号退出
 			goto endloopTry
 		default:
-			apiResp, err = client.Chat(&req)
+			apiResp, err = client.Chat(req)
 			if err == nil && apiResp != nil && len(apiResp.Choices) > 0 {
 				//call succ
 				goto endloopTry
@@ -370,6 +370,13 @@ endloopTry:
 	}
 	//成功后处理环节
 
+	//追加访问大模型token统计数据
+	svcCtx := &compapi.ServiceContext{DB: me.svcCtx.DB, Rds: me.svcCtx.Rds, Config: me.svcCtx.Config}
+	err = compapi.AppendUsageDetailLog(me.ctx, svcCtx, taskData.AuthToken, req, apiResp)
+	if err != nil {
+		return 0, err
+	}
+
 	//更新任务状态 => ReqApi_Done(请求API完成)
 	err = me.updateTaskStatus(taskData.ID, ReqApi_Done)
 	if err != nil {

+ 21 - 0
cli/asynctask/etc/asynctask.yaml

@@ -0,0 +1,21 @@
+BatchLoadTask: 200 #每批次取任务数
+MaxWorker: 10       #最大消费者数量
+MaxChannel: 2      #最大消费通道数量
+Debug: false
+
+DatabaseConf: #数据库配置
+  Type: mysql
+  Host: 127.0.0.1
+  Port: 3306
+  DBName: wechat
+  Username: root
+  Password: simple-admin.
+  MaxOpenConn: 100
+  SSLMode: disable
+  CacheTime: 5
+
+RedisConf: #redis配置
+  Host: 127.0.0.1:6379
+
+
+

+ 20 - 15
hook/credit/credit.go

@@ -7,30 +7,35 @@ import (
 	"wechat-api/ent/creditbalance"
 	"wechat-api/ent/custom_types"
 	"wechat-api/ent/usagetotal"
-	"wechat-api/hook/dify"
 )
 
+type Usage struct {
+	PromptTokens     uint64 `json:"prompt_tokens"`
+	CompletionTokens uint64 `json:"completion_tokens"`
+	TotalTokens      uint64 `json:"total_tokens"`
+}
+
 func AddCreditUsage(tx *ent.Tx, ctx context.Context,
 	agentId string, userId string, departmentId uint64,
 	question *string, answer *string,
-	originalData *custom_types.OriginalData, chatData *dify.ChatResp) error {
+	originalData *custom_types.OriginalData, chatData *Usage, model string) error {
 	// 积分明细表记录使用量
-	modelName, price := GetModelPrice()
-	number := ComputePrice(price, chatData.Metadata.Usage.TotalTokens)
-
+	modelName, price := GetModelPrice(model)
+	number := ComputePrice(price, chatData.TotalTokens)
+	
 	// 记录Token使用信息
 	usageDetailItem, err := tx.UsageDetail.Create().
-		SetType(3).            //1-微信 2-名片 3-智能体
-		SetBotID(agentId).     //智能体ID
+		SetType(3). //1-微信 2-名片 3-智能体
+		SetBotID(agentId). //智能体ID
 		SetReceiverID(userId). //接收者userID
-		SetApp(8).             //8-智能体
+		SetApp(8). //8-智能体
 		SetSessionID(0).
 		SetRequest(*question).
 		SetResponse(*answer).
 		SetOriginalData(*originalData).
-		SetTotalTokens(chatData.Metadata.Usage.TotalTokens).
-		SetPromptTokens(chatData.Metadata.Usage.PromptTokens).
-		SetCompletionTokens(chatData.Metadata.Usage.CompletionTokens).
+		SetTotalTokens(chatData.TotalTokens).
+		SetPromptTokens(chatData.PromptTokens).
+		SetCompletionTokens(chatData.CompletionTokens).
 		SetModel(modelName).
 		SetCredits(number).
 		Save(ctx)
@@ -46,7 +51,7 @@ func AddCreditUsage(tx *ent.Tx, ctx context.Context,
 		if ent.IsNotFound(err) {
 			usageTotal, err = tx.UsageTotal.Create().
 				SetBotID(agentId).
-				SetTotalTokens(chatData.Metadata.Usage.TotalTokens).
+				SetTotalTokens(chatData.TotalTokens).
 				SetEndIndex(usageDetailItem.ID).
 				SetOrganizationID(departmentId).
 				Save(ctx)
@@ -58,7 +63,7 @@ func AddCreditUsage(tx *ent.Tx, ctx context.Context,
 		// 更新Token使用总量
 		_, err = tx.UsageTotal.Update().
 			Where(usagetotal.OrganizationID(departmentId)).
-			SetTotalTokens(usageTotal.TotalTokens + chatData.Metadata.Usage.TotalTokens).
+			SetTotalTokens(usageTotal.TotalTokens + chatData.TotalTokens).
 			SetEndIndex(usageDetailItem.ID).
 			Save(ctx)
 		if err != nil {
@@ -90,8 +95,8 @@ func AddCreditUsage(tx *ent.Tx, ctx context.Context,
 	_, err = tx.CreditUsage.Create().
 		SetUserID(userId).
 		SetNumber(number).
-		SetBeforeNumber(beforeNumber).
-		SetAfterNumber(afterNumber).
+		SetBeforeNumber(0).
+		SetAfterNumber(0).
 		SetNtype(1).
 		SetNid(usageDetailItem.ID).
 		SetTable("usage_detail").

+ 30 - 26
hook/credit/models.go

@@ -1,6 +1,9 @@
 package credit
 
-import "math"
+import (
+	"math"
+	"strings"
+)
 
 var modelArray = []string{
 	"o1",
@@ -27,35 +30,36 @@ var modelArray = []string{
 }
 
 var priceArray = []float64{
+	0.01,
+	0.001667,
+	0.001333,
+	0.000733,
+	0.000548,
+	0.000365,
+	0.000274,
+	0.000267,
+	0.00025,
+	0.000219,
+	0.000205,
+	0.000183,
+	0.000137,
 	0.0001,
-	0.00001667,
-	0.00001333,
-	0.00000733,
-	0.00000548,
-	0.00000365,
-	0.00000274,
-	0.00000267,
-	0.0000025,
-	0.00000219,
-	0.00000205,
-	0.00000183,
-	0.00000137,
-	0.000001,
-	0.00000068,
-	0.00000067,
-	0.00000046,
-	0.00000046,
-	0.00000023,
-	0.00000014,
-	0.00000014,
+	0.000068,
+	0.000067,
+	0.000046,
+	0.000046,
+	0.000023,
+	0.000014,
+	0.000014,
 }
 
-func getModelName() string {
-	return "gpt-4o"
+func getModelName(modelName string) string {
+	// 将字符串转换为小写
+	return strings.ToLower(modelName)
 }
 
-func GetModelPrice() (model string, price float64) {
-	difyModelName := getModelName()
+func GetModelPrice(modelName string) (model string, price float64) {
+	difyModelName := getModelName(modelName)
 	for i, v := range modelArray {
 		if v == difyModelName {
 			return v, priceArray[i]
@@ -70,7 +74,7 @@ func ComputePrice(price float64, tokens uint64) float64 {
 	return math.Round(price*float64(tokens)*scale) / scale
 }
 
-// Subtraction() 保留小数点后6位的精确减法
+// Subtraction 保留小数点后6位的精确减法
 func Subtraction(number1, number2 float64) float64 {
 	d1 := number1 * 1000000
 	d2 := number2 * 1000000

+ 4 - 212
internal/logic/chat/chat_completions_logic.go

@@ -5,12 +5,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"net"
-	"net/url"
-	"regexp"
-	"strconv"
-	"strings"
-
 	"wechat-api/ent"
 	"wechat-api/internal/svc"
 	"wechat-api/internal/types"
@@ -18,11 +12,6 @@ import (
 	"wechat-api/internal/utils/contextkey"
 	"wechat-api/internal/utils/typekit"
 
-	"wechat-api/ent/custom_types"
-	"wechat-api/ent/predicate"
-	"wechat-api/ent/usagedetail"
-	"wechat-api/ent/usagetotal"
-
 	"github.com/zeromicro/go-zero/core/logx"
 )
 
@@ -137,7 +126,7 @@ func (l *ChatCompletionsLogic) AdjustRequest(req *types.CompApiReq, apiKeyObj *e
 	if req.IsBatch {
 		//流模式暂时不支持异步模式
 		//Callback格式非法则取消批量模式
-		if req.Stream || !IsValidURL(&req.Callback, true) {
+		if req.Stream || !compapi.IsValidURL(&req.Callback, true) {
 			req.IsBatch = false
 		}
 	}
@@ -184,107 +173,9 @@ func (l *ChatCompletionsLogic) AppendAsyncRequest(apiKeyObj *ent.ApiKey, req *ty
 }
 
 func (l *ChatCompletionsLogic) AppendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
-
-	logType := 5
-	rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
-	tmpId := 0
-	tmpId, _ = strconv.Atoi(resp.ID)
-	sessionId := uint64(tmpId)
-	orgId := uint64(0)
-	apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(l.ctx)
-	if ok {
-		orgId = apiKeyObj.OrganizationID
-	}
-	promptTokens := uint64(resp.Usage.PromptTokens)
-	completionToken := uint64(resp.Usage.CompletionTokens)
-	totalTokens := promptTokens + completionToken
-
-	msgContent := getMessageContentStr(req.Messages[0].Content)
-
-	_, _, _ = logType, sessionId, totalTokens
-	res, err := l.svcCtx.DB.UsageDetail.Create().
-		SetNotNilType(&logType).
-		SetNotNilBotID(&authToken).
-		SetNotNilReceiverID(&req.EventType).
-		SetNotNilSessionID(&sessionId).
-		SetNillableRequest(&msgContent).
-		SetNillableResponse(&resp.Choices[0].Message.Content).
-		SetNillableOrganizationID(&orgId).
-		SetOriginalData(rawReqResp).
-		SetNillablePromptTokens(&promptTokens).
-		SetNillableCompletionTokens(&completionToken).
-		SetNillableTotalTokens(&totalTokens).
-		Save(l.ctx)
-
-	if err == nil { //插入UsageDetai之后再统计UsageTotal
-		l.updateUsageTotal(authToken, res.ID, orgId)
-	}
-	return err
-}
-
-func (l *ChatCompletionsLogic) getUsagetotalIdByToken(authToken string) (uint64, error) {
-
-	var predicates []predicate.UsageTotal
-	predicates = append(predicates, usagetotal.BotIDEQ(authToken))
-	return l.svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(l.ctx)
-
-}
-
-func (l *ChatCompletionsLogic) replaceUsagetotalTokens(authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
-
-	Id, err := l.getUsagetotalIdByToken(authToken)
-	if err != nil && !ent.IsNotFound(err) {
-		return err
-	}
-	if Id > 0 { //UsageTotal have record by  newUsageDetailId
-		_, err = l.svcCtx.DB.UsageTotal.UpdateOneID(Id).
-			SetTotalTokens(sumTotalTokens).
-			SetEndIndex(newUsageDetailId).
-			Save(l.ctx)
-	} else { //create new record by  newUsageDetailId
-		logType := 5
-		_, err = l.svcCtx.DB.UsageTotal.Create().
-			SetNotNilBotID(&authToken).
-			SetNotNilEndIndex(&newUsageDetailId).
-			SetNotNilTotalTokens(&sumTotalTokens).
-			SetNillableType(&logType).
-			SetNotNilOrganizationID(&orgId).
-			Save(l.ctx)
-	}
-
-	return err
-}
-
-func (l *ChatCompletionsLogic) updateUsageTotal(authToken string, newUsageDetailId uint64, orgId uint64) error {
-
-	sumTotalTokens, err := l.sumTotalTokensByAuthToken(authToken) //首先sum UsageDetail的TotalTokens
-	if err == nil {
-		err = l.replaceUsagetotalTokens(authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens
-	}
-	return err
-}
-
-// sum total_tokens from usagedetail by AuthToken
-func (l *ChatCompletionsLogic) sumTotalTokensByAuthToken(authToken string) (uint64, error) {
-
-	var predicates []predicate.UsageDetail
-	predicates = append(predicates, usagedetail.BotIDEQ(authToken))
-
-	var res []struct {
-		Sum, Min, Max, Count uint64
-	}
-	totalTokens := uint64(0)
-	var err error = nil
-	err = l.svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"),
-		ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(l.ctx, &res)
-	if err == nil {
-		if len(res) > 0 {
-			totalTokens = res[0].Sum
-		} else {
-			totalTokens = 0
-		}
-	}
-	return totalTokens, err
+	svcCtx := &compapi.ServiceContext{Config: l.svcCtx.Config, DB: l.svcCtx.DB,
+		Rds: l.svcCtx.Rds}
+	return compapi.AppendUsageDetailLog(l.ctx, svcCtx, authToken, req, resp)
 }
 
 func (l *FastgptChatLogic) AdjustRequest(req *types.CompApiReq, apiKeyObj *ent.ApiKey) {
@@ -387,105 +278,6 @@ func humanSeeValidResult(ctx context.Context, req *types.CompApiReq, resp *types
 	fmt.Println(typekit.PrettyPrint(nres))
 }
 
-func apiKeyObjAdjust(eventType string, workId string, obj *ent.ApiKey) {
-	if eventType != "fastgpt" {
-		return
-	}
-	obj.OpenaiKey, _ = compapi.GetWorkInfoByID(eventType, workId)
-}
-
-// 合法域名正则(支持通配符、中文域名等场景按需调整)
-var domainRegex = regexp.MustCompile(
-	// 多级域名(如 example.com)
-	`^([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}$` +
-		`|` +
-		// 单级域名(如 localhost 或 mytest-svc)
-		`^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$`,
-)
-
-func IsValidURL(input *string, adjust bool) bool {
-	// 空值直接返回
-	if *input == "" {
-		return false
-	}
-	inputStr := *input
-
-	// --- 预处理输入:自动补全协议 ---
-	// 若输入不包含协议头,默认添加 http://
-	if !strings.Contains(*input, "://") {
-		inputStr = "http://" + *input
-	}
-
-	// --- 解析 URL ---
-	u, err := url.Parse(inputStr)
-	if err != nil {
-		return false
-	}
-
-	// --- 校验协议 ---
-	// 只允许常见协议(按需扩展)
-	switch u.Scheme {
-	case "http", "https", "ftp", "ftps":
-	default:
-		return false
-	}
-
-	// --- 拆分 Host 和 Port ---
-	host, port, err := net.SplitHostPort(u.Host)
-	if err != nil {
-		// 无端口时,整个 Host 作为主机名
-		host = u.Host
-		port = ""
-	}
-
-	// --- 校验主机名 ---
-	// 场景1:IPv4 或 IPv6
-	if ip := net.ParseIP(host); ip != nil {
-		// 允许私有或保留 IP(按需调整)
-		// 示例中允许所有合法 IP
-	} else {
-		// 场景2:域名(包括 localhost)
-		if !domainRegex.MatchString(host) {
-			return false
-		}
-	}
-
-	// --- 校验端口 ---
-	if port != "" {
-		p, err := net.LookupPort("tcp", port) // 动态获取端口(如 "http" 对应 80)
-		if err != nil {
-			// 直接尝试解析为数字端口
-			numPort, err := strconv.Atoi(port)
-			if err != nil || numPort < 1 || numPort > 65535 {
-				return false
-			}
-		} else if p == 0 { // 动态端口为 0 时无效
-			return false
-		}
-	}
-	if adjust {
-		*input = inputStr
-	}
-	return true
-}
-
-func getMessageContentStr(input any) string {
-	str := ""
-	switch val := input.(type) {
-	case string:
-		str = val
-	case []interface{}:
-		if len(val) > 0 {
-			if valc, ok := val[0].(map[string]interface{}); ok {
-				if valcc, ok := valc["text"]; ok {
-					str, _ = valcc.(string)
-				}
-			}
-		}
-	}
-	return str
-}
-
 func isAsyncReqest(req *types.CompApiReq) bool {
 	return req.IsBatch
 }

+ 6 - 2
internal/logic/chatrecords/gpts_submit_api_chat_logic.go

@@ -204,14 +204,18 @@ func (l *GptsSubmitApiChatLogic) GptsSubmitApiChat(tokenStr string, req *types.G
 					originalData := custom_types.OriginalData{}
 					originalData.Request = chatReq
 					originalData.Response = chatData
-
+					usage := credit.Usage{
+						CompletionTokens: chatData.Metadata.Usage.CompletionTokens,
+						PromptTokens:     chatData.Metadata.Usage.PromptTokens,
+						TotalTokens:      chatData.Metadata.Usage.TotalTokens,
+					}
 					if err != nil {
 						l.Logger.Errorf("start transaction error:%v\n", err)
 					} else {
 						err = credit.AddCreditUsage(tx, l.ctx,
 							agentId, userId, *userInfo.DepartmentId,
 							req.Content, &answer,
-							&originalData, &chatData,
+							&originalData, &usage, "gpt-4o-mini",
 						)
 						if err != nil {
 							_ = tx.Rollback()

+ 255 - 0
internal/utils/compapi/func.go

@@ -5,19 +5,274 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"net"
 	"net/http"
+	"net/url"
 	"reflect"
+	"regexp"
+	"strconv"
 	"strings"
 
+	"wechat-api/ent"
+	"wechat-api/ent/custom_types"
+	"wechat-api/hook/credit"
+	"wechat-api/internal/config"
 	"wechat-api/internal/types"
 	"wechat-api/internal/utils/contextkey"
 
 	openai "github.com/openai/openai-go"
 	"github.com/openai/openai-go/option"
 	"github.com/openai/openai-go/packages/ssestream"
+	"github.com/redis/go-redis/v9"
 	"github.com/zeromicro/go-zero/rest/httpx"
 )
 
+type ServiceContext struct {
+	Config config.Config
+	DB     *ent.Client
+	Rds    redis.UniversalClient
+}
+
+func AppendUsageDetailLog(ctx context.Context, svcCtx *ServiceContext,
+	authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
+
+	rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
+	orgId := uint64(0)
+	apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(ctx)
+	if ok {
+		orgId = apiKeyObj.OrganizationID
+	}
+	msgContent := getMessageContentStr(req.Messages[0].Content)
+	tx, err := svcCtx.DB.Tx(context.Background())
+	if err != nil {
+		err = fmt.Errorf("start transaction error:%v", err)
+	} else {
+		usage := credit.Usage{
+			CompletionTokens: uint64(resp.Usage.CompletionTokens),
+			PromptTokens:     uint64(resp.Usage.PromptTokens),
+			TotalTokens:      uint64(resp.Usage.TotalTokens),
+		}
+		err = credit.AddCreditUsage(tx, ctx,
+			authToken, req.EventType, orgId,
+			&msgContent, &resp.Choices[0].Message.Content,
+			&rawReqResp, &usage, req.Model,
+		)
+		if err != nil {
+			_ = tx.Rollback()
+			err = fmt.Errorf("save credits info failed:%v", err)
+		} else {
+			_ = tx.Commit()
+		}
+	}
+	return err
+}
+
+func IsValidURL(input *string, adjust bool) bool {
+
+	// 合法域名正则(支持通配符、中文域名等场景按需调整)
+	var domainRegex = regexp.MustCompile(
+		// 多级域名(如 example.com)
+		`^([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}$` +
+			`|` +
+			// 单级域名(如 localhost 或 mytest-svc)
+			`^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$`,
+	)
+
+	// 空值直接返回
+	if *input == "" {
+		return false
+	}
+	inputStr := *input
+
+	// --- 预处理输入:自动补全协议 ---
+	// 若输入不包含协议头,默认添加 http://
+	if !strings.Contains(*input, "://") {
+		inputStr = "http://" + *input
+	}
+
+	// --- 解析 URL ---
+	u, err := url.Parse(inputStr)
+	if err != nil {
+		return false
+	}
+
+	// --- 校验协议 ---
+	// 只允许常见协议(按需扩展)
+	switch u.Scheme {
+	case "http", "https", "ftp", "ftps":
+	default:
+		return false
+	}
+
+	// --- 拆分 Host 和 Port ---
+	host, port, err := net.SplitHostPort(u.Host)
+	if err != nil {
+		// 无端口时,整个 Host 作为主机名
+		host = u.Host
+		port = ""
+	}
+
+	// --- 校验主机名 ---
+	// 场景1:IPv4 或 IPv6
+	if ip := net.ParseIP(host); ip != nil {
+		// 允许私有或保留 IP(按需调整)
+		// 示例中允许所有合法 IP
+	} else {
+		// 场景2:域名(包括 localhost)
+		if !domainRegex.MatchString(host) {
+			return false
+		}
+	}
+
+	// --- 校验端口 ---
+	if port != "" {
+		p, err := net.LookupPort("tcp", port) // 动态获取端口(如 "http" 对应 80)
+		if err != nil {
+			// 直接尝试解析为数字端口
+			numPort, err := strconv.Atoi(port)
+			if err != nil || numPort < 1 || numPort > 65535 {
+				return false
+			}
+		} else if p == 0 { // 动态端口为 0 时无效
+			return false
+		}
+	}
+	if adjust {
+		*input = inputStr
+	}
+	return true
+}
+
+func getMessageContentStr(input any) string {
+	str := ""
+	switch val := input.(type) {
+	case string:
+		str = val
+	case []any:
+		if len(val) > 0 {
+			if valc, ok := val[0].(map[string]any); ok {
+				if valcc, ok := valc["text"]; ok {
+					str, _ = valcc.(string)
+				}
+			}
+		}
+	}
+	return str
+}
+
+/*
+func AppendUsageDetailLog(ctx context.Context, svcCtx *ServiceContext,
+
+	authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
+
+	logType := 5
+	rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
+	tmpId := 0
+	tmpId, _ = strconv.Atoi(resp.ID)
+	sessionId := uint64(tmpId)
+	orgId := uint64(0)
+	apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(ctx)
+	if ok {
+		orgId = apiKeyObj.OrganizationID
+	}
+	promptTokens := uint64(resp.Usage.PromptTokens)
+	completionToken := uint64(resp.Usage.CompletionTokens)
+	totalTokens := promptTokens + completionToken
+
+	msgContent := getMessageContentStr(req.Messages[0].Content)
+
+	_, _, _ = logType, sessionId, totalTokens
+	res, err := svcCtx.DB.UsageDetail.Create().
+		SetNotNilType(&logType).
+		SetNotNilBotID(&authToken).
+		SetNotNilReceiverID(&req.EventType).
+		SetNotNilSessionID(&sessionId).
+		SetNillableRequest(&msgContent).
+		SetNillableResponse(&resp.Choices[0].Message.Content).
+		SetNillableOrganizationID(&orgId).
+		SetOriginalData(rawReqResp).
+		SetNillablePromptTokens(&promptTokens).
+		SetNillableCompletionTokens(&completionToken).
+		SetNillableTotalTokens(&totalTokens).
+		Save(ctx)
+
+	if err == nil { //插入UsageDetai之后再统计UsageTotal
+		updateUsageTotal(ctx, svcCtx, authToken, res.ID, orgId)
+	}
+	return err
+}
+
+func getUsagetotalIdByToken(ctx context.Context, svcCtx *ServiceContext,
+	authToken string) (uint64, error) {
+
+	var predicates []predicate.UsageTotal
+	predicates = append(predicates, usagetotal.BotIDEQ(authToken))
+	return svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(ctx)
+
+}
+
+func replaceUsagetotalTokens(ctx context.Context, svcCtx *ServiceContext,
+
+	authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
+
+	Id, err := getUsagetotalIdByToken(ctx, svcCtx, authToken)
+	if err != nil && !ent.IsNotFound(err) {
+		return err
+	}
+	if Id > 0 { //UsageTotal have record by  newUsageDetailId
+		_, err = svcCtx.DB.UsageTotal.UpdateOneID(Id).
+			SetTotalTokens(sumTotalTokens).
+			SetEndIndex(newUsageDetailId).
+			Save(ctx)
+	} else { //create new record by  newUsageDetailId
+		logType := 5
+		_, err = svcCtx.DB.UsageTotal.Create().
+			SetNotNilBotID(&authToken).
+			SetNotNilEndIndex(&newUsageDetailId).
+			SetNotNilTotalTokens(&sumTotalTokens).
+			SetNillableType(&logType).
+			SetNotNilOrganizationID(&orgId).
+			Save(ctx)
+	}
+
+	return err
+}
+
+func updateUsageTotal(ctx context.Context, svcCtx *ServiceContext,
+	authToken string, newUsageDetailId uint64, orgId uint64) error {
+
+	sumTotalTokens, err := sumTotalTokensByAuthToken(ctx, svcCtx, authToken) //首先sum UsageDetail的TotalTokens
+	if err == nil {
+		err = replaceUsagetotalTokens(ctx, svcCtx, authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens
+	}
+	return err
+}
+
+// sum total_tokens from usagedetail by AuthToken
+func sumTotalTokensByAuthToken(ctx context.Context, svcCtx *ServiceContext,
+	authToken string) (uint64, error) {
+
+	var predicates []predicate.UsageDetail
+	predicates = append(predicates, usagedetail.BotIDEQ(authToken))
+
+	var res []struct {
+		Sum, Min, Max, Count uint64
+	}
+	totalTokens := uint64(0)
+	var err error = nil
+	err = svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"),
+		ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(ctx, &res)
+	if err == nil {
+		if len(res) > 0 {
+			totalTokens = res[0].Sum
+		} else {
+			totalTokens = 0
+		}
+	}
+	return totalTokens, err
+}
+*/
+
 func IsOpenaiModel(model string) bool {
 
 	prefixes := []string{"gpt-4", "gpt-3", "o1", "o3"}

+ 6 - 3
internal/utils/compapi/intent.go

@@ -11,7 +11,7 @@ type IntentClient struct {
 func (me *IntentClient) ResponseFormatSetting(req *types.CompApiReq) ResponseFormatConfig {
 
 	//Message重构的配置
-	me.ResformatConfig.SysmesArgs = []any{}
+	me.ResformatConfig.SysmesArgs = []any{} //以下SysmesTmpl内无替换占位符,所以这里为空
 	me.ResformatConfig.SysmesTmpl = `# 任务
 1. 首先,判断用户的第一句话是否说了:“你好,(任意内容)通话”,如果说了,则不用理会评级规则,直接强制分配为"语音助手"
 2. 如果不属于“语音助手”,请根据评级规则,对聊天记录给出评级、置信度、评分依据(逐项分析不要遗漏)
@@ -19,6 +19,7 @@ func (me *IntentClient) ResponseFormatSetting(req *types.CompApiReq) ResponseFor
 # 细节说明
 置信度从0到1,0为置信度最低,1为置信度最高。`
 
+	//来自请求的变量用来替换UsermesTmpl中的占位符
 	me.ResformatConfig.UsermesArgs = []any{req.Variables["chat_history"]}
 	me.ResformatConfig.UsermesTmpl = `# 评级规则:
         [
@@ -41,18 +42,20 @@ func (me *IntentClient) ResponseFormatSetting(req *types.CompApiReq) ResponseFor
         `
 
 	//ResponseFormat设置的配置
-	me.ResformatConfig.ResformatDesc = "为通话记录进行评级"
+	me.ResformatConfig.ResformatDesc = "为通话记录进行评级" //Resformat描述
+	//非openai兼容大模型所使用的Resformat文本
 	me.ResformatConfig.ResformatTxt = `{
     "score": str, #评分结果:有意向、待进一步分析、暂时无法沟通、其他
     "confidence_score": int, #置信度分数,范围从0.0到1.0
     "scoring_criteria": str, #请逐步介绍为何评为这个结果
 }`
+	//openai兼容大模型所使用的Resformat结构或其他类型
 	me.ResformatConfig.ResformatStruct = struct {
 		Score           string  `json:"score" jsonschema_description:"评分结果:有意向、待进一步分析、暂时无法沟通、其他"`
 		ConfidenceScore float32 `json:"confidence_score" jsonschema_description:"置信度分数,范围从0.0到1.0"`
 		ScoringCriteria string  `json:"scoring_criteria" jsonschema_description:"请逐步介绍为何评为这个结果"`
 	}{}
-	me.ResformatConfig.HaveSet = true
+	me.ResformatConfig.HaveSet = true //很关键,避免父类的参数再设置一遍
 
 	return me.ResformatConfig
 }