channel.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. package channel
  2. import (
  3. "context"
  4. "encoding/json"
  5. "github.com/openai/openai-go"
  6. "github.com/zeromicro/go-zero/core/logx"
  7. "math/rand"
  8. "time"
  9. "wechat-api/internal/pkg/customer_of_im/proto"
  10. "wechat-api/internal/pkg/wechat_ws"
  11. "wechat-api/internal/svc"
  12. )
  13. type IChannel interface {
  14. OnMessage(msg *wechat_ws.MsgJsonObject) error // channel的入口,接收消息
  15. FriendTalkHandle(msgStr string) error // 其他人发送过来的消息处理
  16. TextToChatMessage(msgStr string) (*ChatMessage, error) // 文本转ChatMessage
  17. GenerateRelevantContext(msg *ChatMessage) (*RelevantContext, error) // 创建相关上下文
  18. Send(message *ReplyMessage) error // 发送消息
  19. }
  20. type Channel struct {
  21. subChannel IChannel
  22. }
  23. func (c *Channel) InitChannel(subChannel IChannel) {
  24. c.subChannel = subChannel
  25. }
  26. func (c *Channel) After(chatCtx *RelevantContext, message *ChatMessage, reply *ReplyMessage) bool {
  27. return true
  28. }
  29. func (c *Channel) Before(chatCtx *RelevantContext, message *ChatMessage, reply *ReplyMessage) bool {
  30. return true
  31. }
  32. func (c *Channel) randomPause() {
  33. // 创建一个新的随机数源
  34. source := rand.NewSource(time.Now().UnixNano())
  35. // 创建一个新的随机数生成器
  36. rng := rand.New(source)
  37. // 生成 2 到 5 之间的随机整数
  38. pauseDuration := 2 + rng.Intn(4) // 2 + (0, 1, 2, 3) => 2, 3, 4, 5
  39. // 暂停程序执行相应的时间
  40. time.Sleep(time.Duration(pauseDuration) * time.Second)
  41. }
  42. func (c *Channel) beforeSendReply(reply *ReplyMessage) error {
  43. jsonReplay := make([]*ReplyJson, 0)
  44. err := json.Unmarshal([]byte(reply.Content), &jsonReplay)
  45. if err == nil && len(jsonReplay) > 0 {
  46. for i, item := range jsonReplay {
  47. if item.Content == "" {
  48. continue
  49. }
  50. switch item.Type {
  51. case "TEXT":
  52. reply.Content = item.Content
  53. _ = c.beforeSendReply(reply)
  54. default:
  55. logx.Error("未知类型", item.Type)
  56. }
  57. // 随机暂停
  58. if i < len(jsonReplay)-1 {
  59. c.randomPause()
  60. }
  61. }
  62. return nil
  63. }
  64. err = c.subChannel.Send(reply)
  65. return err
  66. }
  67. func (c *Channel) handle(chatCtx *RelevantContext, message *ChatMessage, reply *ReplyMessage) error {
  68. chat := NewChatEngine(chatCtx.BaseURL, chatCtx.APIKey, chatCtx.IsFastGPT)
  69. replyText, err := chat.ChatCompletions(context.Background(), openai.ChatModelGPT4, message.Content, message.UserId)
  70. if err != nil {
  71. logx.Error(err)
  72. return err
  73. }
  74. reply.Content = replyText
  75. err = c.beforeSendReply(reply)
  76. return err
  77. }
  78. func (c *Channel) Produce(svcCtx *svc.ServiceContext, chatCtx *RelevantContext, message *ChatMessage) {
  79. reply := &ReplyMessage{
  80. ChatMessage: &ChatMessage{
  81. UserId: message.UserId,
  82. MsgId: svcCtx.NodeID.Generate().String(),
  83. CreateTime: time.Time{},
  84. Type: proto.EnumContentType_Text,
  85. Content: "",
  86. FromUserId: message.ToUserId,
  87. FromUserNickname: message.ToUserNickname,
  88. ToUserId: message.FromUserId,
  89. ToUserNickname: message.FromUserNickname,
  90. GroupId: message.GroupId,
  91. GroupName: message.GroupName,
  92. IsGroup: message.IsGroup,
  93. IsAt: message.IsAt,
  94. },
  95. }
  96. // 进行前置处理
  97. ret := c.Before(chatCtx, message, reply)
  98. if !ret {
  99. return
  100. }
  101. // 可以使用 https://github.com/tmc/langchaingo 来进行会话记忆
  102. err := c.handle(chatCtx, message, reply)
  103. if err != nil {
  104. logx.Error(err)
  105. return
  106. }
  107. _ = c.After(chatCtx, message, reply)
  108. }