compapi_callback.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480
  1. package crontask
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "hash/fnv"
  8. "runtime"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "sync/atomic"
  13. "time"
  14. "wechat-api/ent"
  15. "wechat-api/ent/compapiasynctask"
  16. "wechat-api/ent/predicate"
  17. "wechat-api/internal/types"
  18. "wechat-api/internal/utils/compapi"
  19. "github.com/zeromicro/go-zero/core/logx"
  20. )
  21. const (
  22. Task_Ready = 10 //任务就绪
  23. ReqApi_Done = 20 //请求API完成
  24. Callback_Done = 30 //请求回调完成
  25. All_Done = 30 //全部完成
  26. Task_Suspend = 60 //任务暂停
  27. Task_Fail = 70 //任务失败
  28. //MaxWorker = 5 //最大并发worker数量
  29. MaxLoadTask = 1000 //一次最大获取任务数量
  30. LoopTryCount = 3 //循环体内重试次数
  31. ErrTaskTryCount = 3 //最大允许错误任务重试次数
  32. DefaultDisId = "DIS0001"
  33. )
  34. type Task struct {
  35. Data *ent.CompapiAsynctask
  36. Idx int
  37. Code int
  38. }
  39. // 带会话管理的任务通道组
  40. type TaskDispatcher struct {
  41. mu sync.Mutex
  42. workerChs []chan Task // 每个worker独立通道
  43. }
  44. func NewTaskDispatcher(channelCount int) *TaskDispatcher {
  45. td := &TaskDispatcher{
  46. workerChs: make([]chan Task, channelCount),
  47. }
  48. // 初始化worker通道
  49. for i := range td.workerChs {
  50. td.workerChs[i] = make(chan Task, 100) // 每个worker带缓冲
  51. //fmt.Printf("make worker chan:%d\n", i+1)
  52. }
  53. return td
  54. }
  55. // 按woker下标索引选择workerChs
  56. func (td *TaskDispatcher) getWorkerChanByIdx(idx int) (int, chan Task) {
  57. nidx := idx % len(td.workerChs)
  58. return nidx, td.workerChs[nidx]
  59. }
  60. // 按哈希分片选择workerChs
  61. func (td *TaskDispatcher) getWorkerChanByHash(disID string) (int, chan Task) {
  62. if len(disID) == 0 {
  63. disID = DefaultDisId
  64. }
  65. idx := 0
  66. if len(td.workerChs) > 1 {
  67. h := fnv.New32a()
  68. h.Write([]byte(disID))
  69. idx = int(h.Sum32()) % len(td.workerChs)
  70. }
  71. outStr := fmt.Sprintf("getWorkerChannel by disId Hash:'%s',from workerChs{", disID)
  72. for i := range len(td.workerChs) {
  73. outStr += fmt.Sprintf("#%d", i+1)
  74. if i < len(td.workerChs)-1 {
  75. outStr += ","
  76. }
  77. }
  78. outStr += fmt.Sprintf("} choice chs:'#%d' by '%s'", idx+1, disID)
  79. logx.Debug(outStr)
  80. return idx, td.workerChs[idx]
  81. }
  82. // 分配任务到对应的消费channel
  83. func (td *TaskDispatcher) Dispatch(task Task) {
  84. td.mu.Lock()
  85. defer td.mu.Unlock()
  86. // 根据chatId哈希获得其对应的workerChan
  87. workerChIdx, workerCh := td.getWorkerChanByHash(task.Data.EventType)
  88. // 将任务送入该chatid的workerChan
  89. workerCh <- task
  90. logx.Debugf("Producer:EventType:[%s] Task Push to WorkerChan:#%d", task.Data.EventType, workerChIdx+1)
  91. }
  92. func getGoroutineId() (int64, error) {
  93. // 堆栈结果中需要消除的前缀符
  94. var goroutineSpace = []byte("goroutine ")
  95. bs := make([]byte, 128)
  96. bs = bs[:runtime.Stack(bs, false)]
  97. bs = bytes.TrimPrefix(bs, goroutineSpace)
  98. i := bytes.IndexByte(bs, ' ')
  99. if i < 0 {
  100. return -1, errors.New("get current goroutine id failed")
  101. }
  102. return strconv.ParseInt(string(bs[:i]), 10, 64)
  103. }
  104. func (l *CronTask) CliTest(MaxWorker int, MaxChannel int) {
  105. l.compApiCallback(MaxWorker, MaxChannel)
  106. }
  107. func (l *CronTask) compApiCallback(MaxWorker int, MaxChannel int) {
  108. var (
  109. wg sync.WaitGroup
  110. produced int64 //生产数量(原子计数器)
  111. consumed int64 //消费数量(原子计数器)
  112. )
  113. //创建任务分发器
  114. if MaxWorker <= 0 {
  115. MaxWorker = 2
  116. }
  117. if MaxChannel <= 0 || MaxChannel > MaxWorker {
  118. MaxChannel = MaxWorker
  119. }
  120. dispatcher := NewTaskDispatcher(MaxChannel)
  121. //启动消费者
  122. for widx := range MaxWorker {
  123. cidx, ch := dispatcher.getWorkerChanByIdx(widx)
  124. wg.Add(1)
  125. go func(workerID int, channelID int, taskCh chan Task) {
  126. defer wg.Done()
  127. gid, _ := getGoroutineId()
  128. logx.Infof("Consumer %d(Goroutine:%d) bind WorkerChan:#%d start......\n",
  129. workerID, gid, channelID)
  130. for task := range taskCh {
  131. l.processTask(workerID, task)
  132. atomic.AddInt64(&consumed, 1)
  133. }
  134. }(widx+1, cidx+1, ch)
  135. }
  136. // 生产者
  137. wg.Add(1)
  138. go func() {
  139. defer wg.Done()
  140. gid, _ := getGoroutineId()
  141. logx.Infof("Producer 1(Goroutine:%d) start......\n", gid)
  142. //获得待处理异步任务列表
  143. tasks, err := l.getAsyncReqTaskList()
  144. if err != nil {
  145. logx.Errorf("getAsyncReqTaskList err:%s", err)
  146. return
  147. }
  148. // 分发任务
  149. for _, task := range tasks {
  150. dispatcher.Dispatch(task)
  151. atomic.AddInt64(&produced, 1)
  152. }
  153. logx.Infof("📦Producer 1 此批次共创建任务%d件", len(tasks))
  154. // 关闭所有会话通道
  155. dispatcher.mu.Lock()
  156. for _, ch := range dispatcher.workerChs {
  157. _ = ch
  158. close(ch)
  159. }
  160. dispatcher.mu.Unlock()
  161. }()
  162. wg.Wait()
  163. if atomic.LoadInt64(&produced) > 0 {
  164. logx.Infof("🏁本次任务完成度统计: Producer:1,Consumer:%d (%d/%d)*100=%d%%", MaxWorker, atomic.LoadInt64(&consumed), atomic.LoadInt64(&produced),
  165. (atomic.LoadInt64(&consumed)/atomic.LoadInt64(&produced))*100)
  166. }
  167. }
  168. func (l *CronTask) getAsyncReqTaskList() ([]Task, error) {
  169. var predicates []predicate.CompapiAsynctask
  170. predicates = append(predicates, compapiasynctask.TaskStatusLT(All_Done))
  171. var tasks []Task
  172. res, err := l.svcCtx.DB.CompapiAsynctask.Query().Where(predicates...).
  173. Order(ent.Asc(compapiasynctask.FieldID)).
  174. Limit(MaxLoadTask).
  175. All(l.ctx)
  176. if err == nil {
  177. for idx, val := range res {
  178. tasks = append(tasks, Task{Data: val, Idx: idx})
  179. }
  180. }
  181. return tasks, err
  182. }
  183. func (l *CronTask) processTask(workerID int, task Task) {
  184. /*
  185. fmt.Printf("In processTask,Consumer(%d) dealing Task Detail: User(%s/%s/%s) Async Call %s(%s) on Status:%d\n",
  186. workerID, task.Data.EventType, task.Data.ChatID, task.Data.AuthToken,
  187. task.Data.OpenaiBase, task.Data.OpenaiKey, task.Data.TaskStatus)*/
  188. _ = workerID
  189. var err error
  190. rt := 0
  191. for {
  192. if task.Data.TaskStatus >= All_Done {
  193. break
  194. }
  195. switch task.Data.TaskStatus {
  196. case Task_Ready:
  197. //请求API平台
  198. // succ: taskStatus Task_Ready => ReqApi_Done
  199. // fail: taskStatus保持当前不变或Task_Fail
  200. rt, err = l.requestAPI(task.Data)
  201. case ReqApi_Done:
  202. //结果回调
  203. // succ: taskStatus ReqApi_Done => Callback_Done(All_Done)
  204. // fail: taskStatus保持当前不变或Task_Fail
  205. rt, err = l.requestCallback(task.Data)
  206. }
  207. if err != nil {
  208. //收集错误
  209. if rt == 0 {
  210. //不可恢复错误处理....
  211. }
  212. logx.Debugf("Task ignore by '%s'", err)
  213. return //先暂时忽略处理,也许应按错误类型分别对待
  214. }
  215. }
  216. }
  217. func (l *CronTask) requestCallback(taskData *ent.CompapiAsynctask) (int, error) {
  218. workerId, _ := getGoroutineId()
  219. logx.Debugf("Worker:%d INTO requestCallback for task status:%d", workerId, taskData.TaskStatus)
  220. if needStop, _ := l.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理
  221. return 1, errors.New("too many err retry")
  222. }
  223. if taskData.TaskStatus != ReqApi_Done {
  224. return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
  225. }
  226. if len(taskData.CallbackURL) == 0 {
  227. return 0, errors.New("callback url empty")
  228. }
  229. if len(taskData.ResponseRaw) == 0 {
  230. return 0, errors.New("call api response empty")
  231. }
  232. fstr := "mytest-svc:"
  233. if taskData.RetryCount > 0 && strings.Contains(taskData.CallbackURL, fstr) {
  234. taskData.CallbackURL = strings.Replace(taskData.CallbackURL, fstr, "0.0.0.0:", 1)
  235. }
  236. //先开启事务更新任务状态 => Callback_Done(回调完成)
  237. tx, err := l.updateTaskStatusByTx(taskData.ID, Callback_Done)
  238. if err != nil {
  239. return 0, err
  240. }
  241. //请求预先给定的callback_url
  242. var res map[string]any
  243. //初始化client
  244. client := compapi.NewClient(l.ctx)
  245. for i := range LoopTryCount { //重试机制
  246. res, err = client.Callback(taskData.EventType, taskData.CallbackURL, taskData.ResponseRaw)
  247. //_, err = client.Chat.Completions.New(l.ctx, emptyParams, opts...)
  248. if err == nil {
  249. //call succ
  250. //fmt.Println("callback succ..........")
  251. //fmt.Println(typekit.PrettyPrint(res))
  252. logx.Infof("callback:'%s' succ", taskData.CallbackURL)
  253. break
  254. }
  255. logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId,
  256. taskData.CallbackURL, err, 1+i*5, i+1, LoopTryCount, taskData.RetryCount)
  257. time.Sleep(time.Duration(1+i*5) * time.Second)
  258. }
  259. //多次循环之后依然失败,进入错误任务处理环节
  260. if err != nil {
  261. _ = tx.Rollback() //回滚之前更新状态
  262. //fmt.Printf("Worker:%d client.Chat.Completions.New Fail,Rollback......\n", workerId)
  263. err1 := l.dealErrorTask(taskData, err) //错误任务处理
  264. et := 1
  265. if err1 != nil {
  266. et = 0
  267. }
  268. return et, err
  269. }
  270. //成功后处理环节
  271. err = tx.Commit() //事务提交
  272. if err != nil {
  273. return 0, err
  274. }
  275. //更新taskData.CallbackResponseRaw
  276. l.updateCallbackResponse(taskData.ID, res)
  277. taskData.TaskStatus = Callback_Done //状态迁移
  278. return 1, nil
  279. }
  280. func (l *CronTask) requestAPI(taskData *ent.CompapiAsynctask) (int, error) {
  281. workerId, _ := getGoroutineId()
  282. logx.Debugf("Worker:%d INTO requestAPI for task status:%d", workerId, taskData.TaskStatus)
  283. if needStop, _ := l.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理
  284. return 1, errors.New("too many err retry")
  285. }
  286. if taskData.TaskStatus != Task_Ready {
  287. return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
  288. }
  289. var (
  290. err error
  291. apiResp *types.CompOpenApiResp
  292. tx *ent.Tx
  293. )
  294. req := types.CompApiReq{}
  295. if err = json.Unmarshal([]byte(taskData.RequestRaw), &req); err != nil {
  296. return 0, err
  297. }
  298. //先开启事务更新任务状态 => ReqApi_Done(请求API完成)
  299. tx, err = l.updateTaskStatusByTx(taskData.ID, ReqApi_Done)
  300. if err != nil {
  301. return 0, err
  302. }
  303. //初始化client
  304. client := compapi.NewClient(l.ctx, compapi.WithApiBase(taskData.OpenaiBase),
  305. compapi.WithApiKey(taskData.OpenaiKey))
  306. for i := range LoopTryCount { //重试机制
  307. apiResp, err = client.Chat(&req)
  308. if err == nil && apiResp != nil && len(apiResp.Choices) > 0 {
  309. //call succ
  310. break
  311. } else if apiResp != nil && len(apiResp.Choices) == 0 {
  312. err = errors.New("返回结果缺失,请检查访问权限")
  313. }
  314. logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId,
  315. taskData.CallbackURL, err, 1+i*5, i+1, LoopTryCount, taskData.RetryCount)
  316. time.Sleep(time.Duration(1+i*5) * time.Second)
  317. }
  318. //多次循环之后依然失败,进入错误任务处理环节
  319. if err != nil || apiResp == nil {
  320. if apiResp == nil && err == nil {
  321. err = errors.New("resp is null")
  322. }
  323. _ = tx.Rollback() //回滚之前更新状态
  324. err1 := l.dealErrorTask(taskData, err) //错误任务处理
  325. et := 1
  326. if err1 != nil {
  327. et = 0
  328. }
  329. return et, err
  330. }
  331. //成功后处理环节
  332. //更新taskData.ResponseRaw
  333. taskData.ResponseRaw, err = (*apiResp).ToString()
  334. if err != nil {
  335. _ = tx.Rollback() //回滚之前更新状态
  336. return 0, err
  337. }
  338. err = l.updateApiResponseByTx(tx, taskData.ID, taskData.ResponseRaw)
  339. if err != nil {
  340. return 0, err
  341. }
  342. err = tx.Commit() //事务提交
  343. //fmt.Printf("Worker:%d requestAPI事务提交\n", workerId)
  344. if err != nil {
  345. return 0, err
  346. }
  347. taskData.TaskStatus = ReqApi_Done //状态迁移
  348. return 1, nil
  349. }
  350. // 更新任务状态事务版
  351. func (l *CronTask) updateTaskStatusByTx(Id uint64, status int) (*ent.Tx, error) {
  352. //开启Mysql事务
  353. tx, _ := l.svcCtx.DB.Tx(l.ctx)
  354. _, err := tx.CompapiAsynctask.UpdateOneID(Id).
  355. SetTaskStatus(int8(status)).
  356. SetUpdatedAt(time.Now()).
  357. Save(l.ctx)
  358. if err != nil {
  359. return nil, err
  360. }
  361. return tx, nil
  362. }
  363. // 更新请求大模型后结果事务版
  364. func (l *CronTask) updateApiResponseByTx(tx *ent.Tx, taskId uint64, apiResponse string) error {
  365. _, err := tx.CompapiAsynctask.UpdateOneID(taskId).
  366. SetUpdatedAt(time.Now()).
  367. SetResponseRaw(apiResponse).
  368. SetLastError("").
  369. SetRetryCount(0).
  370. Save(l.ctx)
  371. if err != nil {
  372. _ = tx.Rollback() //回滚之前更新状态
  373. }
  374. return err
  375. }
  376. func (l *CronTask) updateCallbackResponse(taskId uint64, callRes any) error {
  377. callResStr := ""
  378. switch v := callRes.(type) {
  379. case []byte:
  380. callResStr = string(v)
  381. default:
  382. if bs, err := json.Marshal(v); err == nil {
  383. callResStr = string(bs)
  384. } else {
  385. return err
  386. }
  387. }
  388. _, err := l.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId).
  389. SetUpdatedAt(time.Now()).
  390. SetCallbackResponseRaw(callResStr).
  391. SetLastError("").
  392. SetRetryCount(0).
  393. Save(l.ctx)
  394. return err
  395. }
  396. func (l *CronTask) checkErrRetry(taskData *ent.CompapiAsynctask) (bool, error) {
  397. var err error
  398. var needStop = false
  399. if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败
  400. _, err = l.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
  401. SetUpdatedAt(time.Now()).
  402. SetTaskStatus(int8(Task_Fail)).
  403. Save(l.ctx)
  404. if err == nil {
  405. needStop = true
  406. taskData.TaskStatus = Task_Fail
  407. }
  408. }
  409. return needStop, err
  410. }
  411. // 错误任务处理
  412. func (l *CronTask) dealErrorTask(taskData *ent.CompapiAsynctask, lasterr error) error {
  413. logx.Debug("多次循环之后依然失败,进入错误任务处理环节")
  414. cauo := l.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
  415. SetUpdatedAt(time.Now())
  416. if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败
  417. taskData.TaskStatus = Task_Fail
  418. cauo = cauo.SetTaskStatus(int8(Task_Fail))
  419. } else {
  420. cauo = cauo.SetRetryCount(taskData.RetryCount + 1).
  421. SetLastError(lasterr.Error())
  422. }
  423. _, err := cauo.Save(l.ctx)
  424. return err
  425. }