chat_completions_logic.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. package chat
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "wechat-api/ent"
  8. "wechat-api/internal/svc"
  9. "wechat-api/internal/types"
  10. "wechat-api/internal/utils/compapi"
  11. "wechat-api/internal/utils/contextkey"
  12. "wechat-api/internal/utils/typekit"
  13. "github.com/zeromicro/go-zero/core/logx"
  14. )
  15. type baseLogicWorkflow interface {
  16. AppendAsyncRequest(apiKeyObj *ent.ApiKey, req *types.CompApiReq) error
  17. DoSyncRequest(apiKeyObj *ent.ApiKey, req *types.CompApiReq) (*types.CompOpenApiResp, error)
  18. AppendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error
  19. AdjustRequest(req *types.CompApiReq, apiKeyObj *ent.ApiKey)
  20. }
  21. type ChatCompletionsLogic struct {
  22. logx.Logger
  23. ctx context.Context
  24. svcCtx *svc.ServiceContext
  25. }
  26. type FastgptChatLogic struct {
  27. ChatCompletionsLogic
  28. }
  29. type MismatchChatLogic struct {
  30. ChatCompletionsLogic
  31. }
  32. type IntentChatLogic struct {
  33. ChatCompletionsLogic
  34. }
  35. /*
  36. 扩展LogicChat工厂方法
  37. 返回根据不同EventType相关的扩展LogicChat的baseLogicWorkflow接口形式
  38. 每增加一个新的扩展LogicChat结构,需要在此函数中增加相应的创建语句
  39. */
  40. func (l *ChatCompletionsLogic) getLogicWorkflow(apiKeyObj *ent.ApiKey, req *types.CompApiReq) (baseLogicWorkflow, error) {
  41. var (
  42. err error
  43. wf baseLogicWorkflow
  44. )
  45. if apiKeyObj.Edges.Agent.Type != 2 {
  46. err = fmt.Errorf("api agent type not support(%d)", apiKeyObj.Edges.Agent.Type)
  47. } else if req.EventType == "mismatch" {
  48. wf = &MismatchChatLogic{ChatCompletionsLogic: *l}
  49. } else if req.EventType == "intent" {
  50. wf = &IntentChatLogic{ChatCompletionsLogic: *l}
  51. } else {
  52. wf = &FastgptChatLogic{ChatCompletionsLogic: *l}
  53. }
  54. return wf, err
  55. }
  56. func NewChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ChatCompletionsLogic {
  57. return &ChatCompletionsLogic{
  58. Logger: logx.WithContext(ctx),
  59. ctx: ctx,
  60. svcCtx: svcCtx}
  61. }
  62. func (l *ChatCompletionsLogic) ChatCompletions(req *types.CompApiReq) (asyncMode bool, resp *types.CompOpenApiResp, err error) {
  63. // todo: add your logic here and delete this line
  64. var (
  65. apiKeyObj *ent.ApiKey
  66. ok bool
  67. )
  68. asyncMode = false
  69. //从上下文中获取鉴权中间件埋下的apiAuthInfo
  70. apiKeyObj, ok = contextkey.AuthTokenInfoKey.GetValue(l.ctx)
  71. if !ok {
  72. return asyncMode, nil, errors.New("content get auth info err")
  73. }
  74. //根据请求产生相关的工作流接口集
  75. wf, err := l.getLogicWorkflow(apiKeyObj, req)
  76. if err != nil {
  77. return false, nil, err
  78. }
  79. //请求前临时观察相关参数
  80. //PreChatVars(req, apiKeyObj, wf)
  81. //微调部分请求参数
  82. wf.AdjustRequest(req, apiKeyObj)
  83. if isAsyncReqest(req) { //异步请求处理模式
  84. asyncMode = true
  85. err = wf.AppendAsyncRequest(apiKeyObj, req)
  86. } else { //同步请求处理模式
  87. resp, err = wf.DoSyncRequest(apiKeyObj, req)
  88. if err == nil && resp != nil && len(resp.Choices) > 0 {
  89. wf.AppendUsageDetailLog(apiKeyObj.Key, req, resp) //请求记录
  90. } else if resp != nil && len(resp.Choices) == 0 {
  91. err = errors.New("返回结果缺失,请检查访问地址及权限")
  92. }
  93. }
  94. return asyncMode, resp, err
  95. }
  96. func (l *ChatCompletionsLogic) AdjustRequest(req *types.CompApiReq, apiKeyObj *ent.ApiKey) {
  97. if len(req.EventType) == 0 {
  98. req.EventType = "fastgpt"
  99. }
  100. if len(req.Model) == 0 && len(apiKeyObj.Edges.Agent.Model) > 0 {
  101. req.Model = apiKeyObj.Edges.Agent.Model
  102. }
  103. //异步任务相关参数调整
  104. if req.IsBatch {
  105. //流模式暂时不支持异步模式
  106. //Callback格式非法则取消批量模式
  107. if req.Stream || !compapi.IsValidURL(&req.Callback, true) {
  108. req.IsBatch = false
  109. }
  110. }
  111. }
  112. func (l *ChatCompletionsLogic) DoSyncRequest(apiKeyObj *ent.ApiKey, req *types.CompApiReq) (*types.CompOpenApiResp, error) {
  113. resp, err := compapi.NewClient(l.ctx, compapi.WithApiBase(apiKeyObj.Edges.Agent.APIBase),
  114. compapi.WithApiKey(apiKeyObj.Edges.Agent.APIKey)).
  115. Chat(req)
  116. if err != nil {
  117. return nil, err
  118. }
  119. //以下临时测试case
  120. //humanSeeValidResult(l.ctx, req, resp)
  121. return resp, err
  122. }
  123. func (l *ChatCompletionsLogic) AppendAsyncRequest(apiKeyObj *ent.ApiKey, req *types.CompApiReq) error {
  124. rawReqBs, err := json.Marshal(*req)
  125. if err != nil {
  126. return err
  127. }
  128. rawReqStr := string(rawReqBs)
  129. res, err := l.svcCtx.DB.CompapiAsynctask.Create().
  130. SetNotNilAuthToken(&apiKeyObj.Key).
  131. SetNotNilOpenaiBase(&apiKeyObj.Edges.Agent.APIBase).
  132. SetNotNilOpenaiKey(&apiKeyObj.Edges.Agent.APIKey).
  133. SetNotNilOrganizationID(&apiKeyObj.OrganizationID).
  134. SetNotNilEventType(&req.EventType).
  135. SetNillableModel(&req.Model).
  136. SetNillableChatID(&req.ChatId).
  137. SetNillableResponseChatItemID(&req.ResponseChatItemId).
  138. SetNotNilRequestRaw(&rawReqStr).
  139. SetNotNilCallbackURL(&req.Callback).
  140. Save(l.ctx)
  141. if err == nil {
  142. logx.Infof("appendAsyncRequest succ,get id:%d", res.ID)
  143. }
  144. return err
  145. }
  146. func (l *ChatCompletionsLogic) AppendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
  147. svcCtx := &compapi.ServiceContext{Config: l.svcCtx.Config, DB: l.svcCtx.DB,
  148. Rds: l.svcCtx.Rds}
  149. return compapi.AppendUsageDetailLog(l.ctx, svcCtx, authToken, req, resp)
  150. }
  151. func (l *FastgptChatLogic) AdjustRequest(req *types.CompApiReq, apiKeyObj *ent.ApiKey) {
  152. l.ChatCompletionsLogic.AdjustRequest(req, apiKeyObj) //先父类的参数调整
  153. if req.EventType != "fastgpt" {
  154. return
  155. }
  156. if len(req.Model) > 0 {
  157. if req.Variables == nil {
  158. req.Variables = make(map[string]string)
  159. }
  160. req.Variables["model"] = req.Model
  161. }
  162. if len(req.ChatId) > 0 && len(req.FastgptChatId) == 0 {
  163. req.FastgptChatId = req.ChatId
  164. } else if len(req.ChatId) == 0 && len(req.FastgptChatId) > 0 {
  165. req.ChatId = req.FastgptChatId
  166. }
  167. }
  168. func humanSeePreChatVars(req *types.CompApiReq, apiKeyObj *ent.ApiKey, wf baseLogicWorkflow) {
  169. fmt.Println("=========================================")
  170. fmt.Printf("In ChatCompletion Get Token Info:\nKey:'%s'\n", apiKeyObj.Key)
  171. fmt.Printf("Auth Token:'%s'\n", apiKeyObj.Key)
  172. fmt.Printf("ApiKey AgentID:%d\n", apiKeyObj.AgentID)
  173. fmt.Printf("ApiKey APIBase:'%s'\n", apiKeyObj.Edges.Agent.APIBase)
  174. fmt.Printf("ApiKey APIKey:'%s'\n", apiKeyObj.Edges.Agent.APIKey)
  175. fmt.Printf("ApiKey Type:%d\n", apiKeyObj.Edges.Agent.Type)
  176. fmt.Printf("ApiKey Model:'%s'\n", apiKeyObj.Edges.Agent.Model)
  177. fmt.Printf("EventType:'%s'\n", req.EventType)
  178. fmt.Printf("req.ChatId:'%s VS req.FastgptChatId:'%s'\n", req.ChatId, req.FastgptChatId)
  179. fmt.Println("=========================================")
  180. switch wf.(type) {
  181. case *MismatchChatLogic:
  182. fmt.Println("MismatchChatLogic Flow.....")
  183. case *IntentChatLogic:
  184. fmt.Println("IntentChatLogic Flow.....")
  185. case *FastgptChatLogic:
  186. fmt.Println("FastgptChatLogic Flow.....")
  187. default:
  188. fmt.Println("Other Flow.....")
  189. }
  190. }
  191. func humanSeeValidResult(ctx context.Context, req *types.CompApiReq, resp *types.CompOpenApiResp) {
  192. clientFace, err := compapi.NewClient(ctx).GetClientActFace(req.EventType)
  193. if err != nil {
  194. fmt.Println(err)
  195. return
  196. }
  197. taskData := ent.CompapiAsynctask{}
  198. taskData.ID = 1234
  199. taskData.ResponseChatItemID = req.ResponseChatItemId
  200. taskData.EventType = req.EventType
  201. taskData.ChatID = req.ChatId
  202. taskData.ResponseRaw, err = resp.ToString()
  203. if err != nil {
  204. fmt.Println(err)
  205. return
  206. }
  207. var bs []byte
  208. bs, err = clientFace.CallbackPrepare(&taskData)
  209. if err != nil {
  210. fmt.Println(err)
  211. }
  212. fmt.Printf("当前请求EventType:%s\n", req.EventType)
  213. fmt.Printf("当前请求MODEL:%s\n", req.Model)
  214. fmt.Println("client.CallbackPrepare结果[]byte版.........")
  215. fmt.Println(string(bs))
  216. nres := map[string]any{}
  217. err = json.Unmarshal(bs, &nres)
  218. if err != nil {
  219. fmt.Println(err)
  220. }
  221. fmt.Println("client.CallbackPrepare结果map[string]any版.........")
  222. fmt.Println(typekit.PrettyPrint(nres))
  223. config := compapi.ResponseFormatConfig{}
  224. if req.EventType == "mismatch" {
  225. clientInst := clientFace.(*compapi.MismatchClient)
  226. config = clientInst.ResponseFormatSetting(req)
  227. } else if req.EventType == "intent" {
  228. clientInst := clientFace.(*compapi.IntentClient)
  229. config = clientInst.ResponseFormatSetting(req)
  230. } else {
  231. return
  232. }
  233. err = compapi.NewChatResult(resp).ParseContentAs(&config.ResformatStruct)
  234. if err != nil {
  235. fmt.Println(err)
  236. }
  237. nres["content"] = config.ResformatStruct
  238. fmt.Println("client.CallbackPrepare结果ParseContentAs定制版.........")
  239. fmt.Println(typekit.PrettyPrint(nres))
  240. }
  241. func isAsyncReqest(req *types.CompApiReq) bool {
  242. return req.IsBatch
  243. }