func.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. package compapi
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "reflect"
  9. "strconv"
  10. "strings"
  11. "wechat-api/ent"
  12. "wechat-api/ent/custom_types"
  13. "wechat-api/ent/predicate"
  14. "wechat-api/ent/usagedetail"
  15. "wechat-api/ent/usagetotal"
  16. "wechat-api/internal/config"
  17. "wechat-api/internal/types"
  18. "wechat-api/internal/utils/contextkey"
  19. openai "github.com/openai/openai-go"
  20. "github.com/openai/openai-go/option"
  21. "github.com/openai/openai-go/packages/ssestream"
  22. "github.com/redis/go-redis/v9"
  23. "github.com/zeromicro/go-zero/rest/httpx"
  24. )
  25. type ServiceContext struct {
  26. Config config.Config
  27. DB *ent.Client
  28. Rds redis.UniversalClient
  29. }
  30. func AppendUsageDetailLog(ctx context.Context, svcCtx *ServiceContext,
  31. authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
  32. logType := 5
  33. rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
  34. tmpId := 0
  35. tmpId, _ = strconv.Atoi(resp.ID)
  36. sessionId := uint64(tmpId)
  37. orgId := uint64(0)
  38. apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(ctx)
  39. if ok {
  40. orgId = apiKeyObj.OrganizationID
  41. }
  42. promptTokens := uint64(resp.Usage.PromptTokens)
  43. completionToken := uint64(resp.Usage.CompletionTokens)
  44. totalTokens := promptTokens + completionToken
  45. msgContent := getMessageContentStr(req.Messages[0].Content)
  46. _, _, _ = logType, sessionId, totalTokens
  47. res, err := svcCtx.DB.UsageDetail.Create().
  48. SetNotNilType(&logType).
  49. SetNotNilBotID(&authToken).
  50. SetNotNilReceiverID(&req.EventType).
  51. SetNotNilSessionID(&sessionId).
  52. SetNillableRequest(&msgContent).
  53. SetNillableResponse(&resp.Choices[0].Message.Content).
  54. SetNillableOrganizationID(&orgId).
  55. SetOriginalData(rawReqResp).
  56. SetNillablePromptTokens(&promptTokens).
  57. SetNillableCompletionTokens(&completionToken).
  58. SetNillableTotalTokens(&totalTokens).
  59. Save(ctx)
  60. if err == nil { //插入UsageDetai之后再统计UsageTotal
  61. updateUsageTotal(ctx, svcCtx, authToken, res.ID, orgId)
  62. }
  63. return err
  64. }
  65. func getUsagetotalIdByToken(ctx context.Context, svcCtx *ServiceContext,
  66. authToken string) (uint64, error) {
  67. var predicates []predicate.UsageTotal
  68. predicates = append(predicates, usagetotal.BotIDEQ(authToken))
  69. return svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(ctx)
  70. }
  71. func replaceUsagetotalTokens(ctx context.Context, svcCtx *ServiceContext,
  72. authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
  73. Id, err := getUsagetotalIdByToken(ctx, svcCtx, authToken)
  74. if err != nil && !ent.IsNotFound(err) {
  75. return err
  76. }
  77. if Id > 0 { //UsageTotal have record by newUsageDetailId
  78. _, err = svcCtx.DB.UsageTotal.UpdateOneID(Id).
  79. SetTotalTokens(sumTotalTokens).
  80. SetEndIndex(newUsageDetailId).
  81. Save(ctx)
  82. } else { //create new record by newUsageDetailId
  83. logType := 5
  84. _, err = svcCtx.DB.UsageTotal.Create().
  85. SetNotNilBotID(&authToken).
  86. SetNotNilEndIndex(&newUsageDetailId).
  87. SetNotNilTotalTokens(&sumTotalTokens).
  88. SetNillableType(&logType).
  89. SetNotNilOrganizationID(&orgId).
  90. Save(ctx)
  91. }
  92. return err
  93. }
  94. func updateUsageTotal(ctx context.Context, svcCtx *ServiceContext,
  95. authToken string, newUsageDetailId uint64, orgId uint64) error {
  96. sumTotalTokens, err := sumTotalTokensByAuthToken(ctx, svcCtx, authToken) //首先sum UsageDetail的TotalTokens
  97. if err == nil {
  98. err = replaceUsagetotalTokens(ctx, svcCtx, authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens
  99. }
  100. return err
  101. }
  102. // sum total_tokens from usagedetail by AuthToken
  103. func sumTotalTokensByAuthToken(ctx context.Context, svcCtx *ServiceContext,
  104. authToken string) (uint64, error) {
  105. var predicates []predicate.UsageDetail
  106. predicates = append(predicates, usagedetail.BotIDEQ(authToken))
  107. var res []struct {
  108. Sum, Min, Max, Count uint64
  109. }
  110. totalTokens := uint64(0)
  111. var err error = nil
  112. err = svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"),
  113. ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(ctx, &res)
  114. if err == nil {
  115. if len(res) > 0 {
  116. totalTokens = res[0].Sum
  117. } else {
  118. totalTokens = 0
  119. }
  120. }
  121. return totalTokens, err
  122. }
  123. func getMessageContentStr(input any) string {
  124. str := ""
  125. switch val := input.(type) {
  126. case string:
  127. str = val
  128. case []interface{}:
  129. if len(val) > 0 {
  130. if valc, ok := val[0].(map[string]interface{}); ok {
  131. if valcc, ok := valc["text"]; ok {
  132. str, _ = valcc.(string)
  133. }
  134. }
  135. }
  136. }
  137. return str
  138. }
  139. func IsOpenaiModel(model string) bool {
  140. prefixes := []string{"gpt-4", "gpt-3", "o1", "o3"}
  141. // 遍历所有前缀进行检查
  142. for _, prefix := range prefixes {
  143. if strings.HasPrefix(model, prefix) {
  144. return true
  145. }
  146. }
  147. return false
  148. }
  149. func EntStructGenScanField(structPtr any, ignoredTypes ...reflect.Type) (string, []any, error) {
  150. t := reflect.TypeOf(structPtr)
  151. v := reflect.ValueOf(structPtr)
  152. if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
  153. return "", nil, errors.New("input must be a pointer to a struct")
  154. }
  155. t = t.Elem()
  156. v = v.Elem()
  157. var fields []string
  158. var scanArgs []any
  159. ignoredMap := make(map[reflect.Type]struct{})
  160. // 检查调用者是否传入了任何要忽略的类型
  161. if len(ignoredTypes) > 0 {
  162. for _, ignoredType := range ignoredTypes {
  163. if ignoredType != nil { // 防止 nil 类型加入 map
  164. ignoredMap[ignoredType] = struct{}{}
  165. }
  166. }
  167. }
  168. for i := 0; i < t.NumField(); i++ {
  169. field := t.Field(i)
  170. value := v.Field(i)
  171. // Skip unexported fields
  172. if !field.IsExported() {
  173. continue
  174. }
  175. // Get json tag
  176. jsonTag := field.Tag.Get("json")
  177. if jsonTag == "-" || jsonTag == "" {
  178. continue
  179. }
  180. jsonParts := strings.Split(jsonTag, ",")
  181. jsonName := jsonParts[0]
  182. if jsonName == "" {
  183. continue
  184. }
  185. //传入了要忽略的类型时进行处理
  186. if len(ignoredMap) > 0 {
  187. fieldType := field.Type //获取字段的实际 Go 类型
  188. //如果字段是指针,我们通常关心的是指针指向的元素的类型
  189. if fieldType.Kind() == reflect.Ptr {
  190. fieldType = fieldType.Elem() // 获取元素类型
  191. }
  192. if _, shouldIgnore := ignoredMap[fieldType]; shouldIgnore {
  193. continue // 成员类型存在于忽略列表中则忽略
  194. }
  195. }
  196. fields = append(fields, jsonName)
  197. scanArgs = append(scanArgs, value.Addr().Interface())
  198. }
  199. return strings.Join(fields, ", "), scanArgs, nil
  200. }
  201. type StdChatClient struct {
  202. *openai.Client
  203. }
  204. func NewStdChatClient(apiKey string, apiBase string) *StdChatClient {
  205. opts := []option.RequestOption{}
  206. if len(apiKey) > 0 {
  207. opts = append(opts, option.WithAPIKey(apiKey))
  208. }
  209. opts = append(opts, option.WithBaseURL(apiBase))
  210. client := openai.NewClient(opts...)
  211. return &StdChatClient{&client}
  212. }
  213. func NewAiClient(apiKey string, apiBase string) *openai.Client {
  214. opts := []option.RequestOption{}
  215. if len(apiKey) > 0 {
  216. opts = append(opts, option.WithAPIKey(apiKey))
  217. }
  218. opts = append(opts, option.WithBaseURL(apiBase))
  219. client := openai.NewClient(opts...)
  220. return &client
  221. }
  222. func NewFastgptClient(apiKey string) *openai.Client {
  223. //http://fastgpt.ascrm.cn/api/v1/
  224. client := openai.NewClient(option.WithAPIKey(apiKey),
  225. option.WithBaseURL("http://fastgpt.ascrm.cn/api/v1/"))
  226. return &client
  227. }
  228. func NewDeepSeekClient(apiKey string) *openai.Client {
  229. client := openai.NewClient(option.WithAPIKey(apiKey),
  230. option.WithBaseURL("https://api.deepseek.com"))
  231. return &client
  232. }
  233. func DoChatCompletions(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
  234. var (
  235. jsonBytes []byte
  236. err error
  237. )
  238. emptyParams := openai.ChatCompletionNewParams{}
  239. if jsonBytes, err = json.Marshal(chatInfo); err != nil {
  240. return nil, err
  241. }
  242. //fmt.Printf("In DoChatCompletions, req: '%s'\n", string(jsonBytes))
  243. //也许应该对请求体不规范成员名进行检查
  244. customResp := types.CompOpenApiResp{}
  245. reqBodyOps := option.WithRequestBody("application/json", jsonBytes)
  246. respBodyOps := option.WithResponseBodyInto(&customResp)
  247. if _, err = client.Chat.Completions.New(ctx, emptyParams, reqBodyOps, respBodyOps); err != nil {
  248. return nil, err
  249. }
  250. if customResp.FgtErrCode != nil && customResp.FgtErrStatusTxt != nil { //针对fastgpt出错但New()不返回错误的情况
  251. return nil, fmt.Errorf("%s(%d)", *customResp.FgtErrStatusTxt, *customResp.FgtErrCode)
  252. }
  253. return &customResp, nil
  254. }
  255. func DoChatCompletionsStream(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (res *types.CompOpenApiResp, err error) {
  256. var (
  257. jsonBytes []byte
  258. raw *http.Response
  259. //raw []byte
  260. ok bool
  261. hw http.ResponseWriter
  262. )
  263. hw, ok = contextkey.HttpResponseWriterKey.GetValue(ctx) //context取出http.ResponseWriter
  264. if !ok {
  265. return nil, errors.New("content get http writer err")
  266. }
  267. flusher, ok := (hw).(http.Flusher)
  268. if !ok {
  269. http.Error(hw, "Streaming unsupported!", http.StatusInternalServerError)
  270. }
  271. emptyParams := openai.ChatCompletionNewParams{}
  272. if jsonBytes, err = json.Marshal(chatInfo); err != nil {
  273. return nil, err
  274. }
  275. reqBodyOps := option.WithRequestBody("application/json", jsonBytes)
  276. respBodyOps := option.WithResponseBodyInto(&raw)
  277. if _, err = client.Chat.Completions.New(ctx, emptyParams, reqBodyOps, respBodyOps, option.WithJSONSet("stream", true)); err != nil {
  278. return nil, err
  279. }
  280. //设置流式输出头 http1.1
  281. hw.Header().Set("Content-Type", "text/event-stream;charset=utf-8")
  282. hw.Header().Set("Connection", "keep-alive")
  283. hw.Header().Set("Cache-Control", "no-cache")
  284. chatStream := ssestream.NewStream[ApiRespStreamChunk](ApiRespStreamDecoder(raw), err)
  285. defer chatStream.Close()
  286. for chatStream.Next() {
  287. chunk := chatStream.Current()
  288. fmt.Fprintf(hw, "data:%s\n\n", chunk.Data.RAW)
  289. flusher.Flush()
  290. //time.Sleep(1 * time.Millisecond)
  291. }
  292. fmt.Fprintf(hw, "data:%s\n\n", "[DONE]")
  293. flusher.Flush()
  294. httpx.Ok(hw)
  295. return nil, nil
  296. }
  297. func NewChatCompletions(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
  298. if chatInfo.Stream {
  299. return DoChatCompletionsStream(ctx, client, chatInfo)
  300. } else {
  301. return DoChatCompletions(ctx, client, chatInfo)
  302. }
  303. }
  304. func NewMismatchChatCompletions(ctx context.Context, apiKey string, apiBase string, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
  305. client := NewAiClient(apiKey, apiBase)
  306. return NewChatCompletions(ctx, client, chatInfo)
  307. }
  308. func NewFastgptChatCompletions(ctx context.Context, apiKey string, apiBase string, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
  309. client := NewAiClient(apiKey, apiBase)
  310. return NewChatCompletions(ctx, client, chatInfo)
  311. }
  312. func NewDeepSeekChatCompletions(ctx context.Context, apiKey string, chatInfo *types.CompApiReq, chatModel openai.ChatModel) (res *types.CompOpenApiResp, err error) {
  313. client := NewDeepSeekClient(apiKey)
  314. if chatModel != ChatModelDeepSeekV3 {
  315. chatModel = ChatModelDeepSeekR1
  316. }
  317. chatInfo.Model = chatModel
  318. return NewChatCompletions(ctx, client, chatInfo)
  319. }
  320. func DoChatCompletionsStreamOld(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (res *types.CompOpenApiResp, err error) {
  321. var (
  322. jsonBytes []byte
  323. )
  324. emptyParams := openai.ChatCompletionNewParams{}
  325. if jsonBytes, err = json.Marshal(chatInfo); err != nil {
  326. return nil, err
  327. }
  328. reqBodyOps := option.WithRequestBody("application/json", jsonBytes)
  329. //customResp := types.CompOpenApiResp{}
  330. //respBodyOps := option.WithResponseBodyInto(&customResp)
  331. //chatStream := client.Chat.Completions.NewStreaming(ctx, emptyParams, reqBodyOps, respBodyOps)
  332. chatStream := client.Chat.Completions.NewStreaming(ctx, emptyParams, reqBodyOps)
  333. // optionally, an accumulator helper can be used
  334. acc := openai.ChatCompletionAccumulator{}
  335. httpWriter, ok := ctx.Value("HttpResp-Writer").(http.ResponseWriter)
  336. if !ok {
  337. return nil, errors.New("content get writer err")
  338. }
  339. //httpWriter.Header().Set("Content-Type", "text/event-stream;charset=utf-8")
  340. //httpWriter.Header().Set("Connection", "keep-alive")
  341. //httpWriter.Header().Set("Cache-Control", "no-cache")
  342. idx := 0
  343. for chatStream.Next() {
  344. chunk := chatStream.Current()
  345. acc.AddChunk(chunk)
  346. fmt.Printf("=====>get %d chunk:%v\n", idx, chunk)
  347. if _, err := fmt.Fprintf(httpWriter, "%v", chunk); err != nil {
  348. fmt.Printf("Error writing to client:%v \n", err)
  349. break
  350. }
  351. if content, ok := acc.JustFinishedContent(); ok {
  352. println("Content stream finished:", content)
  353. }
  354. // if using tool calls
  355. if tool, ok := acc.JustFinishedToolCall(); ok {
  356. println("Tool call stream finished:", tool.Index, tool.Name, tool.Arguments)
  357. }
  358. if refusal, ok := acc.JustFinishedRefusal(); ok {
  359. println("Refusal stream finished:", refusal)
  360. }
  361. // it's best to use chunks after handling JustFinished events
  362. if len(chunk.Choices) > 0 {
  363. idx++
  364. fmt.Printf("idx:%d get =>'%s'\n", idx, chunk.Choices[0].Delta.Content)
  365. }
  366. }
  367. if err := chatStream.Err(); err != nil {
  368. return nil, err
  369. }
  370. return nil, nil
  371. }
  372. func GetWorkInfoByID(eventType string, workId string) (string, uint) {
  373. val, exist := fastgptWorkIdMap[workId]
  374. if !exist {
  375. val = fastgptWorkIdMap["default"]
  376. }
  377. return val.Id, val.Idx
  378. }
  379. // 获取workToken
  380. func GetWorkTokenByID(eventType string, workId string) string {
  381. id, _ := GetWorkInfoByID(eventType, workId)
  382. return id
  383. }
  384. // 获取workIdx
  385. func GetWorkIdxByID(eventType string, workId string) uint {
  386. _, idx := GetWorkInfoByID(eventType, workId)
  387. return idx
  388. }