@@ -0,0 +1,715 @@
+package main
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"flag"
+	"fmt"
+	"hash/fnv"
+	"os"
+	"os/signal"
+	"reflect"
+	"runtime"
+	"strconv"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"syscall"
+	"time"
+
+	"wechat-api/ent"
+	"wechat-api/ent/compapiasynctask"
+	"wechat-api/ent/predicate"
+	"wechat-api/internal/svc"
+	"wechat-api/internal/types"
+	"wechat-api/internal/utils/compapi"
+
+	"github.com/suyuan32/simple-admin-common/config"
+	"github.com/zeromicro/go-zero/core/conf"
+	"github.com/zeromicro/go-zero/core/logx"
+)
+
+const (
+	Task_Ready    = 10            //task is ready
+	ReqApi_Done   = 20            //API request finished
+	Callback_Done = 30            //callback request finished
+	All_Done      = Callback_Done //all done (for now the success path ends here)
+	Task_Suspend  = 60            //task suspended
+	Task_Fail     = 70            //task failed
+
+	LoopTryCount    = 3 //retry count inside one processing loop
+	LoopDelayFactor = 3
+	ErrTaskTryCount = 3 //maximum allowed retries for a failed task
+
+	DefaultDisId = "DIS0001"
+)
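+// Status flow implied by the constants above (an editorial reading of the code
+// below, not an additional contract): Task_Ready(10) -> ReqApi_Done(20) ->
+// Callback_Done(30), and Callback_Done currently doubles as All_Done;
+// Task_Suspend(60) and Task_Fail(70) sit outside the normal path. With
+// LoopTryCount=3 and LoopDelayFactor=3, the per-attempt retry sleeps are
+// 1+i*3 seconds, i.e. 1s, 4s and 7s.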
+
+type Config struct {
+	BatchLoadTask uint `json:",default=100"`
+	MaxWorker     uint `json:",default=2"`
+	MaxChannel    uint `json:",default=1"`
+	Debug         bool `json:",default=false"`
+	DatabaseConf  config.DatabaseConf
+	RedisConf     config.RedisConf
+}
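+
+// A minimal ./etc/asynctask.yaml sketch for the struct above, loaded via
+// conf.MustLoad in NewAsyncTask. The top-level keys mirror the fields here;
+// the DatabaseConf/RedisConf sub-keys are illustrative only, their
+// authoritative schema lives in github.com/suyuan32/simple-admin-common/config.
+//
+//	BatchLoadTask: 100
+//	MaxWorker: 4
+//	MaxChannel: 2
+//	Debug: false
+//	DatabaseConf:
+//	  Type: mysql
+//	  Host: 127.0.0.1
+//	  Port: 3306
+//	  DBName: wechat
+//	RedisConf:
+//	  Host: 127.0.0.1:6379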
+
+type TaskStat struct {
+}
+
+type AsyncTask struct {
+	logx.Logger
+	ctx    context.Context
+	svcCtx *svc.ServiceContext
+	Conf   Config
+	Stats  TaskStat
+}
+
+type Task struct {
+	Data *ent.CompapiAsynctask
+	Idx  int
+	Code int
+}
+
+// Task channel group with per-session management
+type TaskDispatcher struct {
+	mu        sync.Mutex
+	workerChs []chan Task // one independent channel per worker
+	Debug     bool
+}
+
+var configFile = flag.String("f", "./etc/asynctask.yaml", "the config file")
+
+func getGoroutineId() (int64, error) {
+	// prefix to strip from the stack dump
+	var goroutineSpace = []byte("goroutine ")
+
+	bs := make([]byte, 128)
+	bs = bs[:runtime.Stack(bs, false)]
+	bs = bytes.TrimPrefix(bs, goroutineSpace)
+	i := bytes.IndexByte(bs, ' ')
+	if i < 0 {
+		return -1, errors.New("get current goroutine id failed")
+	}
+	return strconv.ParseInt(string(bs[:i]), 10, 64)
+}
+
+func NewTaskDispatcher(channelCount uint, chanSize uint, debug bool) *TaskDispatcher {
+	td := &TaskDispatcher{
+		workerChs: make([]chan Task, channelCount),
+		Debug:     debug,
+	}
+	// initialize the worker channels
+	for i := range td.workerChs {
+		td.workerChs[i] = make(chan Task, chanSize+1) // each worker channel is buffered
+	}
+	return td
+}
+
+// Select a worker channel by worker index
+func (td *TaskDispatcher) getWorkerChanByIdx(idx uint) (uint, chan Task) {
+	nidx := idx % uint(len(td.workerChs))
+	return nidx, td.workerChs[nidx]
+}
+
+// Select a worker channel by hash sharding
+func (td *TaskDispatcher) getWorkerChanByHash(disID string) (int, chan Task) {
+	if len(disID) == 0 {
+		disID = DefaultDisId
+	}
+	idx := 0
+	if len(td.workerChs) > 1 {
+		h := fnv.New32a()
+		h.Write([]byte(disID))
+		idx = int(h.Sum32()) % len(td.workerChs)
+	}
+	if td.Debug {
+		outStr := fmt.Sprintf("getWorkerChannel by disId Hash:'%s',from workerChs{", disID)
+		for i := range td.workerChs {
+			outStr += fmt.Sprintf("#%d", i+1)
+			if i < len(td.workerChs)-1 {
+				outStr += ","
+			}
+		}
+		outStr += fmt.Sprintf("} chose chs:'#%d' by '%s'", idx+1, disID)
+		logx.Debug(outStr)
+	}
+
+	return idx, td.workerChs[idx]
+}
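+
+// Usage sketch for the hash sharding above (values are hypothetical): FNV-1a is
+// deterministic, so the same disID always maps to the same channel index and
+// tasks that share an EventType always land on the same worker channel.
+//
+//	td := NewTaskDispatcher(4, 25, false)
+//	i1, _ := td.getWorkerChanByHash("wx_event_A")
+//	i2, _ := td.getWorkerChanByHash("wx_event_A")
+//	// i1 == i2 on every call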
+
+// Dispatch routes a task to its consumer channel
+func (td *TaskDispatcher) Dispatch(task Task) {
+	td.mu.Lock()
+	defer td.mu.Unlock()
+
+	// pick the worker channel by hashing the task's EventType
+	workerChIdx, workerCh := td.getWorkerChanByHash(task.Data.EventType)
+	// push the task into that worker channel
+	workerCh <- task
+	if td.Debug {
+		logx.Debugf("Producer:EventType:[%s] Task Push to WorkerChan:#%d", task.Data.EventType, workerChIdx+1)
+	}
+}
+
+// update the task status
+func (me *AsyncTask) updateTaskStatus(taskId uint64, status int) error {
+	_, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId).
+		SetTaskStatus(int8(status)).
+		SetUpdatedAt(time.Now()).
+		Save(me.ctx)
+	return err
+}
+
+// persist the result of the model API request
+func (me *AsyncTask) updateApiResponse(taskId uint64, apiResponse string) error {
+	_, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId).
+		SetUpdatedAt(time.Now()).
+		SetResponseRaw(apiResponse).
+		SetLastError("").
+		SetRetryCount(0).
+		Save(me.ctx)
+	return err
+}
+
+func (me *AsyncTask) updateCallbackResponse(taskId uint64, callRes any) error {
+	callResStr := ""
+	switch v := callRes.(type) {
+	case []byte:
+		callResStr = string(v)
+	default:
+		if bs, err := json.Marshal(v); err == nil {
+			callResStr = string(bs)
+		} else {
+			return err
+		}
+	}
+	_, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId).
+		SetUpdatedAt(time.Now()).
+		SetCallbackResponseRaw(callResStr).
+		SetLastError("").
+		SetRetryCount(0).
+		Save(me.ctx)
+	return err
+}
+
+func (me *AsyncTask) checkErrRetry(taskData *ent.CompapiAsynctask) (bool, error) {
+	var err error
+	var needStop = false
+	if taskData.RetryCount >= ErrTaskTryCount { //too many retries: mark the task as permanently failed
+		_, err = me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
+			SetUpdatedAt(time.Now()).
+			SetTaskStatus(int8(Task_Fail)).
+			Save(me.ctx)
+		if err == nil {
+			needStop = true
+			taskData.TaskStatus = Task_Fail
+		}
+	}
+	return needStop, err
+}
+
+// handle a failed task
+func (me *AsyncTask) dealErrorTask(taskData *ent.CompapiAsynctask, lasterr error) error {
+	logx.Debug("still failing after several attempts, entering error-task handling")
+	cauo := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
+		SetUpdatedAt(time.Now())
+	if taskData.RetryCount >= ErrTaskTryCount { //too many retries: mark the task as permanently failed
+		taskData.TaskStatus = Task_Fail
+		cauo = cauo.SetTaskStatus(int8(Task_Fail))
+	} else {
+		cauo = cauo.SetRetryCount(taskData.RetryCount + 1).
+			SetLastError(lasterr.Error())
+	}
+	_, err := cauo.Save(me.ctx)
+	return err
+}
+
+func (me *AsyncTask) requestCallback(taskData *ent.CompapiAsynctask) (int, error) {
+	var workerId int64
+	if me.Conf.Debug {
+		workerId, _ = getGoroutineId()
+		logx.Debugf("Worker:%d INTO requestCallback for task status:%d", workerId, taskData.TaskStatus)
+	}
+
+	if needStop, _ := me.checkErrRetry(taskData); needStop { //retry-count check: if exceeded, mark as permanently failed and stop
+		return 1, errors.New("too many err retry")
+	}
+	if taskData.TaskStatus != ReqApi_Done {
+		return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
+	}
+
+	if len(taskData.CallbackURL) == 0 {
+		return 0, errors.New("callback url empty")
+	}
+	if len(taskData.ResponseRaw) == 0 {
+		return 0, errors.New("call api response empty")
+	}
+
+	/*
+		fstr := "mytest-svc:"
+		if taskData.RetryCount >= 0 && strings.Contains(taskData.CallbackURL, fstr) {
+			taskData.CallbackURL = strings.Replace(taskData.CallbackURL, fstr, "0.0.0.0:", 1)
+		}
+	*/
+
+	//call the pre-configured callback_url
+	res := map[string]any{}
+	var err error
+	//initialize the client
+	client := compapi.NewClient(me.ctx)
+	for i := range LoopTryCount { //retry loop
+		select {
+		case <-me.ctx.Done(): //exit on cancellation signal
+			goto endloopTry
+		default:
+			res, err = client.Callback(taskData.EventType, taskData.CallbackURL, taskData)
+			if err == nil {
+				//call succeeded
+				//fmt.Println("callback succ..........")
+				//fmt.Println(typekit.PrettyPrint(res))
+				if me.Conf.Debug {
+					logx.Infof("callback:'%s' succ", taskData.CallbackURL)
+				}
+				goto endloopTry
+			}
+			logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId,
+				taskData.CallbackURL, err, 1+i*LoopDelayFactor, i+1, LoopTryCount, taskData.RetryCount)
+			time.Sleep(time.Duration(1+i*LoopDelayFactor) * time.Second)
+		}
+	}
+
+	//still failing after all attempts: enter error-task handling
+endloopTry:
+	if err != nil {
+		err1 := me.dealErrorTask(taskData, err) //handle the failed task
+		et := 1
+		if err1 != nil {
+			et = 0
+		}
+		return et, err
+	}
+	//post-success handling
+
+	//advance the task status => Callback_Done
+	err = me.updateTaskStatus(taskData.ID, Callback_Done)
+	if err != nil {
+		return 0, err
+	}
+	//persist taskData.CallbackResponseRaw
+	if len(res) > 0 {
+		if err = me.updateCallbackResponse(taskData.ID, res); err != nil {
+			logx.Errorf("updateCallbackResponse err:%s", err)
+		}
+	}
+
+	taskData.TaskStatus = Callback_Done //state transition
+	return 1, nil
+}
+
+func (me *AsyncTask) requestAPI(taskData *ent.CompapiAsynctask) (int, error) {
+	var workerId int64
+	if me.Conf.Debug {
+		workerId, _ = getGoroutineId()
+		logx.Debugf("Worker:%d INTO requestAPI for task status:%d", workerId, taskData.TaskStatus)
+	}
+
+	if needStop, _ := me.checkErrRetry(taskData); needStop { //retry-count check: if exceeded, mark as permanently failed and stop
+		return 1, errors.New("too many err retry")
+	}
+
+	if taskData.TaskStatus != Task_Ready {
+		return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
+	}
+	var (
+		err     error
+		apiResp *types.CompOpenApiResp
+	)
+	req := types.CompApiReq{}
+	if err = json.Unmarshal([]byte(taskData.RequestRaw), &req); err != nil {
+		return 0, err
+	}
+	//initialize the client
+	if !strings.HasSuffix(taskData.OpenaiBase, "/") {
+		taskData.OpenaiBase = taskData.OpenaiBase + "/"
+	}
+	client := compapi.NewClient(me.ctx, compapi.WithApiBase(taskData.OpenaiBase),
+		compapi.WithApiKey(taskData.OpenaiKey))
+
+	for i := range LoopTryCount { //retry loop
+
+		select {
+		case <-me.ctx.Done(): //exit on cancellation signal
+			goto endloopTry
+		default:
+			apiResp, err = client.Chat(&req)
+			if err == nil && apiResp != nil && len(apiResp.Choices) > 0 {
+				//call succeeded
+				goto endloopTry
+			} else if apiResp != nil && len(apiResp.Choices) == 0 {
+				err = errors.New("response has no choices, please check access permissions")
+			}
+			if me.Conf.Debug {
+				logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId,
+					taskData.OpenaiBase, err, 1+i*LoopDelayFactor, i+1, LoopTryCount, taskData.RetryCount)
+			}
+			time.Sleep(time.Duration(1+i*LoopDelayFactor) * time.Second)
+		}
+	}
+
+endloopTry:
+	//still failing after all attempts: enter error-task handling
+	if err != nil || apiResp == nil {
+		if apiResp == nil && err == nil {
+			err = errors.New("resp is null")
+		}
+		err1 := me.dealErrorTask(taskData, err) //handle the failed task
+		et := 1
+		if err1 != nil {
+			et = 0
+		}
+		return et, err
+	}
+	//post-success handling
+
+	//advance the task status => ReqApi_Done
+	err = me.updateTaskStatus(taskData.ID, ReqApi_Done)
+	if err != nil {
+		return 0, err
+	}
+	//persist taskData.ResponseRaw
+	taskData.ResponseRaw, err = (*apiResp).ToString()
+	if err != nil {
+		return 0, err
+	}
+	err = me.updateApiResponse(taskData.ID, taskData.ResponseRaw)
+	if err != nil {
+		return 0, err
+	}
+
+	taskData.TaskStatus = ReqApi_Done //state transition
+	return 1, nil
+}
+
+func EntStructGenScanField(structPtr any) (string, []any, error) {
+	t := reflect.TypeOf(structPtr)
+	v := reflect.ValueOf(structPtr)
+
+	if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
+		return "", nil, errors.New("input must be a pointer to a struct")
+	}
+	t = t.Elem()
+	v = v.Elem()
+
+	var fields []string
+	var scanArgs []any
+	for i := 0; i < t.NumField(); i++ {
+		field := t.Field(i)
+		value := v.Field(i)
+
+		// Skip unexported fields
+		if !field.IsExported() {
+			continue
+		}
+
+		// Get json tag
+		jsonTag := field.Tag.Get("json")
+		if jsonTag == "-" || jsonTag == "" {
+			continue
+		}
+
+		jsonParts := strings.Split(jsonTag, ",")
+		jsonName := jsonParts[0]
+		if jsonName == "" {
+			continue
+		}
+
+		fields = append(fields, jsonName)
+		scanArgs = append(scanArgs, value.Addr().Interface())
+	}
+	return strings.Join(fields, ", "), scanArgs, nil
+}
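+
+// Usage sketch for EntStructGenScanField (hypothetical struct; the real caller
+// below passes *ent.CompapiAsynctask): the returned column list and scan
+// destinations stay index-aligned, so they can feed a SELECT and rows.Scan
+// directly.
+//
+//	type row struct {
+//		ID   uint64 `json:"id"`
+//		Name string `json:"name"`
+//	}
+//	r := row{}
+//	cols, args, _ := EntStructGenScanField(&r)
+//	// cols == "id, name"; args == []any{&r.ID, &r.Name}
+//	// rows.Scan(args...) then fills r in place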
+
+/*
+CREATE INDEX idx_compapi_task_status_chat_id_id_desc
+ON compapi_asynctask (task_status, chat_id, id DESC);
+*/
+func (me *AsyncTask) getAsyncReqTaskFairList() ([]Task, error) {
+	fieldListStr, _, err := EntStructGenScanField(&ent.CompapiAsynctask{})
+	if err != nil {
+		return nil, err
+	}
+	rawQuery := fmt.Sprintf(`
+	WITH ranked AS (
+		SELECT %s,
+			ROW_NUMBER() OVER (PARTITION BY chat_id ORDER BY id DESC) AS rn
+		FROM compapi_asynctask
+		WHERE task_status < ?
+	)
+	SELECT %s
+	FROM ranked
+	WHERE rn <= ?
+	ORDER BY rn, id DESC
+	LIMIT ?;
+	`, fieldListStr, fieldListStr)
+
+	// run the raw query
+	rows, err := me.svcCtx.DB.QueryContext(me.ctx, rawQuery, All_Done,
+		me.Conf.BatchLoadTask, me.Conf.BatchLoadTask)
+	if err != nil {
+		return nil, fmt.Errorf("fetch fair tasks query error: %w", err)
+	}
+	defer rows.Close()
+	Idx := 0
+	tasks := []Task{}
+	for rows.Next() {
+		taskrow := ent.CompapiAsynctask{}
+		var scanParams []any
+		_, scanParams, err = EntStructGenScanField(&taskrow)
+		if err != nil {
+			break
+		}
+		err = rows.Scan(scanParams...)
+		if err != nil {
+			break
+		}
+		task := Task{Data: &taskrow, Idx: Idx}
+		tasks = append(tasks, task)
+		Idx++
+	}
+	if err != nil {
+		return nil, fmt.Errorf("fetch fair tasks scan error: %w", err)
+	}
+	if err = rows.Err(); err != nil {
+		return nil, fmt.Errorf("fetch fair tasks rows error: %w", err)
+	}
+	fmt.Printf("getAsyncReqTaskFairList got %d tasks\n", len(tasks))
+	return tasks, nil
+}
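+
+// How the "fair" query above behaves, with made-up rows: if chat A has pending
+// tasks with ids 9,7,5 and chat B has only id 8, the window function assigns
+// rn=1 to ids 9 and 8, rn=2 to 7, rn=3 to 5, and ORDER BY rn, id DESC yields
+// 9, 8, 7, 5. Each chat therefore gets its newest task scheduled before any
+// chat gets a second one, instead of one busy chat monopolizing the batch.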
+
+func (me *AsyncTask) getAsyncReqTaskList() ([]Task, error) {
+	var predicates []predicate.CompapiAsynctask
+	predicates = append(predicates, compapiasynctask.TaskStatusLT(All_Done))
+
+	var tasks []Task
+	res, err := me.svcCtx.DB.CompapiAsynctask.Query().Where(predicates...).
+		Order(ent.Desc(compapiasynctask.FieldID)).
+		Limit(int(me.Conf.BatchLoadTask)).
+		All(me.ctx)
+	if err == nil {
+		for idx, val := range res {
+			tasks = append(tasks, Task{Data: val, Idx: idx})
+		}
+	}
+	return tasks, err
+}
+
+func (me *AsyncTask) processTask(workerID uint, task Task) {
+	/*
+		fmt.Printf("In processTask,Consumer(%d) dealing Task Detail: User(%s/%s/%s) Async Call %s(%s) on Status:%d\n",
+			workerID, task.Data.EventType, task.Data.ChatID, task.Data.AuthToken,
+			task.Data.OpenaiBase, task.Data.OpenaiKey, task.Data.TaskStatus)
+	*/
+	_ = workerID
+	var err error
+	rt := 0
+loop:
+	for {
+		select {
+		case <-me.ctx.Done(): //exit on cancellation signal
+			return
+		default:
+			if task.Data.TaskStatus >= All_Done {
+				break loop //a plain break would only leave the select, hence the label
+			}
+			switch task.Data.TaskStatus {
+			case Task_Ready:
+				//call the API platform
+				// succ: taskStatus Task_Ready => ReqApi_Done
+				// fail: taskStatus stays unchanged or becomes Task_Fail
+				rt, err = me.requestAPI(task.Data)
+			case ReqApi_Done:
+				//deliver the result callback
+				// succ: taskStatus ReqApi_Done => Callback_Done(All_Done)
+				// fail: taskStatus stays unchanged or becomes Task_Fail
+				rt, err = me.requestCallback(task.Data)
+			}
+			if err != nil {
+				//collect the error
+				if rt == 0 {
+					//unrecoverable error handling....
+					logx.Errorf("Task error by '%s'", err)
+				} else {
+					logx.Debugf("Task ignore by '%s'", err)
+				}
+				return //ignored for now; probably should be handled per error type
+			}
+		}
+	}
+}
+
+func (me *AsyncTask) batchWork() (int64, int64) {
+
+	var (
+		wg       sync.WaitGroup
+		produced int64 //number of produced tasks (atomic counter)
+		consumed int64 //number of consumed tasks (atomic counter)
+	)
+	//create the task dispatcher
+	dispatcher := NewTaskDispatcher(me.Conf.MaxChannel,
+		me.Conf.BatchLoadTask/me.Conf.MaxChannel, me.Conf.Debug)
+
+	//start the consumers
+	for widx := range me.Conf.MaxWorker {
+		cidx, ch := dispatcher.getWorkerChanByIdx(widx)
+		wg.Add(1)
+		go func(workerID uint, channelID uint, taskCh chan Task) {
+			defer wg.Done()
+			gidStr := ""
+			if me.Conf.Debug {
+				gid, _ := getGoroutineId()
+				gidStr = fmt.Sprintf("(Goroutine:%d)", gid)
+
+				logx.Infof("Consumer @%d%s bind WorkerChan:#%d start......",
+					workerID, gidStr, channelID)
+			}
+			for task := range taskCh {
+				me.processTask(widx, task)
+				atomic.AddInt64(&consumed, 1)
+			}
+		}(widx+1, cidx+1, ch)
+	}
+
+	// the producer
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		// close all worker channels when the producer exits (success or error),
+		// so the consumers' range loops finish and wg.Wait() cannot hang
+		defer func() {
+			dispatcher.mu.Lock()
+			for _, ch := range dispatcher.workerChs {
+				close(ch)
+			}
+			dispatcher.mu.Unlock()
+		}()
+		gidStr := ""
+		if me.Conf.Debug {
+			gid, _ := getGoroutineId()
+			gidStr = fmt.Sprintf("(Goroutine:%d)", gid)
+			logx.Infof("Producer @1%s start......", gidStr)
+		}
+
+		//fetch the list of pending async tasks
+		//tasks, err := me.getAsyncReqTaskList()
+		tasks, err := me.getAsyncReqTaskFairList()
+		if err != nil {
+			logx.Errorf("getAsyncReqTaskList err:%s", err)
+			return
+		}
+
+		// dispatch the tasks
+		for _, task := range tasks {
+			dispatcher.Dispatch(task)
+			atomic.AddInt64(&produced, 1)
+		}
+
+		fmt.Printf("📦Producer @1 dispatched %d tasks in this batch\n", len(tasks))
+	}()
+
+	wg.Wait()
+	consumedRatStr := ""
+	if atomic.LoadInt64(&produced) > 0 {
+		consumedRatStr = fmt.Sprintf(" (%d/%d)*100=%d%%", atomic.LoadInt64(&consumed), atomic.LoadInt64(&produced),
+			atomic.LoadInt64(&consumed)*100/atomic.LoadInt64(&produced))
+	}
+	fmt.Printf("🏁Batch completion stats: Task dispatch: %d(%d)(Producer:1),Task deal: %d(Consumer:%d)%s\n",
+		atomic.LoadInt64(&produced), me.Conf.BatchLoadTask, atomic.LoadInt64(&consumed), me.Conf.MaxWorker, consumedRatStr)
+
+	return produced, consumed
+}
+
+func (me *AsyncTask) InitServiceContext() *svc.ServiceContext {
+	rds := me.Conf.RedisConf.MustNewUniversalRedis()
+
+	dbOpts := []ent.Option{ent.Log(logx.Info),
+		ent.Driver(me.Conf.DatabaseConf.NewNoCacheDriver())}
+	if me.Conf.Debug {
+		dbOpts = append(dbOpts, ent.Debug())
+	}
+	db := ent.NewClient(dbOpts...)
+	svcCtx := svc.ServiceContext{DB: db, Rds: rds}
+	//svcCtx.Config
+	return &svcCtx
+}
+
+func (me *AsyncTask) adjustConf() {
+	if me.Conf.MaxWorker == 0 {
+		me.Conf.MaxWorker = 2
+	}
+	if me.Conf.MaxChannel == 0 || me.Conf.MaxChannel > me.Conf.MaxWorker {
+		me.Conf.MaxChannel = me.Conf.MaxWorker
+	}
+}
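+
+// Worker-to-channel fan-out implied by adjustConf and getWorkerChanByIdx
+// (illustrative numbers): with MaxWorker=4 and MaxChannel=2, consumers are
+// bound by widx % MaxChannel, so workers #1 and #3 share channel #1 while
+// workers #2 and #4 share channel #2; MaxChannel is clamped to MaxWorker so a
+// channel never exists without at least one consumer draining it.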
+
+func (me *AsyncTask) Run(ctx context.Context) {
+	me.ctx = ctx
+	me.Logger = logx.WithContext(ctx)
+	me.adjustConf()
+
+	/*
+		tasks, err := me.getAsyncReqTaskFairList()
+		if err != nil {
+			fmt.Println(err)
+			return
+		}
+		for idx, task := range tasks {
+			if idx > 20 {
+				break
+			}
+			fmt.Printf("#%d=>%d ||[%s]'%s' || '%s' || '%s'|| '%s'\n", idx, task.Data.ID, task.Data.CreatedAt, task.Data.EventType,
+				task.Data.Model, task.Data.OpenaiBase, task.Data.ChatID)
+		}
+	*/
+
+	batchId := 0
+	for {
+		batchId++
+		select {
+		case <-ctx.Done():
+			// cancellation signal received
+			fmt.Printf("Main will shut down gracefully... Reason: %v\n", ctx.Err())
+			return
+		default:
+			timeStart := time.Now()
+			secStart := timeStart.Unix()
+			fmt.Printf("[%s]batchWork#%d start......\n", timeStart.Format("2006-01-02 15:04:05"), batchId)
+			produced, _ := me.batchWork()
+			timeEnd := time.Now()
+			fmt.Printf("[%s]batchWork#%d end,spent %d sec\n", timeEnd.Format("2006-01-02 15:04:05"),
+				batchId, timeEnd.Unix()-secStart)
+			timeDurnum := 1
+
+			if produced == 0 {
+				timeDurnum = 5
+			}
+			fmt.Printf("batchWork will sleep %d sec\n", timeDurnum)
+			time.Sleep(time.Duration(timeDurnum) * time.Second)
+		}
+	}
+}
+
+func NewAsyncTask() *AsyncTask {
+
+	ataskObj := AsyncTask{}
+	flag.Parse() //also pushes the command-line arguments into flag.CommandLine
+	//fmt.Println(typekit.PrettyPrint(flag.CommandLine))
+	conf.MustLoad(*configFile, &ataskObj.Conf, conf.UseEnv())
+	//fmt.Println(typekit.PrettyPrint(ataskObj.Conf))
+	ataskObj.svcCtx = ataskObj.InitServiceContext()
+
+	return &ataskObj
+}
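+
+// Typical invocation (assumed; the only flag defined here is -f): build this
+// package as its own binary and point it at the config file, e.g.
+//
+//	go run . -f ./etc/asynctask.yaml
+//
+// Settings can also come from the environment, since conf.MustLoad is called
+// with conf.UseEnv().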
+
+func main() {
+
+	fmt.Println("Compapi Asynctask Start......")
+	//ctx, cancel := context.WithCancel(context.Background())
+	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+	defer cancel()
+	NewAsyncTask().Run(ctx)
+
+	fmt.Println("Compapi Asynctask End......")
+
+}