package channel

import (
	"context"
	"encoding/json"
	"math/rand"
	"time"

	mapset "github.com/deckarep/golang-set/v2"
	"github.com/openai/openai-go"
	"github.com/zeromicro/go-zero/core/logx"

	"wechat-api/database/dao/wechat/model"
	"wechat-api/internal/pkg/customer_of_im/proto"
	"wechat-api/internal/pkg/wechat_ws"
	"wechat-api/internal/svc"
)

// IChannel is the set of operations a concrete channel implementation
// provides; Channel delegates sending to it via InitChannel.
type IChannel interface {
	OnMessage(msg *wechat_ws.MsgJsonObject) error                  // entry point of the channel; receives messages
	FriendTalkHandle(msgStr string) error                          // handles messages sent by other users
	TextToChatMessage(msgStr string) (*ChatMessage, error)         // converts raw text into a ChatMessage
	GenerateRelevantContext(msg *ChatMessage) (*RelevantContext, error) // builds the related context for a message
	Send(message *ReplyMessage) error                              // sends a message
}

// Channel holds the shared reply pipeline (Before -> handle -> After) and
// forwards outgoing messages to the concrete sub-channel set by InitChannel.
type Channel struct {
	subChannel IChannel
}

// InitChannel registers the concrete implementation used for sending replies.
// It must be called before Produce, otherwise sending dereferences a nil
// interface.
func (c *Channel) InitChannel(subChannel IChannel) {
	c.subChannel = subChannel
}

// After is a post-processing hook run after a reply has been handled.
// The base implementation does nothing and always returns true.
func (c *Channel) After(chatCtx *RelevantContext, message *ChatMessage, reply *ReplyMessage) bool {
	return true
}

// Before is a pre-processing hook; returning false aborts Produce.
// The base implementation does nothing and always returns true.
func (c *Channel) Before(chatCtx *RelevantContext, message *ChatMessage, reply *ReplyMessage) bool {
	return true
}

// randomPause sleeps for a random whole number of seconds between 2 and 5
// inclusive.
func (c *Channel) randomPause() {
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	pauseDuration := 2 + rng.Intn(4) // 2 + (0, 1, 2, 3) => 2, 3, 4, 5
	time.Sleep(time.Duration(pauseDuration) * time.Second)
}

// parseIDSet decodes raw, a JSON array of strings, into a set. An empty raw
// string is treated as "[]". On a decode error the error is logged under
// label and whatever json.Unmarshal managed to decode (usually nothing) is
// returned, so an unparsable allow list ends up denying access.
func parseIDSet(raw string, label string) mapset.Set[string] {
	if raw == "" {
		raw = "[]"
	}
	ids := make([]string, 0)
	if err := json.Unmarshal([]byte(raw), &ids); err != nil {
		logx.Error(label+" 解析失败", err)
	}
	return mapset.NewSet[string](ids...)
}

// CheckAllowOrBlockList reports whether msg may be answered according to
// wxInfo's block/allow lists. Group messages are checked by GroupId against
// GroupBlockList/GroupAllowList; private messages by UserId against
// BlockList/AllowList. The block list wins: "ALL" or a matching id there
// denies. Otherwise the allow list must contain "ALL" or the id.
//
// wxInfo is only read; empty list fields are treated as "[]" locally rather
// than being written back into the caller's model.
func (c *Channel) CheckAllowOrBlockList(wxInfo *model.Wx, msg *ChatMessage) bool {
	id := msg.UserId
	blockRaw, blockLabel := wxInfo.BlockList, "BlockList"
	allowRaw, allowLabel := wxInfo.AllowList, "AllowList"
	if msg.IsGroup {
		id = msg.GroupId
		blockRaw, blockLabel = wxInfo.GroupBlockList, "GroupBlockList"
		allowRaw, allowLabel = wxInfo.GroupAllowList, "GroupAllowList"
	}

	// Deny if the block list contains ALL or the current id.
	blocked := parseIDSet(blockRaw, blockLabel)
	if blocked.Contains("ALL") || blocked.Contains(id) {
		return false
	}

	// The allow list is only parsed once the block check has passed.
	// Allow only if it contains ALL or the current id.
	allowed := parseIDSet(allowRaw, allowLabel)
	return allowed.Contains("ALL") || allowed.Contains(id)
}

// beforeSendReply delivers reply through the sub-channel. If reply.Content
// decodes as a non-empty JSON array of ReplyJson, each non-empty item is sent
// as its own message (only type "TEXT" is supported; others are logged and
// skipped), with a random pause between items. Otherwise the reply is sent
// as-is.
//
// Note that reply is mutated in place (Type and Content) for each item.
// Delivery is best-effort: a failing item does not stop the remaining ones.
// The first error is returned; any later errors are logged.
func (c *Channel) beforeSendReply(reply *ReplyMessage) error {
	items := make([]*ReplyJson, 0)
	if err := json.Unmarshal([]byte(reply.Content), &items); err != nil || len(items) == 0 {
		return c.subChannel.Send(reply)
	}

	var firstErr error
	for i, item := range items {
		// A JSON null element decodes to a nil pointer; skip it instead of panicking.
		if item == nil || item.Content == "" {
			continue
		}
		switch item.Type {
		case "TEXT":
			reply.Type = proto.EnumContentType_Text
			reply.Content = item.Content
			if err := c.beforeSendReply(reply); err != nil {
				if firstErr == nil {
					firstErr = err
				} else {
					logx.Error(err)
				}
			}
		default:
			logx.Error("未知类型", item.Type)
		}
		// Pause a random amount between consecutive items.
		if i < len(items)-1 {
			c.randomPause()
		}
	}
	return firstErr
}

// handle asks the chat engine configured in chatCtx for a completion of
// message.Content, stores the answer in reply.Content and sends it via
// beforeSendReply.
func (c *Channel) handle(chatCtx *RelevantContext, message *ChatMessage, reply *ReplyMessage) error {
	chat := NewChatEngine(chatCtx.BaseURL, chatCtx.APIKey, chatCtx.IsFastGPT)
	replyText, err := chat.ChatCompletions(context.Background(), openai.ChatModelGPT4, message.Content, message.UserId)
	if err != nil {
		logx.Error(err)
		return err
	}
	reply.Content = replyText
	return c.beforeSendReply(reply)
}

// Produce builds a text reply addressed back to the sender of message, with
// the From/To user fields swapped. It then runs the Before hook, handle, and
// the After hook. Errors from handle are logged and abort the pipeline.
func (c *Channel) Produce(svcCtx *svc.ServiceContext, chatCtx *RelevantContext, message *ChatMessage) {
	reply := &ReplyMessage{
		ChatMessage: &ChatMessage{
			UserId:           message.UserId,
			MsgId:            svcCtx.NodeID.Generate().String(),
			CreateTime:       time.Time{},
			Type:             proto.EnumContentType_Text,
			Content:          "",
			FromUserId:       message.ToUserId,
			FromUserNickname: message.ToUserNickname,
			ToUserId:         message.FromUserId,
			ToUserNickname:   message.FromUserNickname,
			GroupId:          message.GroupId,
			GroupName:        message.GroupName,
			IsGroup:          message.IsGroup,
			IsAt:             message.IsAt,
		},
	}

	// Pre-processing hook; false aborts.
	if !c.Before(chatCtx, message, reply) {
		return
	}

	// https://github.com/tmc/langchaingo could be used here for conversation memory.
	if err := c.handle(chatCtx, message, reply); err != nil {
		logx.Error(err)
		return
	}

	_ = c.After(chatCtx, message, reply) // base hook always returns true; result unused
}