channel.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. package channel
  2. import (
  3. "context"
  4. "encoding/json"
  5. mapset "github.com/deckarep/golang-set/v2"
  6. "github.com/openai/openai-go"
  7. "github.com/zeromicro/go-zero/core/logx"
  8. "math/rand"
  9. "time"
  10. "wechat-api/database/dao/wechat/model"
  11. "wechat-api/internal/pkg/customer_of_im/proto"
  12. "wechat-api/internal/pkg/wechat_ws"
  13. "wechat-api/internal/svc"
  14. )
  15. type IChannel interface {
  16. OnMessage(msg *wechat_ws.MsgJsonObject) error // channel的入口,接收消息
  17. FriendTalkHandle(msgStr string) error // 其他人发送过来的消息处理
  18. TextToChatMessage(msgStr string) (*ChatMessage, error) // 文本转ChatMessage
  19. GenerateRelevantContext(msg *ChatMessage) (*RelevantContext, error) // 创建相关上下文
  20. Send(message *ReplyMessage) error // 发送消息
  21. }
  22. type Channel struct {
  23. subChannel IChannel
  24. }
  25. func (c *Channel) InitChannel(subChannel IChannel) {
  26. c.subChannel = subChannel
  27. }
  28. func (c *Channel) After(chatCtx *RelevantContext, message *ChatMessage, reply *ReplyMessage) bool {
  29. return true
  30. }
  31. func (c *Channel) Before(chatCtx *RelevantContext, message *ChatMessage, reply *ReplyMessage) bool {
  32. return true
  33. }
  34. func (c *Channel) randomPause() {
  35. // 创建一个新的随机数源
  36. source := rand.NewSource(time.Now().UnixNano())
  37. // 创建一个新的随机数生成器
  38. rng := rand.New(source)
  39. // 生成 2 到 5 之间的随机整数
  40. pauseDuration := 2 + rng.Intn(4) // 2 + (0, 1, 2, 3) => 2, 3, 4, 5
  41. // 暂停程序执行相应的时间
  42. time.Sleep(time.Duration(pauseDuration) * time.Second)
  43. }
  44. func (c *Channel) CheckAllowOrBlockList(wxInfo *model.Wx, msg *ChatMessage) bool {
  45. if msg.IsGroup {
  46. if wxInfo.GroupBlockList == "" {
  47. wxInfo.GroupBlockList = "[]"
  48. }
  49. groupBlockList := make([]string, 0)
  50. err := json.Unmarshal([]byte(wxInfo.GroupBlockList), &groupBlockList)
  51. if err != nil {
  52. logx.Error("GroupBlockList 解析失败", err)
  53. }
  54. groupBlockListSet := mapset.NewSet[string](groupBlockList...)
  55. // 群聊黑名单如果包含ALL或者包含当前群聊,则返回false
  56. if groupBlockListSet.Contains("ALL") || groupBlockListSet.Contains(msg.GroupId) {
  57. return false
  58. }
  59. if wxInfo.GroupAllowList == "" {
  60. wxInfo.GroupAllowList = "[]"
  61. }
  62. groupAllowList := make([]string, 0)
  63. err = json.Unmarshal([]byte(wxInfo.GroupAllowList), &groupAllowList)
  64. if err != nil {
  65. logx.Error("GroupAllowList 解析失败", err)
  66. }
  67. groupAllowListSet := mapset.NewSet[string](groupAllowList...)
  68. // 如果群聊白名单不包含ALL并且不包含当前群聊,则返回false
  69. if !groupAllowListSet.Contains("ALL") && !groupAllowListSet.Contains(msg.GroupId) {
  70. return false
  71. }
  72. } else {
  73. if wxInfo.BlockList == "" {
  74. wxInfo.BlockList = "[]"
  75. }
  76. blockList := make([]string, 0)
  77. err := json.Unmarshal([]byte(wxInfo.BlockList), &blockList)
  78. if err != nil {
  79. logx.Error("BlockList 解析失败", err)
  80. }
  81. blockListSet := mapset.NewSet[string](blockList...)
  82. // 如果黑名单包含ALL或者包含当前用户,则返回false
  83. if blockListSet.Contains("ALL") || blockListSet.Contains(msg.UserId) {
  84. return false
  85. }
  86. if wxInfo.AllowList == "" {
  87. wxInfo.AllowList = "[]"
  88. }
  89. allowList := make([]string, 0)
  90. err = json.Unmarshal([]byte(wxInfo.AllowList), &allowList)
  91. if err != nil {
  92. logx.Error("AllowList 解析失败", err)
  93. }
  94. allowListSet := mapset.NewSet[string](allowList...)
  95. // 如果白名单不包含ALL并且不包含当前用户,则返回false
  96. if !allowListSet.Contains("ALL") && !allowListSet.Contains(msg.UserId) {
  97. return false
  98. }
  99. }
  100. return true
  101. }
  102. func (c *Channel) beforeSendReply(reply *ReplyMessage) error {
  103. jsonReplay := make([]*ReplyJson, 0)
  104. err := json.Unmarshal([]byte(reply.Content), &jsonReplay)
  105. if err == nil && len(jsonReplay) > 0 {
  106. for i, item := range jsonReplay {
  107. if item.Content == "" {
  108. continue
  109. }
  110. switch item.Type {
  111. case "TEXT":
  112. reply.Type = proto.EnumContentType_Text
  113. reply.Content = item.Content
  114. _ = c.beforeSendReply(reply)
  115. default:
  116. logx.Error("未知类型", item.Type)
  117. }
  118. // 随机暂停
  119. if i < len(jsonReplay)-1 {
  120. c.randomPause()
  121. }
  122. }
  123. return nil
  124. }
  125. err = c.subChannel.Send(reply)
  126. return err
  127. }
  128. func (c *Channel) handle(chatCtx *RelevantContext, message *ChatMessage, reply *ReplyMessage) error {
  129. chat := NewChatEngine(chatCtx.BaseURL, chatCtx.APIKey, chatCtx.IsFastGPT)
  130. replyText, err := chat.ChatCompletions(context.Background(), openai.ChatModelGPT4, message.Content, message.UserId)
  131. if err != nil {
  132. logx.Error(err)
  133. return err
  134. }
  135. reply.Content = replyText
  136. err = c.beforeSendReply(reply)
  137. return err
  138. }
  139. func (c *Channel) Produce(svcCtx *svc.ServiceContext, chatCtx *RelevantContext, message *ChatMessage) {
  140. reply := &ReplyMessage{
  141. ChatMessage: &ChatMessage{
  142. UserId: message.UserId,
  143. MsgId: svcCtx.NodeID.Generate().String(),
  144. CreateTime: time.Time{},
  145. Type: proto.EnumContentType_Text,
  146. Content: "",
  147. FromUserId: message.ToUserId,
  148. FromUserNickname: message.ToUserNickname,
  149. ToUserId: message.FromUserId,
  150. ToUserNickname: message.FromUserNickname,
  151. GroupId: message.GroupId,
  152. GroupName: message.GroupName,
  153. IsGroup: message.IsGroup,
  154. IsAt: message.IsAt,
  155. },
  156. }
  157. // 进行前置处理
  158. ret := c.Before(chatCtx, message, reply)
  159. if !ret {
  160. return
  161. }
  162. // 可以使用 https://github.com/tmc/langchaingo 来进行会话记忆
  163. err := c.handle(chatCtx, message, reply)
  164. if err != nil {
  165. logx.Error(err)
  166. return
  167. }
  168. _ = c.After(chatCtx, message, reply)
  169. }