|
@@ -5,18 +5,274 @@ import (
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
+ "net"
|
|
|
"net/http"
|
|
|
+ "net/url"
|
|
|
+ "reflect"
|
|
|
+ "regexp"
|
|
|
+ "strconv"
|
|
|
"strings"
|
|
|
|
|
|
+ "wechat-api/ent"
|
|
|
+ "wechat-api/ent/custom_types"
|
|
|
+ "wechat-api/hook/credit"
|
|
|
+ "wechat-api/internal/config"
|
|
|
"wechat-api/internal/types"
|
|
|
"wechat-api/internal/utils/contextkey"
|
|
|
|
|
|
openai "github.com/openai/openai-go"
|
|
|
"github.com/openai/openai-go/option"
|
|
|
"github.com/openai/openai-go/packages/ssestream"
|
|
|
+ "github.com/redis/go-redis/v9"
|
|
|
"github.com/zeromicro/go-zero/rest/httpx"
|
|
|
)
|
|
|
|
|
|
+type ServiceContext struct {
|
|
|
+ Config config.Config
|
|
|
+ DB *ent.Client
|
|
|
+ Rds redis.UniversalClient
|
|
|
+}
|
|
|
+
|
|
|
+func AppendUsageDetailLog(ctx context.Context, svcCtx *ServiceContext,
|
|
|
+ authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
|
|
|
+
|
|
|
+ rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
|
|
|
+ orgId := uint64(0)
|
|
|
+ apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(ctx)
|
|
|
+ if ok {
|
|
|
+ orgId = apiKeyObj.OrganizationID
|
|
|
+ }
|
|
|
+ msgContent := getMessageContentStr(req.Messages[0].Content)
|
|
|
+ tx, err := svcCtx.DB.Tx(context.Background())
|
|
|
+ if err != nil {
|
|
|
+ err = fmt.Errorf("start transaction error:%v", err)
|
|
|
+ } else {
|
|
|
+ usage := credit.Usage{
|
|
|
+ CompletionTokens: uint64(resp.Usage.CompletionTokens),
|
|
|
+ PromptTokens: uint64(resp.Usage.PromptTokens),
|
|
|
+ TotalTokens: uint64(resp.Usage.TotalTokens),
|
|
|
+ }
|
|
|
+ err = credit.AddCreditUsage(tx, ctx,
|
|
|
+ authToken, req.EventType, orgId,
|
|
|
+ &msgContent, &resp.Choices[0].Message.Content,
|
|
|
+ &rawReqResp, &usage,
|
|
|
+ )
|
|
|
+ if err != nil {
|
|
|
+ _ = tx.Rollback()
|
|
|
+ err = fmt.Errorf("save credits info failed:%v", err)
|
|
|
+ } else {
|
|
|
+ _ = tx.Commit()
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+func IsValidURL(input *string, adjust bool) bool {
|
|
|
+
|
|
|
+ // 合法域名正则(支持通配符、中文域名等场景按需调整)
|
|
|
+ var domainRegex = regexp.MustCompile(
|
|
|
+ // 多级域名(如 example.com)
|
|
|
+ `^([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}$` +
|
|
|
+ `|` +
|
|
|
+ // 单级域名(如 localhost 或 mytest-svc)
|
|
|
+ `^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$`,
|
|
|
+ )
|
|
|
+
|
|
|
+ // 空值直接返回
|
|
|
+ if *input == "" {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ inputStr := *input
|
|
|
+
|
|
|
+ // --- 预处理输入:自动补全协议 ---
|
|
|
+ // 若输入不包含协议头,默认添加 http://
|
|
|
+ if !strings.Contains(*input, "://") {
|
|
|
+ inputStr = "http://" + *input
|
|
|
+ }
|
|
|
+
|
|
|
+ // --- 解析 URL ---
|
|
|
+ u, err := url.Parse(inputStr)
|
|
|
+ if err != nil {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ // --- 校验协议 ---
|
|
|
+ // 只允许常见协议(按需扩展)
|
|
|
+ switch u.Scheme {
|
|
|
+ case "http", "https", "ftp", "ftps":
|
|
|
+ default:
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ // --- 拆分 Host 和 Port ---
|
|
|
+ host, port, err := net.SplitHostPort(u.Host)
|
|
|
+ if err != nil {
|
|
|
+ // 无端口时,整个 Host 作为主机名
|
|
|
+ host = u.Host
|
|
|
+ port = ""
|
|
|
+ }
|
|
|
+
|
|
|
+ // --- 校验主机名 ---
|
|
|
+ // 场景1:IPv4 或 IPv6
|
|
|
+ if ip := net.ParseIP(host); ip != nil {
|
|
|
+ // 允许私有或保留 IP(按需调整)
|
|
|
+ // 示例中允许所有合法 IP
|
|
|
+ } else {
|
|
|
+ // 场景2:域名(包括 localhost)
|
|
|
+ if !domainRegex.MatchString(host) {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // --- 校验端口 ---
|
|
|
+ if port != "" {
|
|
|
+ p, err := net.LookupPort("tcp", port) // 动态获取端口(如 "http" 对应 80)
|
|
|
+ if err != nil {
|
|
|
+ // 直接尝试解析为数字端口
|
|
|
+ numPort, err := strconv.Atoi(port)
|
|
|
+ if err != nil || numPort < 1 || numPort > 65535 {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ } else if p == 0 { // 动态端口为 0 时无效
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if adjust {
|
|
|
+ *input = inputStr
|
|
|
+ }
|
|
|
+ return true
|
|
|
+}
|
|
|
+
|
|
|
+func getMessageContentStr(input any) string {
|
|
|
+ str := ""
|
|
|
+ switch val := input.(type) {
|
|
|
+ case string:
|
|
|
+ str = val
|
|
|
+ case []any:
|
|
|
+ if len(val) > 0 {
|
|
|
+ if valc, ok := val[0].(map[string]any); ok {
|
|
|
+ if valcc, ok := valc["text"]; ok {
|
|
|
+ str, _ = valcc.(string)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return str
|
|
|
+}
|
|
|
+
|
|
|
+/*
|
|
|
+func AppendUsageDetailLog(ctx context.Context, svcCtx *ServiceContext,
|
|
|
+
|
|
|
+ authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
|
|
|
+
|
|
|
+ logType := 5
|
|
|
+ rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
|
|
|
+ tmpId := 0
|
|
|
+ tmpId, _ = strconv.Atoi(resp.ID)
|
|
|
+ sessionId := uint64(tmpId)
|
|
|
+ orgId := uint64(0)
|
|
|
+ apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(ctx)
|
|
|
+ if ok {
|
|
|
+ orgId = apiKeyObj.OrganizationID
|
|
|
+ }
|
|
|
+ promptTokens := uint64(resp.Usage.PromptTokens)
|
|
|
+ completionToken := uint64(resp.Usage.CompletionTokens)
|
|
|
+ totalTokens := promptTokens + completionToken
|
|
|
+
|
|
|
+ msgContent := getMessageContentStr(req.Messages[0].Content)
|
|
|
+
|
|
|
+ _, _, _ = logType, sessionId, totalTokens
|
|
|
+ res, err := svcCtx.DB.UsageDetail.Create().
|
|
|
+ SetNotNilType(&logType).
|
|
|
+ SetNotNilBotID(&authToken).
|
|
|
+ SetNotNilReceiverID(&req.EventType).
|
|
|
+ SetNotNilSessionID(&sessionId).
|
|
|
+ SetNillableRequest(&msgContent).
|
|
|
+ SetNillableResponse(&resp.Choices[0].Message.Content).
|
|
|
+ SetNillableOrganizationID(&orgId).
|
|
|
+ SetOriginalData(rawReqResp).
|
|
|
+ SetNillablePromptTokens(&promptTokens).
|
|
|
+ SetNillableCompletionTokens(&completionToken).
|
|
|
+ SetNillableTotalTokens(&totalTokens).
|
|
|
+ Save(ctx)
|
|
|
+
|
|
|
+ if err == nil { //插入UsageDetai之后再统计UsageTotal
|
|
|
+ updateUsageTotal(ctx, svcCtx, authToken, res.ID, orgId)
|
|
|
+ }
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+func getUsagetotalIdByToken(ctx context.Context, svcCtx *ServiceContext,
|
|
|
+ authToken string) (uint64, error) {
|
|
|
+
|
|
|
+ var predicates []predicate.UsageTotal
|
|
|
+ predicates = append(predicates, usagetotal.BotIDEQ(authToken))
|
|
|
+ return svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(ctx)
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+func replaceUsagetotalTokens(ctx context.Context, svcCtx *ServiceContext,
|
|
|
+
|
|
|
+ authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
|
|
|
+
|
|
|
+ Id, err := getUsagetotalIdByToken(ctx, svcCtx, authToken)
|
|
|
+ if err != nil && !ent.IsNotFound(err) {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if Id > 0 { //UsageTotal have record by newUsageDetailId
|
|
|
+ _, err = svcCtx.DB.UsageTotal.UpdateOneID(Id).
|
|
|
+ SetTotalTokens(sumTotalTokens).
|
|
|
+ SetEndIndex(newUsageDetailId).
|
|
|
+ Save(ctx)
|
|
|
+ } else { //create new record by newUsageDetailId
|
|
|
+ logType := 5
|
|
|
+ _, err = svcCtx.DB.UsageTotal.Create().
|
|
|
+ SetNotNilBotID(&authToken).
|
|
|
+ SetNotNilEndIndex(&newUsageDetailId).
|
|
|
+ SetNotNilTotalTokens(&sumTotalTokens).
|
|
|
+ SetNillableType(&logType).
|
|
|
+ SetNotNilOrganizationID(&orgId).
|
|
|
+ Save(ctx)
|
|
|
+ }
|
|
|
+
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+func updateUsageTotal(ctx context.Context, svcCtx *ServiceContext,
|
|
|
+ authToken string, newUsageDetailId uint64, orgId uint64) error {
|
|
|
+
|
|
|
+ sumTotalTokens, err := sumTotalTokensByAuthToken(ctx, svcCtx, authToken) //首先sum UsageDetail的TotalTokens
|
|
|
+ if err == nil {
|
|
|
+ err = replaceUsagetotalTokens(ctx, svcCtx, authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens
|
|
|
+ }
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
+// sum total_tokens from usagedetail by AuthToken
|
|
|
+func sumTotalTokensByAuthToken(ctx context.Context, svcCtx *ServiceContext,
|
|
|
+ authToken string) (uint64, error) {
|
|
|
+
|
|
|
+ var predicates []predicate.UsageDetail
|
|
|
+ predicates = append(predicates, usagedetail.BotIDEQ(authToken))
|
|
|
+
|
|
|
+ var res []struct {
|
|
|
+ Sum, Min, Max, Count uint64
|
|
|
+ }
|
|
|
+ totalTokens := uint64(0)
|
|
|
+ var err error = nil
|
|
|
+ err = svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"),
|
|
|
+ ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(ctx, &res)
|
|
|
+ if err == nil {
|
|
|
+ if len(res) > 0 {
|
|
|
+ totalTokens = res[0].Sum
|
|
|
+ } else {
|
|
|
+ totalTokens = 0
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return totalTokens, err
|
|
|
+}
|
|
|
+*/
|
|
|
+
|
|
|
func IsOpenaiModel(model string) bool {
|
|
|
|
|
|
prefixes := []string{"gpt-4", "gpt-3", "o1", "o3"}
|
|
@@ -30,6 +286,68 @@ func IsOpenaiModel(model string) bool {
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
+func EntStructGenScanField(structPtr any, ignoredTypes ...reflect.Type) (string, []any, error) {
|
|
|
+ t := reflect.TypeOf(structPtr)
|
|
|
+ v := reflect.ValueOf(structPtr)
|
|
|
+
|
|
|
+ if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
|
|
|
+ return "", nil, errors.New("input must be a pointer to a struct")
|
|
|
+ }
|
|
|
+ t = t.Elem()
|
|
|
+ v = v.Elem()
|
|
|
+
|
|
|
+ var fields []string
|
|
|
+ var scanArgs []any
|
|
|
+
|
|
|
+ ignoredMap := make(map[reflect.Type]struct{})
|
|
|
+ // 检查调用者是否传入了任何要忽略的类型
|
|
|
+ if len(ignoredTypes) > 0 {
|
|
|
+ for _, ignoredType := range ignoredTypes {
|
|
|
+ if ignoredType != nil { // 防止 nil 类型加入 map
|
|
|
+ ignoredMap[ignoredType] = struct{}{}
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ for i := 0; i < t.NumField(); i++ {
|
|
|
+ field := t.Field(i)
|
|
|
+ value := v.Field(i)
|
|
|
+
|
|
|
+ // Skip unexported fields
|
|
|
+ if !field.IsExported() {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ // Get json tag
|
|
|
+ jsonTag := field.Tag.Get("json")
|
|
|
+ if jsonTag == "-" || jsonTag == "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ jsonParts := strings.Split(jsonTag, ",")
|
|
|
+ jsonName := jsonParts[0]
|
|
|
+ if jsonName == "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ //传入了要忽略的类型时进行处理
|
|
|
+ if len(ignoredMap) > 0 {
|
|
|
+ fieldType := field.Type //获取字段的实际 Go 类型
|
|
|
+ //如果字段是指针,我们通常关心的是指针指向的元素的类型
|
|
|
+ if fieldType.Kind() == reflect.Ptr {
|
|
|
+ fieldType = fieldType.Elem() // 获取元素类型
|
|
|
+ }
|
|
|
+ if _, shouldIgnore := ignoredMap[fieldType]; shouldIgnore {
|
|
|
+ continue // 成员类型存在于忽略列表中则忽略
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ fields = append(fields, jsonName)
|
|
|
+ scanArgs = append(scanArgs, value.Addr().Interface())
|
|
|
+ }
|
|
|
+ return strings.Join(fields, ", "), scanArgs, nil
|
|
|
+}
|
|
|
+
|
|
|
type StdChatClient struct {
|
|
|
*openai.Client
|
|
|
}
|