compapi_callback.go 13 KB


  1. package crontask
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "hash/fnv"
  8. "runtime"
  9. "strconv"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. "wechat-api/ent"
  14. "wechat-api/ent/compapiasynctask"
  15. "wechat-api/ent/predicate"
  16. "wechat-api/internal/types"
  17. "wechat-api/internal/utils/compapi"
  18. "github.com/zeromicro/go-zero/core/logx"
  19. )
  // Task lifecycle states (stored in the task's TaskStatus field) and batch limits.
//
// A task moves Task_Ready -> ReqApi_Done -> Callback_Done. Rows whose status is
// at or above All_Done are not loaded or processed by this file.
const (
	Task_Ready    = 10            // ready: the upstream API has not been called yet
	ReqApi_Done   = 20            // upstream API call finished
	Callback_Done = 30            // callback request finished
	All_Done      = Callback_Done // "everything done" threshold; same value as Callback_Done
	Task_Suspend  = 60            // task suspended
	Task_Fail     = 70            // task permanently failed

	MaxLoadTask     = 1000      // maximum number of tasks loaded per batch
	LoopTryCount    = 3         // attempts per network call inside one processing run
	ErrTaskTryCount = 3         // recorded error retries allowed before a task is marked Task_Fail
	DefaultDisId    = "DIS0001" // dispatch key used when a task has no EventType
)
  33. type Task struct {
  34. Data *ent.CompapiAsynctask
  35. Idx int
  36. Code int
  37. }
  38. // 带会话管理的任务通道组
  39. type TaskDispatcher struct {
  40. mu sync.Mutex
  41. workerChs []chan Task // 每个worker独立通道
  42. }
  43. func NewTaskDispatcher(channelCount int) *TaskDispatcher {
  44. td := &TaskDispatcher{
  45. workerChs: make([]chan Task, channelCount),
  46. }
  47. // 初始化worker通道
  48. for i := range td.workerChs {
  49. td.workerChs[i] = make(chan Task, 100) // 每个worker带缓冲
  50. //fmt.Printf("make worker chan:%d\n", i+1)
  51. }
  52. return td
  53. }
  54. // 按woker下标索引选择workerChs
  55. func (td *TaskDispatcher) getWorkerChanByIdx(idx int) (int, chan Task) {
  56. nidx := idx % len(td.workerChs)
  57. return nidx, td.workerChs[nidx]
  58. }
  59. // 按哈希分片选择workerChs
  60. func (td *TaskDispatcher) getWorkerChanByHash(disID string) (int, chan Task) {
  61. if len(disID) == 0 {
  62. disID = DefaultDisId
  63. }
  64. idx := 0
  65. if len(td.workerChs) > 1 {
  66. h := fnv.New32a()
  67. h.Write([]byte(disID))
  68. idx = int(h.Sum32()) % len(td.workerChs)
  69. }
  70. outStr := fmt.Sprintf("getWorkerChannel by disId Hash:'%s',from workerChs{", disID)
  71. for i := range len(td.workerChs) {
  72. outStr += fmt.Sprintf("#%d", i+1)
  73. if i < len(td.workerChs)-1 {
  74. outStr += ","
  75. }
  76. }
  77. outStr += fmt.Sprintf("} choice chs:'#%d' by '%s'", idx+1, disID)
  78. logx.Debug(outStr)
  79. return idx, td.workerChs[idx]
  80. }
  81. // 分配任务到对应的消费channel
  82. func (td *TaskDispatcher) Dispatch(task Task) {
  83. td.mu.Lock()
  84. defer td.mu.Unlock()
  85. // 根据chatId哈希获得其对应的workerChan
  86. workerChIdx, workerCh := td.getWorkerChanByHash(task.Data.EventType)
  87. // 将任务送入该chatid的workerChan
  88. workerCh <- task
  89. logx.Debugf("Producer:EventType:[%s] Task Push to WorkerChan:#%d", task.Data.EventType, workerChIdx+1)
  90. }
  // getGoroutineId extracts the current goroutine's numeric id from the header
// line of runtime.Stack output, which has the form "goroutine <id> [...]".
// It is intended for log correlation only. It returns -1 and an error when the
// header cannot be parsed.
func getGoroutineId() (int64, error) {
	buf := make([]byte, 128)
	header := buf[:runtime.Stack(buf, false)]
	header = bytes.TrimPrefix(header, []byte("goroutine "))

	end := bytes.IndexByte(header, ' ')
	if end < 0 {
		return -1, errors.New("get current goroutine id failed")
	}
	return strconv.ParseInt(string(header[:end]), 10, 64)
}
  103. func (l *CronTask) CliTest(MaxWorker int, MaxChannel int) {
  104. l.compApiCallback(MaxWorker, MaxChannel)
  105. }
  106. func (l *CronTask) compApiCallback(MaxWorker int, MaxChannel int) {
  107. var (
  108. wg sync.WaitGroup
  109. produced int64 //生产数量(原子计数器)
  110. consumed int64 //消费数量(原子计数器)
  111. )
  112. //创建任务分发器
  113. if MaxWorker <= 0 {
  114. MaxWorker = 2
  115. }
  116. if MaxChannel <= 0 || MaxChannel > MaxWorker {
  117. MaxChannel = MaxWorker
  118. }
  119. dispatcher := NewTaskDispatcher(MaxChannel)
  120. //启动消费者
  121. for widx := range MaxWorker {
  122. cidx, ch := dispatcher.getWorkerChanByIdx(widx)
  123. wg.Add(1)
  124. go func(workerID int, channelID int, taskCh chan Task) {
  125. defer wg.Done()
  126. gid, _ := getGoroutineId()
  127. logx.Infof("Consumer %d(Goroutine:%d) bind WorkerChan:#%d start......\n",
  128. workerID, gid, channelID)
  129. for task := range taskCh {
  130. l.processTask(workerID, task)
  131. atomic.AddInt64(&consumed, 1)
  132. }
  133. }(widx+1, cidx+1, ch)
  134. }
  135. // 生产者
  136. wg.Add(1)
  137. go func() {
  138. defer wg.Done()
  139. gid, _ := getGoroutineId()
  140. logx.Infof("Producer 1(Goroutine:%d) start......\n", gid)
  141. //获得待处理异步任务列表
  142. tasks, err := l.getAsyncReqTaskList()
  143. if err != nil {
  144. logx.Errorf("getAsyncReqTaskList err:%s", err)
  145. return
  146. }
  147. // 分发任务
  148. for _, task := range tasks {
  149. dispatcher.Dispatch(task)
  150. atomic.AddInt64(&produced, 1)
  151. }
  152. logx.Infof("📦Producer 1 此批次共创建任务%d件", len(tasks))
  153. // 关闭所有会话通道
  154. dispatcher.mu.Lock()
  155. for _, ch := range dispatcher.workerChs {
  156. _ = ch
  157. close(ch)
  158. }
  159. dispatcher.mu.Unlock()
  160. }()
  161. wg.Wait()
  162. if atomic.LoadInt64(&produced) > 0 {
  163. logx.Infof("🏁本次任务完成度统计: Producer:1,Consumer:%d (%d/%d)*100=%d%%", MaxWorker, atomic.LoadInt64(&consumed), atomic.LoadInt64(&produced),
  164. (atomic.LoadInt64(&consumed)/atomic.LoadInt64(&produced))*100)
  165. }
  166. }
  167. func (l *CronTask) getAsyncReqTaskList() ([]Task, error) {
  168. var predicates []predicate.CompapiAsynctask
  169. predicates = append(predicates, compapiasynctask.TaskStatusLT(All_Done))
  170. var tasks []Task
  171. res, err := l.svcCtx.DB.CompapiAsynctask.Query().Where(predicates...).
  172. Order(ent.Asc(compapiasynctask.FieldID)).
  173. Limit(MaxLoadTask).
  174. All(l.ctx)
  175. if err == nil {
  176. for idx, val := range res {
  177. tasks = append(tasks, Task{Data: val, Idx: idx})
  178. }
  179. }
  180. return tasks, err
  181. }
  182. func (l *CronTask) processTask(workerID int, task Task) {
  183. /*
  184. fmt.Printf("In processTask,Consumer(%d) dealing Task Detail: User(%s/%s/%s) Async Call %s(%s) on Status:%d\n",
  185. workerID, task.Data.EventType, task.Data.ChatID, task.Data.AuthToken,
  186. task.Data.OpenaiBase, task.Data.OpenaiKey, task.Data.TaskStatus)*/
  187. _ = workerID
  188. var err error
  189. rt := 0
  190. for {
  191. if task.Data.TaskStatus >= All_Done {
  192. break
  193. }
  194. switch task.Data.TaskStatus {
  195. case Task_Ready:
  196. //请求API平台
  197. // succ: taskStatus Task_Ready => ReqApi_Done
  198. // fail: taskStatus保持当前不变或Task_Fail
  199. rt, err = l.requestAPI(task.Data)
  200. case ReqApi_Done:
  201. //结果回调
  202. // succ: taskStatus ReqApi_Done => Callback_Done(All_Done)
  203. // fail: taskStatus保持当前不变或Task_Fail
  204. rt, err = l.requestCallback(task.Data)
  205. }
  206. if err != nil {
  207. //收集错误
  208. if rt == 0 {
  209. //不可恢复错误处理....
  210. }
  211. logx.Debugf("Task ignore by '%s'", err)
  212. return //先暂时忽略处理,也许应按错误类型分别对待
  213. }
  214. }
  215. }
  216. func (l *CronTask) requestCallback(taskData *ent.CompapiAsynctask) (int, error) {
  217. workerId, _ := getGoroutineId()
  218. logx.Debugf("Worker:%d INTO requestCallback for task status:%d", workerId, taskData.TaskStatus)
  219. if needStop, _ := l.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理
  220. return 1, errors.New("too many err retry")
  221. }
  222. if taskData.TaskStatus != ReqApi_Done {
  223. return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
  224. }
  225. if len(taskData.CallbackURL) == 0 {
  226. return 0, errors.New("callback url empty")
  227. }
  228. if len(taskData.ResponseRaw) == 0 {
  229. return 0, errors.New("call api response empty")
  230. }
  231. /*
  232. fstr := "mytest-svc:"
  233. if taskData.RetryCount > 0 && strings.Contains(taskData.CallbackURL, fstr) {
  234. taskData.CallbackURL = strings.Replace(taskData.CallbackURL, fstr, "0.0.0.0:", 1)
  235. }
  236. */
  237. //先开启事务更新任务状态 => Callback_Done(回调完成)
  238. tx, err := l.updateTaskStatusByTx(taskData.ID, Callback_Done)
  239. if err != nil {
  240. return 0, err
  241. }
  242. //请求预先给定的callback_url
  243. var res map[string]any
  244. //初始化client
  245. client := compapi.NewClient(l.ctx)
  246. for i := range LoopTryCount { //重试机制
  247. res, err = client.Callback(taskData.EventType, taskData.CallbackURL, taskData.ResponseRaw)
  248. //_, err = client.Chat.Completions.New(l.ctx, emptyParams, opts...)
  249. if err == nil {
  250. //call succ
  251. //fmt.Println("callback succ..........")
  252. //fmt.Println(typekit.PrettyPrint(res))
  253. logx.Infof("callback:'%s' succ", taskData.CallbackURL)
  254. break
  255. }
  256. logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId,
  257. taskData.CallbackURL, err, 1+i*5, i+1, LoopTryCount, taskData.RetryCount)
  258. time.Sleep(time.Duration(1+i*5) * time.Second)
  259. }
  260. //多次循环之后依然失败,进入错误任务处理环节
  261. if err != nil {
  262. _ = tx.Rollback() //回滚之前更新状态
  263. //fmt.Printf("Worker:%d client.Chat.Completions.New Fail,Rollback......\n", workerId)
  264. err1 := l.dealErrorTask(taskData, err) //错误任务处理
  265. et := 1
  266. if err1 != nil {
  267. et = 0
  268. }
  269. return et, err
  270. }
  271. //成功后处理环节
  272. err = tx.Commit() //事务提交
  273. if err != nil {
  274. return 0, err
  275. }
  276. //更新taskData.CallbackResponseRaw
  277. l.updateCallbackResponse(taskData.ID, res)
  278. taskData.TaskStatus = Callback_Done //状态迁移
  279. return 1, nil
  280. }
  281. func (l *CronTask) requestAPI(taskData *ent.CompapiAsynctask) (int, error) {
  282. workerId, _ := getGoroutineId()
  283. logx.Debugf("Worker:%d INTO requestAPI for task status:%d", workerId, taskData.TaskStatus)
  284. if needStop, _ := l.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理
  285. return 1, errors.New("too many err retry")
  286. }
  287. if taskData.TaskStatus != Task_Ready {
  288. return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
  289. }
  290. var (
  291. err error
  292. apiResp *types.CompOpenApiResp
  293. tx *ent.Tx
  294. )
  295. req := types.CompApiReq{}
  296. if err = json.Unmarshal([]byte(taskData.RequestRaw), &req); err != nil {
  297. return 0, err
  298. }
  299. //先开启事务更新任务状态 => ReqApi_Done(请求API完成)
  300. tx, err = l.updateTaskStatusByTx(taskData.ID, ReqApi_Done)
  301. if err != nil {
  302. return 0, err
  303. }
  304. //初始化client
  305. client := compapi.NewClient(l.ctx, compapi.WithApiBase(taskData.OpenaiBase),
  306. compapi.WithApiKey(taskData.OpenaiKey))
  307. for i := range LoopTryCount { //重试机制
  308. apiResp, err = client.Chat(&req)
  309. if err == nil && apiResp != nil && len(apiResp.Choices) > 0 {
  310. //call succ
  311. break
  312. } else if apiResp != nil && len(apiResp.Choices) == 0 {
  313. err = errors.New("返回结果缺失,请检查访问权限")
  314. }
  315. logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId,
  316. taskData.CallbackURL, err, 1+i*5, i+1, LoopTryCount, taskData.RetryCount)
  317. time.Sleep(time.Duration(1+i*5) * time.Second)
  318. }
  319. //多次循环之后依然失败,进入错误任务处理环节
  320. if err != nil || apiResp == nil {
  321. if apiResp == nil && err == nil {
  322. err = errors.New("resp is null")
  323. }
  324. _ = tx.Rollback() //回滚之前更新状态
  325. err1 := l.dealErrorTask(taskData, err) //错误任务处理
  326. et := 1
  327. if err1 != nil {
  328. et = 0
  329. }
  330. return et, err
  331. }
  332. //成功后处理环节
  333. //更新taskData.ResponseRaw
  334. taskData.ResponseRaw, err = (*apiResp).ToString()
  335. if err != nil {
  336. _ = tx.Rollback() //回滚之前更新状态
  337. return 0, err
  338. }
  339. err = l.updateApiResponseByTx(tx, taskData.ID, taskData.ResponseRaw)
  340. if err != nil {
  341. return 0, err
  342. }
  343. err = tx.Commit() //事务提交
  344. //fmt.Printf("Worker:%d requestAPI事务提交\n", workerId)
  345. if err != nil {
  346. return 0, err
  347. }
  348. taskData.TaskStatus = ReqApi_Done //状态迁移
  349. return 1, nil
  350. }
  351. // 更新任务状态事务版
  352. func (l *CronTask) updateTaskStatusByTx(Id uint64, status int) (*ent.Tx, error) {
  353. //开启Mysql事务
  354. tx, _ := l.svcCtx.DB.Tx(l.ctx)
  355. _, err := tx.CompapiAsynctask.UpdateOneID(Id).
  356. SetTaskStatus(int8(status)).
  357. SetUpdatedAt(time.Now()).
  358. Save(l.ctx)
  359. if err != nil {
  360. return nil, err
  361. }
  362. return tx, nil
  363. }
  364. // 更新请求大模型后结果事务版
  365. func (l *CronTask) updateApiResponseByTx(tx *ent.Tx, taskId uint64, apiResponse string) error {
  366. _, err := tx.CompapiAsynctask.UpdateOneID(taskId).
  367. SetUpdatedAt(time.Now()).
  368. SetResponseRaw(apiResponse).
  369. SetLastError("").
  370. SetRetryCount(0).
  371. Save(l.ctx)
  372. if err != nil {
  373. _ = tx.Rollback() //回滚之前更新状态
  374. }
  375. return err
  376. }
  377. func (l *CronTask) updateCallbackResponse(taskId uint64, callRes any) error {
  378. callResStr := ""
  379. switch v := callRes.(type) {
  380. case []byte:
  381. callResStr = string(v)
  382. default:
  383. if bs, err := json.Marshal(v); err == nil {
  384. callResStr = string(bs)
  385. } else {
  386. return err
  387. }
  388. }
  389. _, err := l.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId).
  390. SetUpdatedAt(time.Now()).
  391. SetCallbackResponseRaw(callResStr).
  392. SetLastError("").
  393. SetRetryCount(0).
  394. Save(l.ctx)
  395. return err
  396. }
  397. func (l *CronTask) checkErrRetry(taskData *ent.CompapiAsynctask) (bool, error) {
  398. var err error
  399. var needStop = false
  400. if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败
  401. _, err = l.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
  402. SetUpdatedAt(time.Now()).
  403. SetTaskStatus(int8(Task_Fail)).
  404. Save(l.ctx)
  405. if err == nil {
  406. needStop = true
  407. taskData.TaskStatus = Task_Fail
  408. }
  409. }
  410. return needStop, err
  411. }
  412. // 错误任务处理
  413. func (l *CronTask) dealErrorTask(taskData *ent.CompapiAsynctask, lasterr error) error {
  414. logx.Debug("多次循环之后依然失败,进入错误任务处理环节")
  415. cauo := l.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
  416. SetUpdatedAt(time.Now())
  417. if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败
  418. taskData.TaskStatus = Task_Fail
  419. cauo = cauo.SetTaskStatus(int8(Task_Fail))
  420. } else {
  421. cauo = cauo.SetRetryCount(taskData.RetryCount + 1).
  422. SetLastError(lasterr.Error())
  423. }
  424. _, err := cauo.Save(l.ctx)
  425. return err
  426. }