123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573 |
- package compapi
- import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "net"
- "net/http"
- "net/url"
- "reflect"
- "regexp"
- "strconv"
- "strings"
- "wechat-api/ent"
- "wechat-api/ent/custom_types"
- "wechat-api/hook/credit"
- "wechat-api/internal/config"
- "wechat-api/internal/types"
- "wechat-api/internal/utils/contextkey"
- openai "github.com/openai/openai-go"
- "github.com/openai/openai-go/option"
- "github.com/openai/openai-go/packages/ssestream"
- "github.com/redis/go-redis/v9"
- "github.com/zeromicro/go-zero/rest/httpx"
- )
- type ServiceContext struct {
- Config config.Config
- DB *ent.Client
- Rds redis.UniversalClient
- }
- func AppendUsageDetailLog(ctx context.Context, svcCtx *ServiceContext,
- authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
- rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
- orgId := uint64(0)
- apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(ctx)
- if ok {
- orgId = apiKeyObj.OrganizationID
- }
- msgContent := getMessageContentStr(req.Messages[0].Content)
- tx, err := svcCtx.DB.Tx(context.Background())
- if err != nil {
- err = fmt.Errorf("start transaction error:%v", err)
- } else {
- usage := credit.Usage{
- CompletionTokens: uint64(resp.Usage.CompletionTokens),
- PromptTokens: uint64(resp.Usage.PromptTokens),
- TotalTokens: uint64(resp.Usage.TotalTokens),
- }
- err = credit.AddCreditUsage(tx, ctx,
- authToken, req.EventType, orgId,
- &msgContent, &resp.Choices[0].Message.Content,
- &rawReqResp, &usage, req.Model,
- )
- if err != nil {
- _ = tx.Rollback()
- err = fmt.Errorf("save credits info failed:%v", err)
- } else {
- _ = tx.Commit()
- }
- }
- return err
- }
// urlDomainRegex matches host names accepted by IsValidURL. It accepts either
// a multi-label domain ending in a 2-63 letter TLD (e.g. example.com) or a
// single label (e.g. localhost, mytest-svc). It is compiled once at package
// initialization instead of on every call.
var urlDomainRegex = regexp.MustCompile(
	// Multi-label domain, e.g. example.com
	`^([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}$` +
		`|` +
		// Single-label host, e.g. localhost or mytest-svc
		`^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$`,
)

// IsValidURL reports whether *input is a plausible http/https/ftp/ftps URL.
//
// When *input has no "://", "http://" is prepended before validation. The host
// must be an IP literal (IPv4, or IPv6 with or without brackets when no port is
// given) or match urlDomainRegex. An explicit port must be in 1-65535.
// If adjust is true and the input is valid, *input is overwritten with the
// normalized string (including any added scheme). A nil or empty input is
// invalid.
func IsValidURL(input *string, adjust bool) bool {
	if input == nil || *input == "" {
		return false
	}

	inputStr := *input
	// Without a scheme, url.Parse would treat "host:port" as scheme "host".
	if !strings.Contains(inputStr, "://") {
		inputStr = "http://" + inputStr
	}

	u, err := url.Parse(inputStr)
	if err != nil {
		return false
	}

	switch u.Scheme {
	case "http", "https", "ftp", "ftps":
	default:
		return false
	}

	host, port, err := net.SplitHostPort(u.Host)
	if err != nil {
		// No port: the whole Host is the host name. A bracketed IPv6 literal
		// such as "[::1]" keeps its brackets here, so strip them when the inner
		// part is a real IP (otherwise it would fail both checks below).
		host, port = u.Host, ""
		if len(host) > 2 && host[0] == '[' && host[len(host)-1] == ']' {
			if inner := host[1 : len(host)-1]; net.ParseIP(inner) != nil {
				host = inner
			}
		}
	}

	// Any valid IP literal is accepted; otherwise the host must look like a domain.
	if net.ParseIP(host) == nil && !urlDomainRegex.MatchString(host) {
		return false
	}

	// url.Parse only accepts all-digit ports, so a numeric range check suffices.
	if port != "" {
		n, err := strconv.Atoi(port)
		if err != nil || n < 1 || n > 65535 {
			return false
		}
	}

	if adjust {
		*input = inputStr
	}
	return true
}
// getMessageContentStr extracts plain text from a chat message's content.
// A string is returned as-is. For a []any, the "text" value of the first
// element is returned when that element is a map[string]any holding a string
// under "text". Anything else yields "".
func getMessageContentStr(input any) string {
	switch content := input.(type) {
	case string:
		return content
	case []any:
		if len(content) == 0 {
			return ""
		}
		part, isMap := content[0].(map[string]any)
		if !isMap {
			return ""
		}
		text, _ := part["text"].(string)
		return text
	}
	return ""
}
- /*
- func AppendUsageDetailLog(ctx context.Context, svcCtx *ServiceContext,
- authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
- logType := 5
- rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
- tmpId := 0
- tmpId, _ = strconv.Atoi(resp.ID)
- sessionId := uint64(tmpId)
- orgId := uint64(0)
- apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(ctx)
- if ok {
- orgId = apiKeyObj.OrganizationID
- }
- promptTokens := uint64(resp.Usage.PromptTokens)
- completionToken := uint64(resp.Usage.CompletionTokens)
- totalTokens := promptTokens + completionToken
- msgContent := getMessageContentStr(req.Messages[0].Content)
- _, _, _ = logType, sessionId, totalTokens
- res, err := svcCtx.DB.UsageDetail.Create().
- SetNotNilType(&logType).
- SetNotNilBotID(&authToken).
- SetNotNilReceiverID(&req.EventType).
- SetNotNilSessionID(&sessionId).
- SetNillableRequest(&msgContent).
- SetNillableResponse(&resp.Choices[0].Message.Content).
- SetNillableOrganizationID(&orgId).
- SetOriginalData(rawReqResp).
- SetNillablePromptTokens(&promptTokens).
- SetNillableCompletionTokens(&completionToken).
- SetNillableTotalTokens(&totalTokens).
- Save(ctx)
- if err == nil { //插入UsageDetai之后再统计UsageTotal
- updateUsageTotal(ctx, svcCtx, authToken, res.ID, orgId)
- }
- return err
- }
- func getUsagetotalIdByToken(ctx context.Context, svcCtx *ServiceContext,
- authToken string) (uint64, error) {
- var predicates []predicate.UsageTotal
- predicates = append(predicates, usagetotal.BotIDEQ(authToken))
- return svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(ctx)
- }
- func replaceUsagetotalTokens(ctx context.Context, svcCtx *ServiceContext,
- authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
- Id, err := getUsagetotalIdByToken(ctx, svcCtx, authToken)
- if err != nil && !ent.IsNotFound(err) {
- return err
- }
- if Id > 0 { //UsageTotal have record by newUsageDetailId
- _, err = svcCtx.DB.UsageTotal.UpdateOneID(Id).
- SetTotalTokens(sumTotalTokens).
- SetEndIndex(newUsageDetailId).
- Save(ctx)
- } else { //create new record by newUsageDetailId
- logType := 5
- _, err = svcCtx.DB.UsageTotal.Create().
- SetNotNilBotID(&authToken).
- SetNotNilEndIndex(&newUsageDetailId).
- SetNotNilTotalTokens(&sumTotalTokens).
- SetNillableType(&logType).
- SetNotNilOrganizationID(&orgId).
- Save(ctx)
- }
- return err
- }
- func updateUsageTotal(ctx context.Context, svcCtx *ServiceContext,
- authToken string, newUsageDetailId uint64, orgId uint64) error {
- sumTotalTokens, err := sumTotalTokensByAuthToken(ctx, svcCtx, authToken) //首先sum UsageDetail的TotalTokens
- if err == nil {
- err = replaceUsagetotalTokens(ctx, svcCtx, authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens
- }
- return err
- }
- // sum total_tokens from usagedetail by AuthToken
- func sumTotalTokensByAuthToken(ctx context.Context, svcCtx *ServiceContext,
- authToken string) (uint64, error) {
- var predicates []predicate.UsageDetail
- predicates = append(predicates, usagedetail.BotIDEQ(authToken))
- var res []struct {
- Sum, Min, Max, Count uint64
- }
- totalTokens := uint64(0)
- var err error = nil
- err = svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"),
- ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(ctx, &res)
- if err == nil {
- if len(res) > 0 {
- totalTokens = res[0].Sum
- } else {
- totalTokens = 0
- }
- }
- return totalTokens, err
- }
- */
// IsOpenaiModel reports whether model names an OpenAI model family, judged by
// a case-sensitive prefix match against "gpt-4", "gpt-3", "o1" and "o3".
func IsOpenaiModel(model string) bool {
	switch {
	case strings.HasPrefix(model, "gpt-4"),
		strings.HasPrefix(model, "gpt-3"),
		strings.HasPrefix(model, "o1"),
		strings.HasPrefix(model, "o3"):
		return true
	default:
		return false
	}
}
// EntStructGenScanField builds a comma-separated column list and matching scan
// destinations from the fields of the struct pointed to by structPtr.
//
// Only exported fields with a non-empty json tag name are used. Fields tagged
// `json:"-"` or with an empty name (e.g. `json:",omitempty"`) are skipped. A
// field is also skipped when its type, or for pointer fields the pointed-to
// type, appears in ignoredTypes (nil entries are ignored). The column list is
// joined with ", ". Each scan argument is a pointer to the corresponding field.
//
// It returns an error when structPtr is not a non-nil pointer to a struct.
func EntStructGenScanField(structPtr any, ignoredTypes ...reflect.Type) (string, []any, error) {
	t := reflect.TypeOf(structPtr)
	// reflect.TypeOf(nil) is a nil Type; calling Kind on it would panic.
	if t == nil || t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
		return "", nil, errors.New("input must be a pointer to a struct")
	}
	v := reflect.ValueOf(structPtr)
	// A typed nil pointer passes the type check but has no fields to address.
	if v.IsNil() {
		return "", nil, errors.New("input must be a non-nil pointer to a struct")
	}
	t, v = t.Elem(), v.Elem()

	ignored := make(map[reflect.Type]struct{}, len(ignoredTypes))
	for _, it := range ignoredTypes {
		if it != nil {
			ignored[it] = struct{}{}
		}
	}

	var (
		fields   []string
		scanArgs []any
	)
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		if !field.IsExported() {
			continue
		}

		tag := field.Tag.Get("json")
		if tag == "-" {
			continue
		}
		name, _, _ := strings.Cut(tag, ",")
		if name == "" {
			continue
		}

		// For pointer fields, the pointed-to type decides whether to ignore.
		fieldType := field.Type
		if fieldType.Kind() == reflect.Ptr {
			fieldType = fieldType.Elem()
		}
		if _, skip := ignored[fieldType]; skip {
			continue
		}

		fields = append(fields, name)
		scanArgs = append(scanArgs, v.Field(i).Addr().Interface())
	}
	return strings.Join(fields, ", "), scanArgs, nil
}
- type StdChatClient struct {
- *openai.Client
- }
- func NewStdChatClient(apiKey string, apiBase string) *StdChatClient {
- opts := []option.RequestOption{}
- if len(apiKey) > 0 {
- opts = append(opts, option.WithAPIKey(apiKey))
- }
- opts = append(opts, option.WithBaseURL(apiBase))
- client := openai.NewClient(opts...)
- return &StdChatClient{&client}
- }
- func NewAiClient(apiKey string, apiBase string) *openai.Client {
- opts := []option.RequestOption{}
- if len(apiKey) > 0 {
- opts = append(opts, option.WithAPIKey(apiKey))
- }
- opts = append(opts, option.WithBaseURL(apiBase))
- client := openai.NewClient(opts...)
- return &client
- }
- func NewFastgptClient(apiKey string) *openai.Client {
- //http://fastgpt.ascrm.cn/api/v1/
- client := openai.NewClient(option.WithAPIKey(apiKey),
- option.WithBaseURL("http://fastgpt.ascrm.cn/api/v1/"))
- return &client
- }
- func NewDeepSeekClient(apiKey string) *openai.Client {
- client := openai.NewClient(option.WithAPIKey(apiKey),
- option.WithBaseURL("https://api.deepseek.com"))
- return &client
- }
- func DoChatCompletions(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
- var (
- jsonBytes []byte
- err error
- )
- emptyParams := openai.ChatCompletionNewParams{}
- if jsonBytes, err = json.Marshal(chatInfo); err != nil {
- return nil, err
- }
- //fmt.Printf("In DoChatCompletions, req: '%s'\n", string(jsonBytes))
- //也许应该对请求体不规范成员名进行检查
- customResp := types.CompOpenApiResp{}
- reqBodyOps := option.WithRequestBody("application/json", jsonBytes)
- respBodyOps := option.WithResponseBodyInto(&customResp)
- if _, err = client.Chat.Completions.New(ctx, emptyParams, reqBodyOps, respBodyOps); err != nil {
- return nil, err
- }
- if customResp.FgtErrCode != nil && customResp.FgtErrStatusTxt != nil { //针对fastgpt出错但New()不返回错误的情况
- return nil, fmt.Errorf("%s(%d)", *customResp.FgtErrStatusTxt, *customResp.FgtErrCode)
- }
- return &customResp, nil
- }
- func DoChatCompletionsStream(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (res *types.CompOpenApiResp, err error) {
- var (
- jsonBytes []byte
- raw *http.Response
- //raw []byte
- ok bool
- hw http.ResponseWriter
- )
- hw, ok = contextkey.HttpResponseWriterKey.GetValue(ctx) //context取出http.ResponseWriter
- if !ok {
- return nil, errors.New("content get http writer err")
- }
- flusher, ok := (hw).(http.Flusher)
- if !ok {
- http.Error(hw, "Streaming unsupported!", http.StatusInternalServerError)
- }
- emptyParams := openai.ChatCompletionNewParams{}
- if jsonBytes, err = json.Marshal(chatInfo); err != nil {
- return nil, err
- }
- reqBodyOps := option.WithRequestBody("application/json", jsonBytes)
- respBodyOps := option.WithResponseBodyInto(&raw)
- if _, err = client.Chat.Completions.New(ctx, emptyParams, reqBodyOps, respBodyOps, option.WithJSONSet("stream", true)); err != nil {
- return nil, err
- }
- //设置流式输出头 http1.1
- hw.Header().Set("Content-Type", "text/event-stream;charset=utf-8")
- hw.Header().Set("Connection", "keep-alive")
- hw.Header().Set("Cache-Control", "no-cache")
- chatStream := ssestream.NewStream[ApiRespStreamChunk](ApiRespStreamDecoder(raw), err)
- defer chatStream.Close()
- for chatStream.Next() {
- chunk := chatStream.Current()
- fmt.Fprintf(hw, "data:%s\n\n", chunk.Data.RAW)
- flusher.Flush()
- //time.Sleep(1 * time.Millisecond)
- }
- fmt.Fprintf(hw, "data:%s\n\n", "[DONE]")
- flusher.Flush()
- httpx.Ok(hw)
- return nil, nil
- }
- func NewChatCompletions(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
- if chatInfo.Stream {
- return DoChatCompletionsStream(ctx, client, chatInfo)
- } else {
- return DoChatCompletions(ctx, client, chatInfo)
- }
- }
- func NewMismatchChatCompletions(ctx context.Context, apiKey string, apiBase string, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
- client := NewAiClient(apiKey, apiBase)
- return NewChatCompletions(ctx, client, chatInfo)
- }
- func NewFastgptChatCompletions(ctx context.Context, apiKey string, apiBase string, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
- client := NewAiClient(apiKey, apiBase)
- return NewChatCompletions(ctx, client, chatInfo)
- }
- func NewDeepSeekChatCompletions(ctx context.Context, apiKey string, chatInfo *types.CompApiReq, chatModel openai.ChatModel) (res *types.CompOpenApiResp, err error) {
- client := NewDeepSeekClient(apiKey)
- if chatModel != ChatModelDeepSeekV3 {
- chatModel = ChatModelDeepSeekR1
- }
- chatInfo.Model = chatModel
- return NewChatCompletions(ctx, client, chatInfo)
- }
- func DoChatCompletionsStreamOld(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (res *types.CompOpenApiResp, err error) {
- var (
- jsonBytes []byte
- )
- emptyParams := openai.ChatCompletionNewParams{}
- if jsonBytes, err = json.Marshal(chatInfo); err != nil {
- return nil, err
- }
- reqBodyOps := option.WithRequestBody("application/json", jsonBytes)
- //customResp := types.CompOpenApiResp{}
- //respBodyOps := option.WithResponseBodyInto(&customResp)
- //chatStream := client.Chat.Completions.NewStreaming(ctx, emptyParams, reqBodyOps, respBodyOps)
- chatStream := client.Chat.Completions.NewStreaming(ctx, emptyParams, reqBodyOps)
- // optionally, an accumulator helper can be used
- acc := openai.ChatCompletionAccumulator{}
- httpWriter, ok := ctx.Value("HttpResp-Writer").(http.ResponseWriter)
- if !ok {
- return nil, errors.New("content get writer err")
- }
- //httpWriter.Header().Set("Content-Type", "text/event-stream;charset=utf-8")
- //httpWriter.Header().Set("Connection", "keep-alive")
- //httpWriter.Header().Set("Cache-Control", "no-cache")
- idx := 0
- for chatStream.Next() {
- chunk := chatStream.Current()
- acc.AddChunk(chunk)
- fmt.Printf("=====>get %d chunk:%v\n", idx, chunk)
- if _, err := fmt.Fprintf(httpWriter, "%v", chunk); err != nil {
- fmt.Printf("Error writing to client:%v \n", err)
- break
- }
- if content, ok := acc.JustFinishedContent(); ok {
- println("Content stream finished:", content)
- }
- // if using tool calls
- if tool, ok := acc.JustFinishedToolCall(); ok {
- println("Tool call stream finished:", tool.Index, tool.Name, tool.Arguments)
- }
- if refusal, ok := acc.JustFinishedRefusal(); ok {
- println("Refusal stream finished:", refusal)
- }
- // it's best to use chunks after handling JustFinished events
- if len(chunk.Choices) > 0 {
- idx++
- fmt.Printf("idx:%d get =>'%s'\n", idx, chunk.Choices[0].Delta.Content)
- }
- }
- if err := chatStream.Err(); err != nil {
- return nil, err
- }
- return nil, nil
- }
- func GetWorkInfoByID(eventType string, workId string) (string, uint) {
- val, exist := fastgptWorkIdMap[workId]
- if !exist {
- val = fastgptWorkIdMap["default"]
- }
- return val.Id, val.Idx
- }
- // 获取workToken
- func GetWorkTokenByID(eventType string, workId string) string {
- id, _ := GetWorkInfoByID(eventType, workId)
- return id
- }
- // 获取workIdx
- func GetWorkIdxByID(eventType string, workId string) uint {
- _, idx := GetWorkInfoByID(eventType, workId)
- return idx
- }
|