123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- package sop_task
import (
	"context"
	"fmt"
	"strconv"
	"strings"

	"wechat-api/ent"
	"wechat-api/ent/sopnode"
	"wechat-api/ent/sopstage"
	"wechat-api/internal/svc"
	"wechat-api/internal/types"

	"github.com/sashabaranov/go-openai"
	"github.com/zeromicro/go-zero/core/logx"
)
- type TestNodeLogic struct {
- logx.Logger
- ctx context.Context
- svcCtx *svc.ServiceContext
- }
- func NewTestNodeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *TestNodeLogic {
- return &TestNodeLogic{
- Logger: logx.WithContext(ctx),
- ctx: ctx,
- svcCtx: svcCtx}
- }
- func (l *TestNodeLogic) TestNode(req *types.TestNodeReq) (resp *types.TestNodeResp, err error) {
- var nodes []*ent.SopNode
- q := ""
- if req.Type == 1 {
- parentNode, err := l.svcCtx.DB.SopStage.Query().Where(sopstage.IDEQ(req.Id)).Only(l.ctx)
- if err != nil {
- return nil, err
- }
- for _, message := range parentNode.ActionMessage {
- if message.Type == 1 {
- q += message.Content
- }
- }
- nodes, err = l.svcCtx.DB.SopNode.Query().Where(sopnode.StageIDEQ(req.Id), sopnode.ParentIDEQ(0)).All(l.ctx)
- if err != nil {
- return nil, err
- }
- } else {
- parentNode, err := l.svcCtx.DB.SopNode.Query().Where(sopnode.IDEQ(req.Id)).Only(l.ctx)
- if err != nil {
- return nil, err
- }
- for _, message := range parentNode.ActionMessage {
- if message.Type == 1 {
- q += message.Content
- }
- }
- nodes, err = l.svcCtx.DB.SopNode.Query().Where(sopnode.ParentIDEQ(req.Id)).All(l.ctx)
- if err != nil {
- return nil, err
- }
- }
- backupNodeIndex := -1
- needJudge := false
- prompt := fmt.Sprintf(`# 任务
- 请根据历史消息,判断用户回复的内容或深层意图,与哪个节点的意图相匹配。
- # 历史消息:
- 助手发送:%s
- 用户回复:%s
- # 节点列表:`, q, req.Content)
- for i, node := range nodes {
- if node.ConditionList != nil && node.ConditionList[0] != "" {
- needJudge = true
- prompt += fmt.Sprintf(`
- 节点 id: %d
- 节点意图:%s
- `, i, node.ConditionList)
- } else {
- if node.NoReplyCondition == 0 {
- backupNodeIndex = i
- }
- }
- }
- prompt += `
- # 回复要求
- - 如果命中节点:则仅回复节点 id 数字(如命中多个节点,则仅回复最小值)
- - 如果未命中节点:则仅回复一个单词: None`
- index := ""
- if needJudge {
- // 调用openai接口,使用自定义base和key
- baseUrl := l.svcCtx.Config.Fastgpt.BASE_URL
- apiKey := l.svcCtx.Config.Fastgpt.API_KEY
- index, err = ChatWithCustomConfig(baseUrl, apiKey, prompt)
- if err != nil {
- return nil, err
- }
- } else {
- index = "None"
- }
- if index == "None" && backupNodeIndex != -1 {
- index = strconv.Itoa(backupNodeIndex)
- }
- nodeName := ""
- if index != "None" {
- // 将index转换成int
- indexInt, err := strconv.Atoi(index)
- if err != nil {
- return nil, err
- }
- nodeName = nodes[indexInt].Name
- } else {
- nodeName = "FastGPT"
- }
- var nodeNames []string
- nodeNames = append(nodeNames, nodeName)
- resp = &types.TestNodeResp{}
- resp.Msg = "success"
- resp.Data = nodeNames
- return resp, nil
- }
- func ChatWithCustomConfig(baseURL, apiKey, prompt string) (string, error) {
- // 创建OpenAI客户端配置
- config := openai.DefaultConfig(apiKey)
- config.BaseURL = baseURL
- // 创建OpenAI客户端
- openaiClient := openai.NewClientWithConfig(config)
- // 构建请求
- request := openai.ChatCompletionRequest{
- Model: openai.GPT4o,
- Messages: []openai.ChatCompletionMessage{
- {
- Role: "user",
- Content: prompt,
- },
- },
- }
- // 调用Chat接口
- response, err := openaiClient.CreateChatCompletion(context.Background(), request)
- if err != nil {
- return "", err
- }
- // 返回响应内容
- return response.Choices[0].Message.Content, nil
- }
|