func.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582
  1. package compapi
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "net"
  8. "net/http"
  9. "net/url"
  10. "reflect"
  11. "regexp"
  12. "strconv"
  13. "strings"
  14. "wechat-api/ent"
  15. "wechat-api/ent/custom_types"
  16. "wechat-api/hook/credit"
  17. "wechat-api/internal/config"
  18. "wechat-api/internal/types"
  19. "wechat-api/internal/utils/contextkey"
  20. openai "github.com/openai/openai-go"
  21. "github.com/openai/openai-go/option"
  22. "github.com/openai/openai-go/packages/ssestream"
  23. "github.com/redis/go-redis/v9"
  24. "github.com/zeromicro/go-zero/rest/httpx"
  25. )
  26. type ServiceContext struct {
  27. Config config.Config
  28. DB *ent.Client
  29. Rds redis.UniversalClient
  30. }
  31. func AppendUsageDetailLog(ctx context.Context, svcCtx *ServiceContext,
  32. authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
  33. tmp0 := make([]custom_types.VResponseData, 0)
  34. for _, tmp1 := range resp.ResponseData {
  35. tmp21, _ := tmp1["id"].(string)
  36. tmp22, _ := tmp1["model"].(string)
  37. tmp2 := custom_types.VResponseData{ID: tmp21, Model: tmp22}
  38. tmp0 = append(tmp0, tmp2)
  39. }
  40. tmp3 := custom_types.VResponse{ID: resp.ID, Model: resp.Model, ResponseData: tmp0}
  41. rawReqResp := custom_types.OriginalData{Request: req, Response: resp, VResponse: &tmp3}
  42. orgId := uint64(0)
  43. apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(ctx)
  44. if ok {
  45. orgId = apiKeyObj.OrganizationID
  46. }
  47. msgContent := getMessageContentStr(req.Messages[0].Content)
  48. tx, err := svcCtx.DB.Tx(context.Background())
  49. if err != nil {
  50. err = fmt.Errorf("start transaction error:%v", err)
  51. } else {
  52. usage := credit.Usage{
  53. CompletionTokens: uint64(resp.Usage.CompletionTokens),
  54. PromptTokens: uint64(resp.Usage.PromptTokens),
  55. TotalTokens: uint64(resp.Usage.TotalTokens),
  56. }
  57. err = credit.AddCreditUsage(tx, ctx,
  58. authToken, req.EventType, orgId,
  59. &msgContent, &resp.Choices[0].Message.Content,
  60. &rawReqResp, &usage, req.Model,
  61. )
  62. if err != nil {
  63. _ = tx.Rollback()
  64. err = fmt.Errorf("save credits info failed:%v", err)
  65. } else {
  66. _ = tx.Commit()
  67. }
  68. }
  69. return err
  70. }
  // validURLDomainRegex matches the host names IsValidURL accepts:
  //   - a multi-label domain ending in a 2-63 letter TLD (e.g. example.com);
  //   - a single label (e.g. localhost, mytest-svc).
  // Every label is 1-63 ASCII letters, digits or hyphens and neither starts nor
  // ends with a hyphen. It is compiled once at package initialization instead
  // of on every call.
  var validURLDomainRegex = regexp.MustCompile(
  	// Multi-label domain, e.g. example.com.
  	`^([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}$` +
  		`|` +
  		// Single-label host, e.g. localhost or mytest-svc.
  		`^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$`,
  )

  // IsValidURL reports whether *input is a well-formed URL meeting all of:
  //   - the scheme is http, https, ftp or ftps;
  //   - the host is an IP literal (IPv4 or IPv6, including bracketed IPv6 such
  //     as "http://[::1]") or a syntactically valid domain name;
  //   - the port, if present, is in 1-65535.
  //
  // Input without a "://" separator is treated as "http://" + input. When adjust
  // is true and the input is valid, *input is overwritten with that normalized
  // form. A nil or empty input is reported as invalid.
  func IsValidURL(input *string, adjust bool) bool {
  	if input == nil || *input == "" {
  		return false
  	}
  	candidate := *input
  	if !strings.Contains(candidate, "://") {
  		candidate = "http://" + candidate
  	}
  	u, err := url.Parse(candidate)
  	if err != nil {
  		return false
  	}
  	switch u.Scheme {
  	case "http", "https", "ftp", "ftps":
  	default:
  		return false
  	}
  	// Hostname strips IPv6 brackets even when no port is present, which the
  	// previous net.SplitHostPort fallback did not ("http://[::1]" was rejected).
  	host := u.Hostname()
  	if net.ParseIP(host) == nil && !validURLDomainRegex.MatchString(host) {
  		return false
  	}
  	// url.Parse only accepts an all-digit port, so a numeric range check suffices.
  	if port := u.Port(); port != "" {
  		n, err := strconv.Atoi(port)
  		if err != nil || n < 1 || n > 65535 {
  			return false
  		}
  	}
  	if adjust {
  		*input = candidate
  	}
  	return true
  }
  // getMessageContentStr extracts plain text from a chat message content value:
  //   - a string is returned unchanged;
  //   - for a []any list of content parts, the "text" string of the first part
  //     is returned, provided that part is a map[string]any.
  // Any other shape (including an empty list, a missing key or a non-string
  // "text") yields "".
  func getMessageContentStr(input any) string {
  	if s, isStr := input.(string); isStr {
  		return s
  	}
  	parts, isList := input.([]any)
  	if !isList || len(parts) == 0 {
  		return ""
  	}
  	first, isObj := parts[0].(map[string]any)
  	if !isObj {
  		return ""
  	}
  	text, _ := first["text"].(string)
  	return text
  }
  154. /*
  155. func AppendUsageDetailLog(ctx context.Context, svcCtx *ServiceContext,
  156. authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
  157. logType := 5
  158. rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
  159. tmpId := 0
  160. tmpId, _ = strconv.Atoi(resp.ID)
  161. sessionId := uint64(tmpId)
  162. orgId := uint64(0)
  163. apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(ctx)
  164. if ok {
  165. orgId = apiKeyObj.OrganizationID
  166. }
  167. promptTokens := uint64(resp.Usage.PromptTokens)
  168. completionToken := uint64(resp.Usage.CompletionTokens)
  169. totalTokens := promptTokens + completionToken
  170. msgContent := getMessageContentStr(req.Messages[0].Content)
  171. _, _, _ = logType, sessionId, totalTokens
  172. res, err := svcCtx.DB.UsageDetail.Create().
  173. SetNotNilType(&logType).
  174. SetNotNilBotID(&authToken).
  175. SetNotNilReceiverID(&req.EventType).
  176. SetNotNilSessionID(&sessionId).
  177. SetNillableRequest(&msgContent).
  178. SetNillableResponse(&resp.Choices[0].Message.Content).
  179. SetNillableOrganizationID(&orgId).
  180. SetOriginalData(rawReqResp).
  181. SetNillablePromptTokens(&promptTokens).
  182. SetNillableCompletionTokens(&completionToken).
  183. SetNillableTotalTokens(&totalTokens).
  184. Save(ctx)
  185. if err == nil { //插入UsageDetai之后再统计UsageTotal
  186. updateUsageTotal(ctx, svcCtx, authToken, res.ID, orgId)
  187. }
  188. return err
  189. }
  190. func getUsagetotalIdByToken(ctx context.Context, svcCtx *ServiceContext,
  191. authToken string) (uint64, error) {
  192. var predicates []predicate.UsageTotal
  193. predicates = append(predicates, usagetotal.BotIDEQ(authToken))
  194. return svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(ctx)
  195. }
  196. func replaceUsagetotalTokens(ctx context.Context, svcCtx *ServiceContext,
  197. authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
  198. Id, err := getUsagetotalIdByToken(ctx, svcCtx, authToken)
  199. if err != nil && !ent.IsNotFound(err) {
  200. return err
  201. }
  202. if Id > 0 { //UsageTotal have record by newUsageDetailId
  203. _, err = svcCtx.DB.UsageTotal.UpdateOneID(Id).
  204. SetTotalTokens(sumTotalTokens).
  205. SetEndIndex(newUsageDetailId).
  206. Save(ctx)
  207. } else { //create new record by newUsageDetailId
  208. logType := 5
  209. _, err = svcCtx.DB.UsageTotal.Create().
  210. SetNotNilBotID(&authToken).
  211. SetNotNilEndIndex(&newUsageDetailId).
  212. SetNotNilTotalTokens(&sumTotalTokens).
  213. SetNillableType(&logType).
  214. SetNotNilOrganizationID(&orgId).
  215. Save(ctx)
  216. }
  217. return err
  218. }
  219. func updateUsageTotal(ctx context.Context, svcCtx *ServiceContext,
  220. authToken string, newUsageDetailId uint64, orgId uint64) error {
  221. sumTotalTokens, err := sumTotalTokensByAuthToken(ctx, svcCtx, authToken) //首先sum UsageDetail的TotalTokens
  222. if err == nil {
  223. err = replaceUsagetotalTokens(ctx, svcCtx, authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens
  224. }
  225. return err
  226. }
  227. // sum total_tokens from usagedetail by AuthToken
  228. func sumTotalTokensByAuthToken(ctx context.Context, svcCtx *ServiceContext,
  229. authToken string) (uint64, error) {
  230. var predicates []predicate.UsageDetail
  231. predicates = append(predicates, usagedetail.BotIDEQ(authToken))
  232. var res []struct {
  233. Sum, Min, Max, Count uint64
  234. }
  235. totalTokens := uint64(0)
  236. var err error = nil
  237. err = svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"),
  238. ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(ctx, &res)
  239. if err == nil {
  240. if len(res) > 0 {
  241. totalTokens = res[0].Sum
  242. } else {
  243. totalTokens = 0
  244. }
  245. }
  246. return totalTokens, err
  247. }
  248. */
  // IsOpenaiModel reports whether model belongs to an OpenAI model family,
  // judged purely by a case-sensitive prefix: "gpt-4", "gpt-3", "o1" or "o3".
  func IsOpenaiModel(model string) bool {
  	switch {
  	case strings.HasPrefix(model, "gpt-4"),
  		strings.HasPrefix(model, "gpt-3"),
  		strings.HasPrefix(model, "o1"),
  		strings.HasPrefix(model, "o3"):
  		return true
  	default:
  		return false
  	}
  }
  // EntStructGenScanField derives a comma-separated field-name list and matching
  // scan destinations from the json tags of the struct structPtr points to
  // (presumably used as a SELECT column list with rows.Scan — confirm at callers).
  //
  // A field is included when it is exported and its json tag yields a non-empty
  // name (the part before the first ','). For each included field, in
  // declaration order:
  //   - the name is appended to the list, which is joined with ", ";
  //   - a pointer to the field is appended to the scan arguments.
  //
  // A field is skipped when:
  //   - its tag is missing or exactly "-";
  //   - its type is in ignoredTypes, or, for a pointer field, its element type
  //     is. nil entries in ignoredTypes are ignored.
  //
  // It returns an error unless structPtr is a non-nil pointer to a struct.
  func EntStructGenScanField(structPtr any, ignoredTypes ...reflect.Type) (string, []any, error) {
  	v := reflect.ValueOf(structPtr)
  	// Inspect the Value rather than reflect.TypeOf: an untyped nil yields an
  	// invalid Value (Kind() == Invalid) instead of a nil Type that would panic.
  	if v.Kind() != reflect.Ptr || v.Type().Elem().Kind() != reflect.Struct {
  		return "", nil, errors.New("input must be a pointer to a struct")
  	}
  	// A typed nil pointer would otherwise panic on Elem().Field(i).
  	if v.IsNil() {
  		return "", nil, errors.New("input must be a non-nil pointer to a struct")
  	}
  	elem := v.Elem()
  	t := elem.Type()

  	ignored := make(map[reflect.Type]struct{}, len(ignoredTypes))
  	for _, it := range ignoredTypes {
  		if it != nil {
  			ignored[it] = struct{}{}
  		}
  	}

  	var fields []string
  	var scanArgs []any
  	for i := 0; i < t.NumField(); i++ {
  		field := t.Field(i)
  		if !field.IsExported() {
  			continue
  		}
  		tag := field.Tag.Get("json")
  		if tag == "-" {
  			continue
  		}
  		// An empty tag or an empty name part (e.g. `json:",omitempty"`) is skipped.
  		jsonName, _, _ := strings.Cut(tag, ",")
  		if jsonName == "" {
  			continue
  		}
  		fieldType := field.Type
  		if fieldType.Kind() == reflect.Ptr {
  			// For pointer fields the pointed-to type is what callers list.
  			fieldType = fieldType.Elem()
  		}
  		if _, skip := ignored[fieldType]; skip {
  			continue
  		}
  		fields = append(fields, jsonName)
  		scanArgs = append(scanArgs, elem.Field(i).Addr().Interface())
  	}
  	return strings.Join(fields, ", "), scanArgs, nil
  }
  311. type StdChatClient struct {
  312. *openai.Client
  313. }
  314. func NewStdChatClient(apiKey string, apiBase string) *StdChatClient {
  315. opts := []option.RequestOption{}
  316. if len(apiKey) > 0 {
  317. opts = append(opts, option.WithAPIKey(apiKey))
  318. }
  319. opts = append(opts, option.WithBaseURL(apiBase))
  320. client := openai.NewClient(opts...)
  321. return &StdChatClient{&client}
  322. }
  323. func NewAiClient(apiKey string, apiBase string) *openai.Client {
  324. opts := []option.RequestOption{}
  325. if len(apiKey) > 0 {
  326. opts = append(opts, option.WithAPIKey(apiKey))
  327. }
  328. opts = append(opts, option.WithBaseURL(apiBase))
  329. client := openai.NewClient(opts...)
  330. return &client
  331. }
  332. func NewFastgptClient(apiKey string) *openai.Client {
  333. //http://fastgpt.ascrm.cn/api/v1/
  334. client := openai.NewClient(option.WithAPIKey(apiKey),
  335. option.WithBaseURL("http://fastgpt.ascrm.cn/api/v1/"))
  336. return &client
  337. }
  338. func NewDeepSeekClient(apiKey string) *openai.Client {
  339. client := openai.NewClient(option.WithAPIKey(apiKey),
  340. option.WithBaseURL("https://api.deepseek.com"))
  341. return &client
  342. }
  343. func DoChatCompletions(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
  344. var (
  345. jsonBytes []byte
  346. err error
  347. )
  348. emptyParams := openai.ChatCompletionNewParams{}
  349. if jsonBytes, err = json.Marshal(chatInfo); err != nil {
  350. return nil, err
  351. }
  352. //fmt.Printf("In DoChatCompletions, req: '%s'\n", string(jsonBytes))
  353. //也许应该对请求体不规范成员名进行检查
  354. customResp := types.CompOpenApiResp{}
  355. reqBodyOps := option.WithRequestBody("application/json", jsonBytes)
  356. respBodyOps := option.WithResponseBodyInto(&customResp)
  357. if _, err = client.Chat.Completions.New(ctx, emptyParams, reqBodyOps, respBodyOps); err != nil {
  358. return nil, err
  359. }
  360. if customResp.FgtErrCode != nil && customResp.FgtErrStatusTxt != nil { //针对fastgpt出错但New()不返回错误的情况
  361. return nil, fmt.Errorf("%s(%d)", *customResp.FgtErrStatusTxt, *customResp.FgtErrCode)
  362. }
  363. return &customResp, nil
  364. }
  365. func DoChatCompletionsStream(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (res *types.CompOpenApiResp, err error) {
  366. var (
  367. jsonBytes []byte
  368. raw *http.Response
  369. //raw []byte
  370. ok bool
  371. hw http.ResponseWriter
  372. )
  373. hw, ok = contextkey.HttpResponseWriterKey.GetValue(ctx) //context取出http.ResponseWriter
  374. if !ok {
  375. return nil, errors.New("content get http writer err")
  376. }
  377. flusher, ok := (hw).(http.Flusher)
  378. if !ok {
  379. http.Error(hw, "Streaming unsupported!", http.StatusInternalServerError)
  380. }
  381. emptyParams := openai.ChatCompletionNewParams{}
  382. if jsonBytes, err = json.Marshal(chatInfo); err != nil {
  383. return nil, err
  384. }
  385. reqBodyOps := option.WithRequestBody("application/json", jsonBytes)
  386. respBodyOps := option.WithResponseBodyInto(&raw)
  387. if _, err = client.Chat.Completions.New(ctx, emptyParams, reqBodyOps, respBodyOps, option.WithJSONSet("stream", true)); err != nil {
  388. return nil, err
  389. }
  390. //设置流式输出头 http1.1
  391. hw.Header().Set("Content-Type", "text/event-stream;charset=utf-8")
  392. hw.Header().Set("Connection", "keep-alive")
  393. hw.Header().Set("Cache-Control", "no-cache")
  394. chatStream := ssestream.NewStream[ApiRespStreamChunk](ApiRespStreamDecoder(raw), err)
  395. defer chatStream.Close()
  396. for chatStream.Next() {
  397. chunk := chatStream.Current()
  398. fmt.Fprintf(hw, "data:%s\n\n", chunk.Data.RAW)
  399. flusher.Flush()
  400. //time.Sleep(1 * time.Millisecond)
  401. }
  402. fmt.Fprintf(hw, "data:%s\n\n", "[DONE]")
  403. flusher.Flush()
  404. httpx.Ok(hw)
  405. return nil, nil
  406. }
  407. func NewChatCompletions(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
  408. if chatInfo.Stream {
  409. return DoChatCompletionsStream(ctx, client, chatInfo)
  410. } else {
  411. return DoChatCompletions(ctx, client, chatInfo)
  412. }
  413. }
  414. func NewMismatchChatCompletions(ctx context.Context, apiKey string, apiBase string, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
  415. client := NewAiClient(apiKey, apiBase)
  416. return NewChatCompletions(ctx, client, chatInfo)
  417. }
  418. func NewFastgptChatCompletions(ctx context.Context, apiKey string, apiBase string, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
  419. client := NewAiClient(apiKey, apiBase)
  420. return NewChatCompletions(ctx, client, chatInfo)
  421. }
  422. func NewDeepSeekChatCompletions(ctx context.Context, apiKey string, chatInfo *types.CompApiReq, chatModel openai.ChatModel) (res *types.CompOpenApiResp, err error) {
  423. client := NewDeepSeekClient(apiKey)
  424. if chatModel != ChatModelDeepSeekV3 {
  425. chatModel = ChatModelDeepSeekR1
  426. }
  427. chatInfo.Model = chatModel
  428. return NewChatCompletions(ctx, client, chatInfo)
  429. }
  430. func DoChatCompletionsStreamOld(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (res *types.CompOpenApiResp, err error) {
  431. var (
  432. jsonBytes []byte
  433. )
  434. emptyParams := openai.ChatCompletionNewParams{}
  435. if jsonBytes, err = json.Marshal(chatInfo); err != nil {
  436. return nil, err
  437. }
  438. reqBodyOps := option.WithRequestBody("application/json", jsonBytes)
  439. //customResp := types.CompOpenApiResp{}
  440. //respBodyOps := option.WithResponseBodyInto(&customResp)
  441. //chatStream := client.Chat.Completions.NewStreaming(ctx, emptyParams, reqBodyOps, respBodyOps)
  442. chatStream := client.Chat.Completions.NewStreaming(ctx, emptyParams, reqBodyOps)
  443. // optionally, an accumulator helper can be used
  444. acc := openai.ChatCompletionAccumulator{}
  445. httpWriter, ok := ctx.Value("HttpResp-Writer").(http.ResponseWriter)
  446. if !ok {
  447. return nil, errors.New("content get writer err")
  448. }
  449. //httpWriter.Header().Set("Content-Type", "text/event-stream;charset=utf-8")
  450. //httpWriter.Header().Set("Connection", "keep-alive")
  451. //httpWriter.Header().Set("Cache-Control", "no-cache")
  452. idx := 0
  453. for chatStream.Next() {
  454. chunk := chatStream.Current()
  455. acc.AddChunk(chunk)
  456. fmt.Printf("=====>get %d chunk:%v\n", idx, chunk)
  457. if _, err := fmt.Fprintf(httpWriter, "%v", chunk); err != nil {
  458. fmt.Printf("Error writing to client:%v \n", err)
  459. break
  460. }
  461. if content, ok := acc.JustFinishedContent(); ok {
  462. println("Content stream finished:", content)
  463. }
  464. // if using tool calls
  465. if tool, ok := acc.JustFinishedToolCall(); ok {
  466. println("Tool call stream finished:", tool.Index, tool.Name, tool.Arguments)
  467. }
  468. if refusal, ok := acc.JustFinishedRefusal(); ok {
  469. println("Refusal stream finished:", refusal)
  470. }
  471. // it's best to use chunks after handling JustFinished events
  472. if len(chunk.Choices) > 0 {
  473. idx++
  474. fmt.Printf("idx:%d get =>'%s'\n", idx, chunk.Choices[0].Delta.Content)
  475. }
  476. }
  477. if err := chatStream.Err(); err != nil {
  478. return nil, err
  479. }
  480. return nil, nil
  481. }
  482. func GetWorkInfoByID(eventType string, workId string) (string, uint) {
  483. val, exist := fastgptWorkIdMap[workId]
  484. if !exist {
  485. val = fastgptWorkIdMap["default"]
  486. }
  487. return val.Id, val.Idx
  488. }
  489. // 获取workToken
  490. func GetWorkTokenByID(eventType string, workId string) string {
  491. id, _ := GetWorkInfoByID(eventType, workId)
  492. return id
  493. }
  494. // 获取workIdx
  495. func GetWorkIdxByID(eventType string, workId string) uint {
  496. _, idx := GetWorkInfoByID(eventType, workId)
  497. return idx
  498. }