package main import ( "bytes" "context" "encoding/json" "errors" "flag" "fmt" "hash/fnv" "os" "os/signal" "reflect" "runtime" "strconv" "strings" "sync" "sync/atomic" "syscall" "time" "wechat-api/ent" "wechat-api/ent/compapiasynctask" "wechat-api/ent/predicate" "wechat-api/internal/svc" "wechat-api/internal/types" "wechat-api/internal/utils/compapi" "github.com/suyuan32/simple-admin-common/config" "github.com/zeromicro/go-zero/core/conf" "github.com/zeromicro/go-zero/core/logx" ) const ( Task_Ready = 10 //任务就绪 ReqApi_Done = 20 //请求API完成 Callback_Done = 30 //请求回调完成 All_Done = Callback_Done //全部完成 (暂时将成功状态的终点标注于此) Task_Suspend = 60 //任务暂停 Task_Fail = 70 //任务失败 LoopTryCount = 3 //循环体内重试次数 LoopDelayFactor = 3 ErrTaskTryCount = 3 //最大允许错误任务重试次数 DefaultDisId = "DIS0001" ) type Config struct { BatchLoadTask uint `json:",default=100"` MaxWorker uint `json:",default=2"` MaxChannel uint `json:",default=1"` Debug bool `json:",default=false"` DatabaseConf config.DatabaseConf RedisConf config.RedisConf } type TaskStat struct { } type AsyncTask struct { logx.Logger ctx context.Context svcCtx *svc.ServiceContext Conf Config Stats TaskStat } type Task struct { Data *ent.CompapiAsynctask Idx int Code int } // 带会话管理的任务通道组 type TaskDispatcher struct { mu sync.Mutex workerChs []chan Task // 每个worker独立通道 Debug bool } var configFile = flag.String("f", "./etc/asynctask.yaml", "the config file") func getGoroutineId() (int64, error) { // 堆栈结果中需要消除的前缀符 var goroutineSpace = []byte("goroutine ") bs := make([]byte, 128) bs = bs[:runtime.Stack(bs, false)] bs = bytes.TrimPrefix(bs, goroutineSpace) i := bytes.IndexByte(bs, ' ') if i < 0 { return -1, errors.New("get current goroutine id failed") } return strconv.ParseInt(string(bs[:i]), 10, 64) } func NewTaskDispatcher(channelCount uint, chanSize uint, debug bool) *TaskDispatcher { td := &TaskDispatcher{ workerChs: make([]chan Task, channelCount), Debug: debug, } // 初始化worker通道 for i := range td.workerChs { td.workerChs[i] = make(chan Task, chanSize+1) // 每个worker带缓冲 } return td } // 按woker下标索引选择workerChs func (td *TaskDispatcher) getWorkerChanByIdx(idx uint) (uint, chan Task) { nidx := idx % uint(len(td.workerChs)) return nidx, td.workerChs[nidx] } // 按哈希分片选择workerChs func (td *TaskDispatcher) getWorkerChanByHash(disID string) (int, chan Task) { if len(disID) == 0 { disID = DefaultDisId } idx := 0 if len(td.workerChs) > 1 { h := fnv.New32a() h.Write([]byte(disID)) idx = int(h.Sum32()) % len(td.workerChs) } if td.Debug { outStr := fmt.Sprintf("getWorkerChannel by disId Hash:'%s',from workerChs{", disID) for i := range len(td.workerChs) { outStr += fmt.Sprintf("#%d", i+1) if i < len(td.workerChs)-1 { outStr += "," } } outStr += fmt.Sprintf("} choice chs:'#%d' by '%s'", idx+1, disID) logx.Debug(outStr) } return idx, td.workerChs[idx] } // 分配任务到对应的消费channel func (td *TaskDispatcher) Dispatch(task Task) { td.mu.Lock() defer td.mu.Unlock() // 根据chatId哈希获得其对应的workerChan workerChIdx, workerCh := td.getWorkerChanByHash(task.Data.EventType) // 将任务送入该chatid的workerChan workerCh <- task if td.Debug { logx.Debugf("Producer:EventType:[%s] Task Push to WorkerChan:#%d", task.Data.EventType, workerChIdx+1) } } // 更新任务状态 func (me *AsyncTask) updateTaskStatus(taskId uint64, status int) error { _, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId). SetTaskStatus(int8(status)). SetUpdatedAt(time.Now()). Save(me.ctx) return err } // 更新请求大模型后结果 func (me *AsyncTask) updateApiResponse(taskId uint64, apiResponse string) error { _, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId). SetUpdatedAt(time.Now()). SetResponseRaw(apiResponse). SetLastError(""). SetRetryCount(0). Save(me.ctx) return err } func (me *AsyncTask) updateCallbackResponse(taskId uint64, callRes any) error { callResStr := "" switch v := callRes.(type) { case []byte: callResStr = string(v) default: if bs, err := json.Marshal(v); err == nil { callResStr = string(bs) } else { return err } } _, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId). SetUpdatedAt(time.Now()). SetCallbackResponseRaw(callResStr). SetLastError(""). SetRetryCount(0). Save(me.ctx) return err } func (me *AsyncTask) checkErrRetry(taskData *ent.CompapiAsynctask) (bool, error) { var err error var needStop = false if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败 _, err = me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID). SetUpdatedAt(time.Now()). SetTaskStatus(int8(Task_Fail)). Save(me.ctx) if err == nil { needStop = true taskData.TaskStatus = Task_Fail } } return needStop, err } // 错误任务处理 func (me *AsyncTask) dealErrorTask(taskData *ent.CompapiAsynctask, lasterr error) error { logx.Debug("多次循环之后依然失败,进入错误任务处理环节") cauo := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID). SetUpdatedAt(time.Now()) if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败 taskData.TaskStatus = Task_Fail cauo = cauo.SetTaskStatus(int8(Task_Fail)) } else { cauo = cauo.SetRetryCount(taskData.RetryCount + 1). SetLastError(lasterr.Error()) } _, err := cauo.Save(me.ctx) return err } func (me *AsyncTask) requestCallback(taskData *ent.CompapiAsynctask) (int, error) { var workerId int64 if me.Conf.Debug { workerId, _ = getGoroutineId() logx.Debugf("Worker:%d INTO requestCallback for task status:%d", workerId, taskData.TaskStatus) } if needStop, _ := me.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理 return 1, errors.New("too many err retry") } if taskData.TaskStatus != ReqApi_Done { return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus) } if len(taskData.CallbackURL) == 0 { return 0, errors.New("callback url empty") } if len(taskData.ResponseRaw) == 0 { return 0, errors.New("call api response empty") } /* fstr := "mytest-svc:" if taskData.RetryCount >= 0 && strings.Contains(taskData.CallbackURL, fstr) { taskData.CallbackURL = strings.Replace(taskData.CallbackURL, fstr, "0.0.0.0:", 1) } */ //请求预先给定的callback_url res := map[string]any{} var err error //初始化client client := compapi.NewClient(me.ctx) for i := range LoopTryCount { //重试机制 select { case <-me.ctx.Done(): //接到信号退出 goto endloopTry default: res, err = client.Callback(taskData.EventType, taskData.CallbackURL, taskData) if err == nil { //call succ //fmt.Println("callback succ..........") //fmt.Println(typekit.PrettyPrint(res)) if me.Conf.Debug { logx.Infof("callback:'%s' succ", taskData.CallbackURL) } goto endloopTry } logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId, taskData.CallbackURL, err, 1+i*LoopDelayFactor, i+1, LoopTryCount, taskData.RetryCount) time.Sleep(time.Duration(1+i*LoopDelayFactor) * time.Second) } } //多次循环之后依然失败,进入错误任务处理环节 endloopTry: if err != nil { err1 := me.dealErrorTask(taskData, err) //错误任务处理 et := 1 if err1 != nil { et = 0 } return et, err } //成功后处理环节 //更新任务状态 => Callback_Done(回调完成) err = me.updateTaskStatus(taskData.ID, Callback_Done) if err != nil { return 0, err } //更新taskData.CallbackResponseRaw if len(res) > 0 { me.updateCallbackResponse(taskData.ID, res) } taskData.TaskStatus = Callback_Done //状态迁移 return 1, nil } func (me *AsyncTask) requestAPI(taskData *ent.CompapiAsynctask) (int, error) { var workerId int64 if me.Conf.Debug { workerId, _ = getGoroutineId() logx.Debugf("Worker:%d INTO requestAPI for task status:%d", workerId, taskData.TaskStatus) } if needStop, _ := me.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理 return 1, errors.New("too many err retry") } if taskData.TaskStatus != Task_Ready { return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus) } var ( err error apiResp *types.CompOpenApiResp ) req := types.CompApiReq{} if err = json.Unmarshal([]byte(taskData.RequestRaw), &req); err != nil { return 0, err } //初始化client if !strings.HasSuffix(taskData.OpenaiBase, "/") { taskData.OpenaiBase = taskData.OpenaiBase + "/" } client := compapi.NewClient(me.ctx, compapi.WithApiBase(taskData.OpenaiBase), compapi.WithApiKey(taskData.OpenaiKey)) for i := range LoopTryCount { //重试机制 select { case <-me.ctx.Done(): //接到信号退出 goto endloopTry default: apiResp, err = client.Chat(&req) if err == nil && apiResp != nil && len(apiResp.Choices) > 0 { //call succ goto endloopTry } else if apiResp != nil && len(apiResp.Choices) == 0 { err = errors.New("返回结果缺失,请检查访问权限") } if me.Conf.Debug { logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId, taskData.CallbackURL, err, 1+i*LoopDelayFactor, i+1, LoopTryCount, taskData.RetryCount) } time.Sleep(time.Duration(1+i*LoopDelayFactor) * time.Second) } } endloopTry: //多次循环之后依然失败,进入错误任务处理环节 if err != nil || apiResp == nil { if apiResp == nil && err == nil { err = errors.New("resp is null") } err1 := me.dealErrorTask(taskData, err) //错误任务处理 et := 1 if err1 != nil { et = 0 } return et, err } //成功后处理环节 //更新任务状态 => ReqApi_Done(请求API完成) err = me.updateTaskStatus(taskData.ID, ReqApi_Done) if err != nil { return 0, err } //更新taskData.ResponseRaw taskData.ResponseRaw, err = (*apiResp).ToString() if err != nil { return 0, err } err = me.updateApiResponse(taskData.ID, taskData.ResponseRaw) if err != nil { return 0, err } taskData.TaskStatus = ReqApi_Done //状态迁移 return 1, nil } func EntStructGenScanField(structPtr any) (string, []any, error) { t := reflect.TypeOf(structPtr) v := reflect.ValueOf(structPtr) if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct { return "", nil, errors.New("input must be a pointer to a struct") } t = t.Elem() v = v.Elem() var fields []string var scanArgs []any for i := 0; i < t.NumField(); i++ { field := t.Field(i) value := v.Field(i) // Skip unexported fields if !field.IsExported() { continue } // Get json tag jsonTag := field.Tag.Get("json") if jsonTag == "-" || jsonTag == "" { continue } jsonParts := strings.Split(jsonTag, ",") jsonName := jsonParts[0] if jsonName == "" { continue } fields = append(fields, jsonName) scanArgs = append(scanArgs, value.Addr().Interface()) } return strings.Join(fields, ", "), scanArgs, nil } /* CREATE INDEX idx_compapi_task_status_chat_id_id_desc ON compapi_asynctask (task_status, chat_id, id DESC); */ func (me *AsyncTask) getAsyncReqTaskFairList() ([]Task, error) { fieldListStr, _, err := EntStructGenScanField(&ent.CompapiAsynctask{}) if err != nil { return nil, err } rawQuery := fmt.Sprintf(` WITH ranked AS ( SELECT %s, ROW_NUMBER() OVER (PARTITION BY chat_id ORDER BY id DESC) AS rn FROM compapi_asynctask WHERE task_status < ? ) SELECT %s FROM ranked WHERE rn <= ? ORDER BY rn,id DESC LIMIT ?; `, fieldListStr, fieldListStr) // 执行原始查询 rows, err := me.svcCtx.DB.QueryContext(me.ctx, rawQuery, All_Done, me.Conf.BatchLoadTask, me.Conf.BatchLoadTask) if err != nil { return nil, fmt.Errorf("fetch fair tasks query error: %w", err) } defer rows.Close() Idx := 0 tasks := []Task{} for rows.Next() { taskrow := ent.CompapiAsynctask{} var scanParams []any _, scanParams, err = EntStructGenScanField(&taskrow) if err != nil { break } err = rows.Scan(scanParams...) if err != nil { break } task := Task{Data: &taskrow, Idx: Idx} tasks = append(tasks, task) Idx++ } fmt.Printf("getAsyncReqTaskFairList get task:%d\n", len(tasks)) return tasks, nil } func (me *AsyncTask) getAsyncReqTaskList() ([]Task, error) { var predicates []predicate.CompapiAsynctask predicates = append(predicates, compapiasynctask.TaskStatusLT(All_Done)) var tasks []Task res, err := me.svcCtx.DB.CompapiAsynctask.Query().Where(predicates...). Order(ent.Desc(compapiasynctask.FieldID)). Limit(int(me.Conf.BatchLoadTask)). All(me.ctx) if err == nil { for idx, val := range res { tasks = append(tasks, Task{Data: val, Idx: idx}) } } return tasks, err } func (me *AsyncTask) processTask(workerID uint, task Task) { /* fmt.Printf("In processTask,Consumer(%d) dealing Task Detail: User(%s/%s/%s) Async Call %s(%s) on Status:%d\n", workerID, task.Data.EventType, task.Data.ChatID, task.Data.AuthToken, task.Data.OpenaiBase, task.Data.OpenaiKey, task.Data.TaskStatus) */ _ = workerID var err error rt := 0 for { select { case <-me.ctx.Done(): //接到信号退出 return default: if task.Data.TaskStatus >= All_Done { goto endfor //原来的break操作加了switch一层不好使了 } switch task.Data.TaskStatus { case Task_Ready: //请求API平台 // succ: taskStatus Task_Ready => ReqApi_Done // fail: taskStatus保持当前不变或Task_Fail rt, err = me.requestAPI(task.Data) case ReqApi_Done: //结果回调 // succ: taskStatus ReqApi_Done => Callback_Done(All_Done) // fail: taskStatus保持当前不变或Task_Fail rt, err = me.requestCallback(task.Data) } if err != nil { //收集错误 if rt == 0 { //不可恢复错误处理.... logx.Errorf("Task error by '%s'", err) } else { logx.Debugf("Task ignore by '%s'", err) } return //先暂时忽略处理,也许应按错误类型分别对待 } } } endfor: } func (me *AsyncTask) batchWork() (int64, int64) { var ( wg sync.WaitGroup produced int64 //生产数量(原子计数器) consumed int64 //消费数量(原子计数器) ) //创建任务分发器 dispatcher := NewTaskDispatcher(me.Conf.MaxChannel, me.Conf.BatchLoadTask/me.Conf.MaxChannel, me.Conf.Debug) //启动消费者 for widx := range me.Conf.MaxWorker { cidx, ch := dispatcher.getWorkerChanByIdx(widx) wg.Add(1) go func(workerID uint, channelID uint, taskCh chan Task) { defer wg.Done() gidStr := "" if me.Conf.Debug { gid, _ := getGoroutineId() gidStr = fmt.Sprintf("(Goroutine:%d)", gid) logx.Infof("Consumer @%d%s bind WorkerChan:#%d start......", workerID, gidStr, channelID) } for task := range taskCh { me.processTask(widx, task) atomic.AddInt64(&consumed, 1) } }(widx+1, cidx+1, ch) } // 生产者 wg.Add(1) go func() { defer wg.Done() gidStr := "" if me.Conf.Debug { gid, _ := getGoroutineId() gidStr = fmt.Sprintf("(Goroutine:%d)", gid) logx.Infof("Producer @1%s start......", gidStr) } //获得待处理异步任务列表 //tasks, err := me.getAsyncReqTaskList() tasks, err := me.getAsyncReqTaskFairList() if err != nil { logx.Errorf("getAsyncReqTaskList err:%s", err) return } // 分发任务 for _, task := range tasks { dispatcher.Dispatch(task) atomic.AddInt64(&produced, 1) } fmt.Printf("📦Producer @1 此批次共创建任务%d件\n", len(tasks)) // 关闭所有会话通道 dispatcher.mu.Lock() for _, ch := range dispatcher.workerChs { _ = ch close(ch) } dispatcher.mu.Unlock() }() wg.Wait() consumedRatStr := "" if atomic.LoadInt64(&produced) > 0 { consumedRatStr = fmt.Sprintf(" (%d/%d)*100=%d%%", atomic.LoadInt64(&produced), atomic.LoadInt64(&consumed), (atomic.LoadInt64(&consumed)/atomic.LoadInt64(&produced))*100) } fmt.Printf("🏁本次任务完成度统计: Task dispatch: %d(%d)(Producer:1),Task deal: %d(Consumer:%d)%s\n", atomic.LoadInt64(&produced), me.Conf.BatchLoadTask, atomic.LoadInt64(&consumed), me.Conf.MaxWorker, consumedRatStr) return produced, consumed } func (me *AsyncTask) InitServiceContext() *svc.ServiceContext { rds := me.Conf.RedisConf.MustNewUniversalRedis() dbOpts := []ent.Option{ent.Log(logx.Info), ent.Driver(me.Conf.DatabaseConf.NewNoCacheDriver())} if me.Conf.Debug { dbOpts = append(dbOpts, ent.Debug()) } db := ent.NewClient(dbOpts...) svcCtx := svc.ServiceContext{DB: db, Rds: rds} //svcCtx.Config return &svcCtx } func (me *AsyncTask) adjustConf() { if me.Conf.MaxWorker <= 0 { me.Conf.MaxWorker = 2 } if me.Conf.MaxChannel <= 0 || me.Conf.MaxChannel > me.Conf.MaxWorker { me.Conf.MaxChannel = me.Conf.MaxWorker } } func (me *AsyncTask) Run(ctx context.Context) { me.ctx = ctx me.Logger = logx.WithContext(ctx) me.adjustConf() /* tasks, err := me.getAsyncReqTaskFairList() if err != nil { fmt.Println(err) return } for idx, task := range tasks { if idx > 20 { break } fmt.Printf("#%d=>%d ||[%s]'%s' || '%s' || '%s'|| '%s'\n", idx, task.Data.ID, task.Data.CreatedAt, task.Data.EventType, task.Data.Model, task.Data.OpenaiBase, task.Data.ChatID) } */ batchId := 0 for { batchId++ select { case <-ctx.Done(): // 收到了取消信号 fmt.Printf("Main Will Shutting down gracefully... Reason: %v\n", ctx.Err()) return default: timeStart := time.Now() secStart := timeStart.Unix() fmt.Printf("[%s]batchWork#%d start......\n", timeStart.Format("2006-01-02 15:04:05"), batchId) produced, _ := me.batchWork() timeEnd := time.Now() fmt.Printf("[%s]batchWork#%d end,spent %d sec\n", timeEnd.Format("2006-01-02 15:04:05"), batchId, timeEnd.Unix()-secStart) timeDurnum := 1 if produced == 0 { timeDurnum = 5 } fmt.Printf("batchWork will sleep %d sec\n", timeDurnum) time.Sleep(time.Duration(timeDurnum) * time.Second) } } } func NewAsyncTask() *AsyncTask { ataskObj := AsyncTask{} flag.Parse() //将命令行参数也塞入flag.CommandLine结构 //fmt.Println(typekit.PrettyPrint(flag.CommandLine)) conf.MustLoad(*configFile, &ataskObj.Conf, conf.UseEnv()) //fmt.Println(typekit.PrettyPrint(ataskObj.Conf)) ataskObj.svcCtx = ataskObj.InitServiceContext() return &ataskObj } func main() { fmt.Println("Compapi Asynctask Start......") //ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer cancel() NewAsyncTask().Run(ctx) fmt.Println("Compapi Asynctask End......") }