소스 검색

Merge branch 'master' into debug

* master:
  no message

# Conflicts:
#	.gitignore
#	crontask/init.go
#	ent/client.go
#	ent/ent.go
#	ent/mutation.go
#	go.mod
#	internal/logic/chat/chat_completions_logic.go
#	internal/logic/contact/update_contact_logic.go
boweniac 1 주 전
부모
커밋
e1473c6d79
7개의 변경된 파일321개의 추가작업 그리고 101개의 파일을 삭제
  1. 1 1
      .gitignore
  2. 2 0
      crontask/init.go
  3. 1 0
      ent/client.go
  4. 0 6
      go.mod
  5. 315 92
      internal/logic/chat/chat_completions_logic.go
  6. 1 1
      internal/logic/contact/update_contact_logic.go
  7. 1 1
      internal/utils/compapi/config.go

+ 1 - 1
.gitignore

@@ -33,5 +33,5 @@ vendor/
 wechat-api
 etc/wechat.yaml
 /etc/wechat.yaml
-cli/asynctask/etc
+cli/asynctask/etc/asynctask-prod.yaml
 cli/asynctask/wechat_api_open

+ 2 - 0
crontask/init.go

@@ -53,10 +53,12 @@ func ScheduleRun(c *cron.Cron, serverCtx *svc.ServiceContext) {
 	c.AddFunc("0 0 * * *", func() {
 		contactForm.analyze()
 	})
+
 	//l = NewCronTask(context.Background(), serverCtx)
 	//c.AddFunc("* * * * *", func() {
 	//	MaxWorker := 10
 	//	MaxChannel := 3
 	//	l.compApiCallback(MaxWorker, MaxChannel)
 	//})
+
 }

+ 1 - 0
ent/client.go

@@ -2166,6 +2166,7 @@ func (c *ContactClient) QueryContactRelationships(co *Contact) *LabelRelationshi
 	return query
 }
 
+
 // QueryContactFields queries the contact_fields edge of a Contact.
 func (c *ContactClient) QueryContactFields(co *Contact) *ContactFieldQuery {
 	query := (&ContactFieldClient{config: c.config}).Query()

+ 0 - 6
go.mod

@@ -48,12 +48,6 @@ require github.com/invopop/jsonschema v0.13.0
 require github.com/google/uuid v1.6.0
 
 require (
-	github.com/invopop/jsonschema v0.13.0
-	github.com/openai/openai-go v0.1.0-beta.9
-//github.com/openai/openai-go v0.1.0-alpha.62
-)
-
-require (
 	ariga.io/atlas v0.19.2 // indirect
 	filippo.io/edwards25519 v1.1.0 // indirect
 	github.com/ArtisanCloud/PowerLibs/v2 v2.0.49 // indirect

+ 315 - 92
internal/logic/chat/chat_completions_logic.go

@@ -2,8 +2,14 @@ package chat
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
+	"fmt"
+	"net"
+	"net/url"
+	"regexp"
 	"strconv"
+	"strings"
 
 	"wechat-api/ent"
 	"wechat-api/internal/svc"
@@ -25,6 +31,21 @@ type ChatCompletionsLogic struct {
 	svcCtx *svc.ServiceContext
 }
 
+type FastgptChatLogic struct {
+	ChatCompletionsLogic
+}
+
+type MismatchChatLogic struct {
+	ChatCompletionsLogic
+}
+
+type baseLogicWorkflow interface {
+	AppendAsyncRequest(apiKeyObj *ent.ApiKey, req *types.CompApiReq) error
+	DoSyncRequest(apiKeyObj *ent.ApiKey, req *types.CompApiReq) (*types.CompOpenApiResp, error)
+	AppendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error
+	AdjustRequest(req *types.CompApiReq, apiKeyObj *ent.ApiKey)
+}
+
 func NewChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ChatCompletionsLogic {
 	return &ChatCompletionsLogic{
 		Logger: logx.WithContext(ctx),
@@ -32,68 +53,237 @@ func NewChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *C
 		svcCtx: svcCtx}
 }
 
-func (l *ChatCompletionsLogic) ChatCompletions(req *types.CompApiReq) (resp *types.CompOpenApiResp, err error) {
+func (l *FastgptChatLogic) AdjustRequest(req *types.CompApiReq, apiKeyObj *ent.ApiKey) {
+
+	l.ChatCompletionsLogic.AdjustRequest(req, apiKeyObj) //先父类的参数调整
+
+	if req.EventType != "fastgpt" {
+		return
+	}
+	if len(req.Model) > 0 {
+		if req.Variables == nil {
+			req.Variables = make(map[string]string)
+		}
+		req.Variables["model"] = req.Model
+	}
+
+	if len(req.ChatId) > 0 && len(req.FastgptChatId) == 0 {
+		req.FastgptChatId = req.ChatId
+	} else if len(req.ChatId) == 0 && len(req.FastgptChatId) > 0 {
+		req.ChatId = req.FastgptChatId
+	}
+}
+
+func (l *ChatCompletionsLogic) ChatCompletions(req *types.CompApiReq) (asyncMode bool, resp *types.CompOpenApiResp, err error) {
 	// todo: add your logic here and delete this line
 
-	/*
-	   1.鉴权获得token
-	   2.必要参数检测及转换
-	   3. 根据event_type选择不同处理路由
-	*/
 	var (
 		apiKeyObj *ent.ApiKey
 		ok        bool
 	)
-	workToken := compapi.GetWorkTokenByID(req.EventType, req.WorkId)
+	asyncMode = false
+
+	//从上下文中获取鉴权中间件埋下的apiAuthInfo
 	apiKeyObj, ok = contextkey.AuthTokenInfoKey.GetValue(l.ctx)
 	if !ok {
-		return nil, errors.New("content get token err")
-	}
-	if req.WorkId == "TEST_DOUYIN" || req.WorkId == "TEST_DOUYIN_CN" || req.WorkId == "travel" || req.WorkId == "loreal" || req.WorkId == "xiulike-dc" || req.WorkId == "wuhanzhongxin" { //临时加
-		apiKeyObj.OpenaiBase = "http://cn-agent.gkscrm.com/api/v1/"
-		if req.WorkId == "TEST_DOUYIN" || req.WorkId == "TEST_DOUYIN_CN" {
-			workToken = "fastgpt-jsMmQKEM5uX7tDimT1zHlZHkBhMHRT2k61YaxyDJRZTUHehID7sG8BKXADNIU"
-		} else if req.WorkId == "travel" {
-			workToken = "fastgpt-bcnfFtw1lXWdmYGOv165UVD5R1kY28tyXX8SJv8MHhrSMOgVJsuU"
-		} else if req.WorkId == "loreal" {
-			workToken = "fastgpt-qqJeBEkwhgx7wR9fvGToQygmOb7FVjAbGBTFjOAMbd95InEtndke"
-		} else if req.WorkId == "xiulike-dc" {
-			workToken = "fastgpt-ir9RgnKHMT9HIOnPsUCFChN15ZbW9kt1lbd5Y0ohfLw9gOz3KcPrfaZWHRB"
-		} else if req.WorkId == "wuhanzhongxin" {
-			workToken = "fastgpt-jX6Gl50Ivrc7vzyD4xNlahG11cgmJ4N63QHKrntt2gQ78g31haxuAsA"
-		}
-
+		return asyncMode, nil, errors.New("content get auth info err")
 	}
-
+	//微调apiKeyObj的openaikey
+	//apiKeyObjAdjust(req.EventType, req.WorkId, apiKeyObj)
 	/*
 		fmt.Println("=========================================")
 		fmt.Printf("In ChatCompletion Get Token Info:\nKey:'%s'\n", apiKeyObj.Key)
-		fmt.Printf("Title:'%s'\n", apiKeyObj.Title)
-		fmt.Printf("OpenaiBase:'%s'\n", apiKeyObj.OpenaiBase)
-		fmt.Printf("OpenaiKey:'%s'\n", apiKeyObj.OpenaiKey)
-		fmt.Printf("workToken:'%s' because %s/%s\n", workToken, req.EventType, req.WorkId)
+		fmt.Printf("Auth Token:'%s'\n", apiKeyObj.Key)
+		fmt.Printf("ApiKey AgentID:%d\n", apiKeyObj.AgentID)
+		fmt.Printf("ApiKey APIBase:'%s'\n", apiKeyObj.Edges.Agent.APIBase)
+		fmt.Printf("ApiKey APIKey:'%s'\n", apiKeyObj.Edges.Agent.APIKey)
+		fmt.Printf("ApiKey Type:%d\n", apiKeyObj.Edges.Agent.Type)
+		fmt.Printf("ApiKey Model:'%s'\n", apiKeyObj.Edges.Agent.Model)
+		fmt.Printf("EventType:'%s'\n", req.EventType)
+		fmt.Printf("req.ChatId:'%s VS req.FastgptChatId:'%s'\n", req.ChatId, req.FastgptChatId)
 		fmt.Println("=========================================")
 	*/
 
-	if len(apiKeyObj.OpenaiBase) == 0 || len(workToken) == 0 {
-		return nil, errors.New("not auth info")
+	//根据请求产生相关的工作流接口集
+	wf, err := l.getLogicWorkflow(apiKeyObj, req)
+	if err != nil {
+		return false, nil, err
 	}
+	/*
+		switch wf.(type) {
+		case *MismatchChatLogic:
+			fmt.Println("MismatchChatLogic Flow.....")
+		case *FastgptChatLogic:
+			fmt.Println("FastgptChatLogic Flow.....")
+		default:
+			fmt.Println("Other Flow.....")
+		}
+	*/
 
-	apiResp, err := l.workForFastgpt(req, workToken, apiKeyObj.OpenaiBase)
-	if err == nil && apiResp != nil {
-		l.doRequestLog(req, apiResp) //请求记录
+	//微调部分请求参数
+	wf.AdjustRequest(req, apiKeyObj)
+
+	if isAsyncReqest(req) { //异步请求处理模式
+		asyncMode = true
+		err = wf.AppendAsyncRequest(apiKeyObj, req)
+	} else { //同步请求处理模式
+		resp, err = wf.DoSyncRequest(apiKeyObj, req)
+		if err == nil && resp != nil && len(resp.Choices) > 0 {
+			wf.AppendUsageDetailLog(apiKeyObj.Key, req, resp) //请求记录
+		} else if resp != nil && len(resp.Choices) == 0 {
+			err = errors.New("返回结果缺失,请检查访问地址及权限")
+		}
 	}
+	return asyncMode, resp, err
+}
 
-	return apiResp, err
+func (l *ChatCompletionsLogic) getLogicWorkflow(apiKeyObj *ent.ApiKey, req *types.CompApiReq) (baseLogicWorkflow, error) {
+	var (
+		err error
+		wf  baseLogicWorkflow
+	)
+
+	if apiKeyObj.Edges.Agent.Type != 2 {
+		err = fmt.Errorf("api agent type not support(%d)", apiKeyObj.Edges.Agent.Type)
+	} else if req.EventType == "mismatch" {
+		wf = &MismatchChatLogic{ChatCompletionsLogic: *l}
+	} else {
+		wf = &FastgptChatLogic{ChatCompletionsLogic: *l}
+	}
+	return wf, err
 }
 
-func (l *ChatCompletionsLogic) doRequestLog(req *types.CompApiReq, resp *types.CompOpenApiResp) error {
-	authToken, ok := contextkey.OpenapiTokenKey.GetValue(l.ctx)
-	if !ok {
-		return errors.New("content get auth token err")
+func (l *ChatCompletionsLogic) AdjustRequest(req *types.CompApiReq, apiKeyObj *ent.ApiKey) {
+
+	if len(req.EventType) == 0 {
+		req.EventType = "fastgpt"
 	}
 
-	return l.appendUsageDetailLog(authToken, req, resp)
+	if len(req.Model) == 0 && len(apiKeyObj.Edges.Agent.Model) > 0 {
+		req.Model = apiKeyObj.Edges.Agent.Model
+	}
+
+	//异步任务相关参数调整
+	if req.IsBatch {
+		//流模式暂时不支持异步模式
+		//Callback格式非法则取消批量模式
+		if req.Stream || !IsValidURL(&req.Callback, true) {
+			req.IsBatch = false
+		}
+	}
+}
+
+func (l *ChatCompletionsLogic) DoSyncRequest(apiKeyObj *ent.ApiKey, req *types.CompApiReq) (*types.CompOpenApiResp, error) {
+	//return compapi.NewFastgptChatCompletions(l.ctx, apiKeyObj.Edges.Agent.APIKey, apiKeyObj.Edges.Agent.APIBase, req)
+	resp, err := compapi.NewClient(l.ctx, compapi.WithApiBase(apiKeyObj.Edges.Agent.APIBase),
+		compapi.WithApiKey(apiKeyObj.Edges.Agent.APIKey)).
+		Chat(req)
+
+	/*
+		if err != nil {
+			return nil, err
+		}
+		if req.EventType == "mismatch" {
+
+			client := compapi.MismatchClient{}
+			taskData := ent.CompapiAsynctask{}
+			taskData.ID = 1234
+			taskData.ResponseChatItemID = req.ResponseChatItemId
+			taskData.EventType = req.EventType
+			taskData.ChatID = req.ChatId
+			var err error
+			taskData.ResponseRaw, err = resp.ToString()
+			if err != nil {
+				fmt.Println(err)
+				return nil, err
+			}
+			var bs []byte
+			bs, err = client.CallbackPrepare(&taskData)
+			if err != nil {
+				fmt.Println(err)
+				return nil, err
+			}
+			fmt.Println(string(bs))
+			nres := map[string]string{}
+			err = json.Unmarshal(bs, &nres)
+			if err != nil {
+				fmt.Println(err)
+				return nil, err
+			}
+			fmt.Println(typekit.PrettyPrint(nres))
+
+			res := compapi.MismatchResponse{}
+			err = compapi.NewChatResult(resp).ParseContentAs(&res)
+			fmt.Println(err)
+			fmt.Println(typekit.PrettyPrint(res))
+		}
+	*/
+	return resp, err
+}
+
+func (l *ChatCompletionsLogic) AppendAsyncRequest(apiKeyObj *ent.ApiKey, req *types.CompApiReq) error {
+
+	rawReqBs, err := json.Marshal(*req)
+	if err != nil {
+		return err
+	}
+	rawReqStr := string(rawReqBs)
+	res, err := l.svcCtx.DB.CompapiAsynctask.Create().
+		SetNotNilAuthToken(&apiKeyObj.Key).
+		SetNotNilOpenaiBase(&apiKeyObj.Edges.Agent.APIBase).
+		SetNotNilOpenaiKey(&apiKeyObj.Edges.Agent.APIKey).
+		SetNotNilOrganizationID(&apiKeyObj.OrganizationID).
+		SetNotNilEventType(&req.EventType).
+		SetNillableModel(&req.Model).
+		SetNillableChatID(&req.ChatId).
+		SetNillableResponseChatItemID(&req.ResponseChatItemId).
+		SetNotNilRequestRaw(&rawReqStr).
+		SetNotNilCallbackURL(&req.Callback).
+		Save(l.ctx)
+
+	if err == nil {
+		logx.Infof("appendAsyncRequest succ,get id:%d", res.ID)
+	}
+	return err
+}
+
+func (l *ChatCompletionsLogic) AppendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
+
+	logType := 5
+	rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
+	tmpId := 0
+	tmpId, _ = strconv.Atoi(resp.ID)
+	sessionId := uint64(tmpId)
+	orgId := uint64(0)
+	apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(l.ctx)
+	if ok {
+		orgId = apiKeyObj.OrganizationID
+	}
+	promptTokens := uint64(resp.Usage.PromptTokens)
+	completionToken := uint64(resp.Usage.CompletionTokens)
+	totalTokens := promptTokens + completionToken
+
+	msgContent := getMessageContentStr(req.Messages[0].Content)
+
+	_, _, _ = logType, sessionId, totalTokens
+	res, err := l.svcCtx.DB.UsageDetail.Create().
+		SetNotNilType(&logType).
+		SetNotNilBotID(&authToken).
+		SetNotNilReceiverID(&req.EventType).
+		SetNotNilSessionID(&sessionId).
+		SetNillableRequest(&msgContent).
+		SetNillableResponse(&resp.Choices[0].Message.Content).
+		SetNillableOrganizationID(&orgId).
+		SetOriginalData(rawReqResp).
+		SetNillablePromptTokens(&promptTokens).
+		SetNillableCompletionTokens(&completionToken).
+		SetNillableTotalTokens(&totalTokens).
+		Save(l.ctx)
+
+	if err == nil { //插入UsageDetai之后再统计UsageTotal
+		l.updateUsageTotal(authToken, res.ID, orgId)
+	}
+	return err
 }
 
 func (l *ChatCompletionsLogic) getUsagetotalIdByToken(authToken string) (uint64, error) {
@@ -161,72 +351,105 @@ func (l *ChatCompletionsLogic) sumTotalTokensByAuthToken(authToken string) (uint
 	return totalTokens, err
 }
 
-func (l *ChatCompletionsLogic) appendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
+func apiKeyObjAdjust(eventType string, workId string, obj *ent.ApiKey) {
+	if eventType != "fastgpt" {
+		return
+	}
+	obj.OpenaiKey, _ = compapi.GetWorkInfoByID(eventType, workId)
+}
+
+// 合法域名正则(支持通配符、中文域名等场景按需调整)
+var domainRegex = regexp.MustCompile(
+	// 多级域名(如 example.com)
+	`^([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}$` +
+		`|` +
+		// 单级域名(如 localhost 或 mytest-svc)
+		`^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$`,
+)
 
-	logType := 5
-	workIdx := compapi.GetWorkIdxByID(req.EventType, req.WorkId)
-	rawReqResp := custom_types.OriginalData{Request: req, Response: resp}
+func IsValidURL(input *string, adjust bool) bool {
+	// 空值直接返回
+	if *input == "" {
+		return false
+	}
+	inputStr := *input
 
-	tmpId := 0
-	tmpId, _ = strconv.Atoi(resp.ID)
-	sessionId := uint64(tmpId)
-	orgId := uint64(0)
-	apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(l.ctx)
-	if ok {
-		orgId = apiKeyObj.OrganizationID
+	// --- 预处理输入:自动补全协议 ---
+	// 若输入不包含协议头,默认添加 http://
+	if !strings.Contains(*input, "://") {
+		inputStr = "http://" + *input
 	}
-	promptTokens := uint64(resp.Usage.PromptTokens)
-	completionToken := uint64(resp.Usage.CompletionTokens)
-	totalTokens := promptTokens + completionToken
 
-	msgContent := ""
-	switch val := req.Messages[0].Content.(type) {
+	// --- 解析 URL ---
+	u, err := url.Parse(inputStr)
+	if err != nil {
+		return false
+	}
+
+	// --- 校验协议 ---
+	// 只允许常见协议(按需扩展)
+	switch u.Scheme {
+	case "http", "https", "ftp", "ftps":
+	default:
+		return false
+	}
+
+	// --- 拆分 Host 和 Port ---
+	host, port, err := net.SplitHostPort(u.Host)
+	if err != nil {
+		// 无端口时,整个 Host 作为主机名
+		host = u.Host
+		port = ""
+	}
+
+	// --- 校验主机名 ---
+	// 场景1:IPv4 或 IPv6
+	if ip := net.ParseIP(host); ip != nil {
+		// 允许私有或保留 IP(按需调整)
+		// 示例中允许所有合法 IP
+	} else {
+		// 场景2:域名(包括 localhost)
+		if !domainRegex.MatchString(host) {
+			return false
+		}
+	}
+
+	// --- 校验端口 ---
+	if port != "" {
+		p, err := net.LookupPort("tcp", port) // 动态获取端口(如 "http" 对应 80)
+		if err != nil {
+			// 直接尝试解析为数字端口
+			numPort, err := strconv.Atoi(port)
+			if err != nil || numPort < 1 || numPort > 65535 {
+				return false
+			}
+		} else if p == 0 { // 动态端口为 0 时无效
+			return false
+		}
+	}
+	if adjust {
+		*input = inputStr
+	}
+	return true
+}
+
+func getMessageContentStr(input any) string {
+	str := ""
+	switch val := input.(type) {
 	case string:
-		msgContent = val
+		str = val
 	case []interface{}:
 		if len(val) > 0 {
 			if valc, ok := val[0].(map[string]interface{}); ok {
 				if valcc, ok := valc["text"]; ok {
-					msgContent, _ = valcc.(string)
+					str, _ = valcc.(string)
 				}
 			}
 		}
 	}
-
-	res, err := l.svcCtx.DB.UsageDetail.Create().
-		SetNotNilType(&logType).
-		SetNotNilBotID(&authToken).
-		SetNotNilReceiverID(&req.EventType).
-		SetNotNilSessionID(&sessionId).
-		SetNillableApp(&workIdx).
-		//SetNillableRequest(&req.Messages[0].Content).
-		SetNillableRequest(&msgContent).
-		SetNillableResponse(&resp.Choices[0].Message.Content).
-		SetNillableOrganizationID(&orgId).
-		SetOriginalData(rawReqResp).
-		SetNillablePromptTokens(&promptTokens).
-		SetNillableCompletionTokens(&completionToken).
-		SetNillableTotalTokens(&totalTokens).
-		Save(l.ctx)
-
-	if err == nil { //插入UsageDetai之后再统计UsageTotal
-		l.updateUsageTotal(authToken, res.ID, orgId)
-	}
-	return err
+	return str
 }
 
-func (l *ChatCompletionsLogic) workForFastgpt(req *types.CompApiReq, apiKey string, apiBase string) (resp *types.CompOpenApiResp, err error) {
-
-	//apiKey := "fastgpt-d2uehCb2T40h9chNGjf4bpFrVKmMkCFPbrjfVLZ6DAL2zzqzOFJWP"
-	if len(req.ChatId) > 0 && len(req.FastgptChatId) == 0 {
-		req.FastgptChatId = req.ChatId
-	}
-	if len(req.Model) > 0 {
-		if req.Variables == nil {
-			req.Variables = make(map[string]string)
-		}
-		req.Variables["model"] = req.Model
-	}
-	return compapi.NewFastgptChatCompletions(l.ctx, apiKey, apiBase, req)
-
+func isAsyncReqest(req *types.CompApiReq) bool {
+	return req.IsBatch
 }

+ 1 - 1
internal/logic/contact/update_contact_logic.go

@@ -37,7 +37,7 @@ func (l *UpdateContactLogic) UpdateContact(req *types.ContactInfo) (*types.BaseM
 	if req.Cage != nil && *req.Cage > 0 {
 		cage = *req.Cage
 	}
-	err = l.svcCtx.DB.Contact.UpdateOneID(*req.Id).
+	err = tx.Contact.UpdateOneID(*req.Id).
 		Where(contact.OrganizationID(organizationId)).
 		SetNotNilStatus(req.Status).
 		SetNotNilWxWxid(req.WxWxid).

+ 1 - 1
internal/utils/compapi/config.go

@@ -13,7 +13,7 @@ const (
 
 type workIdInfo struct {
 	Id  string
-	Idx int
+	Idx uint
 }
 
 var fastgptWorkIdMap = map[string]workIdInfo{