package main import ( "bytes" "context" "encoding/json" "errors" "flag" "fmt" "hash/fnv" "os" "os/signal" "runtime" "strconv" "strings" "sync" "sync/atomic" "syscall" "time" "wechat-api/ent" "wechat-api/ent/compapiasynctask" "wechat-api/ent/predicate" "wechat-api/internal/svc" "wechat-api/internal/types" "wechat-api/internal/utils/compapi" "github.com/suyuan32/simple-admin-common/config" "github.com/zeromicro/go-zero/core/conf" "github.com/zeromicro/go-zero/core/logx" ) const ( Task_Ready = 10 //任务就绪 ReqApi_Done = 20 //请求API完成 Callback_Done = 30 //请求回调完成 All_Done = Callback_Done //全部完成 (暂时将成功状态的终点标注于此) Task_Suspend = 60 //任务暂停 Task_Fail = 70 //任务失败 LoopTryCount = 3 //循环体内重试次数 ErrTaskTryCount = 3 //最大允许错误任务重试次数 DefaultDisId = "DIS0001" ) type Config struct { BatchLoadTask uint `json:",default=100"` MaxWorker uint `json:",default=2"` MaxChannel uint `json:",default=1"` Debug bool `json:",default=false"` DatabaseConf config.DatabaseConf RedisConf config.RedisConf } type TaskStat struct { } type AsyncTask struct { logx.Logger ctx context.Context svcCtx *svc.ServiceContext Conf Config Stats TaskStat } type Task struct { Data *ent.CompapiAsynctask Idx int Code int } // 带会话管理的任务通道组 type TaskDispatcher struct { mu sync.Mutex workerChs []chan Task // 每个worker独立通道 } var configFile = flag.String("f", "./etc/asynctask.yaml", "the config file") func getGoroutineId() (int64, error) { // 堆栈结果中需要消除的前缀符 var goroutineSpace = []byte("goroutine ") bs := make([]byte, 128) bs = bs[:runtime.Stack(bs, false)] bs = bytes.TrimPrefix(bs, goroutineSpace) i := bytes.IndexByte(bs, ' ') if i < 0 { return -1, errors.New("get current goroutine id failed") } return strconv.ParseInt(string(bs[:i]), 10, 64) } func NewTaskDispatcher(channelCount uint, chanSize uint) *TaskDispatcher { td := &TaskDispatcher{ workerChs: make([]chan Task, channelCount), } // 初始化worker通道 for i := range td.workerChs { td.workerChs[i] = make(chan Task, chanSize+1) // 每个worker带缓冲 } return td } // 按woker下标索引选择workerChs func (td *TaskDispatcher) getWorkerChanByIdx(idx uint) (uint, chan Task) { nidx := idx % uint(len(td.workerChs)) return nidx, td.workerChs[nidx] } // 按哈希分片选择workerChs func (td *TaskDispatcher) getWorkerChanByHash(disID string) (int, chan Task) { if len(disID) == 0 { disID = DefaultDisId } idx := 0 if len(td.workerChs) > 1 { h := fnv.New32a() h.Write([]byte(disID)) idx = int(h.Sum32()) % len(td.workerChs) } outStr := fmt.Sprintf("getWorkerChannel by disId Hash:'%s',from workerChs{", disID) for i := range len(td.workerChs) { outStr += fmt.Sprintf("#%d", i+1) if i < len(td.workerChs)-1 { outStr += "," } } outStr += fmt.Sprintf("} choice chs:'#%d' by '%s'", idx+1, disID) logx.Debug(outStr) return idx, td.workerChs[idx] } // 分配任务到对应的消费channel func (td *TaskDispatcher) Dispatch(task Task) { td.mu.Lock() defer td.mu.Unlock() // 根据chatId哈希获得其对应的workerChan workerChIdx, workerCh := td.getWorkerChanByHash(task.Data.EventType) // 将任务送入该chatid的workerChan workerCh <- task logx.Debugf("Producer:EventType:[%s] Task Push to WorkerChan:#%d", task.Data.EventType, workerChIdx+1) } // 更新任务状态 func (me *AsyncTask) updateTaskStatus(taskId uint64, status int) error { _, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId). SetTaskStatus(int8(status)). SetUpdatedAt(time.Now()). Save(me.ctx) return err } // 更新请求大模型后结果 func (me *AsyncTask) updateApiResponse(taskId uint64, apiResponse string) error { _, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId). SetUpdatedAt(time.Now()). SetResponseRaw(apiResponse). SetLastError(""). SetRetryCount(0). Save(me.ctx) return err } func (me *AsyncTask) updateCallbackResponse(taskId uint64, callRes any) error { callResStr := "" switch v := callRes.(type) { case []byte: callResStr = string(v) default: if bs, err := json.Marshal(v); err == nil { callResStr = string(bs) } else { return err } } _, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId). SetUpdatedAt(time.Now()). SetCallbackResponseRaw(callResStr). SetLastError(""). SetRetryCount(0). Save(me.ctx) return err } func (me *AsyncTask) checkErrRetry(taskData *ent.CompapiAsynctask) (bool, error) { var err error var needStop = false if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败 _, err = me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID). SetUpdatedAt(time.Now()). SetTaskStatus(int8(Task_Fail)). Save(me.ctx) if err == nil { needStop = true taskData.TaskStatus = Task_Fail } } return needStop, err } // 错误任务处理 func (me *AsyncTask) dealErrorTask(taskData *ent.CompapiAsynctask, lasterr error) error { logx.Debug("多次循环之后依然失败,进入错误任务处理环节") cauo := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID). SetUpdatedAt(time.Now()) if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败 taskData.TaskStatus = Task_Fail cauo = cauo.SetTaskStatus(int8(Task_Fail)) } else { cauo = cauo.SetRetryCount(taskData.RetryCount + 1). SetLastError(lasterr.Error()) } _, err := cauo.Save(me.ctx) return err } func (me *AsyncTask) requestCallback(taskData *ent.CompapiAsynctask) (int, error) { var workerId int64 if me.Conf.Debug { workerId, _ = getGoroutineId() logx.Debugf("Worker:%d INTO requestCallback for task status:%d", workerId, taskData.TaskStatus) } if needStop, _ := me.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理 return 1, errors.New("too many err retry") } if taskData.TaskStatus != ReqApi_Done { return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus) } if len(taskData.CallbackURL) == 0 { return 0, errors.New("callback url empty") } if len(taskData.ResponseRaw) == 0 { return 0, errors.New("call api response empty") } /* fstr := "mytest-svc:" if taskData.RetryCount >= 0 && strings.Contains(taskData.CallbackURL, fstr) { taskData.CallbackURL = strings.Replace(taskData.CallbackURL, fstr, "0.0.0.0:", 1) } */ //请求预先给定的callback_url res := map[string]any{} var err error //初始化client client := compapi.NewClient(me.ctx) for i := range LoopTryCount { //重试机制 select { case <-me.ctx.Done(): //接到信号退出 goto endloopTry default: res, err = client.Callback(taskData.EventType, taskData.CallbackURL, taskData) if err == nil { //call succ //fmt.Println("callback succ..........") //fmt.Println(typekit.PrettyPrint(res)) logx.Infof("callback:'%s' succ", taskData.CallbackURL) goto endloopTry } logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId, taskData.CallbackURL, err, 1+i*5, i+1, LoopTryCount, taskData.RetryCount) time.Sleep(time.Duration(1+i*5) * time.Second) } } //多次循环之后依然失败,进入错误任务处理环节 endloopTry: if err != nil { err1 := me.dealErrorTask(taskData, err) //错误任务处理 et := 1 if err1 != nil { et = 0 } return et, err } //成功后处理环节 //更新任务状态 => Callback_Done(回调完成) err = me.updateTaskStatus(taskData.ID, Callback_Done) if err != nil { return 0, err } //更新taskData.CallbackResponseRaw if len(res) > 0 { me.updateCallbackResponse(taskData.ID, res) } taskData.TaskStatus = Callback_Done //状态迁移 return 1, nil } func (me *AsyncTask) requestAPI(taskData *ent.CompapiAsynctask) (int, error) { var workerId int64 if me.Conf.Debug { workerId, _ = getGoroutineId() logx.Debugf("Worker:%d INTO requestAPI for task status:%d", workerId, taskData.TaskStatus) } if needStop, _ := me.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理 return 1, errors.New("too many err retry") } if taskData.TaskStatus != Task_Ready { return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus) } var ( err error apiResp *types.CompOpenApiResp ) req := types.CompApiReq{} if err = json.Unmarshal([]byte(taskData.RequestRaw), &req); err != nil { return 0, err } //初始化client if !strings.HasSuffix(taskData.OpenaiBase, "/") { taskData.OpenaiBase = taskData.OpenaiBase + "/" } client := compapi.NewClient(me.ctx, compapi.WithApiBase(taskData.OpenaiBase), compapi.WithApiKey(taskData.OpenaiKey)) for i := range LoopTryCount { //重试机制 select { case <-me.ctx.Done(): //接到信号退出 goto endloopTry default: apiResp, err = client.Chat(&req) if err == nil && apiResp != nil && len(apiResp.Choices) > 0 { //call succ goto endloopTry } else if apiResp != nil && len(apiResp.Choices) == 0 { err = errors.New("返回结果缺失,请检查访问权限") } if me.Conf.Debug { logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId, taskData.CallbackURL, err, 1+i*5, i+1, LoopTryCount, taskData.RetryCount) } time.Sleep(time.Duration(1+i*5) * time.Second) } } endloopTry: //多次循环之后依然失败,进入错误任务处理环节 if err != nil || apiResp == nil { if apiResp == nil && err == nil { err = errors.New("resp is null") } err1 := me.dealErrorTask(taskData, err) //错误任务处理 et := 1 if err1 != nil { et = 0 } return et, err } //成功后处理环节 //更新任务状态 => ReqApi_Done(请求API完成) err = me.updateTaskStatus(taskData.ID, ReqApi_Done) if err != nil { return 0, err } //更新taskData.ResponseRaw taskData.ResponseRaw, err = (*apiResp).ToString() if err != nil { return 0, err } err = me.updateApiResponse(taskData.ID, taskData.ResponseRaw) if err != nil { return 0, err } taskData.TaskStatus = ReqApi_Done //状态迁移 return 1, nil } func (me *AsyncTask) getAsyncReqTaskList() ([]Task, error) { var predicates []predicate.CompapiAsynctask predicates = append(predicates, compapiasynctask.TaskStatusLT(All_Done)) var tasks []Task res, err := me.svcCtx.DB.CompapiAsynctask.Query().Where(predicates...). Order(ent.Asc(compapiasynctask.FieldID)). Limit(int(me.Conf.BatchLoadTask)). All(me.ctx) if err == nil { for idx, val := range res { tasks = append(tasks, Task{Data: val, Idx: idx}) } } return tasks, err } func (me *AsyncTask) processTask(workerID uint, task Task) { /* fmt.Printf("In processTask,Consumer(%d) dealing Task Detail: User(%s/%s/%s) Async Call %s(%s) on Status:%d\n", workerID, task.Data.EventType, task.Data.ChatID, task.Data.AuthToken, task.Data.OpenaiBase, task.Data.OpenaiKey, task.Data.TaskStatus) */ _ = workerID var err error rt := 0 for { select { case <-me.ctx.Done(): //接到信号退出 return default: if task.Data.TaskStatus >= All_Done { goto endfor //原来的break操作加了switch一层不好使了 } switch task.Data.TaskStatus { case Task_Ready: //请求API平台 // succ: taskStatus Task_Ready => ReqApi_Done // fail: taskStatus保持当前不变或Task_Fail rt, err = me.requestAPI(task.Data) case ReqApi_Done: //结果回调 // succ: taskStatus ReqApi_Done => Callback_Done(All_Done) // fail: taskStatus保持当前不变或Task_Fail rt, err = me.requestCallback(task.Data) } if err != nil { //收集错误 if rt == 0 { //不可恢复错误处理.... logx.Errorf("Task error by '%s'", err) } else { logx.Debugf("Task ignore by '%s'", err) } return //先暂时忽略处理,也许应按错误类型分别对待 } } } endfor: } func (me *AsyncTask) batchWork() (int64, int64) { var ( wg sync.WaitGroup produced int64 //生产数量(原子计数器) consumed int64 //消费数量(原子计数器) ) //创建任务分发器 dispatcher := NewTaskDispatcher(me.Conf.MaxChannel, me.Conf.BatchLoadTask/me.Conf.MaxChannel) //启动消费者 for widx := range me.Conf.MaxWorker { cidx, ch := dispatcher.getWorkerChanByIdx(widx) wg.Add(1) go func(workerID uint, channelID uint, taskCh chan Task) { defer wg.Done() gidStr := "" if me.Conf.Debug { gid, _ := getGoroutineId() gidStr = fmt.Sprintf("(Goroutine:%d)", gid) } logx.Infof("Consumer @%d%s bind WorkerChan:#%d start......", workerID, gidStr, channelID) for task := range taskCh { me.processTask(widx, task) atomic.AddInt64(&consumed, 1) } }(widx+1, cidx+1, ch) } // 生产者 wg.Add(1) go func() { defer wg.Done() gidStr := "" if me.Conf.Debug { gid, _ := getGoroutineId() gidStr = fmt.Sprintf("(Goroutine:%d)", gid) } logx.Infof("Producer @1%s start......", gidStr) //获得待处理异步任务列表 tasks, err := me.getAsyncReqTaskList() if err != nil { logx.Errorf("getAsyncReqTaskList err:%s", err) return } // 分发任务 for _, task := range tasks { dispatcher.Dispatch(task) atomic.AddInt64(&produced, 1) } logx.Infof("📦Producer @1 此批次共创建任务%d件", len(tasks)) // 关闭所有会话通道 dispatcher.mu.Lock() for _, ch := range dispatcher.workerChs { _ = ch close(ch) } dispatcher.mu.Unlock() }() wg.Wait() consumedRatStr := "" if atomic.LoadInt64(&produced) > 0 { consumedRatStr = fmt.Sprintf(" (%d/%d)*100=%d%%", atomic.LoadInt64(&produced), atomic.LoadInt64(&consumed), (atomic.LoadInt64(&consumed)/atomic.LoadInt64(&produced))*100) } logx.Infof("🏁本次任务完成度统计: Task dispatch: %d(%d)(Producer:1),Task deal: %d(Consumer:%d)%s", atomic.LoadInt64(&produced), me.Conf.BatchLoadTask, atomic.LoadInt64(&consumed), me.Conf.MaxWorker, consumedRatStr) return produced, consumed } func (me *AsyncTask) InitServiceContext() *svc.ServiceContext { rds := me.Conf.RedisConf.MustNewUniversalRedis() dbOpts := []ent.Option{ent.Log(logx.Info), ent.Driver(me.Conf.DatabaseConf.NewNoCacheDriver())} if me.Conf.Debug { dbOpts = append(dbOpts, ent.Debug()) } db := ent.NewClient(dbOpts...) svcCtx := svc.ServiceContext{DB: db, Rds: rds} //svcCtx.Config return &svcCtx } func (me *AsyncTask) adjustConf() { if me.Conf.MaxWorker <= 0 { me.Conf.MaxWorker = 2 } if me.Conf.MaxChannel <= 0 || me.Conf.MaxChannel > me.Conf.MaxWorker { me.Conf.MaxChannel = me.Conf.MaxWorker } } func (me *AsyncTask) Run(ctx context.Context) { me.ctx = ctx me.Logger = logx.WithContext(ctx) me.adjustConf() for { select { case <-ctx.Done(): // 收到了取消信号 fmt.Printf("Main Will Shutting down gracefully... Reason: %v\n", ctx.Err()) return default: produced, _ := me.batchWork() timeDurnum := 1 if produced == 0 { timeDurnum = 5 } fmt.Printf("batchWork will sleep %d sec\n", timeDurnum) time.Sleep(time.Duration(timeDurnum) * time.Second) } } } func NewAsyncTask() *AsyncTask { ataskObj := AsyncTask{} flag.Parse() //将命令行参数也塞入flag.CommandLine结构 //fmt.Println(typekit.PrettyPrint(flag.CommandLine)) conf.MustLoad(*configFile, &ataskObj.Conf, conf.UseEnv()) //fmt.Println(typekit.PrettyPrint(ataskObj.Conf)) ataskObj.svcCtx = ataskObj.InitServiceContext() return &ataskObj } func main() { fmt.Println("Compapi Asynctask Start......") //ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer cancel() NewAsyncTask().Run(ctx) fmt.Println("Compapi Asynctask End......") }