浏览代码

Merge branch 'origin/develop/apikey_and_agent_v1.013' into debug

* origin/develop/apikey_and_agent_v1.013:
  loop try的因子常量化
  1.修改取任务方法按chat_id均分 2.发消息封装成 api
  暂存一下
  更改了由于增加信号处理后双层循环不能break的问题
  将compapi asynctask从crontask中独立出来,永久运行不靠框架机制

# Conflicts:
#	crontask/init.go
boweniac 1 周之前
父节点
当前提交
b83800e1bf

+ 715 - 0
cli/asynctask/asynctask.go

@@ -0,0 +1,715 @@
+package main
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"flag"
+	"fmt"
+	"hash/fnv"
+	"os"
+	"os/signal"
+	"reflect"
+	"runtime"
+	"strconv"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"syscall"
+	"time"
+
+	"wechat-api/ent"
+	"wechat-api/ent/compapiasynctask"
+	"wechat-api/ent/predicate"
+	"wechat-api/internal/svc"
+	"wechat-api/internal/types"
+	"wechat-api/internal/utils/compapi"
+
+	"github.com/suyuan32/simple-admin-common/config"
+	"github.com/zeromicro/go-zero/core/conf"
+	"github.com/zeromicro/go-zero/core/logx"
+)
+
+const (
+	Task_Ready    = 10            //任务就绪
+	ReqApi_Done   = 20            //请求API完成
+	Callback_Done = 30            //请求回调完成
+	All_Done      = Callback_Done //全部完成 (暂时将成功状态的终点标注于此)
+	Task_Suspend  = 60            //任务暂停
+	Task_Fail     = 70            //任务失败
+
+	LoopTryCount    = 3 //循环体内重试次数
+	LoopDelayFactor = 3
+	ErrTaskTryCount = 3 //最大允许错误任务重试次数
+
+	DefaultDisId = "DIS0001"
+)
+
+type Config struct {
+	BatchLoadTask uint `json:",default=100"`
+	MaxWorker     uint `json:",default=2"`
+	MaxChannel    uint `json:",default=1"`
+	Debug         bool `json:",default=false"`
+	DatabaseConf  config.DatabaseConf
+	RedisConf     config.RedisConf
+}
+
+type TaskStat struct {
+}
+
+type AsyncTask struct {
+	logx.Logger
+	ctx    context.Context
+	svcCtx *svc.ServiceContext
+	Conf   Config
+	Stats  TaskStat
+}
+
+type Task struct {
+	Data *ent.CompapiAsynctask
+	Idx  int
+	Code int
+}
+
+// 带会话管理的任务通道组
+type TaskDispatcher struct {
+	mu        sync.Mutex
+	workerChs []chan Task // 每个worker独立通道
+	Debug     bool
+}
+
+var configFile = flag.String("f", "./etc/asynctask.yaml", "the config file")
+
+func getGoroutineId() (int64, error) {
+	// 堆栈结果中需要消除的前缀符
+	var goroutineSpace = []byte("goroutine ")
+
+	bs := make([]byte, 128)
+	bs = bs[:runtime.Stack(bs, false)]
+	bs = bytes.TrimPrefix(bs, goroutineSpace)
+	i := bytes.IndexByte(bs, ' ')
+	if i < 0 {
+		return -1, errors.New("get current goroutine id failed")
+	}
+	return strconv.ParseInt(string(bs[:i]), 10, 64)
+}
+
+func NewTaskDispatcher(channelCount uint, chanSize uint, debug bool) *TaskDispatcher {
+	td := &TaskDispatcher{
+		workerChs: make([]chan Task, channelCount),
+		Debug:     debug,
+	}
+	// 初始化worker通道
+	for i := range td.workerChs {
+		td.workerChs[i] = make(chan Task, chanSize+1) // 每个worker带缓冲
+	}
+	return td
+}
+
+// 按woker下标索引选择workerChs
+func (td *TaskDispatcher) getWorkerChanByIdx(idx uint) (uint, chan Task) {
+	nidx := idx % uint(len(td.workerChs))
+	return nidx, td.workerChs[nidx]
+}
+
+// 按哈希分片选择workerChs
+func (td *TaskDispatcher) getWorkerChanByHash(disID string) (int, chan Task) {
+	if len(disID) == 0 {
+		disID = DefaultDisId
+	}
+	idx := 0
+	if len(td.workerChs) > 1 {
+		h := fnv.New32a()
+		h.Write([]byte(disID))
+		idx = int(h.Sum32()) % len(td.workerChs)
+	}
+	if td.Debug {
+		outStr := fmt.Sprintf("getWorkerChannel by disId Hash:'%s',from workerChs{", disID)
+		for i := range len(td.workerChs) {
+			outStr += fmt.Sprintf("#%d", i+1)
+			if i < len(td.workerChs)-1 {
+				outStr += ","
+			}
+		}
+		outStr += fmt.Sprintf("} choice chs:'#%d' by '%s'", idx+1, disID)
+		logx.Debug(outStr)
+	}
+
+	return idx, td.workerChs[idx]
+}
+
+// 分配任务到对应的消费channel
+func (td *TaskDispatcher) Dispatch(task Task) {
+	td.mu.Lock()
+	defer td.mu.Unlock()
+
+	// 根据chatId哈希获得其对应的workerChan
+	workerChIdx, workerCh := td.getWorkerChanByHash(task.Data.EventType)
+	// 将任务送入该chatid的workerChan
+	workerCh <- task
+	if td.Debug {
+		logx.Debugf("Producer:EventType:[%s] Task Push to WorkerChan:#%d", task.Data.EventType, workerChIdx+1)
+	}
+}
+
+// 更新任务状态
+func (me *AsyncTask) updateTaskStatus(taskId uint64, status int) error {
+	_, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId).
+		SetTaskStatus(int8(status)).
+		SetUpdatedAt(time.Now()).
+		Save(me.ctx)
+	return err
+}
+
+// 更新请求大模型后结果
+func (me *AsyncTask) updateApiResponse(taskId uint64, apiResponse string) error {
+	_, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId).
+		SetUpdatedAt(time.Now()).
+		SetResponseRaw(apiResponse).
+		SetLastError("").
+		SetRetryCount(0).
+		Save(me.ctx)
+	return err
+}
+
+func (me *AsyncTask) updateCallbackResponse(taskId uint64, callRes any) error {
+	callResStr := ""
+	switch v := callRes.(type) {
+	case []byte:
+		callResStr = string(v)
+	default:
+		if bs, err := json.Marshal(v); err == nil {
+			callResStr = string(bs)
+		} else {
+			return err
+		}
+	}
+	_, err := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskId).
+		SetUpdatedAt(time.Now()).
+		SetCallbackResponseRaw(callResStr).
+		SetLastError("").
+		SetRetryCount(0).
+		Save(me.ctx)
+	return err
+}
+
+func (me *AsyncTask) checkErrRetry(taskData *ent.CompapiAsynctask) (bool, error) {
+	var err error
+	var needStop = false
+	if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败
+		_, err = me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
+			SetUpdatedAt(time.Now()).
+			SetTaskStatus(int8(Task_Fail)).
+			Save(me.ctx)
+		if err == nil {
+			needStop = true
+			taskData.TaskStatus = Task_Fail
+		}
+	}
+	return needStop, err
+}
+
+// 错误任务处理
+func (me *AsyncTask) dealErrorTask(taskData *ent.CompapiAsynctask, lasterr error) error {
+	logx.Debug("多次循环之后依然失败,进入错误任务处理环节")
+	cauo := me.svcCtx.DB.CompapiAsynctask.UpdateOneID(taskData.ID).
+		SetUpdatedAt(time.Now())
+	if taskData.RetryCount >= ErrTaskTryCount { //错误任务尝试次数超过约定则将任务状态永久设为失败
+		taskData.TaskStatus = Task_Fail
+		cauo = cauo.SetTaskStatus(int8(Task_Fail))
+	} else {
+		cauo = cauo.SetRetryCount(taskData.RetryCount + 1).
+			SetLastError(lasterr.Error())
+	}
+	_, err := cauo.Save(me.ctx)
+	return err
+}
+
+func (me *AsyncTask) requestCallback(taskData *ent.CompapiAsynctask) (int, error) {
+	var workerId int64
+	if me.Conf.Debug {
+		workerId, _ = getGoroutineId()
+		logx.Debugf("Worker:%d INTO requestCallback for task status:%d", workerId, taskData.TaskStatus)
+	}
+
+	if needStop, _ := me.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理
+		return 1, errors.New("too many err retry")
+	}
+	if taskData.TaskStatus != ReqApi_Done {
+		return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
+	}
+
+	if len(taskData.CallbackURL) == 0 {
+		return 0, errors.New("callback url empty")
+	}
+	if len(taskData.ResponseRaw) == 0 {
+		return 0, errors.New("call api response empty")
+	}
+
+	/*
+		fstr := "mytest-svc:"
+		if taskData.RetryCount >= 0 && strings.Contains(taskData.CallbackURL, fstr) {
+			taskData.CallbackURL = strings.Replace(taskData.CallbackURL, fstr, "0.0.0.0:", 1)
+		}
+	*/
+
+	//请求预先给定的callback_url
+	res := map[string]any{}
+	var err error
+	//初始化client
+	client := compapi.NewClient(me.ctx)
+	for i := range LoopTryCount { //重试机制
+		select {
+		case <-me.ctx.Done(): //接到信号退出
+			goto endloopTry
+		default:
+			res, err = client.Callback(taskData.EventType, taskData.CallbackURL, taskData)
+			if err == nil {
+				//call succ
+				//fmt.Println("callback succ..........")
+				//fmt.Println(typekit.PrettyPrint(res))
+				if me.Conf.Debug {
+					logx.Infof("callback:'%s' succ", taskData.CallbackURL)
+				}
+				goto endloopTry
+			}
+			logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId,
+				taskData.CallbackURL, err, 1+i*LoopDelayFactor, i+1, LoopTryCount, taskData.RetryCount)
+			time.Sleep(time.Duration(1+i*LoopDelayFactor) * time.Second)
+		}
+	}
+
+	//多次循环之后依然失败,进入错误任务处理环节
+endloopTry:
+	if err != nil {
+		err1 := me.dealErrorTask(taskData, err) //错误任务处理
+		et := 1
+		if err1 != nil {
+			et = 0
+		}
+		return et, err
+	}
+	//成功后处理环节
+
+	//更新任务状态 => Callback_Done(回调完成)
+	err = me.updateTaskStatus(taskData.ID, Callback_Done)
+	if err != nil {
+		return 0, err
+	}
+	//更新taskData.CallbackResponseRaw
+	if len(res) > 0 {
+		me.updateCallbackResponse(taskData.ID, res)
+	}
+
+	taskData.TaskStatus = Callback_Done //状态迁移
+	return 1, nil
+}
+
+func (me *AsyncTask) requestAPI(taskData *ent.CompapiAsynctask) (int, error) {
+	var workerId int64
+	if me.Conf.Debug {
+		workerId, _ = getGoroutineId()
+		logx.Debugf("Worker:%d INTO requestAPI for task status:%d", workerId, taskData.TaskStatus)
+	}
+
+	if needStop, _ := me.checkErrRetry(taskData); needStop { //重试次数检测,如果超过直接标为永久失败而不再处理
+		return 1, errors.New("too many err retry")
+	}
+
+	if taskData.TaskStatus != Task_Ready {
+		return 0, fmt.Errorf("invalid task run order for status:%d", taskData.TaskStatus)
+	}
+	var (
+		err     error
+		apiResp *types.CompOpenApiResp
+	)
+	req := types.CompApiReq{}
+	if err = json.Unmarshal([]byte(taskData.RequestRaw), &req); err != nil {
+		return 0, err
+	}
+	//初始化client
+	if !strings.HasSuffix(taskData.OpenaiBase, "/") {
+		taskData.OpenaiBase = taskData.OpenaiBase + "/"
+	}
+	client := compapi.NewClient(me.ctx, compapi.WithApiBase(taskData.OpenaiBase),
+		compapi.WithApiKey(taskData.OpenaiKey))
+
+	for i := range LoopTryCount { //重试机制
+
+		select {
+		case <-me.ctx.Done(): //接到信号退出
+			goto endloopTry
+		default:
+			apiResp, err = client.Chat(&req)
+			if err == nil && apiResp != nil && len(apiResp.Choices) > 0 {
+				//call succ
+				goto endloopTry
+			} else if apiResp != nil && len(apiResp.Choices) == 0 {
+				err = errors.New("返回结果缺失,请检查访问权限")
+			}
+			if me.Conf.Debug {
+				logx.Infof("Worker:%d call '%s' fail: '%s',sleep %d Second for next(%d/%d/%d)", workerId,
+					taskData.CallbackURL, err, 1+i*LoopDelayFactor, i+1, LoopTryCount, taskData.RetryCount)
+			}
+			time.Sleep(time.Duration(1+i*LoopDelayFactor) * time.Second)
+		}
+	}
+
+endloopTry:
+	//多次循环之后依然失败,进入错误任务处理环节
+	if err != nil || apiResp == nil {
+		if apiResp == nil && err == nil {
+			err = errors.New("resp is null")
+		}
+		err1 := me.dealErrorTask(taskData, err) //错误任务处理
+		et := 1
+		if err1 != nil {
+			et = 0
+		}
+		return et, err
+	}
+	//成功后处理环节
+
+	//更新任务状态 => ReqApi_Done(请求API完成)
+	err = me.updateTaskStatus(taskData.ID, ReqApi_Done)
+	if err != nil {
+		return 0, err
+	}
+	//更新taskData.ResponseRaw
+	taskData.ResponseRaw, err = (*apiResp).ToString()
+	if err != nil {
+		return 0, err
+	}
+	err = me.updateApiResponse(taskData.ID, taskData.ResponseRaw)
+	if err != nil {
+		return 0, err
+	}
+
+	taskData.TaskStatus = ReqApi_Done //状态迁移
+	return 1, nil
+}
+
+func EntStructGenScanField(structPtr any) (string, []any, error) {
+	t := reflect.TypeOf(structPtr)
+	v := reflect.ValueOf(structPtr)
+
+	if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
+		return "", nil, errors.New("input must be a pointer to a struct")
+	}
+	t = t.Elem()
+	v = v.Elem()
+
+	var fields []string
+	var scanArgs []any
+	for i := 0; i < t.NumField(); i++ {
+		field := t.Field(i)
+		value := v.Field(i)
+
+		// Skip unexported fields
+		if !field.IsExported() {
+			continue
+		}
+
+		// Get json tag
+		jsonTag := field.Tag.Get("json")
+		if jsonTag == "-" || jsonTag == "" {
+			continue
+		}
+
+		jsonParts := strings.Split(jsonTag, ",")
+		jsonName := jsonParts[0]
+		if jsonName == "" {
+			continue
+		}
+
+		fields = append(fields, jsonName)
+		scanArgs = append(scanArgs, value.Addr().Interface())
+	}
+	return strings.Join(fields, ", "), scanArgs, nil
+}
+
+/*
+CREATE INDEX idx_compapi_task_status_chat_id_id_desc
+ON compapi_asynctask (task_status, chat_id, id DESC);
+*/
+func (me *AsyncTask) getAsyncReqTaskFairList() ([]Task, error) {
+	fieldListStr, _, err := EntStructGenScanField(&ent.CompapiAsynctask{})
+	if err != nil {
+		return nil, err
+	}
+	rawQuery := fmt.Sprintf(`
+		    WITH ranked AS (
+		        SELECT %s,
+		               ROW_NUMBER() OVER (PARTITION BY chat_id ORDER BY id DESC) AS rn
+		        FROM compapi_asynctask
+		        WHERE task_status < ?
+		    )
+		    SELECT %s
+		    FROM ranked
+		    WHERE rn <= ?
+			ORDER BY rn,id DESC
+		    LIMIT ?;
+		    `, fieldListStr, fieldListStr)
+
+	// 执行原始查询
+	rows, err := me.svcCtx.DB.QueryContext(me.ctx, rawQuery, All_Done,
+		me.Conf.BatchLoadTask, me.Conf.BatchLoadTask)
+	if err != nil {
+		return nil, fmt.Errorf("fetch fair tasks query error: %w", err)
+	}
+	defer rows.Close()
+	Idx := 0
+	tasks := []Task{}
+	for rows.Next() {
+		taskrow := ent.CompapiAsynctask{}
+		var scanParams []any
+		_, scanParams, err = EntStructGenScanField(&taskrow)
+		if err != nil {
+			break
+		}
+		err = rows.Scan(scanParams...)
+		if err != nil {
+			break
+		}
+		task := Task{Data: &taskrow, Idx: Idx}
+		tasks = append(tasks, task)
+		Idx++
+	}
+	fmt.Printf("getAsyncReqTaskFairList get task:%d\n", len(tasks))
+	return tasks, nil
+}
+
+func (me *AsyncTask) getAsyncReqTaskList() ([]Task, error) {
+	var predicates []predicate.CompapiAsynctask
+	predicates = append(predicates, compapiasynctask.TaskStatusLT(All_Done))
+
+	var tasks []Task
+	res, err := me.svcCtx.DB.CompapiAsynctask.Query().Where(predicates...).
+		Order(ent.Desc(compapiasynctask.FieldID)).
+		Limit(int(me.Conf.BatchLoadTask)).
+		All(me.ctx)
+	if err == nil {
+		for idx, val := range res {
+			tasks = append(tasks, Task{Data: val, Idx: idx})
+		}
+	}
+	return tasks, err
+}
+
+func (me *AsyncTask) processTask(workerID uint, task Task) {
+	/*
+		fmt.Printf("In processTask,Consumer(%d) dealing Task Detail: User(%s/%s/%s) Async Call %s(%s) on Status:%d\n",
+			workerID, task.Data.EventType, task.Data.ChatID, task.Data.AuthToken,
+			task.Data.OpenaiBase, task.Data.OpenaiKey, task.Data.TaskStatus)
+	*/
+	_ = workerID
+	var err error
+	rt := 0
+	for {
+		select {
+		case <-me.ctx.Done(): //接到信号退出
+			return
+		default:
+			if task.Data.TaskStatus >= All_Done {
+				goto endfor //原来的break操作加了switch一层不好使了
+			}
+			switch task.Data.TaskStatus {
+			case Task_Ready:
+				//请求API平台
+				//	succ: taskStatus Task_Ready => ReqApi_Done
+				//  fail: taskStatus保持当前不变或Task_Fail
+				rt, err = me.requestAPI(task.Data)
+			case ReqApi_Done:
+				//结果回调
+				// succ: taskStatus ReqApi_Done => Callback_Done(All_Done)
+				// fail: taskStatus保持当前不变或Task_Fail
+				rt, err = me.requestCallback(task.Data)
+			}
+			if err != nil {
+				//收集错误
+				if rt == 0 {
+					//不可恢复错误处理....
+					logx.Errorf("Task error by '%s'", err)
+				} else {
+					logx.Debugf("Task ignore by '%s'", err)
+				}
+				return //先暂时忽略处理,也许应按错误类型分别对待
+			}
+		}
+	}
+endfor:
+}
+
+func (me *AsyncTask) batchWork() (int64, int64) {
+
+	var (
+		wg       sync.WaitGroup
+		produced int64 //生产数量(原子计数器)
+		consumed int64 //消费数量(原子计数器)
+	)
+	//创建任务分发器
+	dispatcher := NewTaskDispatcher(me.Conf.MaxChannel,
+		me.Conf.BatchLoadTask/me.Conf.MaxChannel, me.Conf.Debug)
+
+	//启动消费者
+	for widx := range me.Conf.MaxWorker {
+		cidx, ch := dispatcher.getWorkerChanByIdx(widx)
+		wg.Add(1)
+		go func(workerID uint, channelID uint, taskCh chan Task) {
+			defer wg.Done()
+			gidStr := ""
+			if me.Conf.Debug {
+				gid, _ := getGoroutineId()
+				gidStr = fmt.Sprintf("(Goroutine:%d)", gid)
+
+				logx.Infof("Consumer @%d%s bind WorkerChan:#%d start......",
+					workerID, gidStr, channelID)
+			}
+			for task := range taskCh {
+				me.processTask(widx, task)
+				atomic.AddInt64(&consumed, 1)
+			}
+		}(widx+1, cidx+1, ch)
+	}
+
+	// 生产者
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		gidStr := ""
+		if me.Conf.Debug {
+			gid, _ := getGoroutineId()
+			gidStr = fmt.Sprintf("(Goroutine:%d)", gid)
+			logx.Infof("Producer @1%s start......", gidStr)
+		}
+
+		//获得待处理异步任务列表
+		//tasks, err := me.getAsyncReqTaskList()
+		tasks, err := me.getAsyncReqTaskFairList()
+		if err != nil {
+			logx.Errorf("getAsyncReqTaskList err:%s", err)
+			return
+		}
+
+		// 分发任务
+		for _, task := range tasks {
+			dispatcher.Dispatch(task)
+			atomic.AddInt64(&produced, 1)
+		}
+
+		fmt.Printf("📦Producer @1 此批次共创建任务%d件\n", len(tasks))
+
+		// 关闭所有会话通道
+		dispatcher.mu.Lock()
+		for _, ch := range dispatcher.workerChs {
+			_ = ch
+			close(ch)
+		}
+		dispatcher.mu.Unlock()
+	}()
+
+	wg.Wait()
+	consumedRatStr := ""
+	if atomic.LoadInt64(&produced) > 0 {
+		consumedRatStr = fmt.Sprintf(" (%d/%d)*100=%d%%", atomic.LoadInt64(&produced), atomic.LoadInt64(&consumed),
+			(atomic.LoadInt64(&consumed)/atomic.LoadInt64(&produced))*100)
+	}
+	fmt.Printf("🏁本次任务完成度统计: Task dispatch: %d(%d)(Producer:1),Task deal: %d(Consumer:%d)%s\n", atomic.LoadInt64(&produced), me.Conf.BatchLoadTask, atomic.LoadInt64(&consumed), me.Conf.MaxWorker, consumedRatStr)
+
+	return produced, consumed
+}
+
+func (me *AsyncTask) InitServiceContext() *svc.ServiceContext {
+	rds := me.Conf.RedisConf.MustNewUniversalRedis()
+
+	dbOpts := []ent.Option{ent.Log(logx.Info),
+		ent.Driver(me.Conf.DatabaseConf.NewNoCacheDriver())}
+	if me.Conf.Debug {
+		dbOpts = append(dbOpts, ent.Debug())
+	}
+	db := ent.NewClient(dbOpts...)
+	svcCtx := svc.ServiceContext{DB: db, Rds: rds}
+	//svcCtx.Config
+	return &svcCtx
+}
+
+func (me *AsyncTask) adjustConf() {
+	if me.Conf.MaxWorker <= 0 {
+		me.Conf.MaxWorker = 2
+	}
+	if me.Conf.MaxChannel <= 0 || me.Conf.MaxChannel > me.Conf.MaxWorker {
+		me.Conf.MaxChannel = me.Conf.MaxWorker
+	}
+}
+
+func (me *AsyncTask) Run(ctx context.Context) {
+	me.ctx = ctx
+	me.Logger = logx.WithContext(ctx)
+	me.adjustConf()
+
+	/*
+		tasks, err := me.getAsyncReqTaskFairList()
+		if err != nil {
+			fmt.Println(err)
+			return
+		}
+		for idx, task := range tasks {
+			if idx > 20 {
+				break
+			}
+			fmt.Printf("#%d=>%d ||[%s]'%s' || '%s' || '%s'|| '%s'\n", idx, task.Data.ID, task.Data.CreatedAt, task.Data.EventType,
+				task.Data.Model, task.Data.OpenaiBase, task.Data.ChatID)
+		}
+	*/
+
+	batchId := 0
+	for {
+		batchId++
+		select {
+		case <-ctx.Done():
+			// 收到了取消信号
+			fmt.Printf("Main Will Shutting down gracefully... Reason: %v\n", ctx.Err())
+			return
+		default:
+			timeStart := time.Now()
+			secStart := timeStart.Unix()
+			fmt.Printf("[%s]batchWork#%d start......\n", timeStart.Format("2006-01-02 15:04:05"), batchId)
+			produced, _ := me.batchWork()
+			timeEnd := time.Now()
+			fmt.Printf("[%s]batchWork#%d end,spent %d sec\n", timeEnd.Format("2006-01-02 15:04:05"),
+				batchId, timeEnd.Unix()-secStart)
+			timeDurnum := 1
+
+			if produced == 0 {
+				timeDurnum = 5
+			}
+			fmt.Printf("batchWork will sleep %d sec\n", timeDurnum)
+			time.Sleep(time.Duration(timeDurnum) * time.Second)
+		}
+	}
+}
+
+func NewAsyncTask() *AsyncTask {
+
+	ataskObj := AsyncTask{}
+	flag.Parse() //将命令行参数也塞入flag.CommandLine结构
+	//fmt.Println(typekit.PrettyPrint(flag.CommandLine))
+	conf.MustLoad(*configFile, &ataskObj.Conf, conf.UseEnv())
+	//fmt.Println(typekit.PrettyPrint(ataskObj.Conf))
+	ataskObj.svcCtx = ataskObj.InitServiceContext()
+
+	return &ataskObj
+}
+
+func main() {
+
+	fmt.Println("Compapi Asynctask Start......")
+	//ctx, cancel := context.WithCancel(context.Background())
+	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+	defer cancel()
+	NewAsyncTask().Run(ctx)
+
+	fmt.Println("Compapi Asynctask End......")
+
+}

+ 21 - 0
cli/asynctask/etc/asynctask.yaml

@@ -0,0 +1,21 @@
+BatchLoadTask: 200 #每批次取任务数
+MaxWorker: 10       #最大消费者数量
+MaxChannel: 2      #最大消费通道数量
+Debug: false
+
+DatabaseConf: #数据库配置
+  Type: mysql
+  Host: 127.0.0.1
+  Port: 3306
+  DBName: wechat
+  Username: root
+  Password: simple-admin.
+  MaxOpenConn: 100
+  SSLMode: disable
+  CacheTime: 5
+
+RedisConf: #redis配置
+  Host: 127.0.0.1:6379
+
+
+

+ 11 - 0
desc/openapi/chat.api

@@ -210,3 +210,14 @@ service Wechat {
     @handler chatCompletions
     post /chat/completions (CompApiReq) returns (CompOpenApiResp)
 }
+
+@server(
+	
+    group: chat
+    middleware: OpenAuthority
+)
+
+service Wechat {
+	@handler sendTextMsg
+	post /wx/sendTextMsg (SendTextMsgReq) returns (BaseMsgResp)
+}	

+ 44 - 0
internal/handler/chat/send_text_msg_handler.go

@@ -0,0 +1,44 @@
+package chat
+
+import (
+	"net/http"
+
+	"github.com/zeromicro/go-zero/rest/httpx"
+
+	"wechat-api/internal/logic/chat"
+	"wechat-api/internal/svc"
+	"wechat-api/internal/types"
+)
+
+// swagger:route post /v1/wx/sendTextMsg chat SendTextMsg
+//
+
+//
+
+//
+// Parameters:
+//  + name: body
+//    require: true
+//    in: body
+//    type: SendTextMsgReq
+//
+// Responses:
+//  200: BaseMsgResp
+
+func SendTextMsgHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		var req types.SendTextMsgReq
+		if err := httpx.Parse(r, &req, true); err != nil {
+			httpx.ErrorCtx(r.Context(), w, err)
+			return
+		}
+
+		l := chat.NewSendTextMsgLogic(r.Context(), svcCtx)
+		resp, err := l.SendTextMsg(&req)
+		if err != nil {
+			httpx.ErrorCtx(r.Context(), w, err)
+		} else {
+			httpx.OkJsonCtx(r.Context(), w, resp)
+		}
+	}
+}

+ 13 - 0
internal/handler/routes.go

@@ -921,6 +921,19 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
 
 	server.AddRoutes(
 		rest.WithMiddlewares(
+			[]rest.Middleware{serverCtx.OpenAuthority},
+			[]rest.Route{
+				{
+					Method:  http.MethodPost,
+					Path:    "/wx/sendTextMsg",
+					Handler: chat.SendTextMsgHandler(serverCtx),
+				},
+			}...,
+		),
+	)
+
+	server.AddRoutes(
+		rest.WithMiddlewares(
 			[]rest.Middleware{serverCtx.Authority},
 			[]rest.Route{
 				{

+ 108 - 0
internal/logic/chat/send_text_msg_logic.go

@@ -0,0 +1,108 @@
+package chat
+
+import (
+	"context"
+	"errors"
+
+	"wechat-api/ent"
+	"wechat-api/ent/predicate"
+	"wechat-api/ent/wx"
+	"wechat-api/hook"
+	"wechat-api/internal/svc"
+	"wechat-api/internal/types"
+	"wechat-api/internal/utils/contextkey"
+
+	"github.com/suyuan32/simple-admin-common/enum/errorcode"
+	"github.com/suyuan32/simple-admin-common/msg/errormsg"
+	"github.com/zeromicro/go-zero/core/logx"
+)
+
+type SendTextMsgLogic struct {
+	logx.Logger
+	ctx    context.Context
+	svcCtx *svc.ServiceContext
+}
+
+func NewSendTextMsgLogic(ctx context.Context, svcCtx *svc.ServiceContext) *SendTextMsgLogic {
+	return &SendTextMsgLogic{
+		Logger: logx.WithContext(ctx),
+		ctx:    ctx,
+		svcCtx: svcCtx}
+}
+
+func (l *SendTextMsgLogic) SendTextMsg(req *types.SendTextMsgReq) (resp *types.BaseMsgResp, err error) {
+	// todo: add your logic here and delete this line
+
+	var (
+		apiKeyObj *ent.ApiKey
+		ok        bool
+	)
+
+	//从上下文中获取鉴权中间件埋下的apiAuthInfo
+	apiKeyObj, ok = contextkey.AuthTokenInfoKey.GetValue(l.ctx)
+	if !ok {
+		return nil, errors.New("content get auth info err")
+	}
+
+	//根据wx实体的wxid查询
+	var predicates []predicate.Wx
+	if req.WxWxid != nil {
+		predicates = append(predicates, wx.WxidContains(*req.WxWxid))
+	}
+
+	wxInfo, err := l.svcCtx.DB.Wx.Query().Where(predicates...).Only(l.ctx)
+
+	//根据wx实体的主键ID查询
+	//wxInfo, err := l.svcCtx.DB.Wx.Get(l.ctx, *req.AgentWxId)
+
+	l.Infof("wxInfo = %v", wxInfo)
+
+	if err != nil {
+		l.Error("查询微信信息失败", err)
+		return
+	}
+	if wxInfo.OrganizationID != apiKeyObj.OrganizationID {
+		return nil, errors.New("OID不一致")
+	}
+
+	privateIP := ""
+	adminPort := ""
+	port := ""
+
+	if wxInfo.ServerID != 0 {
+		serverInfo, err := l.svcCtx.DB.Server.Get(l.ctx, wxInfo.ServerID)
+		if err != nil {
+			l.Error("查询服务器信息失败", err)
+			return nil, err
+		}
+		privateIP = serverInfo.PrivateIP
+		adminPort = serverInfo.AdminPort
+		port = wxInfo.Port
+	}
+
+	var ctype uint64
+	if req.Ctype != nil && *req.Ctype != 0 {
+		ctype = *req.Ctype
+	}
+
+	var hookClient *hook.Hook
+	if ctype == 3 {
+		hookClient = hook.NewWecomHook("", adminPort, port)
+	} else {
+		hookClient = hook.NewHook(privateIP, adminPort, port)
+	}
+
+	err = hookClient.SendTextMsg(*req.Wxid, *req.Msg, wxInfo.Wxid)
+
+	if err != nil {
+		l.Errorf("发送微信文本消息失败:%v\n", err)
+		return nil, err
+	}
+
+	resp = &types.BaseMsgResp{
+		Msg:  errormsg.Success,
+		Code: errorcode.OK,
+	}
+
+	return resp, nil
+}

+ 1 - 0
internal/utils/compapi/base.go

@@ -106,6 +106,7 @@ func (me *Client) Callback(clientType string, callbackUrl string, params any) (m
 	resp := map[string]any{}
 	err = me.OAC.Post(me.ctx, callbackUrl, newParams, &resp)
 	if err != nil {
+		fmt.Printf("Callback Post(%s) By Params:'%s' error\n", callbackUrl, string(newParams))
 		return nil, err
 	}
 	return resp, nil