package crontask

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"hash/fnv"
	"runtime"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"wechat-api/ent"
	"wechat-api/ent/compapiasynctask"
	"wechat-api/ent/predicate"
	"wechat-api/internal/types"
	"wechat-api/internal/utils/compapi"

	"github.com/openai/openai-go/option"
	"github.com/zeromicro/go-zero/core/logx"
)

// Values of CompapiAsynctask.TaskStatus used by this cron job, plus the
// tuning knobs for one batch run.
const (
	Task_Ready    = 10 // task is ready; the upstream API has not been called yet
	ReqApi_Done   = 20 // upstream API call finished and its response was stored
	Callback_Done = 30 // the callback URL has been notified
	All_Done      = 30 // everything finished (same value as Callback_Done)
	Task_Suspend  = 60 // task suspended
	Task_Fail     = 70 // task permanently failed

	//MaxWorker = 5 // maximum number of concurrent workers
	MaxLoadTask     = 1000 // maximum number of tasks loaded per batch
	LoopTryCount    = 3    // attempts per remote call within a single run
	ErrTaskTryCount = 3    // failed runs allowed before a task is marked Task_Fail

	DefaultChatId = "FOO001" // routing key used for tasks without a chat id
)

// Task is one async task handed from the producer to a worker.
type Task struct {
	Data *ent.CompapiAsynctask // the task row loaded from the database
	Idx  int                   // position of the task within the loaded batch
	Code int                   // NOTE(review): never set in this file
}

// TaskDispatcher fans tasks out to a fixed set of worker channels. Tasks with
// the same chat id are always routed to the same channel, so they are handled
// by one worker in dispatch order.
type TaskDispatcher struct {
	mu        sync.Mutex
	workerChs []chan Task // one buffered channel per worker
}

// NewTaskDispatcher creates a dispatcher with workerCount buffered channels.
// A non-positive workerCount is treated as 1 so that routing never divides by
// zero.
func NewTaskDispatcher(workerCount int) *TaskDispatcher {
	if workerCount <= 0 {
		workerCount = 1
	}
	td := &TaskDispatcher{
		workerChs: make([]chan Task, workerCount),
	}
	for i := range td.workerChs {
		td.workerChs[i] = make(chan Task, 100) // buffered so the producer rarely blocks
	}
	return td
}

// getWorkerChannel hashes chatId with FNV-1a onto one of the worker channels,
// so that the same chat id always reaches the same worker. An empty chatId is
// routed as DefaultChatId. It returns the channel index and the channel.
func (td *TaskDispatcher) getWorkerChannel(chatId string) (int, chan Task) {
	if len(chatId) == 0 {
		chatId = DefaultChatId
	}
	h := fnv.New32a()
	_, _ = h.Write([]byte(chatId)) // hash.Hash.Write never returns an error
	// Reduce in uint32 space: int(h.Sum32()) is negative for large hashes on
	// 32-bit platforms, which would produce a negative (panicking) index.
	idx := int(h.Sum32() % uint32(len(td.workerChs)))
	return idx, td.workerChs[idx]
}

// Dispatch sends task to the worker channel selected by its chat id. It blocks
// while that channel's buffer is full.
func (td *TaskDispatcher) Dispatch(task Task) {
	td.mu.Lock()
	defer td.mu.Unlock()
	workerChId, workerCh := td.getWorkerChannel(task.Data.ChatID)
	workerCh <- task
	logx.Debugf("producer:ChatId:[%s] Task Push to WorkerChan:%d", task.Data.ChatID, workerChId+1)
}

// getGoroutineId parses the current goroutine id from the header line of
// runtime.Stack ("goroutine <id> [...]"). It is only used for log output.
func getGoroutineId() (int64, error) {
	// Prefix of the stack header that has to be stripped.
	var goroutineSpace = []byte("goroutine ")
	bs := make([]byte, 128)
	bs = bs[:runtime.Stack(bs, false)]
	bs = bytes.TrimPrefix(bs, goroutineSpace)
	i := bytes.IndexByte(bs, ' ')
	if i < 0 {
		return -1, errors.New("get current goroutine id failed")
	}
	return strconv.ParseInt(string(bs[:i]), 10, 64)
}

// compApiCallback runs one batch of async tasks. A single producer loads up to
// MaxLoadTask unfinished tasks and dispatches them by chat id to MaxWorker
// consumer goroutines (2 when MaxWorker <= 0). Each consumer drives its tasks
// through processTask. The call returns once every dispatched task has been
// processed.
func (l *CronTask) compApiCallback(MaxWorker int) {
	var (
		wg       sync.WaitGroup
		produced int64 // number of dispatched tasks (atomic counter)
		consumed int64 // number of processed tasks (atomic counter)
	)

	if MaxWorker <= 0 {
		MaxWorker = 2
	}
	dispatcher := NewTaskDispatcher(MaxWorker)

	// Consumers: one goroutine per worker channel; each exits when its channel
	// is closed by the producer.
	for i, ch := range dispatcher.workerChs {
		wg.Add(1)
		go func(workerID int, taskCh chan Task) {
			defer wg.Done()
			gid, _ := getGoroutineId()
			logx.Infof("Consumer Goroutine:%d start......\n", gid)
			for task := range taskCh {
				l.processTask(workerID, task)
				atomic.AddInt64(&consumed, 1)
			}
		}(i+1, ch)
	}

	// Producer.
	wg.Add(1)
	go func() {
		defer wg.Done()
		// Close every worker channel on all exit paths. Returning early without
		// closing (e.g. when loading the task list fails) would leave the
		// consumers ranging forever and deadlock wg.Wait below.
		defer func() {
			dispatcher.mu.Lock()
			defer dispatcher.mu.Unlock()
			for _, ch := range dispatcher.workerChs {
				close(ch)
			}
		}()

		gid, _ := getGoroutineId()
		logx.Infof("Producer Goroutine:%d start......\n", gid)

		tasks, err := l.getAsyncReqTaskList()
		if err != nil {
			logx.Errorf("getAsyncReqTaskList err:%s", err)
			return
		}
		for _, task := range tasks {
			dispatcher.Dispatch(task)
			atomic.AddInt64(&produced, 1)
		}
		logx.Infof("📦Producer Goroutine:%d 此批次共创建任务%d件", gid, len(tasks))
	}()

	wg.Wait()

	if total := atomic.LoadInt64(&produced); total > 0 {
		done := atomic.LoadInt64(&consumed)
		// Multiply before dividing: (done/total)*100 truncates any partial
		// completion to 0 in integer arithmetic.
		logx.Infof("🏁本次任务完成度统计: Producer:1,Consumer:%d (%d/%d)*100=%d%%",
			MaxWorker, done, total, done*100/total)
	}
}

// getAsyncReqTaskList loads up to MaxLoadTask tasks whose status is below
// All_Done, ordered by ascending id, and wraps each row in a Task whose Idx is
// its position in the batch. On query error it returns nil and the error.
func (l *CronTask) getAsyncReqTaskList() ([]Task, error) {
	predicates := []predicate.CompapiAsynctask{
		compapiasynctask.TaskStatusLT(All_Done),
	}
	res, err := l.svcCtx.DB.CompapiAsynctask.Query().
		Where(predicates...).
		Order(ent.Asc(compapiasynctask.FieldID)).
		Limit(MaxLoadTask).
		All(l.ctx)
	if err != nil {
		return nil, err
	}
	tasks := make([]Task, 0, len(res))
	for idx, val := range res {
		tasks = append(tasks, Task{Data: val, Idx: idx})
	}
	return tasks, nil
}

// processTask drives one task through its state machine until its status is
// >= All_Done or a step fails:
//
//	Task_Ready  --requestAPI-->      ReqApi_Done
//	ReqApi_Done --requestCallback--> Callback_Done (== All_Done)
//
// A failed step ends processing for this run; any other non-final status has
// no handler and is logged and skipped.
func (l *CronTask) processTask(workerID int, task Task) {
	_ = workerID
	for task.Data.TaskStatus < All_Done {
		var (
			rt  int
			err error
		)
		switch task.Data.TaskStatus {
		case Task_Ready:
			// success: Task_Ready => ReqApi_Done
			// failure: status unchanged or Task_Fail
			rt, err = l.requestAPI(task.Data)
		case ReqApi_Done:
			// success: ReqApi_Done => Callback_Done (All_Done)
			// failure: status unchanged or Task_Fail
			rt, err = l.requestCallback(task.Data)
		default:
			// Without this branch an unhandled status would spin here forever
			// with a nil error, hanging this worker and the whole batch.
			logx.Errorf("task %d: unsupported task status %d, skipped", task.Data.ID, task.Data.TaskStatus)
			return
		}
		if err != nil {
			// rt == 1: the step already persisted the failure (retry count or
			// Task_Fail); rt == 0: it did not (validation or DB error).
			// NOTE(review): errors could be handled per type instead of just ending the run.
			logx.Errorf("task %d: step for status %d failed (rt=%d): %v", task.Data.ID, task.Data.TaskStatus, rt, err)
			return
		}
	}
}

// requestCallback POSTs the stored API response (taskData.ResponseRaw) to
// taskData.CallbackURL, trying up to LoopTryCount times with growing back-off.
// The status change to Callback_Done is written inside a transaction that is
// committed only if the callback succeeds and rolled back otherwise.
//
// It returns (1, nil) on success. On failure the int is 1 when the failure has
// been recorded on the task (retry count / Task_Fail) and 0 otherwise.
func (l *CronTask) requestCallback(taskData *ent.CompapiAsynctask) (int, error) {
	workerId, _ := getGoroutineId()
	logx.Debugf("Worker:%d INTO requestCallback for task status:%d", workerId, taskData.TaskStatus)

	// Once the retry budget is exhausted the task is marked failed for good.
	if needStop, _ := l.checkErrRetry(taskData); needStop {
		return 1, errors.New("too many err retry")
	}
	if taskData.TaskStatus != ReqApi_Done {
		return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
	}
	if len(taskData.ResponseRaw) == 0 {
		return 0, errors.New("call api response empty")
	}
	if len(taskData.CallbackURL) == 0 {
		return 0, errors.New("callback url empty")
	}
	// Decoded only to validate the stored response before sending it on.
	req := types.CompOpenApiResp{}
	if err := json.Unmarshal([]byte(taskData.ResponseRaw), &req); err != nil {
		return 0, err
	}

	// Tentatively move to Callback_Done inside a transaction.
	// NOTE(review): the transaction stays open across the HTTP retries and
	// back-off sleeps below.
	tx, err := l.updateTaskStatusByTx(taskData.ID, Callback_Done)
	if err != nil {
		return 0, err
	}

	client := compapi.NewAiClient("", taskData.CallbackURL)
	customResp := types.BaseDataInfo{}
	opts := []option.RequestOption{
		option.WithResponseBodyInto(&customResp),
		option.WithRequestBody("application/json", []byte(taskData.ResponseRaw)),
	}
	for i := range LoopTryCount {
		err = client.Post(l.ctx, taskData.CallbackURL, nil, nil, opts...)
		if err == nil {
			break
		}
		// Back off only when another attempt follows.
		if i+1 < LoopTryCount {
			logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)",
				workerId, taskData.CallbackURL, err, 1+i*5, i+1, LoopTryCount, taskData.RetryCount)
			time.Sleep(time.Duration(1+i*5) * time.Second)
		}
	}
	if err != nil {
		_ = tx.Rollback() // undo the tentative status change
		et := 1
		if err1 := l.dealErrorTask(taskData, err); err1 != nil {
			et = 0
		}
		return et, err
	}

	if err = tx.Commit(); err != nil {
		return 0, err
	}
	taskData.TaskStatus = Callback_Done
	return 1, nil
}

// requestAPI calls the FastGPT chat-completions endpoint (taskData.OpenaiBase,
// authenticated with taskData.OpenaiKey) with the request stored in
// taskData.RequestRaw, trying up to LoopTryCount times. On success the
// response is saved to ResponseRaw and the status moves to ReqApi_Done inside
// one committed transaction; on failure the transaction is rolled back.
//
// It returns (1, nil) on success. On failure the int is 1 when the failure has
// been recorded on the task (retry count / Task_Fail) and 0 otherwise.
func (l *CronTask) requestAPI(taskData *ent.CompapiAsynctask) (int, error) {
	workerId, _ := getGoroutineId()
	logx.Debugf("Worker:%d INTO requestAPI for task status:%d", workerId, taskData.TaskStatus)

	// Once the retry budget is exhausted the task is marked failed for good.
	if needStop, _ := l.checkErrRetry(taskData); needStop {
		return 1, errors.New("too many err retry")
	}
	if taskData.TaskStatus != Task_Ready {
		return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
	}
	if taskData.EventType != "fastgpt" {
		return 0, fmt.Errorf("event type :'%s' not support", taskData.EventType)
	}

	req := types.CompApiReq{}
	if err := json.Unmarshal([]byte(taskData.RequestRaw), &req); err != nil {
		return 0, err
	}

	// Tentatively move to ReqApi_Done inside a transaction.
	tx, err := l.updateTaskStatusByTx(taskData.ID, ReqApi_Done)
	if err != nil {
		return 0, err
	}

	var apiResp *types.CompOpenApiResp
	for i := range LoopTryCount {
		apiResp, err = compapi.NewFastgptChatCompletions(l.ctx, taskData.OpenaiKey, taskData.OpenaiBase, &req)
		if err == nil && apiResp != nil && len(apiResp.Choices) > 0 {
			break
		}
		if apiResp != nil && len(apiResp.Choices) == 0 {
			err = errors.New("返回结果缺失,请检查访问权限")
		}
		// Back off only when another attempt follows. Log the endpoint that
		// was actually called (OpenaiBase), not the callback URL.
		if i+1 < LoopTryCount {
			logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)",
				workerId, taskData.OpenaiBase, err, 1+i*5, i+1, LoopTryCount, taskData.RetryCount)
			time.Sleep(time.Duration(1+i*5) * time.Second)
		}
	}
	if err != nil || apiResp == nil {
		if err == nil {
			err = errors.New("resp is null")
		}
		_ = tx.Rollback() // undo the tentative status change
		et := 1
		if err1 := l.dealErrorTask(taskData, err); err1 != nil {
			et = 0
		}
		return et, err
	}

	respBs, err := json.Marshal(*apiResp)
	if err != nil {
		_ = tx.Rollback()
		return 0, err
	}
	taskData.ResponseRaw = string(respBs)
	_, err = tx.CompapiAsynctask.UpdateOneID(taskData.ID).
		SetResponseRaw(taskData.ResponseRaw).
		Save(l.ctx)
	if err != nil {
		_ = tx.Rollback()
		return 0, err
	}

	if err = tx.Commit(); err != nil {
		return 0, err
	}
	taskData.TaskStatus = ReqApi_Done
	return 1, nil
}

// updateTaskStatusByTx opens a transaction and, inside it, sets the task's
// status and UpdatedAt. The caller owns the returned transaction and must
// Commit or Rollback it. On any error the transaction (if opened) is rolled
// back and nil is returned.
func (l *CronTask) updateTaskStatusByTx(Id uint64, status int) (*ent.Tx, error) {
	tx, err := l.svcCtx.DB.Tx(l.ctx)
	if err != nil {
		return nil, err
	}
	_, err = tx.CompapiAsynctask.UpdateOneID(Id).
		SetTaskStatus(int8(status)).
		SetUpdatedAt(time.Now()).
		Save(l.ctx)
	if err != nil {
		// Release the transaction instead of leaking it (and its connection).
		_ = tx.Rollback()
		return nil, err
	}
	return tx, nil
}

// checkErrRetry marks the task as Task_Fail (in the database and in taskData)
// once its RetryCount has reached ErrTaskTryCount. It reports true only when
// that update succeeded; the returned error is the update error, if any.
func (l *CronTask) checkErrRetry(taskData *ent.CompapiAsynctask) (bool, error) {
	if taskData.RetryCount < ErrTaskTryCount {
		return false, nil
	}
	_, err := l.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
		SetUpdatedAt(time.Now()).
		SetTaskStatus(int8(Task_Fail)).
		Save(l.ctx)
	if err != nil {
		return false, err
	}
	taskData.TaskStatus = Task_Fail
	return true, nil
}

// dealErrorTask records a run whose retries were all exhausted. Below
// ErrTaskTryCount it increments RetryCount and stores lasterr as LastError;
// at or above it marks the task Task_Fail (also in taskData). It returns the
// database update error, if any.
func (l *CronTask) dealErrorTask(taskData *ent.CompapiAsynctask, lasterr error) error {
	logx.Debug("多次循环之后依然失败,进入错误任务处理环节")
	cauo := l.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
		SetUpdatedAt(time.Now())
	if taskData.RetryCount >= ErrTaskTryCount {
		taskData.TaskStatus = Task_Fail
		cauo = cauo.SetTaskStatus(int8(Task_Fail))
	} else {
		cauo = cauo.SetRetryCount(taskData.RetryCount + 1)
		if lasterr != nil { // guard: lasterr.Error() on nil would panic
			cauo = cauo.SetLastError(lasterr.Error())
		}
	}
	_, err := cauo.Save(l.ctx)
	return err
}