compapi_callback.go 12 KB


  1. package crontask
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "hash/fnv"
  8. "runtime"
  9. "strconv"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. "wechat-api/ent"
  14. "wechat-api/ent/compapiasynctask"
  15. "wechat-api/ent/predicate"
  16. "wechat-api/internal/types"
  17. "wechat-api/internal/utils/compapi"
  18. "github.com/openai/openai-go/option"
  19. "github.com/zeromicro/go-zero/core/logx"
  20. )
  // Task lifecycle states written to CompapiAsynctask.TaskStatus.
const (
	Task_Ready    = 10            // task is ready to run
	ReqApi_Done   = 20            // upstream API request finished
	Callback_Done = 30            // callback request finished
	All_Done      = Callback_Done // whole task finished: the callback is the last step, so this is also 30
	Task_Suspend  = 60            // task suspended
	Task_Fail     = 70            // task failed permanently
)

// Batch, retry and routing settings.
const (
	// MaxWorker = 5 // max concurrent workers (now passed as compApiCallback's argument)
	MaxLoadTask     = 1000     // max number of tasks loaded per batch
	LoopTryCount    = 3        // attempts inside one request's retry loop
	ErrTaskTryCount = 3        // max error retries before a task is marked Task_Fail
	DefaultChatId   = "FOO001" // routing key used for tasks with an empty chat id
)
  34. type Task struct {
  35. Data *ent.CompapiAsynctask
  36. Idx int
  37. Code int
  38. }
  39. // 带会话管理的任务通道组
  40. type TaskDispatcher struct {
  41. mu sync.Mutex
  42. workerChs []chan Task // 每个worker独立通道
  43. }
  44. func NewTaskDispatcher(workerCount int) *TaskDispatcher {
  45. td := &TaskDispatcher{
  46. workerChs: make([]chan Task, workerCount),
  47. }
  48. // 初始化worker通道
  49. for i := range td.workerChs {
  50. td.workerChs[i] = make(chan Task, 100) // 每个worker带缓冲
  51. //fmt.Printf("make worker chan:%d\n", i+1)
  52. }
  53. return td
  54. }
  // getWorkerChannel shards by chat id: it hashes chatId with FNV-1a and returns
// the index and channel of the worker that owns it, so the same chat id is
// always routed to the same worker. An empty chatId is routed as DefaultChatId.
func (td *TaskDispatcher) getWorkerChannel(chatId string) (int, chan Task) {
	if chatId == "" {
		chatId = DefaultChatId
	}
	h := fnv.New32a()
	_, _ = h.Write([]byte(chatId)) // hash.Hash.Write never returns an error
	// Reduce in uint32 before converting to int: on 32-bit platforms
	// int(h.Sum32()) is negative for hashes >= 2^31, which would produce a
	// negative index and panic.
	idx := int(h.Sum32() % uint32(len(td.workerChs)))
	return idx, td.workerChs[idx]
}
  65. // 分配任务到对应的消费channel
  66. func (td *TaskDispatcher) Dispatch(task Task) {
  67. td.mu.Lock()
  68. defer td.mu.Unlock()
  69. // 根据chatId哈希获得其对应的workerChan
  70. workerChId, workerCh := td.getWorkerChannel(task.Data.ChatID)
  71. // 将任务送入该chatid的workerChan
  72. workerCh <- task
  73. logx.Debugf("producer:ChatId:[%s] Task Push to WorkerChan:%d", task.Data.ChatID, workerChId+1)
  74. }
  75. func getGoroutineId() (int64, error) {
  76. // 堆栈结果中需要消除的前缀符
  77. var goroutineSpace = []byte("goroutine ")
  78. bs := make([]byte, 128)
  79. bs = bs[:runtime.Stack(bs, false)]
  80. bs = bytes.TrimPrefix(bs, goroutineSpace)
  81. i := bytes.IndexByte(bs, ' ')
  82. if i < 0 {
  83. return -1, errors.New("get current goroutine id failed")
  84. }
  85. return strconv.ParseInt(string(bs[:i]), 10, 64)
  86. }
  // compApiCallback processes one batch of pending async tasks with a single
// producer and MaxWorker consumers (non-positive MaxWorker defaults to 2).
//
// The producer loads the batch and dispatches each task to the worker channel
// owning its chat id. Each consumer runs processTask for every task it receives.
// The call returns once the producer has closed all worker channels and every
// consumer has drained its channel.
func (l *CronTask) compApiCallback(MaxWorker int) {
	var (
		wg       sync.WaitGroup
		produced atomic.Int64 // tasks dispatched by the producer
		consumed atomic.Int64 // tasks finished by consumers (successfully or not)
	)
	if MaxWorker <= 0 {
		MaxWorker = 2
	}
	dispatcher := NewTaskDispatcher(MaxWorker)

	// Consumers: one goroutine per worker channel.
	for i, ch := range dispatcher.workerChs {
		wg.Add(1)
		go func(workerID int, taskCh chan Task) {
			defer wg.Done()
			goroutineID, _ := getGoroutineId()
			logx.Infof("Consumer Goroutine:%d start......\n", goroutineID)
			for task := range taskCh {
				l.processTask(workerID, task)
				consumed.Add(1)
			}
		}(i+1, ch)
	}

	// Producer.
	wg.Add(1)
	go func() {
		defer wg.Done()
		// Close every worker channel on ALL exit paths (including the early
		// return on a load error). Otherwise the consumers would block forever
		// in their range loops and wg.Wait below would never return.
		defer func() {
			dispatcher.mu.Lock()
			defer dispatcher.mu.Unlock()
			for _, ch := range dispatcher.workerChs {
				close(ch)
			}
		}()

		goroutineID, _ := getGoroutineId()
		logx.Infof("Producer Goroutine:%d start......\n", goroutineID)

		tasks, err := l.getAsyncReqTaskList()
		if err != nil {
			logx.Errorf("getAsyncReqTaskList err:%s", err)
			return
		}
		for _, task := range tasks {
			dispatcher.Dispatch(task)
			produced.Add(1)
		}
		logx.Infof("📦Producer Goroutine:%d 此批次共创建任务%d件", goroutineID, len(tasks))
	}()

	wg.Wait()

	if p := produced.Load(); p > 0 {
		c := consumed.Load()
		// Multiply before dividing: integer c/p is 0 for any partial completion.
		logx.Infof("🏁本次任务完成度统计: Producer:1,Consumer:%d (%d/%d)*100=%d%%", MaxWorker, c, p, c*100/p)
	}
}
  143. func (l *CronTask) getAsyncReqTaskList() ([]Task, error) {
  144. var predicates []predicate.CompapiAsynctask
  145. predicates = append(predicates, compapiasynctask.TaskStatusLT(All_Done))
  146. var tasks []Task
  147. res, err := l.svcCtx.DB.CompapiAsynctask.Query().Where(predicates...).
  148. Order(ent.Asc(compapiasynctask.FieldID)).
  149. Limit(MaxLoadTask).
  150. All(l.ctx)
  151. if err == nil {
  152. for idx, val := range res {
  153. tasks = append(tasks, Task{Data: val, Idx: idx})
  154. }
  155. }
  156. return tasks, err
  157. }
  158. func (l *CronTask) processTask(workerID int, task Task) {
  159. //fmt.Printf("In processTask,Consumer(%d) dealing\n", workerID)
  160. //fmt.Printf("Task Detail: User(%s/%s) Async Call %s\n", task.Data.ChatID, task.Data.AuthToken, task.Data.OpenaiBase)
  161. _ = workerID
  162. var err error
  163. rt := 0
  164. for {
  165. if task.Data.TaskStatus >= All_Done {
  166. break
  167. }
  168. switch task.Data.TaskStatus {
  169. case Task_Ready:
  170. //请求API平台
  171. // succ: taskStatus Task_Ready => ReqApi_Done
  172. // fail: taskStatus保持当前不变或Task_Fail
  173. rt, err = l.requestAPI(task.Data)
  174. case ReqApi_Done:
  175. //结果回调
  176. // succ: taskStatus ReqApi_Done => Callback_Done(All_Done)
  177. // fail: taskStatus保持当前不变或Task_Fail
  178. rt, err = l.requestCallback(task.Data)
  179. }
  180. if err != nil {
  181. //收集错误
  182. if rt == 0 {
  183. //不可恢复错误处理....
  184. }
  185. //fmt.Println("===>ERROR:", err, ",Task Ignore...")
  186. return //先暂时忽略处理,也许应按错误类型分别对待
  187. }
  188. }
  189. }
  190. func (l *CronTask) requestCallback(taskData *ent.CompapiAsynctask) (int, error) {
  191. workerId, _ := getGoroutineId()
  192. logx.Debugf("Worker:%d INTO requestCallback for task status:%d", workerId, taskData.TaskStatus)
  193. if needStop, _ := l.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理
  194. return 1, errors.New("too many err retry")
  195. }
  196. if taskData.TaskStatus != ReqApi_Done {
  197. return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
  198. }
  199. req := types.CompOpenApiResp{}
  200. if len(taskData.ResponseRaw) == 0 {
  201. return 0, errors.New("call api response empty")
  202. }
  203. if len(taskData.CallbackURL) == 0 {
  204. return 0, errors.New("callback url empty")
  205. }
  206. if err := json.Unmarshal([]byte(taskData.ResponseRaw), &req); err != nil {
  207. return 0, err
  208. }
  209. //先开启事务更新任务状态 => Callback_Done(回调完成)
  210. tx, err := l.updateTaskStatusByTx(taskData.ID, Callback_Done)
  211. if err != nil {
  212. return 0, err
  213. }
  214. //请求预先给定的callback_url
  215. client := compapi.NewAiClient("", taskData.CallbackURL)
  216. //emptyParams := openai.ChatCompletionNewParams{}
  217. customResp := types.BaseDataInfo{}
  218. opts := []option.RequestOption{option.WithResponseBodyInto(&customResp)}
  219. opts = append(opts, option.WithRequestBody("application/json", []byte(taskData.ResponseRaw)))
  220. for i := range LoopTryCount { //重试机制
  221. err = client.Post(l.ctx, taskData.CallbackURL, nil, nil, opts...)
  222. //_, err = client.Chat.Completions.New(l.ctx, emptyParams, opts...)
  223. if err == nil {
  224. //call succ
  225. break
  226. }
  227. logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId,
  228. taskData.CallbackURL, err, 1+i*5, i+1, LoopTryCount, taskData.RetryCount)
  229. time.Sleep(time.Duration(1+i*5) * time.Second)
  230. }
  231. if err != nil {
  232. _ = tx.Rollback() //回滚之前更新状态
  233. //fmt.Printf("Worker:%d client.Chat.Completions.New Fail,Rollback......\n", workerId)
  234. err1 := l.dealErrorTask(taskData, err) //错误任务处理
  235. et := 1
  236. if err1 != nil {
  237. et = 0
  238. }
  239. return et, err
  240. }
  241. err = tx.Commit() //事务提交
  242. //fmt.Printf("Worker:%d requestCallback事务提交\n", workerId)
  243. if err != nil {
  244. return 0, err
  245. }
  246. taskData.TaskStatus = Callback_Done //状态迁移
  247. return 1, nil
  248. }
  // requestAPI sends the stored request (taskData.RequestRaw) to the FastGPT
// endpoint at taskData.OpenaiBase. It saves the response in ResponseRaw and
// moves the task from Task_Ready to ReqApi_Done.
//
// The status update and the ResponseRaw write share one transaction, which
// stays open during the API attempts (up to LoopTryCount). The transaction is
// committed only after a response with at least one choice is stored. Otherwise
// it is rolled back and dealErrorTask records the failure.
//
// It returns (1, nil) on success. On error, the int is 1 when the task was
// marked failed or the failure was recorded for a later retry, and 0 otherwise.
func (l *CronTask) requestAPI(taskData *ent.CompapiAsynctask) (int, error) {
	workerId, _ := getGoroutineId()
	logx.Debugf("Worker:%d INTO requestAPI for task status:%d", workerId, taskData.TaskStatus)

	// Retry budget check: an exhausted task is marked Task_Fail and not processed further.
	if needStop, _ := l.checkErrRetry(taskData); needStop {
		return 1, errors.New("too many err retry")
	}
	if taskData.TaskStatus != Task_Ready {
		return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
	}
	if taskData.EventType != "fastgpt" {
		return 0, fmt.Errorf("event type :'%s' not support", taskData.EventType)
	}

	var (
		err     error
		apiResp *types.CompOpenApiResp
		tx      *ent.Tx
	)
	req := types.CompApiReq{}
	if err = json.Unmarshal([]byte(taskData.RequestRaw), &req); err != nil {
		return 0, err
	}

	// Tentatively mark ReqApi_Done; committed only together with the response.
	tx, err = l.updateTaskStatusByTx(taskData.ID, ReqApi_Done)
	if err != nil {
		return 0, err
	}

	for i := range LoopTryCount {
		apiResp, err = compapi.NewFastgptChatCompletions(l.ctx,
			taskData.OpenaiKey, taskData.OpenaiBase, &req)
		if err == nil && apiResp != nil && len(apiResp.Choices) > 0 {
			break
		}
		if apiResp != nil && len(apiResp.Choices) == 0 {
			err = errors.New("返回结果缺失,请检查访问权限")
		}
		// Log the endpoint actually called (OpenaiBase), not the callback URL.
		if i+1 < LoopTryCount {
			wait := 1 + i*5 // 1s, 6s, ... between attempts
			logx.Infof("Worker:%d call '%s' fail: '%v',sleep %d Second for next(%d/%d/%d)", workerId,
				taskData.OpenaiBase, err, wait, i+1, LoopTryCount, taskData.RetryCount)
			time.Sleep(time.Duration(wait) * time.Second)
			continue
		}
		// Final attempt: no point sleeping (with the transaction still open) before giving up.
		logx.Infof("Worker:%d call '%s' fail: '%v',no retry left(%d/%d/%d)", workerId,
			taskData.OpenaiBase, err, i+1, LoopTryCount, taskData.RetryCount)
	}

	if err != nil || apiResp == nil {
		if apiResp == nil && err == nil {
			err = errors.New("resp is null")
		}
		_ = tx.Rollback() // best effort: undo the tentative ReqApi_Done; dealErrorTask writes outside the tx
		et := 1
		if dealErr := l.dealErrorTask(taskData, err); dealErr != nil {
			et = 0
		}
		return et, err
	}

	respBs, err := json.Marshal(*apiResp)
	if err != nil {
		_ = tx.Rollback()
		return 0, err
	}
	respRaw := string(respBs)
	_, err = tx.CompapiAsynctask.UpdateOneID(taskData.ID).
		SetResponseRaw(respRaw).
		Save(l.ctx)
	if err != nil {
		_ = tx.Rollback()
		return 0, err
	}
	if err = tx.Commit(); err != nil {
		return 0, err
	}

	// Mirror the committed state in memory only once it is durable.
	taskData.ResponseRaw = respRaw
	taskData.TaskStatus = ReqApi_Done
	return 1, nil
}
  // updateTaskStatusByTx begins a transaction and, inside it, sets the task's
// status and updated_at. It returns the still-open transaction; the caller
// must Commit or Rollback it.
//
// On any error it returns a nil transaction and leaves nothing open. A failed
// begin is reported instead of being dereferenced later, and a failed update
// rolls the transaction back instead of leaking it.
func (l *CronTask) updateTaskStatusByTx(Id uint64, status int) (*ent.Tx, error) {
	tx, err := l.svcCtx.DB.Tx(l.ctx)
	if err != nil {
		return nil, fmt.Errorf("begin transaction: %w", err)
	}
	_, err = tx.CompapiAsynctask.UpdateOneID(Id).
		SetTaskStatus(int8(status)).
		SetUpdatedAt(time.Now()).
		Save(l.ctx)
	if err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			err = errors.Join(err, fmt.Errorf("rollback: %w", rbErr))
		}
		return nil, err
	}
	return tx, nil
}
  335. func (l *CronTask) checkErrRetry(taskData *ent.CompapiAsynctask) (bool, error) {
  336. var err error
  337. var needStop = false
  338. if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败
  339. _, err = l.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
  340. SetUpdatedAt(time.Now()).
  341. SetTaskStatus(int8(Task_Fail)).
  342. Save(l.ctx)
  343. if err == nil {
  344. needStop = true
  345. taskData.TaskStatus = Task_Fail
  346. }
  347. }
  348. return needStop, err
  349. }
  // dealErrorTask records a step failure after its retry loop was exhausted.
// lasterr, when non-nil, is stored as LastError. If the task's RetryCount has
// reached ErrTaskTryCount, the task is marked Task_Fail; otherwise its
// RetryCount is incremented so a later batch retries it.
//
// taskData.TaskStatus is updated in memory only after the DB write succeeds,
// matching checkErrRetry. Returns the DB write error, if any.
func (l *CronTask) dealErrorTask(taskData *ent.CompapiAsynctask, lasterr error) error {
	logx.Debug("多次循环之后依然失败,进入错误任务处理环节")

	markFailed := taskData.RetryCount >= ErrTaskTryCount

	cauo := l.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
		SetUpdatedAt(time.Now())
	// Record the cause in both branches (the failed branch used to drop it), and
	// tolerate a nil error instead of panicking on lasterr.Error().
	if lasterr != nil {
		cauo = cauo.SetLastError(lasterr.Error())
	}
	if markFailed {
		cauo = cauo.SetTaskStatus(int8(Task_Fail))
	} else {
		cauo = cauo.SetRetryCount(taskData.RetryCount + 1)
	}

	if _, err := cauo.Save(l.ctx); err != nil {
		return err
	}
	if markFailed {
		taskData.TaskStatus = Task_Fail
	}
	return nil
}