asynctask.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682
  1. package main
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "flag"
  8. "fmt"
  9. "hash/fnv"
  10. "os"
  11. "os/signal"
  12. "runtime"
  13. "strconv"
  14. "strings"
  15. "sync"
  16. "sync/atomic"
  17. "syscall"
  18. "time"
  19. "wechat-api/ent"
  20. "wechat-api/ent/compapiasynctask"
  21. "wechat-api/ent/predicate"
  22. "wechat-api/internal/svc"
  23. "wechat-api/internal/types"
  24. "wechat-api/internal/utils/compapi"
  25. "github.com/suyuan32/simple-admin-common/config"
  26. "github.com/zeromicro/go-zero/core/conf"
  27. "github.com/zeromicro/go-zero/core/logx"
  28. )
  29. const (
  30. Task_Ready = 10 //任务就绪
  31. ReqApi_Done = 20 //请求API完成
  32. Callback_Done = 30 //请求回调完成
  33. All_Done = Callback_Done //全部完成 (暂时将成功状态的终点标注于此)
  34. Task_Suspend = 60 //任务暂停
  35. Task_Fail = 70 //任务失败
  36. LoopTryCount = 3 //循环体内重试次数
  37. LoopDelayFactor = 3
  38. ErrTaskTryCount = 3 //最大允许错误任务重试次数
  39. DefaultDisId = "DIS0001"
  40. )
  41. type Config struct {
  42. BatchLoadTask uint `json:",default=100"`
  43. MaxWorker uint `json:",default=2"`
  44. MaxChannel uint `json:",default=1"`
  45. Debug bool `json:",default=false"`
  46. DatabaseConf config.DatabaseConf
  47. RedisConf config.RedisConf
  48. }
  49. type TaskStat struct {
  50. }
  51. type AsyncTask struct {
  52. logx.Logger
  53. ctx context.Context
  54. svcCtx *svc.ServiceContext
  55. Conf Config
  56. Stats TaskStat
  57. }
  58. type Task struct {
  59. Data *ent.CompapiAsynctask
  60. Idx int
  61. Code int
  62. }
  63. // 带会话管理的任务通道组
  64. type TaskDispatcher struct {
  65. mu sync.Mutex
  66. workerChs []chan Task // 每个worker独立通道
  67. Debug bool
  68. }
  69. var configFile = flag.String("f", "./etc/asynctask.yaml", "the config file")
  70. func getGoroutineId() (int64, error) {
  71. // 堆栈结果中需要消除的前缀符
  72. var goroutineSpace = []byte("goroutine ")
  73. bs := make([]byte, 128)
  74. bs = bs[:runtime.Stack(bs, false)]
  75. bs = bytes.TrimPrefix(bs, goroutineSpace)
  76. i := bytes.IndexByte(bs, ' ')
  77. if i < 0 {
  78. return -1, errors.New("get current goroutine id failed")
  79. }
  80. return strconv.ParseInt(string(bs[:i]), 10, 64)
  81. }
  82. func NewTaskDispatcher(channelCount uint, chanSize uint, debug bool) *TaskDispatcher {
  83. td := &TaskDispatcher{
  84. workerChs: make([]chan Task, channelCount),
  85. Debug: debug,
  86. }
  87. // 初始化worker通道
  88. for i := range td.workerChs {
  89. td.workerChs[i] = make(chan Task, chanSize+1) // 每个worker带缓冲
  90. }
  91. return td
  92. }
  93. // 按woker下标索引选择workerChs
  94. func (td *TaskDispatcher) getWorkerChanByIdx(idx uint) (uint, chan Task) {
  95. nidx := idx % uint(len(td.workerChs))
  96. return nidx, td.workerChs[nidx]
  97. }
  98. // 按哈希分片选择workerChs
  99. func (td *TaskDispatcher) getWorkerChanByHash(disID string) (int, chan Task) {
  100. if len(disID) == 0 {
  101. disID = DefaultDisId
  102. }
  103. idx := 0
  104. if len(td.workerChs) > 1 {
  105. h := fnv.New32a()
  106. h.Write([]byte(disID))
  107. idx = int(h.Sum32()) % len(td.workerChs)
  108. }
  109. if td.Debug {
  110. outStr := fmt.Sprintf("getWorkerChannel by disId Hash:'%s',from workerChs{", disID)
  111. for i := range len(td.workerChs) {
  112. outStr += fmt.Sprintf("#%d", i+1)
  113. if i < len(td.workerChs)-1 {
  114. outStr += ","
  115. }
  116. }
  117. outStr += fmt.Sprintf("} choice chs:'#%d' by '%s'", idx+1, disID)
  118. logx.Debug(outStr)
  119. }
  120. return idx, td.workerChs[idx]
  121. }
  122. // 分配任务到对应的消费channel
  123. func (td *TaskDispatcher) Dispatch(task Task) {
  124. td.mu.Lock()
  125. defer td.mu.Unlock()
  126. // 根据chatId哈希获得其对应的workerChan
  127. workerChIdx, workerCh := td.getWorkerChanByHash(task.Data.EventType)
  128. // 将任务送入该chatid的workerChan
  129. workerCh <- task
  130. if td.Debug {
  131. logx.Debugf("Producer:EventType:[%s] Task Push to WorkerChan:#%d", task.Data.EventType, workerChIdx+1)
  132. }
  133. }
  134. // 更新任务状态
  135. func (me *AsyncTask) updateTaskStatus(taskId uint64, status int) error {
  136. _, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId).
  137. SetTaskStatus(int8(status)).
  138. SetUpdatedAt(time.Now()).
  139. Save(me.ctx)
  140. return err
  141. }
  142. // 更新请求大模型后结果
  143. func (me *AsyncTask) updateApiResponse(taskId uint64, apiResponse string) error {
  144. _, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId).
  145. SetUpdatedAt(time.Now()).
  146. SetResponseRaw(apiResponse).
  147. SetLastError("").
  148. SetRetryCount(0).
  149. Save(me.ctx)
  150. return err
  151. }
  152. func (me *AsyncTask) updateCallbackResponse(taskId uint64, callRes any) error {
  153. callResStr := ""
  154. switch v := callRes.(type) {
  155. case []byte:
  156. callResStr = string(v)
  157. default:
  158. if bs, err := json.Marshal(v); err == nil {
  159. callResStr = string(bs)
  160. } else {
  161. return err
  162. }
  163. }
  164. _, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId).
  165. SetUpdatedAt(time.Now()).
  166. SetCallbackResponseRaw(callResStr).
  167. SetLastError("").
  168. SetRetryCount(0).
  169. Save(me.ctx)
  170. return err
  171. }
  172. func (me *AsyncTask) checkErrRetry(taskData *ent.CompapiAsynctask) (bool, error) {
  173. var err error
  174. var needStop = false
  175. if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败
  176. _, err = me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
  177. SetUpdatedAt(time.Now()).
  178. SetTaskStatus(int8(Task_Fail)).
  179. Save(me.ctx)
  180. if err == nil {
  181. needStop = true
  182. taskData.TaskStatus = Task_Fail
  183. }
  184. }
  185. return needStop, err
  186. }
  187. // 错误任务处理
  188. func (me *AsyncTask) dealErrorTask(taskData *ent.CompapiAsynctask, lasterr error) error {
  189. logx.Debug("多次循环之后依然失败,进入错误任务处理环节")
  190. cauo := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
  191. SetUpdatedAt(time.Now())
  192. if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败
  193. taskData.TaskStatus = Task_Fail
  194. cauo = cauo.SetTaskStatus(int8(Task_Fail))
  195. } else {
  196. cauo = cauo.SetRetryCount(taskData.RetryCount + 1).
  197. SetLastError(lasterr.Error())
  198. }
  199. _, err := cauo.Save(me.ctx)
  200. return err
  201. }
  202. func (me *AsyncTask) requestCallback(taskData *ent.CompapiAsynctask) (int, error) {
  203. var workerId int64
  204. if me.Conf.Debug {
  205. workerId, _ = getGoroutineId()
  206. logx.Debugf("Worker:%d INTO requestCallback for task status:%d", workerId, taskData.TaskStatus)
  207. }
  208. if needStop, _ := me.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理
  209. return 1, errors.New("too many err retry")
  210. }
  211. if taskData.TaskStatus != ReqApi_Done {
  212. return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
  213. }
  214. if len(taskData.CallbackURL) == 0 {
  215. return 0, errors.New("callback url empty")
  216. }
  217. if len(taskData.ResponseRaw) == 0 {
  218. return 0, errors.New("previous status call api response empty")
  219. }
  220. /*
  221. fstr := "mytest-svc:"
  222. if taskData.RetryCount >= 0 && strings.Contains(taskData.CallbackURL, fstr) {
  223. taskData.CallbackURL = strings.Replace(taskData.CallbackURL, fstr, "0.0.0.0:", 1)
  224. }
  225. */
  226. //请求预先给定的callback_url
  227. res := map[string]any{}
  228. var err error
  229. //初始化client
  230. client := compapi.NewClient(me.ctx)
  231. for i := range LoopTryCount { //重试机制
  232. select {
  233. case <-me.ctx.Done(): //接到信号退出
  234. goto endloopTry
  235. default:
  236. res, err = client.Callback(taskData.EventType, taskData.CallbackURL, taskData)
  237. if err == nil {
  238. //call succ
  239. //fmt.Println("callback succ..........")
  240. //fmt.Println(typekit.PrettyPrint(res))
  241. if me.Conf.Debug {
  242. logx.Infof("callback:'%s' succ", taskData.CallbackURL)
  243. }
  244. goto endloopTry
  245. }
  246. logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId,
  247. taskData.CallbackURL, err, 1+i*LoopDelayFactor, i+1, LoopTryCount, taskData.RetryCount)
  248. time.Sleep(time.Duration(1+i*LoopDelayFactor) * time.Second)
  249. }
  250. }
  251. //多次循环之后依然失败,进入错误任务处理环节
  252. endloopTry:
  253. if err != nil {
  254. err1 := me.dealErrorTask(taskData, err) //错误任务处理
  255. et := 1
  256. if err1 != nil {
  257. et = 0
  258. }
  259. return et, err
  260. }
  261. //成功后处理环节
  262. //更新任务状态 => Callback_Done(回调完成)
  263. err = me.updateTaskStatus(taskData.ID, Callback_Done)
  264. if err != nil {
  265. return 0, err
  266. }
  267. //更新taskData.CallbackResponseRaw
  268. if len(res) > 0 {
  269. me.updateCallbackResponse(taskData.ID, res)
  270. }
  271. taskData.TaskStatus = Callback_Done //状态迁移
  272. return 1, nil
  273. }
  274. func (me *AsyncTask) requestAPI(taskData *ent.CompapiAsynctask) (int, error) {
  275. var workerId int64
  276. if me.Conf.Debug {
  277. workerId, _ = getGoroutineId()
  278. logx.Debugf("Worker:%d INTO requestAPI for task status:%d", workerId, taskData.TaskStatus)
  279. }
  280. if needStop, _ := me.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理
  281. return 1, errors.New("too many err retry")
  282. }
  283. if taskData.TaskStatus != Task_Ready {
  284. return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
  285. }
  286. var (
  287. err error
  288. apiResp *types.CompOpenApiResp
  289. )
  290. req := &types.CompApiReq{}
  291. if err = json.Unmarshal([]byte(taskData.RequestRaw), req); err != nil {
  292. return 0, err
  293. }
  294. //初始化client
  295. if !strings.HasSuffix(taskData.OpenaiBase, "/") {
  296. taskData.OpenaiBase = taskData.OpenaiBase + "/"
  297. }
  298. client := compapi.NewClient(me.ctx, compapi.WithApiBase(taskData.OpenaiBase),
  299. compapi.WithApiKey(taskData.OpenaiKey))
  300. for i := range LoopTryCount { //重试机制
  301. select {
  302. case <-me.ctx.Done(): //接到信号退出
  303. goto endloopTry
  304. default:
  305. apiResp, err = client.Chat(req)
  306. if err == nil && apiResp != nil && len(apiResp.Choices) > 0 {
  307. //call succ
  308. goto endloopTry
  309. } else if apiResp != nil && len(apiResp.Choices) == 0 {
  310. err = errors.New("返回结果缺失,请检查访问权限")
  311. }
  312. if me.Conf.Debug {
  313. logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId,
  314. taskData.CallbackURL, err, 1+i*LoopDelayFactor, i+1, LoopTryCount, taskData.RetryCount)
  315. }
  316. time.Sleep(time.Duration(1+i*LoopDelayFactor) * time.Second)
  317. }
  318. }
  319. endloopTry:
  320. //多次循环之后依然失败,进入错误任务处理环节
  321. if err != nil || apiResp == nil {
  322. if apiResp == nil && err == nil {
  323. err = errors.New("resp is null")
  324. }
  325. err1 := me.dealErrorTask(taskData, err) //错误任务处理
  326. et := 1
  327. if err1 != nil {
  328. et = 0
  329. }
  330. return et, err
  331. }
  332. //成功后处理环节
  333. //追加访问大模型token统计数据
  334. svcCtx := &compapi.ServiceContext{DB: me.svcCtx.DB, Rds: me.svcCtx.Rds, Config: me.svcCtx.Config}
  335. err = compapi.AppendUsageDetailLog(me.ctx, svcCtx, taskData.AuthToken, req, apiResp)
  336. if err != nil {
  337. return 0, err
  338. }
  339. //更新任务状态 => ReqApi_Done(请求API完成)
  340. err = me.updateTaskStatus(taskData.ID, ReqApi_Done)
  341. if err != nil {
  342. return 0, err
  343. }
  344. //更新taskData.ResponseRaw
  345. taskData.ResponseRaw, err = (*apiResp).ToString()
  346. if err != nil {
  347. return 0, err
  348. }
  349. err = me.updateApiResponse(taskData.ID, taskData.ResponseRaw)
  350. if err != nil {
  351. return 0, err
  352. }
  353. taskData.TaskStatus = ReqApi_Done //状态迁移
  354. return 1, nil
  355. }
  356. /*
  357. CREATE INDEX idx_compapi_task_status_chat_id_id_desc
  358. ON compapi_asynctask (task_status, chat_id, id DESC);
  359. */
  360. func (me *AsyncTask) getAsyncReqTaskFairList() ([]Task, error) {
  361. fieldListStr, _, err := compapi.EntStructGenScanField(&ent.CompapiAsynctask{})
  362. if err != nil {
  363. return nil, err
  364. }
  365. rawQuery := fmt.Sprintf(`
  366. WITH ranked AS (
  367. SELECT %s,
  368. ROW_NUMBER() OVER (PARTITION BY chat_id ORDER BY id DESC) AS rn
  369. FROM compapi_asynctask
  370. WHERE task_status < ?
  371. )
  372. SELECT %s
  373. FROM ranked
  374. WHERE rn <= ?
  375. ORDER BY rn,id DESC
  376. LIMIT ?;
  377. `, fieldListStr, fieldListStr)
  378. // 执行原始查询
  379. rows, err := me.svcCtx.DB.QueryContext(me.ctx, rawQuery, All_Done,
  380. me.Conf.BatchLoadTask, me.Conf.BatchLoadTask)
  381. if err != nil {
  382. return nil, fmt.Errorf("fetch fair tasks query error: %w", err)
  383. }
  384. defer rows.Close()
  385. Idx := 0
  386. tasks := []Task{}
  387. for rows.Next() {
  388. taskrow := ent.CompapiAsynctask{}
  389. var scanParams []any
  390. _, scanParams, err = compapi.EntStructGenScanField(&taskrow)
  391. if err != nil {
  392. break
  393. }
  394. err = rows.Scan(scanParams...)
  395. if err != nil {
  396. break
  397. }
  398. task := Task{Data: &taskrow, Idx: Idx}
  399. tasks = append(tasks, task)
  400. Idx++
  401. }
  402. fmt.Printf("getAsyncReqTaskFairList get task:%d\n", len(tasks))
  403. return tasks, nil
  404. }
  405. func (me *AsyncTask) getAsyncReqTaskList() ([]Task, error) {
  406. var predicates []predicate.CompapiAsynctask
  407. predicates = append(predicates, compapiasynctask.TaskStatusLT(All_Done))
  408. var tasks []Task
  409. res, err := me.svcCtx.DB.CompapiAsynctask.Query().Where(predicates...).
  410. Order(ent.Desc(compapiasynctask.FieldID)).
  411. Limit(int(me.Conf.BatchLoadTask)).
  412. All(me.ctx)
  413. if err == nil {
  414. for idx, val := range res {
  415. tasks = append(tasks, Task{Data: val, Idx: idx})
  416. }
  417. }
  418. return tasks, err
  419. }
  420. func (me *AsyncTask) processTask(workerID uint, task Task) {
  421. /*
  422. fmt.Printf("In processTask,Consumer(%d) dealing Task Detail: User(%s/%s/%s) Async Call %s(%s) on Status:%d\n",
  423. workerID, task.Data.EventType, task.Data.ChatID, task.Data.AuthToken,
  424. task.Data.OpenaiBase, task.Data.OpenaiKey, task.Data.TaskStatus)
  425. */
  426. _ = workerID
  427. var err error
  428. rt := 0
  429. for {
  430. select {
  431. case <-me.ctx.Done(): //接到信号退出
  432. return
  433. default:
  434. if task.Data.TaskStatus >= All_Done {
  435. goto endfor //原来的break操作加了switch一层不好使了
  436. }
  437. switch task.Data.TaskStatus {
  438. case Task_Ready:
  439. //请求API平台
  440. // succ: taskStatus Task_Ready => ReqApi_Done
  441. // fail: taskStatus保持当前不变或Task_Fail
  442. rt, err = me.requestAPI(task.Data)
  443. case ReqApi_Done:
  444. //结果回调
  445. // succ: taskStatus ReqApi_Done => Callback_Done(All_Done)
  446. // fail: taskStatus保持当前不变或Task_Fail
  447. rt, err = me.requestCallback(task.Data)
  448. }
  449. if err != nil {
  450. //收集错误
  451. if rt == 0 {
  452. //不可恢复错误处理....
  453. logx.Errorf("Task error by '%s'", err)
  454. } else {
  455. logx.Debugf("Task ignore by '%s'", err)
  456. }
  457. return //先暂时忽略处理,也许应按错误类型分别对待
  458. }
  459. }
  460. }
  461. endfor:
  462. }
  463. func (me *AsyncTask) batchWork() (int64, int64) {
  464. var (
  465. wg sync.WaitGroup
  466. produced int64 //生产数量(原子计数器)
  467. consumed int64 //消费数量(原子计数器)
  468. )
  469. //创建任务分发器
  470. dispatcher := NewTaskDispatcher(me.Conf.MaxChannel,
  471. me.Conf.BatchLoadTask/me.Conf.MaxChannel, me.Conf.Debug)
  472. //启动消费者
  473. for widx := range me.Conf.MaxWorker {
  474. cidx, ch := dispatcher.getWorkerChanByIdx(widx)
  475. wg.Add(1)
  476. go func(workerID uint, channelID uint, taskCh chan Task) {
  477. defer wg.Done()
  478. gidStr := ""
  479. if me.Conf.Debug {
  480. gid, _ := getGoroutineId()
  481. gidStr = fmt.Sprintf("(Goroutine:%d)", gid)
  482. logx.Infof("Consumer @%d%s bind WorkerChan:#%d start......",
  483. workerID, gidStr, channelID)
  484. }
  485. for task := range taskCh {
  486. me.processTask(widx, task)
  487. atomic.AddInt64(&consumed, 1)
  488. }
  489. }(widx+1, cidx+1, ch)
  490. }
  491. // 生产者
  492. wg.Add(1)
  493. go func() {
  494. defer wg.Done()
  495. gidStr := ""
  496. if me.Conf.Debug {
  497. gid, _ := getGoroutineId()
  498. gidStr = fmt.Sprintf("(Goroutine:%d)", gid)
  499. logx.Infof("Producer @1%s start......", gidStr)
  500. }
  501. //获得待处理异步任务列表
  502. //tasks, err := me.getAsyncReqTaskList()
  503. tasks, err := me.getAsyncReqTaskFairList()
  504. if err != nil {
  505. logx.Errorf("getAsyncReqTaskFairList err:%s", err)
  506. return
  507. }
  508. // 分发任务
  509. for _, task := range tasks {
  510. dispatcher.Dispatch(task)
  511. atomic.AddInt64(&produced, 1)
  512. }
  513. fmt.Printf("📦Producer @1 此批次共创建任务%d件\n", len(tasks))
  514. // 关闭所有会话通道
  515. dispatcher.mu.Lock()
  516. for _, ch := range dispatcher.workerChs {
  517. _ = ch
  518. close(ch)
  519. }
  520. dispatcher.mu.Unlock()
  521. }()
  522. wg.Wait()
  523. consumedRatStr := ""
  524. if atomic.LoadInt64(&produced) > 0 {
  525. consumedRatStr = fmt.Sprintf(" (%d/%d)*100=%d%%", atomic.LoadInt64(&produced), atomic.LoadInt64(&consumed),
  526. (atomic.LoadInt64(&consumed)/atomic.LoadInt64(&produced))*100)
  527. }
  528. fmt.Printf("🏁本次任务完成度统计: Task dispatch: %d(%d)(Producer:1),Task deal: %d(Consumer:%d)%s\n", atomic.LoadInt64(&produced), me.Conf.BatchLoadTask, atomic.LoadInt64(&consumed), me.Conf.MaxWorker, consumedRatStr)
  529. return produced, consumed
  530. }
  531. func (me *AsyncTask) InitServiceContext() *svc.ServiceContext {
  532. rds := me.Conf.RedisConf.MustNewUniversalRedis()
  533. dbOpts := []ent.Option{ent.Log(logx.Info),
  534. ent.Driver(me.Conf.DatabaseConf.NewNoCacheDriver())}
  535. if me.Conf.Debug {
  536. dbOpts = append(dbOpts, ent.Debug())
  537. }
  538. db := ent.NewClient(dbOpts...)
  539. svcCtx := svc.ServiceContext{DB: db, Rds: rds}
  540. //svcCtx.Config
  541. return &svcCtx
  542. }
  543. func (me *AsyncTask) adjustConf() {
  544. if me.Conf.MaxWorker <= 0 {
  545. me.Conf.MaxWorker = 2
  546. }
  547. if me.Conf.MaxChannel <= 0 || me.Conf.MaxChannel > me.Conf.MaxWorker {
  548. me.Conf.MaxChannel = me.Conf.MaxWorker
  549. }
  550. }
  551. func (me *AsyncTask) Run(ctx context.Context) {
  552. me.ctx = ctx
  553. me.Logger = logx.WithContext(ctx)
  554. me.adjustConf()
  555. /*
  556. tasks, err := me.getAsyncReqTaskFairList()
  557. if err != nil {
  558. fmt.Println(err)
  559. return
  560. }
  561. for idx, task := range tasks {
  562. if idx > 20 {
  563. break
  564. }
  565. fmt.Printf("#%d=>%d ||[%s]'%s' || '%s' || '%s'|| '%s'\n", idx, task.Data.ID, task.Data.CreatedAt, task.Data.EventType,
  566. task.Data.Model, task.Data.OpenaiBase, task.Data.ChatID)
  567. }
  568. */
  569. batchId := 0
  570. for {
  571. batchId++
  572. select {
  573. case <-ctx.Done():
  574. // 收到了取消信号
  575. fmt.Printf("Main Will Shutting down gracefully... Reason: %v\n", ctx.Err())
  576. return
  577. default:
  578. timeStart := time.Now()
  579. secStart := timeStart.Unix()
  580. fmt.Printf("[%s]batchWork#%d start......\n", timeStart.Format("2006-01-02 15:04:05"), batchId)
  581. produced, _ := me.batchWork()
  582. timeEnd := time.Now()
  583. fmt.Printf("[%s]batchWork#%d end,spent %d sec\n", timeEnd.Format("2006-01-02 15:04:05"),
  584. batchId, timeEnd.Unix()-secStart)
  585. timeDurnum := 1
  586. if produced == 0 {
  587. timeDurnum = 5
  588. }
  589. fmt.Printf("batchWork will sleep %d sec\n", timeDurnum)
  590. time.Sleep(time.Duration(timeDurnum) * time.Second)
  591. }
  592. }
  593. }
  594. func NewAsyncTask() *AsyncTask {
  595. ataskObj := AsyncTask{}
  596. flag.Parse() //将命令行参数也塞入flag.CommandLine结构
  597. //fmt.Println(typekit.PrettyPrint(flag.CommandLine))
  598. conf.MustLoad(*configFile, &ataskObj.Conf, conf.UseEnv())
  599. //fmt.Println(typekit.PrettyPrint(ataskObj.Conf))
  600. ataskObj.svcCtx = ataskObj.InitServiceContext()
  601. return &ataskObj
  602. }
  603. func main() {
  604. fmt.Println("Compapi Asynctask Start......")
  605. //ctx, cancel := context.WithCancel(context.Background())
  606. ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
  607. defer cancel()
  608. NewAsyncTask().Run(ctx)
  609. fmt.Println("Compapi Asynctask End......")
  610. }