Browse Source

实现异步任务调用大模型后追加token usage统计

liwei 1 day ago
parent
commit
44d0adec43

+ 12 - 45
cli/asynctask/asynctask.go

@@ -10,7 +10,6 @@ import (
 	"hash/fnv"
 	"os"
 	"os/signal"
-	"reflect"
 	"runtime"
 	"strconv"
 	"strings"
@@ -324,8 +323,8 @@ func (me *AsyncTask) requestAPI(taskData *ent.CompapiAsynctask) (int, error) {
 		err     error
 		apiResp *types.CompOpenApiResp
 	)
-	req := types.CompApiReq{}
-	if err = json.Unmarshal([]byte(taskData.RequestRaw), &req); err != nil {
+	req := &types.CompApiReq{}
+	if err = json.Unmarshal([]byte(taskData.RequestRaw), req); err != nil {
 		return 0, err
 	}
 	//初始化client
@@ -341,7 +340,7 @@ func (me *AsyncTask) requestAPI(taskData *ent.CompapiAsynctask) (int, error) {
 		case <-me.ctx.Done(): //接到信号退出
 			goto endloopTry
 		default:
-			apiResp, err = client.Chat(&req)
+			apiResp, err = client.Chat(req)
 			if err == nil && apiResp != nil && len(apiResp.Choices) > 0 {
 				//call succ
 				goto endloopTry
@@ -371,6 +370,13 @@ endloopTry:
 	}
 	//成功后处理环节
 
+	//追加访问大模型token统计数据
+	svcCtx := &compapi.ServiceContext{DB: me.svcCtx.DB, Rds: me.svcCtx.Rds, Config: me.svcCtx.Config}
+	err = compapi.AppendUsageDetailLog(me.ctx, svcCtx, taskData.AuthToken, req, apiResp)
+	if err != nil {
+		return 0, err
+	}
+
 	//更新任务状态 => ReqApi_Done(请求API完成)
 	err = me.updateTaskStatus(taskData.ID, ReqApi_Done)
 	if err != nil {
@@ -390,51 +396,12 @@ endloopTry:
 	return 1, nil
 }
 
-func EntStructGenScanField(structPtr any) (string, []any, error) {
-	t := reflect.TypeOf(structPtr)
-	v := reflect.ValueOf(structPtr)
-
-	if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
-		return "", nil, errors.New("input must be a pointer to a struct")
-	}
-	t = t.Elem()
-	v = v.Elem()
-
-	var fields []string
-	var scanArgs []any
-	for i := 0; i < t.NumField(); i++ {
-		field := t.Field(i)
-		value := v.Field(i)
-
-		// Skip unexported fields
-		if !field.IsExported() {
-			continue
-		}
-
-		// Get json tag
-		jsonTag := field.Tag.Get("json")
-		if jsonTag == "-" || jsonTag == "" {
-			continue
-		}
-
-		jsonParts := strings.Split(jsonTag, ",")
-		jsonName := jsonParts[0]
-		if jsonName == "" {
-			continue
-		}
-
-		fields = append(fields, jsonName)
-		scanArgs = append(scanArgs, value.Addr().Interface())
-	}
-	return strings.Join(fields, ", "), scanArgs, nil
-}
-
 /*
 CREATE INDEX idx_compapi_task_status_chat_id_id_desc
 ON compapi_asynctask (task_status, chat_id, id DESC);
 */
 func (me *AsyncTask) getAsyncReqTaskFairList() ([]Task, error) {
-	fieldListStr, _, err := EntStructGenScanField(&ent.CompapiAsynctask{})
+	fieldListStr, _, err := compapi.EntStructGenScanField(&ent.CompapiAsynctask{})
 	if err != nil {
 		return nil, err
 	}
@@ -464,7 +431,7 @@ func (me *AsyncTask) getAsyncReqTaskFairList() ([]Task, error) {
 	for rows.Next() {
 		taskrow := ent.CompapiAsynctask{}
 		var scanParams []any
-		_, scanParams, err = EntStructGenScanField(&taskrow)
+		_, scanParams, err = compapi.EntStructGenScanField(&taskrow)
 		if err != nil {
 			break
 		}

+ 4 - 244
internal/logic/chat/chat_completions_logic.go

@@ -5,24 +5,13 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"net"
-	"net/url"
-	"regexp"
-	"strconv"
-	"strings"
 	"wechat-api/ent"
-	"wechat-api/hook/credit"
 	"wechat-api/internal/svc"
 	"wechat-api/internal/types"
 	"wechat-api/internal/utils/compapi"
 	"wechat-api/internal/utils/contextkey"
 	"wechat-api/internal/utils/typekit"
 
-	"wechat-api/ent/custom_types"
-	"wechat-api/ent/predicate"
-	"wechat-api/ent/usagedetail"
-	"wechat-api/ent/usagetotal"
-
 	"github.com/zeromicro/go-zero/core/logx"
 )
 
@@ -137,7 +126,7 @@ func (l *ChatCompletionsLogic) AdjustRequest(req *types.CompApiReq, apiKeyObj *e
 	if req.IsBatch {
 		//流模式暂时不支持异步模式
 		//Callback格式非法则取消批量模式
-		if req.Stream || !IsValidURL(&req.Callback, true) {
+		if req.Stream || !compapi.IsValidURL(&req.Callback, true) {
 			req.IsBatch = false
 		}
 	}
@@ -184,139 +173,9 @@ func (l *ChatCompletionsLogic) AppendAsyncRequest(apiKeyObj *ent.ApiKey, req *ty
 }
 
 func (l *ChatCompletionsLogic) AppendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
-	rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
-	orgId := uint64(0)
-	apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(l.ctx)
-	if ok {
-		orgId = apiKeyObj.OrganizationID
-	}
-	msgContent := getMessageContentStr(req.Messages[0].Content)
-	tx, err := l.svcCtx.DB.Tx(context.Background())
-	if err != nil {
-		l.Logger.Errorf("start transaction error:%v\n", err)
-	} else {
-		usage := credit.Usage{
-			CompletionTokens: uint64(resp.Usage.CompletionTokens),
-			PromptTokens:     uint64(resp.Usage.PromptTokens),
-			TotalTokens:      uint64(resp.Usage.TotalTokens),
-		}
-		err = credit.AddCreditUsage(tx, l.ctx,
-			authToken, req.EventType, orgId,
-			&msgContent, &resp.Choices[0].Message.Content,
-			&rawReqResp, &usage,
-		)
-		if err != nil {
-			_ = tx.Rollback()
-			l.Logger.Errorf("save credits info failed:%v\n", err)
-		} else {
-			_ = tx.Commit()
-		}
-	}
-	return err
-}
-
-//func (l *ChatCompletionsLogic) AppendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
-//
-//	logType := 5
-//	rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
-//	tmpId := 0
-//	tmpId, _ = strconv.Atoi(resp.ID)
-//	sessionId := uint64(tmpId)
-//	orgId := uint64(0)
-//	apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(l.ctx)
-//	if ok {
-//		orgId = apiKeyObj.OrganizationID
-//	}
-//	promptTokens := uint64(resp.Usage.PromptTokens)
-//	completionToken := uint64(resp.Usage.CompletionTokens)
-//	totalTokens := promptTokens + completionToken
-//
-//	msgContent := getMessageContentStr(req.Messages[0].Content)
-//
-//	_, _, _ = logType, sessionId, totalTokens
-//	res, err := l.svcCtx.DB.UsageDetail.Create().
-//		SetNotNilType(&logType).
-//		SetNotNilBotID(&authToken).
-//		SetNotNilReceiverID(&req.EventType).
-//		SetNotNilSessionID(&sessionId).
-//		SetNillableRequest(&msgContent).
-//		SetNillableResponse(&resp.Choices[0].Message.Content).
-//		SetNillableOrganizationID(&orgId).
-//		SetOriginalData(rawReqResp).
-//		SetNillablePromptTokens(&promptTokens).
-//		SetNillableCompletionTokens(&completionToken).
-//		SetNillableTotalTokens(&totalTokens).
-//		Save(l.ctx)
-//
-//	if err == nil { //插入UsageDetai之后再统计UsageTotal
-//		l.updateUsageTotal(authToken, res.ID, orgId)
-//	}
-//	return err
-//}
-
-func (l *ChatCompletionsLogic) getUsagetotalIdByToken(authToken string) (uint64, error) {
-
-	var predicates []predicate.UsageTotal
-	predicates = append(predicates, usagetotal.BotIDEQ(authToken))
-	return l.svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(l.ctx)
-
-}
-
-func (l *ChatCompletionsLogic) replaceUsagetotalTokens(authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
-
-	Id, err := l.getUsagetotalIdByToken(authToken)
-	if err != nil && !ent.IsNotFound(err) {
-		return err
-	}
-	if Id > 0 { //UsageTotal have record by  newUsageDetailId
-		_, err = l.svcCtx.DB.UsageTotal.UpdateOneID(Id).
-			SetTotalTokens(sumTotalTokens).
-			SetEndIndex(newUsageDetailId).
-			Save(l.ctx)
-	} else { //create new record by  newUsageDetailId
-		logType := 5
-		_, err = l.svcCtx.DB.UsageTotal.Create().
-			SetNotNilBotID(&authToken).
-			SetNotNilEndIndex(&newUsageDetailId).
-			SetNotNilTotalTokens(&sumTotalTokens).
-			SetNillableType(&logType).
-			SetNotNilOrganizationID(&orgId).
-			Save(l.ctx)
-	}
-
-	return err
-}
-
-func (l *ChatCompletionsLogic) updateUsageTotal(authToken string, newUsageDetailId uint64, orgId uint64) error {
-
-	sumTotalTokens, err := l.sumTotalTokensByAuthToken(authToken) //首先sum UsageDetail的TotalTokens
-	if err == nil {
-		err = l.replaceUsagetotalTokens(authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens
-	}
-	return err
-}
-
-// sum total_tokens from usagedetail by AuthToken
-func (l *ChatCompletionsLogic) sumTotalTokensByAuthToken(authToken string) (uint64, error) {
-
-	var predicates []predicate.UsageDetail
-	predicates = append(predicates, usagedetail.BotIDEQ(authToken))
-
-	var res []struct {
-		Sum, Min, Max, Count uint64
-	}
-	totalTokens := uint64(0)
-	var err error = nil
-	err = l.svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"),
-		ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(l.ctx, &res)
-	if err == nil {
-		if len(res) > 0 {
-			totalTokens = res[0].Sum
-		} else {
-			totalTokens = 0
-		}
-	}
-	return totalTokens, err
+	svcCtx := &compapi.ServiceContext{Config: l.svcCtx.Config, DB: l.svcCtx.DB,
+		Rds: l.svcCtx.Rds}
+	return compapi.AppendUsageDetailLog(l.ctx, svcCtx, authToken, req, resp)
 }
 
 func (l *FastgptChatLogic) AdjustRequest(req *types.CompApiReq, apiKeyObj *ent.ApiKey) {
@@ -419,105 +278,6 @@ func humanSeeValidResult(ctx context.Context, req *types.CompApiReq, resp *types
 	fmt.Println(typekit.PrettyPrint(nres))
 }
 
-func apiKeyObjAdjust(eventType string, workId string, obj *ent.ApiKey) {
-	if eventType != "fastgpt" {
-		return
-	}
-	obj.OpenaiKey, _ = compapi.GetWorkInfoByID(eventType, workId)
-}
-
-// 合法域名正则(支持通配符、中文域名等场景按需调整)
-var domainRegex = regexp.MustCompile(
-	// 多级域名(如 example.com)
-	`^([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}$` +
-		`|` +
-		// 单级域名(如 localhost 或 mytest-svc)
-		`^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$`,
-)
-
-func IsValidURL(input *string, adjust bool) bool {
-	// 空值直接返回
-	if *input == "" {
-		return false
-	}
-	inputStr := *input
-
-	// --- 预处理输入:自动补全协议 ---
-	// 若输入不包含协议头,默认添加 http://
-	if !strings.Contains(*input, "://") {
-		inputStr = "http://" + *input
-	}
-
-	// --- 解析 URL ---
-	u, err := url.Parse(inputStr)
-	if err != nil {
-		return false
-	}
-
-	// --- 校验协议 ---
-	// 只允许常见协议(按需扩展)
-	switch u.Scheme {
-	case "http", "https", "ftp", "ftps":
-	default:
-		return false
-	}
-
-	// --- 拆分 Host 和 Port ---
-	host, port, err := net.SplitHostPort(u.Host)
-	if err != nil {
-		// 无端口时,整个 Host 作为主机名
-		host = u.Host
-		port = ""
-	}
-
-	// --- 校验主机名 ---
-	// 场景1:IPv4 或 IPv6
-	if ip := net.ParseIP(host); ip != nil {
-		// 允许私有或保留 IP(按需调整)
-		// 示例中允许所有合法 IP
-	} else {
-		// 场景2:域名(包括 localhost)
-		if !domainRegex.MatchString(host) {
-			return false
-		}
-	}
-
-	// --- 校验端口 ---
-	if port != "" {
-		p, err := net.LookupPort("tcp", port) // 动态获取端口(如 "http" 对应 80)
-		if err != nil {
-			// 直接尝试解析为数字端口
-			numPort, err := strconv.Atoi(port)
-			if err != nil || numPort < 1 || numPort > 65535 {
-				return false
-			}
-		} else if p == 0 { // 动态端口为 0 时无效
-			return false
-		}
-	}
-	if adjust {
-		*input = inputStr
-	}
-	return true
-}
-
-func getMessageContentStr(input any) string {
-	str := ""
-	switch val := input.(type) {
-	case string:
-		str = val
-	case []interface{}:
-		if len(val) > 0 {
-			if valc, ok := val[0].(map[string]interface{}); ok {
-				if valcc, ok := valc["text"]; ok {
-					str, _ = valcc.(string)
-				}
-			}
-		}
-	}
-	return str
-}
-
 func isAsyncReqest(req *types.CompApiReq) bool {
 	return req.IsBatch
 }

+ 318 - 0
internal/utils/compapi/func.go

@@ -5,18 +5,274 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"net"
 	"net/http"
+	"net/url"
+	"reflect"
+	"regexp"
+	"strconv"
 	"strings"
 
+	"wechat-api/ent"
+	"wechat-api/ent/custom_types"
+	"wechat-api/hook/credit"
+	"wechat-api/internal/config"
 	"wechat-api/internal/types"
 	"wechat-api/internal/utils/contextkey"
 
 	openai "github.com/openai/openai-go"
 	"github.com/openai/openai-go/option"
 	"github.com/openai/openai-go/packages/ssestream"
+	"github.com/redis/go-redis/v9"
 	"github.com/zeromicro/go-zero/rest/httpx"
 )
 
+type ServiceContext struct {
+	Config config.Config
+	DB     *ent.Client
+	Rds    redis.UniversalClient
+}
+
+func AppendUsageDetailLog(ctx context.Context, svcCtx *ServiceContext,
+	authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
+
+	rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
+	orgId := uint64(0)
+	apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(ctx)
+	if ok {
+		orgId = apiKeyObj.OrganizationID
+	}
+	msgContent := getMessageContentStr(req.Messages[0].Content)
+	tx, err := svcCtx.DB.Tx(context.Background())
+	if err != nil {
+		err = fmt.Errorf("start transaction error:%v", err)
+	} else {
+		usage := credit.Usage{
+			CompletionTokens: uint64(resp.Usage.CompletionTokens),
+			PromptTokens:     uint64(resp.Usage.PromptTokens),
+			TotalTokens:      uint64(resp.Usage.TotalTokens),
+		}
+		err = credit.AddCreditUsage(tx, ctx,
+			authToken, req.EventType, orgId,
+			&msgContent, &resp.Choices[0].Message.Content,
+			&rawReqResp, &usage,
+		)
+		if err != nil {
+			_ = tx.Rollback()
+			err = fmt.Errorf("save credits info failed:%v", err)
+		} else {
+			_ = tx.Commit()
+		}
+	}
+	return err
+}
+
+func IsValidURL(input *string, adjust bool) bool {
+
+	// 合法域名正则(支持通配符、中文域名等场景按需调整)
+	var domainRegex = regexp.MustCompile(
+		// 多级域名(如 example.com)
+		`^([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}$` +
+			`|` +
+			// 单级域名(如 localhost 或 mytest-svc)
+			`^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$`,
+	)
+
+	// 空值直接返回
+	if *input == "" {
+		return false
+	}
+	inputStr := *input
+
+	// --- 预处理输入:自动补全协议 ---
+	// 若输入不包含协议头,默认添加 http://
+	if !strings.Contains(*input, "://") {
+		inputStr = "http://" + *input
+	}
+
+	// --- 解析 URL ---
+	u, err := url.Parse(inputStr)
+	if err != nil {
+		return false
+	}
+
+	// --- 校验协议 ---
+	// 只允许常见协议(按需扩展)
+	switch u.Scheme {
+	case "http", "https", "ftp", "ftps":
+	default:
+		return false
+	}
+
+	// --- 拆分 Host 和 Port ---
+	host, port, err := net.SplitHostPort(u.Host)
+	if err != nil {
+		// 无端口时,整个 Host 作为主机名
+		host = u.Host
+		port = ""
+	}
+
+	// --- 校验主机名 ---
+	// 场景1:IPv4 或 IPv6
+	if ip := net.ParseIP(host); ip != nil {
+		// 允许私有或保留 IP(按需调整)
+		// 示例中允许所有合法 IP
+	} else {
+		// 场景2:域名(包括 localhost)
+		if !domainRegex.MatchString(host) {
+			return false
+		}
+	}
+
+	// --- 校验端口 ---
+	if port != "" {
+		p, err := net.LookupPort("tcp", port) // 动态获取端口(如 "http" 对应 80)
+		if err != nil {
+			// 直接尝试解析为数字端口
+			numPort, err := strconv.Atoi(port)
+			if err != nil || numPort < 1 || numPort > 65535 {
+				return false
+			}
+		} else if p == 0 { // 动态端口为 0 时无效
+			return false
+		}
+	}
+	if adjust {
+		*input = inputStr
+	}
+	return true
+}
+
+func getMessageContentStr(input any) string {
+	str := ""
+	switch val := input.(type) {
+	case string:
+		str = val
+	case []any:
+		if len(val) > 0 {
+			if valc, ok := val[0].(map[string]any); ok {
+				if valcc, ok := valc["text"]; ok {
+					str, _ = valcc.(string)
+				}
+			}
+		}
+	}
+	return str
+}
+
+/*
+func AppendUsageDetailLog(ctx context.Context, svcCtx *ServiceContext,
+
+	authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
+
+	logType := 5
+	rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
+	tmpId := 0
+	tmpId, _ = strconv.Atoi(resp.ID)
+	sessionId := uint64(tmpId)
+	orgId := uint64(0)
+	apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(ctx)
+	if ok {
+		orgId = apiKeyObj.OrganizationID
+	}
+	promptTokens := uint64(resp.Usage.PromptTokens)
+	completionToken := uint64(resp.Usage.CompletionTokens)
+	totalTokens := promptTokens + completionToken
+
+	msgContent := getMessageContentStr(req.Messages[0].Content)
+
+	_, _, _ = logType, sessionId, totalTokens
+	res, err := svcCtx.DB.UsageDetail.Create().
+		SetNotNilType(&logType).
+		SetNotNilBotID(&authToken).
+		SetNotNilReceiverID(&req.EventType).
+		SetNotNilSessionID(&sessionId).
+		SetNillableRequest(&msgContent).
+		SetNillableResponse(&resp.Choices[0].Message.Content).
+		SetNillableOrganizationID(&orgId).
+		SetOriginalData(rawReqResp).
+		SetNillablePromptTokens(&promptTokens).
+		SetNillableCompletionTokens(&completionToken).
+		SetNillableTotalTokens(&totalTokens).
+		Save(ctx)
+
+	if err == nil { //插入UsageDetai之后再统计UsageTotal
+		updateUsageTotal(ctx, svcCtx, authToken, res.ID, orgId)
+	}
+	return err
+}
+
+func getUsagetotalIdByToken(ctx context.Context, svcCtx *ServiceContext,
+	authToken string) (uint64, error) {
+
+	var predicates []predicate.UsageTotal
+	predicates = append(predicates, usagetotal.BotIDEQ(authToken))
+	return svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(ctx)
+
+}
+
+func replaceUsagetotalTokens(ctx context.Context, svcCtx *ServiceContext,
+
+	authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
+
+	Id, err := getUsagetotalIdByToken(ctx, svcCtx, authToken)
+	if err != nil && !ent.IsNotFound(err) {
+		return err
+	}
+	if Id > 0 { //UsageTotal have record by  newUsageDetailId
+		_, err = svcCtx.DB.UsageTotal.UpdateOneID(Id).
+			SetTotalTokens(sumTotalTokens).
+			SetEndIndex(newUsageDetailId).
+			Save(ctx)
+	} else { //create new record by  newUsageDetailId
+		logType := 5
+		_, err = svcCtx.DB.UsageTotal.Create().
+			SetNotNilBotID(&authToken).
+			SetNotNilEndIndex(&newUsageDetailId).
+			SetNotNilTotalTokens(&sumTotalTokens).
+			SetNillableType(&logType).
+			SetNotNilOrganizationID(&orgId).
+			Save(ctx)
+	}
+
+	return err
+}
+
+func updateUsageTotal(ctx context.Context, svcCtx *ServiceContext,
+	authToken string, newUsageDetailId uint64, orgId uint64) error {
+
+	sumTotalTokens, err := sumTotalTokensByAuthToken(ctx, svcCtx, authToken) //首先sum UsageDetail的TotalTokens
+	if err == nil {
+		err = replaceUsagetotalTokens(ctx, svcCtx, authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens
+	}
+	return err
+}
+
+// sum total_tokens from usagedetail by AuthToken
+func sumTotalTokensByAuthToken(ctx context.Context, svcCtx *ServiceContext,
+	authToken string) (uint64, error) {
+
+	var predicates []predicate.UsageDetail
+	predicates = append(predicates, usagedetail.BotIDEQ(authToken))
+
+	var res []struct {
+		Sum, Min, Max, Count uint64
+	}
+	totalTokens := uint64(0)
+	var err error = nil
+	err = svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"),
+		ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(ctx, &res)
+	if err == nil {
+		if len(res) > 0 {
+			totalTokens = res[0].Sum
+		} else {
+			totalTokens = 0
+		}
+	}
+	return totalTokens, err
+}
+*/
+
 func IsOpenaiModel(model string) bool {
 
 	prefixes := []string{"gpt-4", "gpt-3", "o1", "o3"}
@@ -30,6 +286,68 @@ func IsOpenaiModel(model string) bool {
 	return false
 }
 
+func EntStructGenScanField(structPtr any, ignoredTypes ...reflect.Type) (string, []any, error) {
+	t := reflect.TypeOf(structPtr)
+	v := reflect.ValueOf(structPtr)
+
+	if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
+		return "", nil, errors.New("input must be a pointer to a struct")
+	}
+	t = t.Elem()
+	v = v.Elem()
+
+	var fields []string
+	var scanArgs []any
+
+	ignoredMap := make(map[reflect.Type]struct{})
+	// 检查调用者是否传入了任何要忽略的类型
+	if len(ignoredTypes) > 0 {
+		for _, ignoredType := range ignoredTypes {
+			if ignoredType != nil { // 防止 nil 类型加入 map
+				ignoredMap[ignoredType] = struct{}{}
+			}
+		}
+	}
+
+	for i := 0; i < t.NumField(); i++ {
+		field := t.Field(i)
+		value := v.Field(i)
+
+		// Skip unexported fields
+		if !field.IsExported() {
+			continue
+		}
+
+		// Get json tag
+		jsonTag := field.Tag.Get("json")
+		if jsonTag == "-" || jsonTag == "" {
+			continue
+		}
+
+		jsonParts := strings.Split(jsonTag, ",")
+		jsonName := jsonParts[0]
+		if jsonName == "" {
+			continue
+		}
+
+		//传入了要忽略的类型时进行处理
+		if len(ignoredMap) > 0 {
+			fieldType := field.Type //获取字段的实际 Go 类型
+			//如果字段是指针,我们通常关心的是指针指向的元素的类型
+			if fieldType.Kind() == reflect.Ptr {
+				fieldType = fieldType.Elem() // 获取元素类型
+			}
+			if _, shouldIgnore := ignoredMap[fieldType]; shouldIgnore {
+				continue // 成员类型存在于忽略列表中则忽略
+			}
+		}
+
+		fields = append(fields, jsonName)
+		scanArgs = append(scanArgs, value.Addr().Interface())
+	}
+	return strings.Join(fields, ", "), scanArgs, nil
+}
+
 type StdChatClient struct {
 	*openai.Client
 }

+ 6 - 3
internal/utils/compapi/intent.go

@@ -11,7 +11,7 @@ type IntentClient struct {
 func (me *IntentClient) ResponseFormatSetting(req *types.CompApiReq) ResponseFormatConfig {
 
 	//Message重构的配置
-	me.ResformatConfig.SysmesArgs = []any{}
+	me.ResformatConfig.SysmesArgs = []any{} //以下SysmesTmpl内无替换占位符,所以这里为空
 	me.ResformatConfig.SysmesTmpl = `# 任务
 1. 首先,判断用户的第一句话是否说了:“你好,(任意内容)通话”,如果说了,则不用理会评级规则,直接强制分配为"语音助手"
 2. 如果不属于“语音助手”,请根据评级规则,对聊天记录给出评级、置信度、评分依据(逐项分析不要遗漏)
@@ -19,6 +19,7 @@ func (me *IntentClient) ResponseFormatSetting(req *types.CompApiReq) ResponseFor
 # 细节说明
 置信度从0到1,0为置信度最低,1为置信度最高。`
 
+	//来自请求的变量用来替换UsermesTmpl中的占位符
 	me.ResformatConfig.UsermesArgs = []any{req.Variables["chat_history"]}
 	me.ResformatConfig.UsermesTmpl = `# 评级规则:
         [
@@ -41,18 +42,20 @@ func (me *IntentClient) ResponseFormatSetting(req *types.CompApiReq) ResponseFor
         `
 
 	//ResponseFormat设置的配置
-	me.ResformatConfig.ResformatDesc = "为通话记录进行评级"
+	me.ResformatConfig.ResformatDesc = "为通话记录进行评级" //Resformat描述
+	//非openai兼容大模型所使用的Resformat文本
 	me.ResformatConfig.ResformatTxt = `{
     "score": str, #评分结果:有意向、待进一步分析、暂时无法沟通、其他
     "confidence_score": int, #置信度分数,范围从0.0到1.0
     "scoring_criteria": str, #请逐步介绍为何评为这个结果
 }`
+	//openai兼容大模型所使用的Resformat结构或其他类型
 	me.ResformatConfig.ResformatStruct = struct {
 		Score           string  `json:"score" jsonschema_description:"评分结果:有意向、待进一步分析、暂时无法沟通、其他"`
 		ConfidenceScore float32 `json:"confidence_score" jsonschema_description:"置信度分数,范围从0.0到1.0"`
 		ScoringCriteria string  `json:"scoring_criteria" jsonschema_description:"请逐步介绍为何评为这个结果"`
 	}{}
-	me.ResformatConfig.HaveSet = true
+	me.ResformatConfig.HaveSet = true //很关键,避免父类的参数再设置一遍
 
 	return me.ResformatConfig
 }