wuroy.eth 1 giorno fa
parent
commit
d1f6709cd0

+ 71 - 0
internal/pkg/customer_of_im/channel/channel.go

@@ -3,10 +3,12 @@ package channel
 import (
 	"context"
 	"encoding/json"
+	mapset "github.com/deckarep/golang-set/v2"
 	"github.com/openai/openai-go"
 	"github.com/zeromicro/go-zero/core/logx"
 	"math/rand"
 	"time"
+	"wechat-api/database/dao/wechat/model"
 	"wechat-api/internal/pkg/customer_of_im/proto"
 	"wechat-api/internal/pkg/wechat_ws"
 	"wechat-api/internal/svc"
@@ -48,6 +50,75 @@ func (c *Channel) randomPause() {
 	time.Sleep(time.Duration(pauseDuration) * time.Second)
 }
 
+func (c *Channel) CheckAllowOrBlockList(wxInfo *model.Wx, msg *ChatMessage) bool {
+
+	if msg.IsGroup {
+		if wxInfo.GroupBlockList == "" {
+			wxInfo.GroupBlockList = "[]"
+		}
+		groupBlockList := make([]string, 0)
+		err := json.Unmarshal([]byte(wxInfo.GroupBlockList), &groupBlockList)
+		if err != nil {
+			logx.Error("GroupBlockList 解析失败", err)
+		}
+		groupBlockListSet := mapset.NewSet[string](groupBlockList...)
+
+		// 群聊黑名单如果包含ALL或者包含当前群聊,则返回false
+		if groupBlockListSet.Contains("ALL") || groupBlockListSet.Contains(msg.GroupId) {
+			return false
+		}
+
+		if wxInfo.GroupAllowList == "" {
+			wxInfo.GroupAllowList = "[]"
+		}
+		groupAllowList := make([]string, 0)
+		err = json.Unmarshal([]byte(wxInfo.GroupAllowList), &groupAllowList)
+		if err != nil {
+			logx.Error("GroupAllowList 解析失败", err)
+		}
+		groupAllowListSet := mapset.NewSet[string](groupAllowList...)
+
+		// 如果群聊白名单不包含ALL并且不包含当前群聊,则返回false
+		if !groupAllowListSet.Contains("ALL") && !groupAllowListSet.Contains(msg.GroupId) {
+			return false
+		}
+
+	} else {
+		if wxInfo.BlockList == "" {
+			wxInfo.BlockList = "[]"
+		}
+		blockList := make([]string, 0)
+		err := json.Unmarshal([]byte(wxInfo.BlockList), &blockList)
+		if err != nil {
+			logx.Error("BlockList 解析失败", err)
+		}
+		blockListSet := mapset.NewSet[string](blockList...)
+
+		// 如果黑名单包含ALL或者包含当前用户,则返回false
+		if blockListSet.Contains("ALL") || blockListSet.Contains(msg.UserId) {
+			return false
+		}
+
+		if wxInfo.AllowList == "" {
+			wxInfo.AllowList = "[]"
+		}
+		allowList := make([]string, 0)
+		err = json.Unmarshal([]byte(wxInfo.AllowList), &allowList)
+		if err != nil {
+			logx.Error("AllowList 解析失败", err)
+		}
+		allowListSet := mapset.NewSet[string](allowList...)
+
+		// 如果白名单不包含ALL并且不包含当前用户,则返回false
+		if !allowListSet.Contains("ALL") && !allowListSet.Contains(msg.UserId) {
+			return false
+		}
+	}
+
+	return true
+
+}
+
 func (c *Channel) beforeSendReply(reply *ReplyMessage) error {
 	jsonReplay := make([]*ReplyJson, 0)
 	err := json.Unmarshal([]byte(reply.Content), &jsonReplay)

+ 49 - 19
internal/pkg/customer_of_im/channel/wechat.go

@@ -4,6 +4,7 @@ import (
 	"encoding/base64"
 	"encoding/json"
 	"errors"
+	"github.com/zeromicro/go-zero/core/collection"
 	"github.com/zeromicro/go-zero/core/logx"
 	"google.golang.org/protobuf/encoding/protojson"
 	"regexp"
@@ -35,14 +36,16 @@ func NewWechatChannel(ws *wechat_ws.WechatWsClient, svcCtx *svc.ServiceContext)
 	return wechatChannel
 }
 
-func (w *WechatChannel) GetWxsCache() map[string]*model.Wx {
+func (w *WechatChannel) GetWxsCache() *collection.SafeMap {
 
 	cacheKey := "WechatChannel_WxList"
 
-	wxs := make(map[string]*model.Wx)
+	wxs := collection.NewSafeMap()
+
 	v, exist := w.svcCtx.Cache.Get(cacheKey)
 	if exist {
-		value, ok := v.(map[string]*model.Wx)
+
+		value, ok := v.(*collection.SafeMap)
 		if ok {
 			wxs = value
 			return wxs
@@ -54,12 +57,13 @@ func (w *WechatChannel) GetWxsCache() map[string]*model.Wx {
 	if err != nil {
 		logx.Error("获取微信列表失败:", err)
 	} else {
+
 		for _, wx := range wxlist {
 			if wx.AgentID != 0 {
 				wx.APIKey = w.svcCtx.Config.OpenAI.ApiKey
 				wx.APIBase = w.svcCtx.Config.OpenAI.BaseUrl
 			}
-			wxs[wx.Wxid] = wx
+			wxs.Set(wx.Wxid, wx)
 		}
 		w.svcCtx.Cache.SetWithExpire(cacheKey, wxs, time.Minute*10)
 	}
@@ -67,7 +71,23 @@ func (w *WechatChannel) GetWxsCache() map[string]*model.Wx {
 	return wxs
 }
 
-func (w *WechatChannel) GetContactCache(wxid string) *model.Contact {
+func (w *WechatChannel) GetWxInfo(wxid string) (*model.Wx, error) {
+	wxs := w.GetWxsCache()
+
+	v, exists := wxs.Get(wxid)
+	if !exists {
+		return nil, errors.New("未查找对应微信,不予处理")
+	}
+
+	wx, ok := v.(*model.Wx)
+	if !ok {
+		return nil, errors.New("微信信息转换失败")
+	}
+
+	return wx, nil
+}
+
+func (w *WechatChannel) GetContactCache(wxid string) (*model.Contact, error) {
 
 	cacheKey := "WechatChannel_WxContact_" + wxid
 
@@ -76,7 +96,7 @@ func (w *WechatChannel) GetContactCache(wxid string) *model.Contact {
 	if exist {
 		value, ok := v.(*model.Contact)
 		if ok {
-			return value
+			return value, nil
 		}
 	}
 
@@ -84,10 +104,10 @@ func (w *WechatChannel) GetContactCache(wxid string) *model.Contact {
 	contact, err := contactDao.Where(contactDao.Wxid.Eq(wxid)).First()
 	if err != nil {
 		logx.Error("获取微信联系人失败:", err)
-		return nil
+		return nil, err
 	} else {
 		w.svcCtx.Cache.SetWithExpire(cacheKey, contact, time.Hour*24)
-		return contact
+		return contact, nil
 	}
 }
 
@@ -156,23 +176,22 @@ func (w *WechatChannel) SendText(message *ReplyMessage) error {
 }
 
 func (w *WechatChannel) GenerateRelevantContext(msg *ChatMessage) (*RelevantContext, error) {
-	wxs := w.GetWxsCache()
-
-	wx, exists := wxs[msg.UserId]
-
-	if !exists {
-		return nil, errors.New("未查找对应微信,不予处理")
+	wx, err := w.GetWxInfo(msg.UserId)
+	if err != nil {
+		logx.Error("获取微信信息失败:", err)
+		return nil, err
 	}
+
 	msg.ToUserNickname = wx.Nickname
 
-	from := w.GetContactCache(msg.FromUserId)
-	if from != nil {
+	from, err := w.GetContactCache(msg.FromUserId)
+	if err == nil {
 		msg.FromUserNickname = from.Nickname
 	}
 
 	if msg.GroupId != "" {
-		group := w.GetContactCache(msg.GroupId)
-		if group != nil {
+		group, err := w.GetContactCache(msg.GroupId)
+		if err == nil {
 			msg.GroupName = group.Nickname
 		}
 	}
@@ -280,6 +299,17 @@ func (w *WechatChannel) FriendTalkHandle(msgStr string) error {
 		return nil
 	}
 
+	wxInfo, err := w.GetWxInfo(chatMsg.UserId)
+	if err != nil {
+		logx.Error("FriendTalkNotice GetWxInfo err:", err)
+		return err
+	}
+
+	//白名单检查
+	if !w.CheckAllowOrBlockList(wxInfo, chatMsg) {
+		return errors.New("不在白名单中")
+	}
+
 	chatCtx, err := w.GenerateRelevantContext(chatMsg)
 	if err != nil {
 		logx.Info("FriendTalkNotice GenerateRelevantContext err:", err)
@@ -293,6 +323,6 @@ func (w *WechatChannel) FriendTalkHandle(msgStr string) error {
 	return nil
 }
 func (w *WechatChannel) WeChatTalkToFriendHandle(msgStr string) error {
-	logx.Info("收到发送的消息:", msgStr)
+	//logx.Info("收到发送的消息:", msgStr)
 	return nil
 }