chat_completions_logic.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. package chat
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "strconv"
  7. "wechat-api/ent"
  8. "wechat-api/internal/svc"
  9. "wechat-api/internal/types"
  10. "wechat-api/internal/utils/compapi"
  11. "wechat-api/internal/utils/contextkey"
  12. "wechat-api/ent/custom_types"
  13. "wechat-api/ent/predicate"
  14. "wechat-api/ent/usagedetail"
  15. "wechat-api/ent/usagetotal"
  16. "github.com/zeromicro/go-zero/core/logx"
  17. )
  18. type ChatCompletionsLogic struct {
  19. logx.Logger
  20. ctx context.Context
  21. svcCtx *svc.ServiceContext
  22. }
  23. func NewChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ChatCompletionsLogic {
  24. return &ChatCompletionsLogic{
  25. Logger: logx.WithContext(ctx),
  26. ctx: ctx,
  27. svcCtx: svcCtx}
  28. }
  29. func (l *ChatCompletionsLogic) ChatCompletions(req *types.CompApiReq) (resp *types.CompOpenApiResp, err error) {
  30. // todo: add your logic here and delete this line
  31. /*
  32. 1.鉴权获得token
  33. 2.必要参数检测及转换
  34. 3. 根据event_type选择不同处理路由
  35. */
  36. var (
  37. apiKeyObj *ent.ApiKey
  38. ok bool
  39. )
  40. workToken := compapi.GetWorkTokenByID(req.EventType, req.WorkId)
  41. apiKeyObj, ok = contextkey.AuthTokenInfoKey.GetValue(l.ctx)
  42. if !ok {
  43. return nil, errors.New("content get token err")
  44. }
  45. /*
  46. fmt.Println("=========================================")
  47. fmt.Printf("In ChatCompletion Get Token Info:\nKey:'%s'\n", apiKeyObj.Key)
  48. fmt.Printf("Title:'%s'\n", apiKeyObj.Title)
  49. fmt.Printf("OpenaiBase:'%s'\n", apiKeyObj.OpenaiBase)
  50. fmt.Printf("OpenaiKey:'%s'\n", apiKeyObj.OpenaiKey)
  51. fmt.Printf("workToken:'%s' because %s/%s\n", workToken, req.EventType, req.WorkId)
  52. fmt.Println("=========================================")
  53. */
  54. if len(apiKeyObj.OpenaiBase) == 0 || len(workToken) == 0 {
  55. return nil, errors.New("not auth info")
  56. }
  57. apiResp, err := l.workForFastgpt(req, workToken, apiKeyObj.OpenaiBase)
  58. if err == nil && apiResp != nil {
  59. l.doRequestLog(req, apiResp) //请求记录
  60. }
  61. return apiResp, err
  62. }
  63. func (l *ChatCompletionsLogic) doRequestLog(req *types.CompApiReq, resp *types.CompOpenApiResp) error {
  64. authToken, ok := contextkey.OpenapiTokenKey.GetValue(l.ctx)
  65. if !ok {
  66. return errors.New("content get auth token err")
  67. }
  68. return l.appendUsageDetailLog(authToken, req, resp)
  69. }
  70. func (l *ChatCompletionsLogic) getUsagetotalIdByToken(authToken string) (uint64, error) {
  71. var predicates []predicate.UsageTotal
  72. predicates = append(predicates, usagetotal.BotIDEQ(authToken))
  73. return l.svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(l.ctx)
  74. }
  75. func (l *ChatCompletionsLogic) replaceUsagetotalTokens(authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
  76. Id, err := l.getUsagetotalIdByToken(authToken)
  77. if err != nil && !ent.IsNotFound(err) {
  78. return err
  79. }
  80. if Id > 0 { //UsageTotal have record by newUsageDetailId
  81. _, err = l.svcCtx.DB.UsageTotal.UpdateOneID(Id).
  82. SetTotalTokens(sumTotalTokens).
  83. SetEndIndex(newUsageDetailId).
  84. Save(l.ctx)
  85. } else { //create new record by newUsageDetailId
  86. logType := 5
  87. _, err = l.svcCtx.DB.UsageTotal.Create().
  88. SetNotNilBotID(&authToken).
  89. SetNotNilEndIndex(&newUsageDetailId).
  90. SetNotNilTotalTokens(&sumTotalTokens).
  91. SetNillableType(&logType).
  92. SetNotNilOrganizationID(&orgId).
  93. Save(l.ctx)
  94. }
  95. return err
  96. }
  97. func (l *ChatCompletionsLogic) updateUsageTotal(authToken string, newUsageDetailId uint64, orgId uint64) error {
  98. sumTotalTokens, err := l.sumTotalTokensByAuthToken(authToken) //首先sum UsageDetail的TotalTokens
  99. if err == nil {
  100. err = l.replaceUsagetotalTokens(authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens
  101. }
  102. return err
  103. }
  104. // sum total_tokens from usagedetail by AuthToken
  105. func (l *ChatCompletionsLogic) sumTotalTokensByAuthToken(authToken string) (uint64, error) {
  106. var predicates []predicate.UsageDetail
  107. predicates = append(predicates, usagedetail.BotIDEQ(authToken))
  108. var res []struct {
  109. Sum, Min, Max, Count uint64
  110. }
  111. totalTokens := uint64(0)
  112. var err error = nil
  113. err = l.svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"),
  114. ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(l.ctx, &res)
  115. if err == nil {
  116. if len(res) > 0 {
  117. totalTokens = res[0].Sum
  118. } else {
  119. totalTokens = 0
  120. }
  121. }
  122. return totalTokens, err
  123. }
  124. func (l *ChatCompletionsLogic) appendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
  125. logType := 5
  126. workIdx := compapi.GetWorkIdxByID(req.EventType, req.WorkId)
  127. rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
  128. tmpId := 0
  129. tmpId, _ = strconv.Atoi(resp.ID)
  130. sessionId := uint64(tmpId)
  131. orgId := uint64(0)
  132. apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(l.ctx)
  133. if ok {
  134. orgId = apiKeyObj.OrganizationID
  135. }
  136. promptTokens := uint64(resp.Usage.PromptTokens)
  137. completionToken := uint64(resp.Usage.CompletionTokens)
  138. totalTokens := promptTokens + completionToken
  139. fmt.Printf("====>n appendUsageDetailLog:%v|||%T\n", req.Messages[0].Content, req.Messages[0].Content)
  140. res, err := l.svcCtx.DB.UsageDetail.Create().
  141. SetNotNilType(&logType).
  142. SetNotNilBotID(&authToken).
  143. SetNotNilReceiverID(&req.EventType).
  144. SetNotNilSessionID(&sessionId).
  145. SetNillableApp(&workIdx).
  146. //SetNillableRequest(&req.Messages[0].Content).
  147. SetNillableResponse(&resp.Choices[0].Message.Content).
  148. SetNillableOrganizationID(&orgId).
  149. SetOriginalData(rawReqResp).
  150. SetNillablePromptTokens(&promptTokens).
  151. SetNillableCompletionTokens(&completionToken).
  152. SetNillableTotalTokens(&totalTokens).
  153. Save(l.ctx)
  154. if err == nil { //插入UsageDetai之后再统计UsageTotal
  155. l.updateUsageTotal(authToken, res.ID, orgId)
  156. }
  157. return err
  158. }
  159. func (l *ChatCompletionsLogic) workForFastgpt(req *types.CompApiReq, apiKey string, apiBase string) (resp *types.CompOpenApiResp, err error) {
  160. //apiKey := "fastgpt-d2uehCb2T40h9chNGjf4bpFrVKmMkCFPbrjfVLZ6DAL2zzqzOFJWP"
  161. if len(req.ChatId) > 0 && len(req.FastgptChatId) == 0 {
  162. req.FastgptChatId = req.ChatId
  163. }
  164. if len(req.Model) > 0 {
  165. if req.Variables == nil {
  166. req.Variables = make(map[string]string)
  167. }
  168. req.Variables["model"] = req.Model
  169. }
  170. return compapi.NewFastgptChatCompletions(l.ctx, apiKey, apiBase, req)
  171. }