asynctask.go 16 KB


  1. package main
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "flag"
  8. "fmt"
  9. "hash/fnv"
  10. "os"
  11. "os/signal"
  12. "runtime"
  13. "strconv"
  14. "strings"
  15. "sync"
  16. "sync/atomic"
  17. "syscall"
  18. "time"
  19. "wechat-api/ent"
  20. "wechat-api/ent/compapiasynctask"
  21. "wechat-api/ent/predicate"
  22. "wechat-api/internal/svc"
  23. "wechat-api/internal/types"
  24. "wechat-api/internal/utils/compapi"
  25. "github.com/suyuan32/simple-admin-common/config"
  26. "github.com/zeromicro/go-zero/core/conf"
  27. "github.com/zeromicro/go-zero/core/logx"
  28. )
  // Task lifecycle codes compared against and written to a task's TaskStatus.
  const (
  	Task_Ready    = 10            // the task is ready to run
  	ReqApi_Done   = 20            // the API request has completed
  	Callback_Done = 30            // the callback request has completed
  	All_Done      = Callback_Done // everything is done (for now the success path ends here)
  	Task_Suspend  = 60            // the task is suspended
  	Task_Fail     = 70            // the task has failed
  )

  // Retry limits and the fallback dispatch key.
  const (
  	LoopTryCount    = 3         // attempts made inside one retry loop
  	ErrTaskTryCount = 3         // maximum number of retries allowed for a failing task
  	DefaultDisId    = "DIS0001" // hash key used when a task supplies an empty one
  )
  40. type Config struct {
  41. BatchLoadTask uint `json:",default=100"`
  42. MaxWorker uint `json:",default=2"`
  43. MaxChannel uint `json:",default=1"`
  44. Debug bool `json:",default=false"`
  45. DatabaseConf config.DatabaseConf
  46. RedisConf config.RedisConf
  47. }
  // TaskStat is a placeholder for run statistics; it currently has no fields.
  type TaskStat struct{}
  50. type AsyncTask struct {
  51. logx.Logger
  52. ctx context.Context
  53. svcCtx *svc.ServiceContext
  54. Conf Config
  55. Stats TaskStat
  56. }
  57. type Task struct {
  58. Data *ent.CompapiAsynctask
  59. Idx int
  60. Code int
  61. }
  62. // 带会话管理的任务通道组
  63. type TaskDispatcher struct {
  64. mu sync.Mutex
  65. workerChs []chan Task // 每个worker独立通道
  66. }
  // configFile holds the value of the -f flag: the path of the config file.
  var configFile = flag.String(
  	"f",
  	"./etc/asynctask.yaml",
  	"the config file",
  )
  // getGoroutineId returns the numeric ID of the calling goroutine.
  //
  // It reads the first bytes of the goroutine's stack trace, drops the
  // leading "goroutine " text and parses everything up to the next space as a
  // base-10 integer. It returns -1 and an error when no space is found, or the
  // strconv parse error when the text is not a number.
  func getGoroutineId() (int64, error) {
  	const header = "goroutine " // prefix to strip from the stack dump

  	buf := make([]byte, 128)
  	written := runtime.Stack(buf, false)
  	trace := bytes.TrimPrefix(buf[:written], []byte(header))

  	idText, _, found := bytes.Cut(trace, []byte{' '})
  	if !found {
  		return -1, errors.New("get current goroutine id failed")
  	}
  	return strconv.ParseInt(string(idText), 10, 64)
  }
  80. func NewTaskDispatcher(channelCount uint, chanSize uint) *TaskDispatcher {
  81. td := &TaskDispatcher{
  82. workerChs: make([]chan Task, channelCount),
  83. }
  84. // 初始化worker通道
  85. for i := range td.workerChs {
  86. td.workerChs[i] = make(chan Task, chanSize+1) // 每个worker带缓冲
  87. }
  88. return td
  89. }
  90. // 按woker下标索引选择workerChs
  91. func (td *TaskDispatcher) getWorkerChanByIdx(idx uint) (uint, chan Task) {
  92. nidx := idx % uint(len(td.workerChs))
  93. return nidx, td.workerChs[nidx]
  94. }
  95. // 按哈希分片选择workerChs
  96. func (td *TaskDispatcher) getWorkerChanByHash(disID string) (int, chan Task) {
  97. if len(disID) == 0 {
  98. disID = DefaultDisId
  99. }
  100. idx := 0
  101. if len(td.workerChs) > 1 {
  102. h := fnv.New32a()
  103. h.Write([]byte(disID))
  104. idx = int(h.Sum32()) % len(td.workerChs)
  105. }
  106. outStr := fmt.Sprintf("getWorkerChannel by disId Hash:'%s',from workerChs{", disID)
  107. for i := range len(td.workerChs) {
  108. outStr += fmt.Sprintf("#%d", i+1)
  109. if i < len(td.workerChs)-1 {
  110. outStr += ","
  111. }
  112. }
  113. outStr += fmt.Sprintf("} choice chs:'#%d' by '%s'", idx+1, disID)
  114. logx.Debug(outStr)
  115. return idx, td.workerChs[idx]
  116. }
  117. // 分配任务到对应的消费channel
  118. func (td *TaskDispatcher) Dispatch(task Task) {
  119. td.mu.Lock()
  120. defer td.mu.Unlock()
  121. // 根据chatId哈希获得其对应的workerChan
  122. workerChIdx, workerCh := td.getWorkerChanByHash(task.Data.EventType)
  123. // 将任务送入该chatid的workerChan
  124. workerCh <- task
  125. logx.Debugf("Producer:EventType:[%s] Task Push to WorkerChan:#%d", task.Data.EventType, workerChIdx+1)
  126. }
  127. // 更新任务状态
  128. func (me *AsyncTask) updateTaskStatus(taskId uint64, status int) error {
  129. _, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId).
  130. SetTaskStatus(int8(status)).
  131. SetUpdatedAt(time.Now()).
  132. Save(me.ctx)
  133. return err
  134. }
  135. // 更新请求大模型后结果
  136. func (me *AsyncTask) updateApiResponse(taskId uint64, apiResponse string) error {
  137. _, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId).
  138. SetUpdatedAt(time.Now()).
  139. SetResponseRaw(apiResponse).
  140. SetLastError("").
  141. SetRetryCount(0).
  142. Save(me.ctx)
  143. return err
  144. }
  145. func (me *AsyncTask) updateCallbackResponse(taskId uint64, callRes any) error {
  146. callResStr := ""
  147. switch v := callRes.(type) {
  148. case []byte:
  149. callResStr = string(v)
  150. default:
  151. if bs, err := json.Marshal(v); err == nil {
  152. callResStr = string(bs)
  153. } else {
  154. return err
  155. }
  156. }
  157. _, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId).
  158. SetUpdatedAt(time.Now()).
  159. SetCallbackResponseRaw(callResStr).
  160. SetLastError("").
  161. SetRetryCount(0).
  162. Save(me.ctx)
  163. return err
  164. }
  165. func (me *AsyncTask) checkErrRetry(taskData *ent.CompapiAsynctask) (bool, error) {
  166. var err error
  167. var needStop = false
  168. if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败
  169. _, err = me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
  170. SetUpdatedAt(time.Now()).
  171. SetTaskStatus(int8(Task_Fail)).
  172. Save(me.ctx)
  173. if err == nil {
  174. needStop = true
  175. taskData.TaskStatus = Task_Fail
  176. }
  177. }
  178. return needStop, err
  179. }
  180. // 错误任务处理
  181. func (me *AsyncTask) dealErrorTask(taskData *ent.CompapiAsynctask, lasterr error) error {
  182. logx.Debug("多次循环之后依然失败,进入错误任务处理环节")
  183. cauo := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
  184. SetUpdatedAt(time.Now())
  185. if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败
  186. taskData.TaskStatus = Task_Fail
  187. cauo = cauo.SetTaskStatus(int8(Task_Fail))
  188. } else {
  189. cauo = cauo.SetRetryCount(taskData.RetryCount + 1).
  190. SetLastError(lasterr.Error())
  191. }
  192. _, err := cauo.Save(me.ctx)
  193. return err
  194. }
  195. func (me *AsyncTask) requestCallback(taskData *ent.CompapiAsynctask) (int, error) {
  196. var workerId int64
  197. if me.Conf.Debug {
  198. workerId, _ = getGoroutineId()
  199. logx.Debugf("Worker:%d INTO requestCallback for task status:%d", workerId, taskData.TaskStatus)
  200. }
  201. if needStop, _ := me.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理
  202. return 1, errors.New("too many err retry")
  203. }
  204. if taskData.TaskStatus != ReqApi_Done {
  205. return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
  206. }
  207. if len(taskData.CallbackURL) == 0 {
  208. return 0, errors.New("callback url empty")
  209. }
  210. if len(taskData.ResponseRaw) == 0 {
  211. return 0, errors.New("call api response empty")
  212. }
  213. /*
  214. fstr := "mytest-svc:"
  215. if taskData.RetryCount >= 0 && strings.Contains(taskData.CallbackURL, fstr) {
  216. taskData.CallbackURL = strings.Replace(taskData.CallbackURL, fstr, "0.0.0.0:", 1)
  217. }
  218. */
  219. //请求预先给定的callback_url
  220. res := map[string]any{}
  221. var err error
  222. //初始化client
  223. client := compapi.NewClient(me.ctx)
  224. for i := range LoopTryCount { //重试机制
  225. select {
  226. case <-me.ctx.Done(): //接到信号退出
  227. goto endloopTry
  228. default:
  229. res, err = client.Callback(taskData.EventType, taskData.CallbackURL, taskData)
  230. if err == nil {
  231. //call succ
  232. //fmt.Println("callback succ..........")
  233. //fmt.Println(typekit.PrettyPrint(res))
  234. logx.Infof("callback:'%s' succ", taskData.CallbackURL)
  235. goto endloopTry
  236. }
  237. logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId,
  238. taskData.CallbackURL, err, 1+i*5, i+1, LoopTryCount, taskData.RetryCount)
  239. time.Sleep(time.Duration(1+i*5) * time.Second)
  240. }
  241. }
  242. //多次循环之后依然失败,进入错误任务处理环节
  243. endloopTry:
  244. if err != nil {
  245. err1 := me.dealErrorTask(taskData, err) //错误任务处理
  246. et := 1
  247. if err1 != nil {
  248. et = 0
  249. }
  250. return et, err
  251. }
  252. //成功后处理环节
  253. //更新任务状态 => Callback_Done(回调完成)
  254. err = me.updateTaskStatus(taskData.ID, Callback_Done)
  255. if err != nil {
  256. return 0, err
  257. }
  258. //更新taskData.CallbackResponseRaw
  259. if len(res) > 0 {
  260. me.updateCallbackResponse(taskData.ID, res)
  261. }
  262. taskData.TaskStatus = Callback_Done //状态迁移
  263. return 1, nil
  264. }
  265. func (me *AsyncTask) requestAPI(taskData *ent.CompapiAsynctask) (int, error) {
  266. var workerId int64
  267. if me.Conf.Debug {
  268. workerId, _ = getGoroutineId()
  269. logx.Debugf("Worker:%d INTO requestAPI for task status:%d", workerId, taskData.TaskStatus)
  270. }
  271. if needStop, _ := me.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理
  272. return 1, errors.New("too many err retry")
  273. }
  274. if taskData.TaskStatus != Task_Ready {
  275. return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
  276. }
  277. var (
  278. err error
  279. apiResp *types.CompOpenApiResp
  280. )
  281. req := types.CompApiReq{}
  282. if err = json.Unmarshal([]byte(taskData.RequestRaw), &req); err != nil {
  283. return 0, err
  284. }
  285. //初始化client
  286. if !strings.HasSuffix(taskData.OpenaiBase, "/") {
  287. taskData.OpenaiBase = taskData.OpenaiBase + "/"
  288. }
  289. client := compapi.NewClient(me.ctx, compapi.WithApiBase(taskData.OpenaiBase),
  290. compapi.WithApiKey(taskData.OpenaiKey))
  291. for i := range LoopTryCount { //重试机制
  292. select {
  293. case <-me.ctx.Done(): //接到信号退出
  294. goto endloopTry
  295. default:
  296. apiResp, err = client.Chat(&req)
  297. if err == nil && apiResp != nil && len(apiResp.Choices) > 0 {
  298. //call succ
  299. goto endloopTry
  300. } else if apiResp != nil && len(apiResp.Choices) == 0 {
  301. err = errors.New("返回结果缺失,请检查访问权限")
  302. }
  303. if me.Conf.Debug {
  304. logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId,
  305. taskData.CallbackURL, err, 1+i*5, i+1, LoopTryCount, taskData.RetryCount)
  306. }
  307. time.Sleep(time.Duration(1+i*5) * time.Second)
  308. }
  309. }
  310. endloopTry:
  311. //多次循环之后依然失败,进入错误任务处理环节
  312. if err != nil || apiResp == nil {
  313. if apiResp == nil && err == nil {
  314. err = errors.New("resp is null")
  315. }
  316. err1 := me.dealErrorTask(taskData, err) //错误任务处理
  317. et := 1
  318. if err1 != nil {
  319. et = 0
  320. }
  321. return et, err
  322. }
  323. //成功后处理环节
  324. //更新任务状态 => ReqApi_Done(请求API完成)
  325. err = me.updateTaskStatus(taskData.ID, ReqApi_Done)
  326. if err != nil {
  327. return 0, err
  328. }
  329. //更新taskData.ResponseRaw
  330. taskData.ResponseRaw, err = (*apiResp).ToString()
  331. if err != nil {
  332. return 0, err
  333. }
  334. err = me.updateApiResponse(taskData.ID, taskData.ResponseRaw)
  335. if err != nil {
  336. return 0, err
  337. }
  338. taskData.TaskStatus = ReqApi_Done //状态迁移
  339. return 1, nil
  340. }
  341. func (me *AsyncTask) getAsyncReqTaskList() ([]Task, error) {
  342. var predicates []predicate.CompapiAsynctask
  343. predicates = append(predicates, compapiasynctask.TaskStatusLT(All_Done))
  344. var tasks []Task
  345. res, err := me.svcCtx.DB.CompapiAsynctask.Query().Where(predicates...).
  346. Order(ent.Asc(compapiasynctask.FieldID)).
  347. Limit(int(me.Conf.BatchLoadTask)).
  348. All(me.ctx)
  349. if err == nil {
  350. for idx, val := range res {
  351. tasks = append(tasks, Task{Data: val, Idx: idx})
  352. }
  353. }
  354. return tasks, err
  355. }
  356. func (me *AsyncTask) processTask(workerID uint, task Task) {
  357. /*
  358. fmt.Printf("In processTask,Consumer(%d) dealing Task Detail: User(%s/%s/%s) Async Call %s(%s) on Status:%d\n",
  359. workerID, task.Data.EventType, task.Data.ChatID, task.Data.AuthToken,
  360. task.Data.OpenaiBase, task.Data.OpenaiKey, task.Data.TaskStatus)
  361. */
  362. _ = workerID
  363. var err error
  364. rt := 0
  365. for {
  366. select {
  367. case <-me.ctx.Done(): //接到信号退出
  368. return
  369. default:
  370. if task.Data.TaskStatus >= All_Done {
  371. goto endfor //原来的break操作加了switch一层不好使了
  372. }
  373. switch task.Data.TaskStatus {
  374. case Task_Ready:
  375. //请求API平台
  376. // succ: taskStatus Task_Ready => ReqApi_Done
  377. // fail: taskStatus保持当前不变或Task_Fail
  378. rt, err = me.requestAPI(task.Data)
  379. case ReqApi_Done:
  380. //结果回调
  381. // succ: taskStatus ReqApi_Done => Callback_Done(All_Done)
  382. // fail: taskStatus保持当前不变或Task_Fail
  383. rt, err = me.requestCallback(task.Data)
  384. }
  385. if err != nil {
  386. //收集错误
  387. if rt == 0 {
  388. //不可恢复错误处理....
  389. logx.Errorf("Task error by '%s'", err)
  390. } else {
  391. logx.Debugf("Task ignore by '%s'", err)
  392. }
  393. return //先暂时忽略处理,也许应按错误类型分别对待
  394. }
  395. }
  396. }
  397. endfor:
  398. }
  399. func (me *AsyncTask) batchWork() (int64, int64) {
  400. var (
  401. wg sync.WaitGroup
  402. produced int64 //生产数量(原子计数器)
  403. consumed int64 //消费数量(原子计数器)
  404. )
  405. //创建任务分发器
  406. dispatcher := NewTaskDispatcher(me.Conf.MaxChannel, me.Conf.BatchLoadTask/me.Conf.MaxChannel)
  407. //启动消费者
  408. for widx := range me.Conf.MaxWorker {
  409. cidx, ch := dispatcher.getWorkerChanByIdx(widx)
  410. wg.Add(1)
  411. go func(workerID uint, channelID uint, taskCh chan Task) {
  412. defer wg.Done()
  413. gidStr := ""
  414. if me.Conf.Debug {
  415. gid, _ := getGoroutineId()
  416. gidStr = fmt.Sprintf("(Goroutine:%d)", gid)
  417. }
  418. logx.Infof("Consumer @%d%s bind WorkerChan:#%d start......",
  419. workerID, gidStr, channelID)
  420. for task := range taskCh {
  421. me.processTask(widx, task)
  422. atomic.AddInt64(&consumed, 1)
  423. }
  424. }(widx+1, cidx+1, ch)
  425. }
  426. // 生产者
  427. wg.Add(1)
  428. go func() {
  429. defer wg.Done()
  430. gidStr := ""
  431. if me.Conf.Debug {
  432. gid, _ := getGoroutineId()
  433. gidStr = fmt.Sprintf("(Goroutine:%d)", gid)
  434. }
  435. logx.Infof("Producer @1%s start......", gidStr)
  436. //获得待处理异步任务列表
  437. tasks, err := me.getAsyncReqTaskList()
  438. if err != nil {
  439. logx.Errorf("getAsyncReqTaskList err:%s", err)
  440. return
  441. }
  442. // 分发任务
  443. for _, task := range tasks {
  444. dispatcher.Dispatch(task)
  445. atomic.AddInt64(&produced, 1)
  446. }
  447. logx.Infof("📦Producer @1 此批次共创建任务%d件", len(tasks))
  448. // 关闭所有会话通道
  449. dispatcher.mu.Lock()
  450. for _, ch := range dispatcher.workerChs {
  451. _ = ch
  452. close(ch)
  453. }
  454. dispatcher.mu.Unlock()
  455. }()
  456. wg.Wait()
  457. consumedRatStr := ""
  458. if atomic.LoadInt64(&produced) > 0 {
  459. consumedRatStr = fmt.Sprintf(" (%d/%d)*100=%d%%", atomic.LoadInt64(&produced), atomic.LoadInt64(&consumed),
  460. (atomic.LoadInt64(&consumed)/atomic.LoadInt64(&produced))*100)
  461. }
  462. logx.Infof("🏁本次任务完成度统计: Task dispatch: %d(%d)(Producer:1),Task deal: %d(Consumer:%d)%s", atomic.LoadInt64(&produced), me.Conf.BatchLoadTask, atomic.LoadInt64(&consumed), me.Conf.MaxWorker, consumedRatStr)
  463. return produced, consumed
  464. }
  465. func (me *AsyncTask) InitServiceContext() *svc.ServiceContext {
  466. rds := me.Conf.RedisConf.MustNewUniversalRedis()
  467. dbOpts := []ent.Option{ent.Log(logx.Info),
  468. ent.Driver(me.Conf.DatabaseConf.NewNoCacheDriver())}
  469. if me.Conf.Debug {
  470. dbOpts = append(dbOpts, ent.Debug())
  471. }
  472. db := ent.NewClient(dbOpts...)
  473. svcCtx := svc.ServiceContext{DB: db, Rds: rds}
  474. //svcCtx.Config
  475. return &svcCtx
  476. }
  477. func (me *AsyncTask) adjustConf() {
  478. if me.Conf.MaxWorker <= 0 {
  479. me.Conf.MaxWorker = 2
  480. }
  481. if me.Conf.MaxChannel <= 0 || me.Conf.MaxChannel > me.Conf.MaxWorker {
  482. me.Conf.MaxChannel = me.Conf.MaxWorker
  483. }
  484. }
  485. func (me *AsyncTask) Run(ctx context.Context) {
  486. me.ctx = ctx
  487. me.Logger = logx.WithContext(ctx)
  488. me.adjustConf()
  489. for {
  490. select {
  491. case <-ctx.Done():
  492. // 收到了取消信号
  493. fmt.Printf("Main Will Shutting down gracefully... Reason: %v\n", ctx.Err())
  494. return
  495. default:
  496. produced, _ := me.batchWork()
  497. timeDurnum := 1
  498. if produced == 0 {
  499. timeDurnum = 5
  500. }
  501. fmt.Printf("batchWork will sleep %d sec\n", timeDurnum)
  502. time.Sleep(time.Duration(timeDurnum) * time.Second)
  503. }
  504. }
  505. }
  506. func NewAsyncTask() *AsyncTask {
  507. ataskObj := AsyncTask{}
  508. flag.Parse() //将命令行参数也塞入flag.CommandLine结构
  509. //fmt.Println(typekit.PrettyPrint(flag.CommandLine))
  510. conf.MustLoad(*configFile, &ataskObj.Conf, conf.UseEnv())
  511. //fmt.Println(typekit.PrettyPrint(ataskObj.Conf))
  512. ataskObj.svcCtx = ataskObj.InitServiceContext()
  513. return &ataskObj
  514. }
  515. func main() {
  516. fmt.Println("Compapi Asynctask Start......")
  517. //ctx, cancel := context.WithCancel(context.Background())
  518. ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
  519. defer cancel()
  520. NewAsyncTask().Run(ctx)
  521. fmt.Println("Compapi Asynctask End......")
  522. }