|
- package crontask
- import (
- "bytes"
- "encoding/json"
- "errors"
- "fmt"
- "hash/fnv"
- "runtime"
- "strconv"
- "sync"
- "sync/atomic"
- "time"
- "wechat-api/ent"
- "wechat-api/ent/compapiasynctask"
- "wechat-api/ent/predicate"
- "wechat-api/internal/types"
- "wechat-api/internal/utils/compapi"
- "github.com/openai/openai-go/option"
- "github.com/zeromicro/go-zero/core/logx"
- )
// Task status values used with CompapiAsynctask.TaskStatus.
const (
	Task_Ready    = 10 // task is ready to be processed
	ReqApi_Done   = 20 // the API request has completed
	Callback_Done = 30 // the callback request has completed
	All_Done      = 30 // everything is done (same value as Callback_Done)
	Task_Suspend  = 60 // task is suspended
	Task_Fail     = 70 // task has failed
)

// Limits and defaults used while loading and processing tasks.
const (
	//MaxWorker = 5 // maximum number of concurrent workers
	MaxLoadTask     = 1000     // maximum number of tasks fetched in one run
	LoopTryCount    = 3        // attempts made inside one retry loop
	ErrTaskTryCount = 3        // maximum number of retries allowed for a failing task
	DefaultChatId   = "FOO001" // chat id substituted when a task has an empty one
)
- type Task struct {
- Data *ent.CompapiAsynctask
- Idx int
- Code int
- }
- // 带会话管理的任务通道组
- type TaskDispatcher struct {
- mu sync.Mutex
- workerChs []chan Task // 每个worker独立通道
- }
- func NewTaskDispatcher(workerCount int) *TaskDispatcher {
- td := &TaskDispatcher{
- workerChs: make([]chan Task, workerCount),
- }
- // 初始化worker通道
- for i := range td.workerChs {
- td.workerChs[i] = make(chan Task, 100) // 每个worker带缓冲
- //fmt.Printf("make worker chan:%d\n", i+1)
- }
- return td
- }
- // 哈希分片算法(确保相同chatid路由到同一个worker)
- func (td *TaskDispatcher) getWorkerChannel(chatId string) (int, chan Task) {
- if len(chatId) == 0 {
- chatId = DefaultChatId
- }
- h := fnv.New32a()
- h.Write([]byte(chatId))
- idx := int(h.Sum32()) % len(td.workerChs)
- return idx, td.workerChs[idx]
- }
- // 分配任务到对应的消费channel
- func (td *TaskDispatcher) Dispatch(task Task) {
- td.mu.Lock()
- defer td.mu.Unlock()
- // 根据chatId哈希获得其对应的workerChan
- workerChId, workerCh := td.getWorkerChannel(task.Data.ChatID)
- // 将任务送入该chatid的workerChan
- workerCh <- task
- logx.Debugf("producer:ChatId:[%s] Task Push to WorkerChan:%d", task.Data.ChatID, workerChId+1)
- }
// getGoroutineId extracts the id of the calling goroutine from the first bytes
// of its own stack trace: the "goroutine " prefix is dropped and the text up
// to the next space is parsed as a base-10 integer.
//
// It returns -1 and an error when no space follows the id, or the
// strconv.ParseInt result when the text cannot be parsed.
func getGoroutineId() (int64, error) {
	// Prefix to strip from the stack dump.
	const header = "goroutine "

	var scratch [128]byte
	n := runtime.Stack(scratch[:], false)
	rest := bytes.TrimPrefix(scratch[:n], []byte(header))

	idText, _, found := bytes.Cut(rest, []byte{' '})
	if !found {
		return -1, errors.New("get current goroutine id failed")
	}
	return strconv.ParseInt(string(idText), 10, 64)
}
- func (l *CronTask) compApiCallback(MaxWorker int) {
- var (
- wg sync.WaitGroup
- produced int64 //生产数量(原子计数器)
- consumed int64 //消费数量(原子计数器)
- )
- //创建任务分发器
- if MaxWorker <= 0 {
- MaxWorker = 2
- }
- dispatcher := NewTaskDispatcher(MaxWorker)
- //启动消费者
- for i, ch := range dispatcher.workerChs {
- wg.Add(1)
- go func(workerID int, taskCh chan Task) {
- defer wg.Done()
- workerId, _ := getGoroutineId()
- logx.Infof("Consumer Goroutine:%d start......\n", workerId)
- for task := range taskCh {
- l.processTask(workerID, task)
- atomic.AddInt64(&consumed, 1)
- }
- }(i+1, ch)
- }
- // 生产者
- wg.Add(1)
- go func() {
- defer wg.Done()
- workerId, _ := getGoroutineId()
- logx.Infof("Producer Goroutine:%d start......\n", workerId)
- //获得待处理异步任务列表
- tasks, err := l.getAsyncReqTaskList()
- if err != nil {
- logx.Errorf("getAsyncReqTaskList err:%s", err)
- return
- }
- // 分发任务
- for _, task := range tasks {
- dispatcher.Dispatch(task)
- atomic.AddInt64(&produced, 1)
- }
- logx.Infof("📦Producer Goroutine:%d 此批次共创建任务%d件", workerId, len(tasks))
- // 关闭所有会话通道
- dispatcher.mu.Lock()
- for _, ch := range dispatcher.workerChs {
- _ = ch
- close(ch)
- }
- dispatcher.mu.Unlock()
- }()
- wg.Wait()
- if atomic.LoadInt64(&produced) > 0 {
- logx.Infof("🏁本次任务完成度统计: Producer:1,Consumer:%d (%d/%d)*100=%d%%", MaxWorker, atomic.LoadInt64(&consumed), atomic.LoadInt64(&produced),
- (atomic.LoadInt64(&consumed)/atomic.LoadInt64(&produced))*100)
- }
- }
- func (l *CronTask) getAsyncReqTaskList() ([]Task, error) {
- var predicates []predicate.CompapiAsynctask
- predicates = append(predicates, compapiasynctask.TaskStatusLT(All_Done))
- var tasks []Task
- res, err := l.svcCtx.DB.CompapiAsynctask.Query().Where(predicates...).
- Order(ent.Asc(compapiasynctask.FieldID)).
- Limit(MaxLoadTask).
- All(l.ctx)
- if err == nil {
- for idx, val := range res {
- tasks = append(tasks, Task{Data: val, Idx: idx})
- }
- }
- return tasks, err
- }
- func (l *CronTask) processTask(workerID int, task Task) {
- //fmt.Printf("In processTask,Consumer(%d) dealing\n", workerID)
- //fmt.Printf("Task Detail: User(%s/%s) Async Call %s\n", task.Data.ChatID, task.Data.AuthToken, task.Data.OpenaiBase)
- _ = workerID
- var err error
- rt := 0
- for {
- if task.Data.TaskStatus >= All_Done {
- break
- }
- switch task.Data.TaskStatus {
- case Task_Ready:
- //请求API平台
- // succ: taskStatus Task_Ready => ReqApi_Done
- // fail: taskStatus保持当前不变或Task_Fail
- rt, err = l.requestAPI(task.Data)
- case ReqApi_Done:
- //结果回调
- // succ: taskStatus ReqApi_Done => Callback_Done(All_Done)
- // fail: taskStatus保持当前不变或Task_Fail
- rt, err = l.requestCallback(task.Data)
- }
- if err != nil {
- //收集错误
- if rt == 0 {
- //不可恢复错误处理....
- }
- //fmt.Println("===>ERROR:", err, ",Task Ignore...")
- return //先暂时忽略处理,也许应按错误类型分别对待
- }
- }
- }
- func (l *CronTask) requestCallback(taskData *ent.CompapiAsynctask) (int, error) {
- workerId, _ := getGoroutineId()
- logx.Debugf("Worker:%d INTO requestCallback for task status:%d", workerId, taskData.TaskStatus)
- if needStop, _ := l.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理
- return 1, errors.New("too many err retry")
- }
- if taskData.TaskStatus != ReqApi_Done {
- return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
- }
- req := types.CompOpenApiResp{}
- if len(taskData.ResponseRaw) == 0 {
- return 0, errors.New("call api response empty")
- }
- if len(taskData.CallbackURL) == 0 {
- return 0, errors.New("callback url empty")
- }
- if err := json.Unmarshal([]byte(taskData.ResponseRaw), &req); err != nil {
- return 0, err
- }
- //先开启事务更新任务状态 => Callback_Done(回调完成)
- tx, err := l.updateTaskStatusByTx(taskData.ID, Callback_Done)
- if err != nil {
- return 0, err
- }
- //请求预先给定的callback_url
- client := compapi.NewAiClient("", taskData.CallbackURL)
- //emptyParams := openai.ChatCompletionNewParams{}
- customResp := types.BaseDataInfo{}
- opts := []option.RequestOption{option.WithResponseBodyInto(&customResp)}
- opts = append(opts, option.WithRequestBody("application/json", []byte(taskData.ResponseRaw)))
- for i := range LoopTryCount { //重试机制
- err = client.Post(l.ctx, taskData.CallbackURL, nil, nil, opts...)
- //_, err = client.Chat.Completions.New(l.ctx, emptyParams, opts...)
- if err == nil {
- //call succ
- break
- }
- logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId,
- taskData.CallbackURL, err, 1+i*5, i+1, LoopTryCount, taskData.RetryCount)
- time.Sleep(time.Duration(1+i*5) * time.Second)
- }
- if err != nil {
- _ = tx.Rollback() //回滚之前更新状态
- //fmt.Printf("Worker:%d client.Chat.Completions.New Fail,Rollback......\n", workerId)
- err1 := l.dealErrorTask(taskData, err) //错误任务处理
- et := 1
- if err1 != nil {
- et = 0
- }
- return et, err
- }
- err = tx.Commit() //事务提交
- //fmt.Printf("Worker:%d requestCallback事务提交\n", workerId)
- if err != nil {
- return 0, err
- }
- taskData.TaskStatus = Callback_Done //状态迁移
- return 1, nil
- }
- func (l *CronTask) requestAPI(taskData *ent.CompapiAsynctask) (int, error) {
- workerId, _ := getGoroutineId()
- logx.Debugf("Worker:%d INTO requestAPI for task status:%d", workerId, taskData.TaskStatus)
- if needStop, _ := l.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理
- return 1, errors.New("too many err retry")
- }
- if taskData.TaskStatus != Task_Ready {
- return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
- }
- if taskData.EventType != "fastgpt" {
- return 0, fmt.Errorf("event type :'%s' not support", taskData.EventType)
- }
- var (
- err error
- apiResp *types.CompOpenApiResp
- tx *ent.Tx
- )
- req := types.CompApiReq{}
- if err = json.Unmarshal([]byte(taskData.RequestRaw), &req); err != nil {
- return 0, err
- }
- //先开启事务更新任务状态 => ReqApi_Done(请求API完成)
- tx, err = l.updateTaskStatusByTx(taskData.ID, ReqApi_Done)
- if err != nil {
- return 0, err
- }
- for i := range LoopTryCount { //重试机制
- apiResp, err = compapi.NewFastgptChatCompletions(l.ctx,
- taskData.OpenaiKey, taskData.OpenaiBase, &req)
- if err == nil && apiResp != nil && len(apiResp.Choices) > 0 {
- //call succ
- break
- } else if apiResp != nil && len(apiResp.Choices) == 0 {
- err = errors.New("返回结果缺失,请检查访问权限")
- }
- logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId,
- taskData.CallbackURL, err, 1+i*5, i+1, LoopTryCount, taskData.RetryCount)
- time.Sleep(time.Duration(1+i*5) * time.Second)
- }
- if err != nil || apiResp == nil {
- if apiResp == nil && err == nil {
- err = errors.New("resp is null")
- }
- _ = tx.Rollback() //回滚之前更新状态
- //fmt.Printf("Worker:%d NewFastgptChatCompletions Fail,Rollback......\n", workerId)
- err1 := l.dealErrorTask(taskData, err) //错误任务处理
- et := 1
- if err1 != nil {
- et = 0
- }
- return et, err
- }
- respBs, err := json.Marshal(*apiResp)
- if err != nil {
- _ = tx.Rollback() //回滚之前更新状态
- return 0, err
- }
- taskData.ResponseRaw = string(respBs)
- _, err = tx.CompapiAsynctask.UpdateOneID(taskData.ID).
- SetResponseRaw(taskData.ResponseRaw).
- Save(l.ctx)
- if err != nil {
- _ = tx.Rollback() //回滚之前更新状态
- return 0, err
- }
- err = tx.Commit() //事务提交
- //fmt.Printf("Worker:%d requestAPI事务提交\n", workerId)
- if err != nil {
- return 0, err
- }
- taskData.TaskStatus = ReqApi_Done //状态迁移
- return 1, nil
- }
- // 更新任务状态事务版
- func (l *CronTask) updateTaskStatusByTx(Id uint64, status int) (*ent.Tx, error) {
- //开启Mysql事务
- tx, _ := l.svcCtx.DB.Tx(l.ctx)
- _, err := tx.CompapiAsynctask.UpdateOneID(Id).
- SetTaskStatus(int8(status)).
- SetUpdatedAt(time.Now()).
- Save(l.ctx)
- if err != nil {
- return nil, err
- }
- return tx, nil
- }
- func (l *CronTask) checkErrRetry(taskData *ent.CompapiAsynctask) (bool, error) {
- var err error
- var needStop = false
- if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败
- _, err = l.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
- SetUpdatedAt(time.Now()).
- SetTaskStatus(int8(Task_Fail)).
- Save(l.ctx)
- if err == nil {
- needStop = true
- taskData.TaskStatus = Task_Fail
- }
- }
- return needStop, err
- }
- // 错误任务处理
- func (l *CronTask) dealErrorTask(taskData *ent.CompapiAsynctask, lasterr error) error {
- logx.Debug("多次循环之后依然失败,进入错误任务处理环节")
- cauo := l.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
- SetUpdatedAt(time.Now())
- if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败
- taskData.TaskStatus = Task_Fail
- cauo = cauo.SetTaskStatus(int8(Task_Fail))
- } else {
- cauo = cauo.SetRetryCount(taskData.RetryCount + 1).
- SetLastError(lasterr.Error())
- }
- _, err := cauo.Save(l.ctx)
- return err
- }
|