package compapi

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"reflect"
	"regexp"
	"strconv"
	"strings"

	"wechat-api/ent"
	"wechat-api/ent/custom_types"
	"wechat-api/hook/credit"
	"wechat-api/internal/config"
	"wechat-api/internal/types"
	"wechat-api/internal/utils/contextkey"

	openai "github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/packages/ssestream"
	"github.com/redis/go-redis/v9"
	"github.com/zeromicro/go-zero/rest/httpx"
)

// ServiceContext bundles the configuration, the ent database client and the
// Redis client used by the helpers in this package.
type ServiceContext struct {
	Config config.Config
	DB     *ent.Client
	Rds    redis.UniversalClient
}

// AppendUsageDetailLog records the token usage of one chat-completion exchange
// via credit.AddCreditUsage inside a database transaction.
//
// The organization ID comes from the auth-token info stored in ctx (0 when it
// is absent). The text of the first request message and of the first response
// choice are stored as request/response content, and req/resp themselves are
// stored as the original data. When recording fails the transaction is rolled
// back; otherwise it is committed, and a commit failure is returned as well.
// A nil req/resp or a response without choices is reported as an error rather
// than causing a panic.
func AppendUsageDetailLog(ctx context.Context, svcCtx *ServiceContext, authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
	if req == nil || resp == nil {
		return errors.New("append usage detail log: nil request or response")
	}
	if len(resp.Choices) == 0 {
		return errors.New("append usage detail log: response has no choices")
	}

	rawReqResp := custom_types.OriginalData{Request: req, Response: resp}

	orgId := uint64(0)
	if apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(ctx); ok {
		orgId = apiKeyObj.OrganizationID
	}

	// An empty message list is logged as an empty request instead of panicking.
	msgContent := ""
	if len(req.Messages) > 0 {
		msgContent = getMessageContentStr(req.Messages[0].Content)
	}

	// NOTE(review): the transaction is started from context.Background(),
	// presumably so a cancelled request context cannot abort the billing
	// record — confirm this is intentional.
	tx, err := svcCtx.DB.Tx(context.Background())
	if err != nil {
		return fmt.Errorf("start transaction error:%w", err)
	}

	usage := credit.Usage{
		CompletionTokens: uint64(resp.Usage.CompletionTokens),
		PromptTokens:     uint64(resp.Usage.PromptTokens),
		TotalTokens:      uint64(resp.Usage.TotalTokens),
	}
	err = credit.AddCreditUsage(tx, ctx,
		authToken, req.EventType, orgId,
		&msgContent, &resp.Choices[0].Message.Content,
		&rawReqResp, &usage, req.Model,
	)
	if err != nil {
		// Best effort: the AddCreditUsage failure is the error worth reporting.
		_ = tx.Rollback()
		return fmt.Errorf("save credits info failed:%w", err)
	}
	if err = tx.Commit(); err != nil {
		return fmt.Errorf("commit credits info failed:%w", err)
	}
	return nil
}

// isValidURLDomainRegex matches host names accepted by IsValidURL: either a
// multi-label domain whose last label is 2-63 ASCII letters (e.g. example.com)
// or a single label such as localhost or mytest-svc. Every label is 1-63 ASCII
// alphanumerics or hyphens and may not start or end with a hyphen, so
// non-ASCII (internationalized) names are rejected unless punycode-encoded.
// Compiled once instead of on every call; adjust the pattern as needed.
var isValidURLDomainRegex = regexp.MustCompile(
	// Multi-label domain (e.g. example.com).
	`^([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}$` +
		`|` +
		// Single-label host (e.g. localhost or mytest-svc).
		`^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$`,
)

// IsValidURL reports whether *input is a URL with an http, https, ftp or ftps
// scheme, a host that is an IP literal or a domain name (see
// isValidURLDomainRegex), and, when present, a port in 1-65535. Private and
// reserved IP addresses are accepted.
//
// If *input contains no "://", "http://" is prepended before validating. When
// adjust is true and the input is valid, *input is overwritten with that
// (possibly scheme-prefixed) form. A nil or empty input is invalid.
func IsValidURL(input *string, adjust bool) bool {
	if input == nil || *input == "" {
		return false
	}

	// Pre-process: default to http:// when no scheme is given.
	inputStr := *input
	if !strings.Contains(inputStr, "://") {
		inputStr = "http://" + inputStr
	}

	u, err := url.Parse(inputStr)
	if err != nil {
		return false
	}

	// Only common schemes are allowed (extend as needed).
	switch u.Scheme {
	case "http", "https", "ftp", "ftps":
	default:
		return false
	}

	host, port, err := net.SplitHostPort(u.Host)
	if err != nil {
		// No port: the whole Host is the host name.
		host, port = u.Host, ""
		// A bracketed IPv6 literal without a port ("[::1]") still carries its
		// brackets here; unwrap it so it is validated as an IP address.
		if n := len(host); n > 2 && host[0] == '[' && host[n-1] == ']' && net.ParseIP(host[1:n-1]) != nil {
			host = host[1 : n-1]
		}
	}

	// The host must be an IPv4/IPv6 literal or a domain name (incl. localhost).
	if net.ParseIP(host) == nil && !isValidURLDomainRegex.MatchString(host) {
		return false
	}

	// url.Parse rejects non-numeric ports, so a numeric range check suffices
	// and no service-name lookup is needed.
	if port != "" {
		if p, err := strconv.Atoi(port); err != nil || p < 1 || p > 65535 {
			return false
		}
	}

	if adjust {
		*input = inputStr
	}
	return true
}

// getMessageContentStr extracts plain text from a chat message's content.
// A string is returned unchanged. For a []any of content parts, the "text"
// value of the first part that is a map[string]any holding a string "text"
// entry is returned, so a leading non-text part (e.g. an image) does not hide
// the text. Any other shape yields "".
func getMessageContentStr(input any) string {
	switch val := input.(type) {
	case string:
		return val
	case []any:
		for _, part := range val {
			if m, ok := part.(map[string]any); ok {
				if text, ok := m["text"].(string); ok {
					return text
				}
			}
		}
	}
	return ""
}

/*
Legacy implementation that wrote UsageDetail/UsageTotal rows directly; kept for
reference. It references ent packages (predicate, usagetotal, usagedetail) that
this file no longer imports.

func AppendUsageDetailLog(ctx context.Context, svcCtx *ServiceContext, authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
	logType := 5
	rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
	tmpId := 0
	tmpId, _ = strconv.Atoi(resp.ID)
	sessionId := uint64(tmpId)
	orgId := uint64(0)
	apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(ctx)
	if ok {
		orgId = apiKeyObj.OrganizationID
	}
	promptTokens := uint64(resp.Usage.PromptTokens)
	completionToken := uint64(resp.Usage.CompletionTokens)
	totalTokens := promptTokens + completionToken

	msgContent := getMessageContentStr(req.Messages[0].Content)

	_, _, _ = logType, sessionId, totalTokens
	res, err := svcCtx.DB.UsageDetail.Create().
		SetNotNilType(&logType).
		SetNotNilBotID(&authToken).
		SetNotNilReceiverID(&req.EventType).
		SetNotNilSessionID(&sessionId).
		SetNillableRequest(&msgContent).
		SetNillableResponse(&resp.Choices[0].Message.Content).
		SetNillableOrganizationID(&orgId).
		SetOriginalData(rawReqResp).
		SetNillablePromptTokens(&promptTokens).
		SetNillableCompletionTokens(&completionToken).
		SetNillableTotalTokens(&totalTokens).
		Save(ctx)

	if err == nil { // Update UsageTotal only after the UsageDetail row was inserted.
		updateUsageTotal(ctx, svcCtx, authToken, res.ID, orgId)
	}
	return err
}

func getUsagetotalIdByToken(ctx context.Context, svcCtx *ServiceContext, authToken string) (uint64, error) {
	var predicates []predicate.UsageTotal
	predicates = append(predicates, usagetotal.BotIDEQ(authToken))
	return svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(ctx)
}

func replaceUsagetotalTokens(ctx context.Context, svcCtx *ServiceContext, authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
	Id, err := getUsagetotalIdByToken(ctx, svcCtx, authToken)
	if err != nil && !ent.IsNotFound(err) {
		return err
	}
	if Id > 0 { // UsageTotal already has a record: update it with newUsageDetailId
		_, err = svcCtx.DB.UsageTotal.UpdateOneID(Id).
			SetTotalTokens(sumTotalTokens).
			SetEndIndex(newUsageDetailId).
			Save(ctx)
	} else { // create a new record for newUsageDetailId
		logType := 5
		_, err = svcCtx.DB.UsageTotal.Create().
			SetNotNilBotID(&authToken).
			SetNotNilEndIndex(&newUsageDetailId).
			SetNotNilTotalTokens(&sumTotalTokens).
			SetNillableType(&logType).
			SetNotNilOrganizationID(&orgId).
			Save(ctx)
	}
	return err
}

func updateUsageTotal(ctx context.Context, svcCtx *ServiceContext, authToken string, newUsageDetailId uint64, orgId uint64) error {
	sumTotalTokens, err := sumTotalTokensByAuthToken(ctx, svcCtx, authToken) // first sum UsageDetail.TotalTokens
	if err == nil {
		err = replaceUsagetotalTokens(ctx, svcCtx, authToken, sumTotalTokens, newUsageDetailId, orgId) // then update (or create) UsageTotal.TotalTokens
	}
	return err
}

// sum total_tokens from usagedetail by AuthToken
func sumTotalTokensByAuthToken(ctx context.Context, svcCtx *ServiceContext, authToken string) (uint64, error) {
	var predicates []predicate.UsageDetail
	predicates = append(predicates, usagedetail.BotIDEQ(authToken))
	var res []struct {
		Sum, Min, Max, Count uint64
	}
	totalTokens := uint64(0)
	var err error = nil
	err = svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"),
		ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(ctx, &res)
	if err == nil {
		if len(res) > 0 {
			totalTokens = res[0].Sum
		} else {
			totalTokens = 0
		}
	}
	return totalTokens, err
}
*/

// IsOpenaiModel reports whether model starts with one of the prefixes
// "gpt-4", "gpt-3", "o1" or "o3".
func IsOpenaiModel(model string) bool {
	prefixes := []string{"gpt-4", "gpt-3", "o1", "o3"}
	for _, prefix := range prefixes {
		if strings.HasPrefix(model, prefix) {
			return true
		}
	}
	return false
}

// EntStructGenScanField builds a column list and matching scan destinations
// from the exported fields of the struct that structPtr points to.
//
// Every exported field whose json tag has a non-empty name (a tag of exactly
// "-" is skipped) contributes that name to the returned ", "-separated column
// list and the field's address to the returned scan arguments, in declaration
// order. Fields whose type — or, for pointer fields, whose element type — is
// listed in ignoredTypes are skipped; nil entries in ignoredTypes are ignored.
// An error is returned unless structPtr is a non-nil pointer to a struct.
func EntStructGenScanField(structPtr any, ignoredTypes ...reflect.Type) (string, []any, error) {
	v := reflect.ValueOf(structPtr)
	// Rejects a nil interface (invalid Kind), non-pointers, nil pointers and
	// pointers to non-structs; a nil *struct would otherwise panic in v.Field.
	if v.Kind() != reflect.Ptr || v.IsNil() || v.Elem().Kind() != reflect.Struct {
		return "", nil, errors.New("input must be a pointer to a struct")
	}
	v = v.Elem()
	t := v.Type()

	ignored := make(map[reflect.Type]struct{}, len(ignoredTypes))
	for _, typ := range ignoredTypes {
		if typ != nil {
			ignored[typ] = struct{}{}
		}
	}

	var fields []string
	var scanArgs []any
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		if !field.IsExported() {
			continue
		}

		tag := field.Tag.Get("json")
		if tag == "-" {
			continue
		}
		name, _, _ := strings.Cut(tag, ",")
		if name == "" {
			continue
		}

		// Skip fields whose type, or pointed-to type, is in the ignore list.
		if _, skip := ignored[field.Type]; skip {
			continue
		}
		if field.Type.Kind() == reflect.Ptr {
			if _, skip := ignored[field.Type.Elem()]; skip {
				continue
			}
		}

		fields = append(fields, name)
		scanArgs = append(scanArgs, v.Field(i).Addr().Interface())
	}
	return strings.Join(fields, ", "), scanArgs, nil
}

// StdChatClient wraps an *openai.Client.
type StdChatClient struct {
	*openai.Client
}

// NewStdChatClient returns a StdChatClient configured like NewAiClient.
func NewStdChatClient(apiKey string, apiBase string) *StdChatClient {
	return &StdChatClient{NewAiClient(apiKey, apiBase)}
}

// NewAiClient returns an OpenAI-compatible client that targets apiBase.
// The API key option is only set when apiKey is non-empty.
func NewAiClient(apiKey string, apiBase string) *openai.Client {
	opts := []option.RequestOption{}
	if len(apiKey) > 0 {
		opts = append(opts, option.WithAPIKey(apiKey))
	}
	opts = append(opts, option.WithBaseURL(apiBase))
	client := openai.NewClient(opts...)
	return &client
}

// NewFastgptClient returns a client for the FastGPT endpoint at
// http://fastgpt.ascrm.cn/api/v1/.
func NewFastgptClient(apiKey string) *openai.Client {
	client := openai.NewClient(option.WithAPIKey(apiKey), option.WithBaseURL("http://fastgpt.ascrm.cn/api/v1/"))
	return &client
}

// NewDeepSeekClient returns a client for https://api.deepseek.com.
func NewDeepSeekClient(apiKey string) *openai.Client {
	client := openai.NewClient(option.WithAPIKey(apiKey), option.WithBaseURL("https://api.deepseek.com"))
	return &client
}

// DoChatCompletions sends chatInfo, marshalled to JSON verbatim, as a
// non-streaming chat-completion request and decodes the response body into a
// types.CompOpenApiResp.
//
// FastGPT can report a failure in the response body while the call itself
// succeeds; when both FgtErrCode and FgtErrStatusTxt are set that failure is
// returned as an error.
func DoChatCompletions(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
	var (
		jsonBytes []byte
		err       error
	)
	// The typed params stay empty; the body is supplied via WithRequestBody.
	emptyParams := openai.ChatCompletionNewParams{}
	if jsonBytes, err = json.Marshal(chatInfo); err != nil {
		return nil, err
	}
	// TODO: consider validating non-standard member names in the request body.
	customResp := types.CompOpenApiResp{}
	reqBodyOps := option.WithRequestBody("application/json", jsonBytes)
	respBodyOps := option.WithResponseBodyInto(&customResp)
	if _, err = client.Chat.Completions.New(ctx, emptyParams, reqBodyOps, respBodyOps); err != nil {
		return nil, err
	}
	if customResp.FgtErrCode != nil && customResp.FgtErrStatusTxt != nil {
		// FastGPT failed but New() did not return an error.
		return nil, fmt.Errorf("%s(%d)", *customResp.FgtErrStatusTxt, *customResp.FgtErrCode)
	}
	return &customResp, nil
}

// DoChatCompletionsStream sends chatInfo as a streaming ("stream": true)
// chat-completion request and relays each upstream chunk as an SSE "data:"
// event to the http.ResponseWriter stored in ctx, finishing with
// "data:[DONE]". It always returns a nil response; the output goes to the
// writer.
//
// Errors are returned when the writer is missing from ctx, does not support
// flushing (a 500 "Streaming unsupported!" is written first), when the request
// fails, when the upstream stream reports an error, or when writing to the
// client fails.
func DoChatCompletionsStream(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (res *types.CompOpenApiResp, err error) {
	var (
		jsonBytes []byte
		raw       *http.Response
		ok        bool
		hw        http.ResponseWriter
	)
	hw, ok = contextkey.HttpResponseWriterKey.GetValue(ctx) // writer stored in ctx by the handler
	if !ok {
		return nil, errors.New("content get http writer err")
	}
	flusher, ok := hw.(http.Flusher)
	if !ok {
		// Without a Flusher the loop below would call Flush on a nil interface
		// and panic, so stop here.
		http.Error(hw, "Streaming unsupported!", http.StatusInternalServerError)
		return nil, errors.New("streaming unsupported: http writer does not implement http.Flusher")
	}

	emptyParams := openai.ChatCompletionNewParams{}
	if jsonBytes, err = json.Marshal(chatInfo); err != nil {
		return nil, err
	}
	reqBodyOps := option.WithRequestBody("application/json", jsonBytes)
	// Capture the raw *http.Response so its SSE body can be decoded below.
	respBodyOps := option.WithResponseBodyInto(&raw)
	if _, err = client.Chat.Completions.New(ctx, emptyParams, reqBodyOps, respBodyOps, option.WithJSONSet("stream", true)); err != nil {
		return nil, err
	}

	// SSE response headers (HTTP/1.1).
	hw.Header().Set("Content-Type", "text/event-stream;charset=utf-8")
	hw.Header().Set("Connection", "keep-alive")
	hw.Header().Set("Cache-Control", "no-cache")
	// Commit the 200 status before the first body write; once the body has
	// been written a WriteHeader call is superfluous.
	httpx.Ok(hw)

	chatStream := ssestream.NewStream[ApiRespStreamChunk](ApiRespStreamDecoder(raw), nil)
	defer chatStream.Close()

	for chatStream.Next() {
		chunk := chatStream.Current()
		if _, err = fmt.Fprintf(hw, "data:%s\n\n", chunk.Data.RAW); err != nil {
			// e.g. the client disconnected; stop relaying.
			return nil, fmt.Errorf("write stream chunk: %w", err)
		}
		flusher.Flush()
	}
	// Surface an upstream read/decode failure instead of announcing [DONE].
	if err = chatStream.Err(); err != nil {
		return nil, err
	}

	if _, err = fmt.Fprintf(hw, "data:%s\n\n", "[DONE]"); err != nil {
		return nil, fmt.Errorf("write stream terminator: %w", err)
	}
	flusher.Flush()
	return nil, nil
}

// NewChatCompletions dispatches to DoChatCompletionsStream when
// chatInfo.Stream is set and to DoChatCompletions otherwise.
func NewChatCompletions(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
	if chatInfo.Stream {
		return DoChatCompletionsStream(ctx, client, chatInfo)
	}
	return DoChatCompletions(ctx, client, chatInfo)
}

// NewMismatchChatCompletions runs NewChatCompletions with a client built by
// NewAiClient(apiKey, apiBase).
func NewMismatchChatCompletions(ctx context.Context, apiKey string, apiBase string, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
	client := NewAiClient(apiKey, apiBase)
	return NewChatCompletions(ctx, client, chatInfo)
}

// NewFastgptChatCompletions runs NewChatCompletions with a client built by
// NewAiClient(apiKey, apiBase).
func NewFastgptChatCompletions(ctx context.Context, apiKey string, apiBase string, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
	client := NewAiClient(apiKey, apiBase)
	return NewChatCompletions(ctx, client, chatInfo)
}

// NewDeepSeekChatCompletions runs NewChatCompletions against DeepSeek. Any
// chatModel other than ChatModelDeepSeekV3 is replaced by ChatModelDeepSeekR1,
// and the chosen model is written into chatInfo.Model (mutating the caller's
// request).
func NewDeepSeekChatCompletions(ctx context.Context, apiKey string, chatInfo *types.CompApiReq, chatModel openai.ChatModel) (res *types.CompOpenApiResp, err error) {
	client := NewDeepSeekClient(apiKey)
	if chatModel != ChatModelDeepSeekV3 {
		chatModel = ChatModelDeepSeekR1
	}
	chatInfo.Model = chatModel
	return NewChatCompletions(ctx, client, chatInfo)
}

// DoChatCompletionsStreamOld is an earlier streaming implementation built on
// client.Chat.Completions.NewStreaming. It writes each chunk, formatted with
// %v (not as SSE events), to the http.ResponseWriter stored in ctx under the
// string key "HttpResp-Writer", prints debug output to stdout, and returns a
// nil response.
//
// NOTE(review): NewChatCompletions uses DoChatCompletionsStream instead; this
// looks superseded — consider marking it Deprecated or removing it.
func DoChatCompletionsStreamOld(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (res *types.CompOpenApiResp, err error) {
	var (
		jsonBytes []byte
	)
	emptyParams := openai.ChatCompletionNewParams{}
	if jsonBytes, err = json.Marshal(chatInfo); err != nil {
		return nil, err
	}
	reqBodyOps := option.WithRequestBody("application/json", jsonBytes)
	//customResp := types.CompOpenApiResp{}
	//respBodyOps := option.WithResponseBodyInto(&customResp)
	//chatStream := client.Chat.Completions.NewStreaming(ctx, emptyParams, reqBodyOps, respBodyOps)
	chatStream := client.Chat.Completions.NewStreaming(ctx, emptyParams, reqBodyOps)
	// Release the upstream response on every return path (it was never closed before).
	defer chatStream.Close()

	// optionally, an accumulator helper can be used
	acc := openai.ChatCompletionAccumulator{}

	httpWriter, ok := ctx.Value("HttpResp-Writer").(http.ResponseWriter)
	if !ok {
		return nil, errors.New("content get writer err")
	}

	//httpWriter.Header().Set("Content-Type", "text/event-stream;charset=utf-8")
	//httpWriter.Header().Set("Connection", "keep-alive")
	//httpWriter.Header().Set("Cache-Control", "no-cache")

	idx := 0
	for chatStream.Next() {
		chunk := chatStream.Current()
		acc.AddChunk(chunk)

		fmt.Printf("=====>get %d chunk:%v\n", idx, chunk)
		if _, err := fmt.Fprintf(httpWriter, "%v", chunk); err != nil {
			fmt.Printf("Error writing to client:%v \n", err)
			break
		}

		if content, ok := acc.JustFinishedContent(); ok {
			println("Content stream finished:", content)
		}

		// if using tool calls
		if tool, ok := acc.JustFinishedToolCall(); ok {
			println("Tool call stream finished:", tool.Index, tool.Name, tool.Arguments)
		}

		if refusal, ok := acc.JustFinishedRefusal(); ok {
			println("Refusal stream finished:", refusal)
		}

		// it's best to use chunks after handling JustFinished events
		if len(chunk.Choices) > 0 {
			idx++
			fmt.Printf("idx:%d get =>'%s'\n", idx, chunk.Choices[0].Delta.Content)
		}
	}

	if err := chatStream.Err(); err != nil {
		return nil, err
	}
	return nil, nil
}

// GetWorkInfoByID looks workId up in fastgptWorkIdMap, falling back to the
// "default" entry when it is missing, and returns that entry's Id and Idx.
// eventType is currently unused.
func GetWorkInfoByID(eventType string, workId string) (string, uint) {
	val, exist := fastgptWorkIdMap[workId]
	if !exist {
		val = fastgptWorkIdMap["default"]
	}
	return val.Id, val.Idx
}

// GetWorkTokenByID returns the work token (Id) chosen by GetWorkInfoByID.
func GetWorkTokenByID(eventType string, workId string) string {
	id, _ := GetWorkInfoByID(eventType, workId)
	return id
}

// GetWorkIdxByID returns the work index (Idx) chosen by GetWorkInfoByID.
func GetWorkIdxByID(eventType string, workId string) uint {
	_, idx := GetWorkInfoByID(eventType, workId)
	return idx
}