package crontask import ( "bytes" "encoding/json" "errors" "fmt" "hash/fnv" "runtime" "strconv" "strings" "sync" "sync/atomic" "time" "wechat-api/ent" "wechat-api/ent/compapiasynctask" "wechat-api/ent/predicate" "wechat-api/internal/types" "wechat-api/internal/utils/compapi" "github.com/zeromicro/go-zero/core/logx" ) const ( Task_Ready = 10 //任务就绪 ReqApi_Done = 20 //请求API完成 Callback_Done = 30 //请求回调完成 All_Done = 30 //全部完成 Task_Suspend = 60 //任务暂停 Task_Fail = 70 //任务失败 //MaxWorker = 5 //最大并发worker数量 MaxLoadTask = 1000 //一次最大获取任务数量 LoopTryCount = 3 //循环体内重试次数 ErrTaskTryCount = 3 //最大允许错误任务重试次数 DefaultDisId = "DIS0001" ) type Task struct { Data *ent.CompapiAsynctask Idx int Code int } // 带会话管理的任务通道组 type TaskDispatcher struct { mu sync.Mutex workerChs []chan Task // 每个worker独立通道 } func NewTaskDispatcher(channelCount int) *TaskDispatcher { td := &TaskDispatcher{ workerChs: make([]chan Task, channelCount), } // 初始化worker通道 for i := range td.workerChs { td.workerChs[i] = make(chan Task, 100) // 每个worker带缓冲 //fmt.Printf("make worker chan:%d\n", i+1) } return td } // 按woker下标索引选择workerChs func (td *TaskDispatcher) getWorkerChanByIdx(idx int) (int, chan Task) { nidx := idx % len(td.workerChs) return nidx, td.workerChs[nidx] } // 按哈希分片选择workerChs func (td *TaskDispatcher) getWorkerChanByHash(disID string) (int, chan Task) { if len(disID) == 0 { disID = DefaultDisId } idx := 0 if len(td.workerChs) > 1 { h := fnv.New32a() h.Write([]byte(disID)) idx = int(h.Sum32()) % len(td.workerChs) } outStr := fmt.Sprintf("getWorkerChannel by disId Hash:'%s',from workerChs{", disID) for i := range len(td.workerChs) { outStr += fmt.Sprintf("#%d", i+1) if i < len(td.workerChs)-1 { outStr += "," } } outStr += fmt.Sprintf("} choice chs:'#%d' by '%s'", idx+1, disID) logx.Debug(outStr) return idx, td.workerChs[idx] } // 分配任务到对应的消费channel func (td *TaskDispatcher) Dispatch(task Task) { td.mu.Lock() defer td.mu.Unlock() // 根据chatId哈希获得其对应的workerChan workerChIdx, workerCh := td.getWorkerChanByHash(task.Data.EventType) // 将任务送入该chatid的workerChan workerCh <- task logx.Debugf("Producer:EventType:[%s] Task Push to WorkerChan:#%d", task.Data.EventType, workerChIdx+1) } func getGoroutineId() (int64, error) { // 堆栈结果中需要消除的前缀符 var goroutineSpace = []byte("goroutine ") bs := make([]byte, 128) bs = bs[:runtime.Stack(bs, false)] bs = bytes.TrimPrefix(bs, goroutineSpace) i := bytes.IndexByte(bs, ' ') if i < 0 { return -1, errors.New("get current goroutine id failed") } return strconv.ParseInt(string(bs[:i]), 10, 64) } func (l *CronTask) CliTest(MaxWorker int, MaxChannel int) { l.compApiCallback(MaxWorker, MaxChannel) } func (l *CronTask) compApiCallback(MaxWorker int, MaxChannel int) { var ( wg sync.WaitGroup produced int64 //生产数量(原子计数器) consumed int64 //消费数量(原子计数器) ) //创建任务分发器 if MaxWorker <= 0 { MaxWorker = 2 } if MaxChannel <= 0 || MaxChannel > MaxWorker { MaxChannel = MaxWorker } dispatcher := NewTaskDispatcher(MaxChannel) //启动消费者 for widx := range MaxWorker { cidx, ch := dispatcher.getWorkerChanByIdx(widx) wg.Add(1) go func(workerID int, channelID int, taskCh chan Task) { defer wg.Done() gid, _ := getGoroutineId() logx.Infof("Consumer %d(Goroutine:%d) bind WorkerChan:#%d start......\n", workerID, gid, channelID) for task := range taskCh { l.processTask(workerID, task) atomic.AddInt64(&consumed, 1) } }(widx+1, cidx+1, ch) } // 生产者 wg.Add(1) go func() { defer wg.Done() gid, _ := getGoroutineId() logx.Infof("Producer 1(Goroutine:%d) start......\n", gid) //获得待处理异步任务列表 tasks, err := l.getAsyncReqTaskList() if err != nil { logx.Errorf("getAsyncReqTaskList err:%s", err) return } // 分发任务 for _, task := range tasks { dispatcher.Dispatch(task) atomic.AddInt64(&produced, 1) } logx.Infof("📦Producer 1 此批次共创建任务%d件", len(tasks)) // 关闭所有会话通道 dispatcher.mu.Lock() for _, ch := range dispatcher.workerChs { _ = ch close(ch) } dispatcher.mu.Unlock() }() wg.Wait() if atomic.LoadInt64(&produced) > 0 { logx.Infof("🏁本次任务完成度统计: Producer:1,Consumer:%d (%d/%d)*100=%d%%", MaxWorker, atomic.LoadInt64(&consumed), atomic.LoadInt64(&produced), (atomic.LoadInt64(&consumed)/atomic.LoadInt64(&produced))*100) } } func (l *CronTask) getAsyncReqTaskList() ([]Task, error) { var predicates []predicate.CompapiAsynctask predicates = append(predicates, compapiasynctask.TaskStatusLT(All_Done)) var tasks []Task res, err := l.svcCtx.DB.CompapiAsynctask.Query().Where(predicates...). Order(ent.Asc(compapiasynctask.FieldID)). Limit(MaxLoadTask). All(l.ctx) if err == nil { for idx, val := range res { tasks = append(tasks, Task{Data: val, Idx: idx}) } } return tasks, err } func (l *CronTask) processTask(workerID int, task Task) { /* fmt.Printf("In processTask,Consumer(%d) dealing Task Detail: User(%s/%s/%s) Async Call %s(%s) on Status:%d\n", workerID, task.Data.EventType, task.Data.ChatID, task.Data.AuthToken, task.Data.OpenaiBase, task.Data.OpenaiKey, task.Data.TaskStatus)*/ _ = workerID var err error rt := 0 for { if task.Data.TaskStatus >= All_Done { break } switch task.Data.TaskStatus { case Task_Ready: //请求API平台 // succ: taskStatus Task_Ready => ReqApi_Done // fail: taskStatus保持当前不变或Task_Fail rt, err = l.requestAPI(task.Data) case ReqApi_Done: //结果回调 // succ: taskStatus ReqApi_Done => Callback_Done(All_Done) // fail: taskStatus保持当前不变或Task_Fail rt, err = l.requestCallback(task.Data) } if err != nil { //收集错误 if rt == 0 { //不可恢复错误处理.... } logx.Debugf("Task ignore by '%s'", err) return //先暂时忽略处理,也许应按错误类型分别对待 } } } func (l *CronTask) requestCallback(taskData *ent.CompapiAsynctask) (int, error) { workerId, _ := getGoroutineId() logx.Debugf("Worker:%d INTO requestCallback for task status:%d", workerId, taskData.TaskStatus) if needStop, _ := l.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理 return 1, errors.New("too many err retry") } if taskData.TaskStatus != ReqApi_Done { return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus) } if len(taskData.CallbackURL) == 0 { return 0, errors.New("callback url empty") } if len(taskData.ResponseRaw) == 0 { return 0, errors.New("call api response empty") } fstr := "mytest-svc:" if taskData.RetryCount > 0 && strings.Contains(taskData.CallbackURL, fstr) { taskData.CallbackURL = strings.Replace(taskData.CallbackURL, fstr, "0.0.0.0:", 1) } //先开启事务更新任务状态 => Callback_Done(回调完成) tx, err := l.updateTaskStatusByTx(taskData.ID, Callback_Done) if err != nil { return 0, err } //请求预先给定的callback_url var res map[string]any //初始化client client := compapi.NewClient(l.ctx) for i := range LoopTryCount { //重试机制 res, err = client.Callback(taskData.EventType, taskData.CallbackURL, taskData.ResponseRaw) //_, err = client.Chat.Completions.New(l.ctx, emptyParams, opts...) if err == nil { //call succ //fmt.Println("callback succ..........") //fmt.Println(typekit.PrettyPrint(res)) logx.Infof("callback:'%s' succ", taskData.CallbackURL) break } logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId, taskData.CallbackURL, err, 1+i*5, i+1, LoopTryCount, taskData.RetryCount) time.Sleep(time.Duration(1+i*5) * time.Second) } //多次循环之后依然失败,进入错误任务处理环节 if err != nil { _ = tx.Rollback() //回滚之前更新状态 //fmt.Printf("Worker:%d client.Chat.Completions.New Fail,Rollback......\n", workerId) err1 := l.dealErrorTask(taskData, err) //错误任务处理 et := 1 if err1 != nil { et = 0 } return et, err } //成功后处理环节 err = tx.Commit() //事务提交 if err != nil { return 0, err } //更新taskData.CallbackResponseRaw l.updateCallbackResponse(taskData.ID, res) taskData.TaskStatus = Callback_Done //状态迁移 return 1, nil } func (l *CronTask) requestAPI(taskData *ent.CompapiAsynctask) (int, error) { workerId, _ := getGoroutineId() logx.Debugf("Worker:%d INTO requestAPI for task status:%d", workerId, taskData.TaskStatus) if needStop, _ := l.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理 return 1, errors.New("too many err retry") } if taskData.TaskStatus != Task_Ready { return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus) } var ( err error apiResp *types.CompOpenApiResp tx *ent.Tx ) req := types.CompApiReq{} if err = json.Unmarshal([]byte(taskData.RequestRaw), &req); err != nil { return 0, err } //先开启事务更新任务状态 => ReqApi_Done(请求API完成) tx, err = l.updateTaskStatusByTx(taskData.ID, ReqApi_Done) if err != nil { return 0, err } //初始化client client := compapi.NewClient(l.ctx, compapi.WithApiBase(taskData.OpenaiBase), compapi.WithApiKey(taskData.OpenaiKey)) for i := range LoopTryCount { //重试机制 apiResp, err = client.Chat(&req) if err == nil && apiResp != nil && len(apiResp.Choices) > 0 { //call succ break } else if apiResp != nil && len(apiResp.Choices) == 0 { err = errors.New("返回结果缺失,请检查访问权限") } logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId, taskData.CallbackURL, err, 1+i*5, i+1, LoopTryCount, taskData.RetryCount) time.Sleep(time.Duration(1+i*5) * time.Second) } //多次循环之后依然失败,进入错误任务处理环节 if err != nil || apiResp == nil { if apiResp == nil && err == nil { err = errors.New("resp is null") } _ = tx.Rollback() //回滚之前更新状态 err1 := l.dealErrorTask(taskData, err) //错误任务处理 et := 1 if err1 != nil { et = 0 } return et, err } //成功后处理环节 //更新taskData.ResponseRaw taskData.ResponseRaw, err = (*apiResp).ToString() if err != nil { _ = tx.Rollback() //回滚之前更新状态 return 0, err } err = l.updateApiResponseByTx(tx, taskData.ID, taskData.ResponseRaw) if err != nil { return 0, err } err = tx.Commit() //事务提交 //fmt.Printf("Worker:%d requestAPI事务提交\n", workerId) if err != nil { return 0, err } taskData.TaskStatus = ReqApi_Done //状态迁移 return 1, nil } // 更新任务状态事务版 func (l *CronTask) updateTaskStatusByTx(Id uint64, status int) (*ent.Tx, error) { //开启Mysql事务 tx, _ := l.svcCtx.DB.Tx(l.ctx) _, err := tx.CompapiAsynctask.UpdateOneID(Id). SetTaskStatus(int8(status)). SetUpdatedAt(time.Now()). Save(l.ctx) if err != nil { return nil, err } return tx, nil } // 更新请求大模型后结果事务版 func (l *CronTask) updateApiResponseByTx(tx *ent.Tx, taskId uint64, apiResponse string) error { _, err := tx.CompapiAsynctask.UpdateOneID(taskId). SetUpdatedAt(time.Now()). SetResponseRaw(apiResponse). SetLastError(""). SetRetryCount(0). Save(l.ctx) if err != nil { _ = tx.Rollback() //回滚之前更新状态 } return err } func (l *CronTask) updateCallbackResponse(taskId uint64, callRes any) error { callResStr := "" switch v := callRes.(type) { case []byte: callResStr = string(v) default: if bs, err := json.Marshal(v); err == nil { callResStr = string(bs) } else { return err } } _, err := l.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId). SetUpdatedAt(time.Now()). SetCallbackResponseRaw(callResStr). SetLastError(""). SetRetryCount(0). Save(l.ctx) return err } func (l *CronTask) checkErrRetry(taskData *ent.CompapiAsynctask) (bool, error) { var err error var needStop = false if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败 _, err = l.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID). SetUpdatedAt(time.Now()). SetTaskStatus(int8(Task_Fail)). Save(l.ctx) if err == nil { needStop = true taskData.TaskStatus = Task_Fail } } return needStop, err } // 错误任务处理 func (l *CronTask) dealErrorTask(taskData *ent.CompapiAsynctask, lasterr error) error { logx.Debug("多次循环之后依然失败,进入错误任务处理环节") cauo := l.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID). SetUpdatedAt(time.Now()) if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败 taskData.TaskStatus = Task_Fail cauo = cauo.SetTaskStatus(int8(Task_Fail)) } else { cauo = cauo.SetRetryCount(taskData.RetryCount + 1). SetLastError(lasterr.Error()) } _, err := cauo.Save(l.ctx) return err }