jimmyyem 6 өдөр өмнө
parent
commit
13df842ec9

+ 113 - 0
hook/credit/credit.go

@@ -0,0 +1,113 @@
+package credit
+
+import (
+	"context"
+	"fmt"
+	"wechat-api/ent"
+	"wechat-api/ent/creditbalance"
+	"wechat-api/ent/custom_types"
+	"wechat-api/ent/usagetotal"
+	"wechat-api/hook/dify"
+)
+
+func AddCreditUsage(tx *ent.Tx, ctx context.Context,
+	agentId string, userId string, departmentId uint64,
+	question *string, answer *string,
+	originalData *custom_types.OriginalData, chatData *dify.ChatResp) error {
+	// 积分明细表记录使用量
+	modelName, price := GetModelPrice()
+	number := ComputePrice(price, chatData.Metadata.Usage.TotalTokens)
+
+	// 记录Token使用信息
+	usageDetailItem, err := tx.UsageDetail.Create().
+		SetType(3).            //1-微信 2-名片 3-智能体
+		SetBotID(agentId).     //智能体ID
+		SetReceiverID(userId). //接收者userID
+		SetApp(8).             //8-智能体
+		SetSessionID(0).
+		SetRequest(*question).
+		SetResponse(*answer).
+		SetOriginalData(*originalData).
+		SetTotalTokens(chatData.Metadata.Usage.TotalTokens).
+		SetPromptTokens(chatData.Metadata.Usage.PromptTokens).
+		SetCompletionTokens(chatData.Metadata.Usage.CompletionTokens).
+		SetModel(modelName).
+		SetCredits(number).
+		Save(ctx)
+
+	if err != nil {
+		fmt.Printf("create usage_detail failed:%v\n", err)
+		return err
+	}
+
+	// 记录Token使用总量
+	usageTotal, err := tx.UsageTotal.Query().Where(usagetotal.OrganizationID(departmentId)).First(ctx)
+	if err != nil {
+		if ent.IsNotFound(err) {
+			usageTotal, err = tx.UsageTotal.Create().
+				SetBotID(agentId).
+				SetTotalTokens(chatData.Metadata.Usage.TotalTokens).
+				SetEndIndex(usageDetailItem.ID).
+				SetOrganizationID(departmentId).
+				Save(ctx)
+		} else {
+			fmt.Printf("create usage_total failed:organization_id:%v err:%v\n", departmentId, err)
+			return err
+		}
+	} else {
+		// 更新Token使用总量
+		_, err = tx.UsageTotal.Update().
+			Where(usagetotal.OrganizationID(departmentId)).
+			SetTotalTokens(usageTotal.TotalTokens + chatData.Metadata.Usage.TotalTokens).
+			SetEndIndex(usageDetailItem.ID).
+			Save(ctx)
+		if err != nil {
+			fmt.Printf("update usage_total failed:organization_id:%v err:%v\n", departmentId, err)
+			return err
+		}
+	}
+
+	creditBalanceItem, err := tx.CreditBalance.Query().Where(creditbalance.OrganizationID(departmentId)).First(ctx)
+	if err != nil {
+		if ent.IsNotFound(err) {
+			creditBalanceItem, err = tx.CreditBalance.Create().
+				SetOrganizationID(departmentId).
+				SetBalance(0).
+				Save(ctx)
+			if err != nil {
+				fmt.Printf("create credit_balance failed. organization_id:%v error:%v\n", departmentId, err)
+				return err
+			}
+		} else {
+			fmt.Printf("query credit_balance failed: organization_id:%v error:%v\n", departmentId, err)
+			return err
+		}
+	}
+
+	// 积分使用明细记录
+	beforeNumber := creditBalanceItem.Balance
+	afterNumber := Subtraction(beforeNumber, number)
+	_, err = tx.CreditUsage.Create().
+		SetUserID(userId).
+		SetNumber(number).
+		SetBeforeNumber(beforeNumber).
+		SetAfterNumber(afterNumber).
+		SetNtype(1).
+		SetNid(usageDetailItem.ID).
+		SetTable("usage_detail").
+		SetOrganizationID(departmentId).
+		Save(ctx)
+	if err != nil {
+		fmt.Printf("create credit_usage failed:%v\n", err)
+		return err
+	}
+
+	// 积分账户扣减积分
+	_, err = tx.CreditBalance.Update().Where(creditbalance.OrganizationID(departmentId)).SetBalance(afterNumber).Save(ctx)
+	if err != nil {
+		fmt.Printf("update credit_balance failed: organization_id:%v error:%v\n", departmentId, err)
+		return err
+	}
+
+	return nil
+}

+ 79 - 0
hook/credit/models.go

@@ -0,0 +1,79 @@
+package credit
+
+import "math"
+
+var modelArray = []string{
+	"o1",
+	"gpt-4o",
+	"gpt-4.1",
+	"o3-mini",
+	"moonshot-v1-32k",
+	"deepseek-r1",
+	"moonshot-v1-8k",
+	"gpt-4.1-mini",
+	"gpt-3.5-turbo",
+	"qwen-max",
+	"doubao1.5-pro-256k",
+	"deepseek-v3",
+	"qwq-32b-preview",
+	"gpt-4o-mini",
+	"qwen2.5-14b-instruct-1m",
+	"gpt-4.1-nano",
+	"doubao1.5-pro",
+	"doubao1.5-pro-32k",
+	"chatglm3",
+	"qwen-turbo",
+	"doubao1.5-lite-32k",
+}
+
+var priceArray = []float64{
+	0.0001,
+	0.00001667,
+	0.00001333,
+	0.00000733,
+	0.00000548,
+	0.00000365,
+	0.00000274,
+	0.00000267,
+	0.0000025,
+	0.00000219,
+	0.00000205,
+	0.00000183,
+	0.00000137,
+	0.000001,
+	0.00000068,
+	0.00000067,
+	0.00000046,
+	0.00000046,
+	0.00000023,
+	0.00000014,
+	0.00000014,
+}
+
+func getModelName() string {
+	return "gpt-4o"
+}
+
+func GetModelPrice() (model string, price float64) {
+	difyModelName := getModelName()
+	for i, v := range modelArray {
+		if v == difyModelName {
+			return v, priceArray[i]
+		}
+	}
+
+	return modelArray[0], priceArray[0]
+}
+
+func ComputePrice(price float64, tokens uint64) float64 {
+	scale := float64(1000000)
+	return math.Round(price*float64(tokens)*scale) / scale
+}
+
+// Subtraction() 保留小数点后6位的精确减法
+func Subtraction(number1, number2 float64) float64 {
+	d1 := number1 * 1000000
+	d2 := number2 * 1000000
+	res := math.Floor(d1-d2) / 1000000
+	return res
+}

+ 13 - 93
internal/logic/chatrecords/gpts_submit_api_chat_logic.go

@@ -9,7 +9,6 @@ import (
 	"github.com/suyuan32/simple-admin-core/rpc/types/core"
 	"github.com/zeromicro/go-zero/core/logx"
 	"io"
-	"math"
 	"net/http"
 	"strconv"
 	"strings"
@@ -18,7 +17,7 @@ import (
 	"wechat-api/ent/creditbalance"
 	"wechat-api/ent/custom_types"
 	"wechat-api/ent/employee"
-	"wechat-api/ent/usagetotal"
+	"wechat-api/hook/credit"
 	"wechat-api/hook/dify"
 	"wechat-api/internal/svc"
 	"wechat-api/internal/types"
@@ -199,110 +198,31 @@ func (l *GptsSubmitApiChatLogic) GptsSubmitApiChat(tokenStr string, req *types.G
 			if finish {
 				if switcher {
 					tx, err := l.svcCtx.DB.Tx(context.Background())
+					agentId := strconv.Itoa(int(*req.AgentId))
+
 					// 构造 original_data
 					originalData := custom_types.OriginalData{}
 					originalData.Request = chatReq
 					originalData.Response = chatData
 
-					agentId := strconv.Itoa(int(*req.AgentId))
-					// 记录Token使用信息
-					usageDetailItem, err := tx.UsageDetail.Create().
-						SetType(3).            //1-微信 2-名片 3-智能体
-						SetBotID(agentId).     //智能体ID
-						SetReceiverID(userId). //接收者userID
-						SetApp(8).             //8-智能体
-						SetSessionID(0).
-						SetRequest(*req.Content).
-						SetResponse(answer).
-						SetOriginalData(originalData).
-						SetTotalTokens(chatData.Metadata.Usage.TotalTokens).
-						SetPromptTokens(chatData.Metadata.Usage.PromptTokens).
-						SetCompletionTokens(chatData.Metadata.Usage.CompletionTokens).
-						Save(l.ctx)
-
-					if err != nil {
-						_ = tx.Rollback()
-						l.Logger.Errorf("save data to usage_detail error:%v\n", err)
-					}
-
-					// 记录Token使用总量
-					usageTotal, err := tx.UsageTotal.Query().Where(usagetotal.OrganizationID(*userInfo.DepartmentId)).First(l.ctx)
 					if err != nil {
-						if ent.IsNotFound(err) {
-							usageTotal, err = tx.UsageTotal.Create().
-								SetOrganizationID(*userInfo.DepartmentId).
-								SetTotalTokens(0).
-								Save(l.ctx)
-						} else {
+						l.Logger.Errorf("start transaction error:%v\n", err)
+					} else {
+						err = credit.AddCreditUsage(tx, l.ctx,
+							agentId, userId, *userInfo.DepartmentId,
+							req.Content, &answer,
+							&originalData, &chatData,
+						)
+						if err != nil {
 							_ = tx.Rollback()
-							l.Logger.Errorf("create usage_total failed:departmentId:%v err:%v\n", *userInfo.DepartmentId, err)
-						}
-					}
-					_, err = tx.UsageTotal.Update().
-						Where(usagetotal.OrganizationID(*userInfo.DepartmentId)).
-						SetTotalTokens(usageTotal.TotalTokens + chatData.Metadata.Usage.TotalTokens).
-						SetEndIndex(usageDetailItem.ID).
-						Save(l.ctx)
-					if err != nil {
-						_ = tx.Rollback()
-						l.Logger.Errorf("update usage_total failed:departmentId:%v err:%v\n", *userInfo.DepartmentId, err)
-					}
-
-					// 积分明细表记录使用量
-					// 根据1积分=10000token 根据Token换算积分使用量
-					var rate float64 = 10000
-					change := float64(chatData.Metadata.Usage.TotalTokens) / rate
-					number := math.Round(change*rate) / rate
-
-					creditBalanceItem, err := tx.CreditBalance.Query().Where(creditbalance.OrganizationID(*userInfo.DepartmentId)).First(l.ctx)
-					if err != nil {
-						if ent.IsNotFound(err) {
-							creditBalanceItem, err = tx.CreditBalance.Create().SetOrganizationID(*userInfo.DepartmentId).SetBalance(0).Save(l.ctx)
-							if err != nil {
-								_ = tx.Rollback()
-								l.Logger.Errorf("create credit_balance failed. organization:%v error:%v\n", *userInfo.DepartmentId, err)
-							}
+							l.Logger.Errorf("save credits info failed:%v\n", err)
 						} else {
-							_ = tx.Rollback()
-							l.Logger.Errorf("query credit_balance failed. organization:%v error:%v\n", *userInfo.DepartmentId, err)
+							_ = tx.Commit()
 						}
 					}
-
-					beforeNumber := creditBalanceItem.Balance
-					afterNumber := l.subtraction(beforeNumber, number)
-					_, err = tx.CreditUsage.Create().
-						SetUserID(userId).
-						SetNumber(number).
-						SetBeforeNumber(beforeNumber).
-						SetAfterNumber(afterNumber).
-						SetNtype(1).
-						SetNid(usageDetailItem.ID).
-						SetTable("usage_detail").
-						SetOrganizationID(*userInfo.DepartmentId).
-						Save(l.ctx)
-					if err != nil {
-						_ = tx.Rollback()
-						l.Logger.Errorf("save data to credit_usage error:%v\n", err)
-					}
-
-					// 减积分
-					_, err = tx.CreditBalance.Update().Where(creditbalance.OrganizationID(*userInfo.DepartmentId)).SetBalance(creditBalance.Balance - number).Save(l.ctx)
-					if err != nil {
-						_ = tx.Rollback()
-						l.Logger.Errorf("update organization:%v balance error:%v\n", *userInfo.DepartmentId, err)
-					}
-					_ = tx.Commit()
 				}
 				break
 			}
 		}
 	}
 }
-
-// subtraction() 保留小数点后4位的精确减法
-func (l *GptsSubmitApiChatLogic) subtraction(number1, number2 float64) float64 {
-	d1 := number1 * 10000
-	d2 := number2 * 10000
-	res := math.Floor(d1-d2) / 10000
-	return res
-}