test_node_logic.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. package sop_task
  import (
	"context"
	"fmt"
	"strconv"
	"strings"

	"github.com/sashabaranov/go-openai"
	"github.com/zeromicro/go-zero/core/logx"

	"wechat-api/ent"
	"wechat-api/ent/sopnode"
	"wechat-api/ent/sopstage"
	"wechat-api/internal/svc"
	"wechat-api/internal/types"
)
  14. type TestNodeLogic struct {
  15. logx.Logger
  16. ctx context.Context
  17. svcCtx *svc.ServiceContext
  18. }
  19. func NewTestNodeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *TestNodeLogic {
  20. return &TestNodeLogic{
  21. Logger: logx.WithContext(ctx),
  22. ctx: ctx,
  23. svcCtx: svcCtx}
  24. }
  25. func (l *TestNodeLogic) TestNode(req *types.TestNodeReq) (resp *types.TestNodeResp, err error) {
  26. var nodes []*ent.SopNode
  27. q := ""
  28. if req.Type == 1 {
  29. parentNode, err := l.svcCtx.DB.SopStage.Query().Where(sopstage.IDEQ(req.Id)).Only(l.ctx)
  30. if err != nil {
  31. return nil, err
  32. }
  33. for _, message := range parentNode.ActionMessage {
  34. if message.Type == 1 {
  35. q += message.Content
  36. }
  37. }
  38. nodes, err = l.svcCtx.DB.SopNode.Query().Where(sopnode.StageIDEQ(req.Id), sopnode.ParentIDEQ(0)).All(l.ctx)
  39. if err != nil {
  40. return nil, err
  41. }
  42. } else {
  43. parentNode, err := l.svcCtx.DB.SopNode.Query().Where(sopnode.IDEQ(req.Id)).Only(l.ctx)
  44. if err != nil {
  45. return nil, err
  46. }
  47. for _, message := range parentNode.ActionMessage {
  48. if message.Type == 1 {
  49. q += message.Content
  50. }
  51. }
  52. nodes, err = l.svcCtx.DB.SopNode.Query().Where(sopnode.ParentIDEQ(req.Id)).All(l.ctx)
  53. if err != nil {
  54. return nil, err
  55. }
  56. }
  57. backupNodeIndex := -1
  58. needJudge := false
  59. prompt := fmt.Sprintf(`# 任务
  60. 请根据历史消息,判断用户回复的内容或深层意图,与哪个节点的意图相匹配。
  61. # 历史消息:
  62. 助手发送:%s
  63. 用户回复:%s
  64. # 节点列表:`, q, req.Content)
  65. for i, node := range nodes {
  66. if node.ConditionList != nil && node.ConditionList[0] != "" {
  67. needJudge = true
  68. prompt += fmt.Sprintf(`
  69. 节点 id: %d
  70. 节点意图:%s
  71. `, i, node.ConditionList)
  72. } else {
  73. if node.NoReplyCondition == 0 {
  74. backupNodeIndex = i
  75. }
  76. }
  77. }
  78. prompt += `
  79. # 回复要求
  80. - 如果命中节点:则仅回复节点 id 数字(如命中多个节点,则仅回复最小值)
  81. - 如果未命中节点:则仅回复一个单词: None`
  82. index := ""
  83. if needJudge {
  84. // 调用openai接口,使用自定义base和key
  85. baseUrl := l.svcCtx.Config.Fastgpt.BASE_URL
  86. apiKey := l.svcCtx.Config.Fastgpt.API_KEY
  87. index, err = ChatWithCustomConfig(baseUrl, apiKey, prompt)
  88. if err != nil {
  89. return nil, err
  90. }
  91. } else {
  92. index = "None"
  93. }
  94. if index == "None" && backupNodeIndex != -1 {
  95. index = strconv.Itoa(backupNodeIndex)
  96. }
  97. nodeName := ""
  98. if index != "None" {
  99. // 将index转换成int
  100. indexInt, err := strconv.Atoi(index)
  101. if err != nil {
  102. return nil, err
  103. }
  104. nodeName = nodes[indexInt].Name
  105. } else {
  106. nodeName = "FastGPT"
  107. }
  108. var nodeNames []string
  109. nodeNames = append(nodeNames, nodeName)
  110. resp = &types.TestNodeResp{}
  111. resp.Msg = "success"
  112. resp.Data = nodeNames
  113. return resp, nil
  114. }
  115. func ChatWithCustomConfig(baseURL, apiKey, prompt string) (string, error) {
  116. // 创建OpenAI客户端配置
  117. config := openai.DefaultConfig(apiKey)
  118. config.BaseURL = baseURL
  119. // 创建OpenAI客户端
  120. openaiClient := openai.NewClientWithConfig(config)
  121. // 构建请求
  122. request := openai.ChatCompletionRequest{
  123. Model: openai.GPT4o,
  124. Messages: []openai.ChatCompletionMessage{
  125. {
  126. Role: "user",
  127. Content: prompt,
  128. },
  129. },
  130. }
  131. // 调用Chat接口
  132. response, err := openaiClient.CreateChatCompletion(context.Background(), request)
  133. if err != nil {
  134. return "", err
  135. }
  136. // 返回响应内容
  137. return response.Choices[0].Message.Content, nil
  138. }