123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201 |
- package channel
- import (
- "context"
- "encoding/json"
- mapset "github.com/deckarep/golang-set/v2"
- "github.com/openai/openai-go"
- "github.com/zeromicro/go-zero/core/logx"
- "math/rand"
- "time"
- "wechat-api/database/dao/wechat/model"
- "wechat-api/internal/pkg/customer_of_im/proto"
- "wechat-api/internal/pkg/wechat_ws"
- "wechat-api/internal/svc"
- )
- type IChannel interface {
- OnMessage(msg *wechat_ws.MsgJsonObject) error // channel的入口,接收消息
- FriendTalkHandle(msgStr string) error // 其他人发送过来的消息处理
- TextToChatMessage(msgStr string) (*ChatMessage, error) // 文本转ChatMessage
- GenerateRelevantContext(msg *ChatMessage) (*RelevantContext, error) // 创建相关上下文
- Send(message *ReplyMessage) error // 发送消息
- }
- type Channel struct {
- subChannel IChannel
- }
- func (c *Channel) InitChannel(subChannel IChannel) {
- c.subChannel = subChannel
- }
- func (c *Channel) After(chatCtx *RelevantContext, message *ChatMessage, reply *ReplyMessage) bool {
- return true
- }
- func (c *Channel) Before(chatCtx *RelevantContext, message *ChatMessage, reply *ReplyMessage) bool {
- return true
- }
- func (c *Channel) randomPause() {
- // 创建一个新的随机数源
- source := rand.NewSource(time.Now().UnixNano())
- // 创建一个新的随机数生成器
- rng := rand.New(source)
- // 生成 2 到 5 之间的随机整数
- pauseDuration := 2 + rng.Intn(4) // 2 + (0, 1, 2, 3) => 2, 3, 4, 5
- // 暂停程序执行相应的时间
- time.Sleep(time.Duration(pauseDuration) * time.Second)
- }
- func (c *Channel) CheckAllowOrBlockList(wxInfo *model.Wx, msg *ChatMessage) bool {
- if msg.IsGroup {
- if wxInfo.GroupBlockList == "" {
- wxInfo.GroupBlockList = "[]"
- }
- groupBlockList := make([]string, 0)
- err := json.Unmarshal([]byte(wxInfo.GroupBlockList), &groupBlockList)
- if err != nil {
- logx.Error("GroupBlockList 解析失败", err)
- }
- groupBlockListSet := mapset.NewSet[string](groupBlockList...)
- // 群聊黑名单如果包含ALL或者包含当前群聊,则返回false
- if groupBlockListSet.Contains("ALL") || groupBlockListSet.Contains(msg.GroupId) {
- return false
- }
- if wxInfo.GroupAllowList == "" {
- wxInfo.GroupAllowList = "[]"
- }
- groupAllowList := make([]string, 0)
- err = json.Unmarshal([]byte(wxInfo.GroupAllowList), &groupAllowList)
- if err != nil {
- logx.Error("GroupAllowList 解析失败", err)
- }
- groupAllowListSet := mapset.NewSet[string](groupAllowList...)
- // 如果群聊白名单不包含ALL并且不包含当前群聊,则返回false
- if !groupAllowListSet.Contains("ALL") && !groupAllowListSet.Contains(msg.GroupId) {
- return false
- }
- } else {
- if wxInfo.BlockList == "" {
- wxInfo.BlockList = "[]"
- }
- blockList := make([]string, 0)
- err := json.Unmarshal([]byte(wxInfo.BlockList), &blockList)
- if err != nil {
- logx.Error("BlockList 解析失败", err)
- }
- blockListSet := mapset.NewSet[string](blockList...)
- // 如果黑名单包含ALL或者包含当前用户,则返回false
- if blockListSet.Contains("ALL") || blockListSet.Contains(msg.UserId) {
- return false
- }
- if wxInfo.AllowList == "" {
- wxInfo.AllowList = "[]"
- }
- allowList := make([]string, 0)
- err = json.Unmarshal([]byte(wxInfo.AllowList), &allowList)
- if err != nil {
- logx.Error("AllowList 解析失败", err)
- }
- allowListSet := mapset.NewSet[string](allowList...)
- // 如果白名单不包含ALL并且不包含当前用户,则返回false
- if !allowListSet.Contains("ALL") && !allowListSet.Contains(msg.UserId) {
- return false
- }
- }
- return true
- }
- func (c *Channel) beforeSendReply(reply *ReplyMessage) error {
- jsonReplay := make([]*ReplyJson, 0)
- err := json.Unmarshal([]byte(reply.Content), &jsonReplay)
- if err == nil && len(jsonReplay) > 0 {
- for i, item := range jsonReplay {
- if item.Content == "" {
- continue
- }
- switch item.Type {
- case "TEXT":
- reply.Type = proto.EnumContentType_Text
- reply.Content = item.Content
- _ = c.beforeSendReply(reply)
- default:
- logx.Error("未知类型", item.Type)
- }
- // 随机暂停
- if i < len(jsonReplay)-1 {
- c.randomPause()
- }
- }
- return nil
- }
- err = c.subChannel.Send(reply)
- return err
- }
- func (c *Channel) handle(chatCtx *RelevantContext, message *ChatMessage, reply *ReplyMessage) error {
- chat := NewChatEngine(chatCtx.BaseURL, chatCtx.APIKey, chatCtx.IsFastGPT)
- replyText, err := chat.ChatCompletions(context.Background(), openai.ChatModelGPT4, message.Content, message.UserId)
- if err != nil {
- logx.Error(err)
- return err
- }
- reply.Content = replyText
- err = c.beforeSendReply(reply)
- return err
- }
- func (c *Channel) Produce(svcCtx *svc.ServiceContext, chatCtx *RelevantContext, message *ChatMessage) {
- reply := &ReplyMessage{
- ChatMessage: &ChatMessage{
- UserId: message.UserId,
- MsgId: svcCtx.NodeID.Generate().String(),
- CreateTime: time.Time{},
- Type: proto.EnumContentType_Text,
- Content: "",
- FromUserId: message.ToUserId,
- FromUserNickname: message.ToUserNickname,
- ToUserId: message.FromUserId,
- ToUserNickname: message.FromUserNickname,
- GroupId: message.GroupId,
- GroupName: message.GroupName,
- IsGroup: message.IsGroup,
- IsAt: message.IsAt,
- },
- }
- // 进行前置处理
- ret := c.Before(chatCtx, message, reply)
- if !ret {
- return
- }
- // 可以使用 https://github.com/tmc/langchaingo 来进行会话记忆
- err := c.handle(chatCtx, message, reply)
- if err != nil {
- logx.Error(err)
- return
- }
- _ = c.After(chatCtx, message, reply)
- }
|