chat_completions_logic.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. package chat
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net"
  7. "net/url"
  8. "strconv"
  9. "strings"
  10. "time"
  11. "unicode"
  12. "wechat-api/ent"
  13. "wechat-api/internal/svc"
  14. "wechat-api/internal/types"
  15. "wechat-api/internal/utils/compapi"
  16. "wechat-api/internal/utils/contextkey"
  17. "wechat-api/internal/utils/typekit"
  18. "wechat-api/ent/custom_types"
  19. "wechat-api/ent/predicate"
  20. "wechat-api/ent/usagedetail"
  21. "wechat-api/ent/usagetotal"
  22. "github.com/zeromicro/go-zero/core/logx"
  23. )
  24. type ChatCompletionsLogic struct {
  25. logx.Logger
  26. ctx context.Context
  27. svcCtx *svc.ServiceContext
  28. }
  // NewChatCompletionsLogic builds a ChatCompletionsLogic bound to ctx. Its
  // embedded logger is obtained from ctx via logx.WithContext.
  func NewChatCompletionsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ChatCompletionsLogic {
  	logic := ChatCompletionsLogic{
  		ctx:    ctx,
  		svcCtx: svcCtx,
  		Logger: logx.WithContext(ctx),
  	}
  	return &logic
  }
  35. func (l *ChatCompletionsLogic) ChatCompletions(req *types.CompApiReq) (asyncMode bool, resp *types.CompOpenApiResp, err error) {
  36. // todo: add your logic here and delete this line
  37. /*
  38. 1.鉴权获得token
  39. 2.必要参数检测及转换
  40. 3. 根据event_type选择不同处理路由
  41. */
  42. var (
  43. apiKeyObj *ent.ApiKey
  44. ok bool
  45. )
  46. asyncMode = false
  47. //微调部分请求参数
  48. reqAdjust(req)
  49. //从上下文中获取鉴权中间件埋下的apiAuthInfo
  50. apiKeyObj, ok = contextkey.AuthTokenInfoKey.GetValue(l.ctx)
  51. if !ok {
  52. return asyncMode, nil, errors.New("content get auth info err")
  53. }
  54. //微调apiKeyObj的openaikey
  55. apiKeyObjAdjust(req.EventType, req.WorkId, apiKeyObj)
  56. apiKeyObj.CreatedAt = time.Now()
  57. fmt.Println(typekit.PrettyPrint(apiKeyObj))
  58. fmt.Println("=========================================")
  59. fmt.Printf("In ChatCompletion Get Token Info:\nKey:'%s'\n", apiKeyObj.Key)
  60. fmt.Printf("Title:'%s'\n", apiKeyObj.Title)
  61. fmt.Printf("OpenaiBase:'%s'\n", apiKeyObj.OpenaiBase)
  62. fmt.Printf("OpenaiKey:'%s'\n", apiKeyObj.OpenaiKey)
  63. fmt.Printf("workToken:'%s' because %s/%s\n", apiKeyObj.OpenaiKey, req.EventType, req.WorkId)
  64. fmt.Printf("req.ChatId:'%s VS req.FastgptChatId:'%s'\n", req.ChatId, req.FastgptChatId)
  65. fmt.Printf("apiKeyObj.CreatedAt:'%v' || apiKeyObj.UpdatedAt:'%v'\n", apiKeyObj.CreatedAt, apiKeyObj.UpdatedAt)
  66. fmt.Println("=========================================")
  67. if isAsyncReqest(req) { //异步请求处理模式
  68. fmt.Println("~~~~~~~~~~~~~~~~~~~isAsyncReqest:", req.Callback)
  69. asyncMode = true
  70. err = l.appendAsyncRequest(apiKeyObj, req)
  71. } else { //同步请求处理模式
  72. fmt.Println("~~~~~~~~~~~~~~~~~~~isSyncReqest")
  73. resp, err = l.workForFastgpt(apiKeyObj, req)
  74. if err == nil && resp != nil {
  75. l.doSyncRequestLog(apiKeyObj, req, resp) //请求记录
  76. }
  77. }
  78. return asyncMode, resp, err
  79. }
  80. func (l *ChatCompletionsLogic) appendAsyncRequest(apiKeyObj *ent.ApiKey, req *types.CompApiReq) error {
  81. workIDIdx := int8(compapi.GetWorkIdxByID(req.EventType, req.WorkId))
  82. rawReqResp := custom_types.OriginalData{Request: req}
  83. res, err := l.svcCtx.DB.CompapiJob.Create().
  84. SetNotNilCallbackURL(&req.Callback).
  85. SetNotNilAuthToken(&apiKeyObj.Key).
  86. SetNotNilEventType(&req.EventType).
  87. SetNillableWorkidIdx(&workIDIdx).
  88. SetNotNilRequestJSON(&rawReqResp).
  89. SetNillableChatID(&req.ChatId).
  90. Save(l.ctx)
  91. if err == nil {
  92. logx.Infof("appendAsyncRequest succ,get id:%d", res.ID)
  93. }
  94. return err
  95. }
  96. func (l *ChatCompletionsLogic) doSyncRequestLog(obj *ent.ApiKey, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
  97. return l.appendUsageDetailLog(obj.Key, req, resp)
  98. }
  99. func (l *ChatCompletionsLogic) getUsagetotalIdByToken(authToken string) (uint64, error) {
  100. var predicates []predicate.UsageTotal
  101. predicates = append(predicates, usagetotal.BotIDEQ(authToken))
  102. return l.svcCtx.DB.UsageTotal.Query().Where(predicates...).FirstID(l.ctx)
  103. }
  104. func (l *ChatCompletionsLogic) replaceUsagetotalTokens(authToken string, sumTotalTokens uint64, newUsageDetailId uint64, orgId uint64) error {
  105. Id, err := l.getUsagetotalIdByToken(authToken)
  106. if err != nil && !ent.IsNotFound(err) {
  107. return err
  108. }
  109. if Id > 0 { //UsageTotal have record by newUsageDetailId
  110. _, err = l.svcCtx.DB.UsageTotal.UpdateOneID(Id).
  111. SetTotalTokens(sumTotalTokens).
  112. SetEndIndex(newUsageDetailId).
  113. Save(l.ctx)
  114. } else { //create new record by newUsageDetailId
  115. logType := 5
  116. _, err = l.svcCtx.DB.UsageTotal.Create().
  117. SetNotNilBotID(&authToken).
  118. SetNotNilEndIndex(&newUsageDetailId).
  119. SetNotNilTotalTokens(&sumTotalTokens).
  120. SetNillableType(&logType).
  121. SetNotNilOrganizationID(&orgId).
  122. Save(l.ctx)
  123. }
  124. return err
  125. }
  126. func (l *ChatCompletionsLogic) updateUsageTotal(authToken string, newUsageDetailId uint64, orgId uint64) error {
  127. sumTotalTokens, err := l.sumTotalTokensByAuthToken(authToken) //首先sum UsageDetail的TotalTokens
  128. if err == nil {
  129. err = l.replaceUsagetotalTokens(authToken, sumTotalTokens, newUsageDetailId, orgId) //再更新(包含新建)Usagetotal的otalTokens
  130. }
  131. return err
  132. }
  133. // sum total_tokens from usagedetail by AuthToken
  134. func (l *ChatCompletionsLogic) sumTotalTokensByAuthToken(authToken string) (uint64, error) {
  135. var predicates []predicate.UsageDetail
  136. predicates = append(predicates, usagedetail.BotIDEQ(authToken))
  137. var res []struct {
  138. Sum, Min, Max, Count uint64
  139. }
  140. totalTokens := uint64(0)
  141. var err error = nil
  142. err = l.svcCtx.DB.UsageDetail.Query().Where(predicates...).Aggregate(ent.Sum("total_tokens"),
  143. ent.Min("total_tokens"), ent.Max("total_tokens"), ent.Count()).Scan(l.ctx, &res)
  144. if err == nil {
  145. if len(res) > 0 {
  146. totalTokens = res[0].Sum
  147. } else {
  148. totalTokens = 0
  149. }
  150. }
  151. return totalTokens, err
  152. }
  // appendUsageDetailLog inserts one UsageDetail row for a completed
  // synchronous request, then refreshes the UsageTotal summary for authToken.
  //
  // The row stores the text of the first request message, the first choice's
  // reply, token usage (total = prompt + completion) and the raw
  // request/response. If req has no messages or resp has no choices, the
  // corresponding text is stored as "" instead of panicking. The returned
  // error is that of the UsageDetail insert. A failure while refreshing
  // UsageTotal is logged but not returned.
  func (l *ChatCompletionsLogic) appendUsageDetailLog(authToken string, req *types.CompApiReq, resp *types.CompOpenApiResp) error {
  	logType := 5 // NOTE(review): same type value used for UsageTotal rows; meaning not visible here
  	workIdx := int(compapi.GetWorkIdxByID(req.EventType, req.WorkId))
  	rawReqResp := custom_types.OriginalData{Request: req, Response: resp}

  	// resp.ID is not necessarily numeric. Anything that is not a valid
  	// unsigned 64-bit integer is recorded as session 0 (best-effort).
  	sessionId := uint64(0)
  	if parsed, perr := strconv.ParseUint(resp.ID, 10, 64); perr == nil {
  		sessionId = parsed
  	}

  	orgId := uint64(0)
  	if apiKeyObj, ok := contextkey.AuthTokenInfoKey.GetValue(l.ctx); ok && apiKeyObj != nil {
  		orgId = apiKeyObj.OrganizationID
  	}

  	promptTokens := uint64(resp.Usage.PromptTokens)
  	completionToken := uint64(resp.Usage.CompletionTokens)
  	totalTokens := promptTokens + completionToken

  	// Request text: either a plain string, or multi-part content whose first
  	// part carries a "text" field.
  	msgContent := ""
  	if len(req.Messages) > 0 {
  		switch val := req.Messages[0].Content.(type) {
  		case string:
  			msgContent = val
  		case []interface{}:
  			if len(val) > 0 {
  				if part, ok := val[0].(map[string]interface{}); ok {
  					msgContent, _ = part["text"].(string)
  				}
  			}
  		}
  	}

  	respContent := ""
  	if len(resp.Choices) > 0 {
  		respContent = resp.Choices[0].Message.Content
  	}

  	res, err := l.svcCtx.DB.UsageDetail.Create().
  		SetNotNilType(&logType).
  		SetNotNilBotID(&authToken).
  		SetNotNilReceiverID(&req.EventType).
  		SetNotNilSessionID(&sessionId).
  		SetNillableApp(&workIdx).
  		SetNillableRequest(&msgContent).
  		SetNillableResponse(&respContent).
  		SetNillableOrganizationID(&orgId).
  		SetOriginalData(rawReqResp).
  		SetNillablePromptTokens(&promptTokens).
  		SetNillableCompletionTokens(&completionToken).
  		SetNillableTotalTokens(&totalTokens).
  		Save(l.ctx)
  	if err != nil {
  		return err
  	}

  	// Refresh UsageTotal only after the detail row exists so the sum includes it.
  	if totalErr := l.updateUsageTotal(authToken, res.ID, orgId); totalErr != nil {
  		l.Errorf("appendUsageDetailLog: update usage total after detail %d failed: %v", res.ID, totalErr)
  	}
  	return nil
  }
  200. func (l *ChatCompletionsLogic) workForFastgpt(apiKeyObj *ent.ApiKey, req *types.CompApiReq) (resp *types.CompOpenApiResp, err error) {
  201. //apiKey := "fastgpt-d2uehCb2T40h9chNGjf4bpFrVKmMkCFPbrjfVLZ6DAL2zzqzOFJWP"
  202. return compapi.NewFastgptChatCompletions(l.ctx, apiKeyObj.OpenaiKey, apiKeyObj.OpenaiBase, req)
  203. }
  204. func reqAdjust(req *types.CompApiReq) {
  205. if req.EventType != "fastgpt" {
  206. return
  207. }
  208. if len(req.Model) > 0 {
  209. if req.Variables == nil {
  210. req.Variables = make(map[string]string)
  211. }
  212. req.Variables["model"] = req.Model
  213. }
  214. if len(req.ChatId) > 0 && len(req.FastgptChatId) == 0 {
  215. req.FastgptChatId = req.ChatId
  216. } else if len(req.ChatId) == 0 && len(req.FastgptChatId) > 0 {
  217. req.ChatId = req.FastgptChatId
  218. }
  219. }
  220. func apiKeyObjAdjust(eventType string, workId string, obj *ent.ApiKey) {
  221. if eventType != "fastgpt" {
  222. return
  223. }
  224. obj.OpenaiKey, _ = compapi.GetWorkInfoByID(eventType, workId)
  225. }
  // IsValidURL reports whether s is an acceptable absolute URL.
  //
  // A URL is accepted when all of the following hold:
  //   - it is 10..2048 bytes long;
  //   - it parses with net/url;
  //   - its scheme is http, https, ftp, ftps or sftp;
  //   - its host is an IP literal (IPv4 or IPv6, e.g. "http://[::1]/") or a
  //     domain accepted by isValidDomain;
  //   - any explicit port is in 1..65535.
  //
  // The check is purely syntactic; no network access is performed.
  func IsValidURL(s string) bool {
  	// Stage 1: cheap length pre-check (common URL length range).
  	if len(s) < 10 || len(s) > 2048 {
  		return false
  	}
  	// Stage 2: standard-library parsing; require scheme and host.
  	u, err := url.Parse(s)
  	if err != nil || u.Scheme == "" || u.Host == "" {
  		return false
  	}
  	// Stage 3: scheme whitelist (common network protocols).
  	switch u.Scheme {
  	case "http", "https", "ftp", "ftps", "sftp":
  	default:
  		return false
  	}
  	// Stage 4: port. Validated for every host type, IP literals included, and
  	// range-checked so "99999" or "00000" are rejected.
  	if port := u.Port(); port != "" {
  		for _, r := range port {
  			if !unicode.IsDigit(r) {
  				return false
  			}
  		}
  		n, convErr := strconv.Atoi(port)
  		if convErr != nil || n < 1 || n > 65535 {
  			return false
  		}
  	}
  	// Stage 5: host. IP literals are checked before the character filter
  	// because IPv6 addresses contain ':', which that filter rejects.
  	host := u.Hostname()
  	if net.ParseIP(host) != nil {
  		return true
  	}
  	if strings.Contains(host, "..") || strings.ContainsAny(host, "!#$%&'()*,:;<=>?[]^`{|}~") {
  		return false
  	}
  	return isValidDomain(host)
  }

  // isValidDomain reports whether host is a syntactically plausible domain
  // name. It rejects:
  //   - spaces, '_', '+', '/' and '\';
  //   - fewer than two dot-separated labels;
  //   - any label that is not 1..63 bytes long;
  //   - any label that starts or ends with '-'.
  //
  // Other characters, including non-ASCII ones (internationalised names), are
  // not rejected.
  //
  // No DNS resolution is done on purpose: a lookup would block the request
  // path, and its outcome would not change the result.
  func isValidDomain(host string) bool {
  	if strings.ContainsAny(host, " _+/\\") {
  		return false
  	}
  	labels := strings.Split(host, ".")
  	if len(labels) < 2 { // need at least a second-level and a top-level label
  		return false
  	}
  	for _, label := range labels {
  		if len(label) < 1 || len(label) > 63 {
  			return false
  		}
  		if label[0] == '-' || label[len(label)-1] == '-' {
  			return false
  		}
  	}
  	return true
  }
  294. func isAsyncReqest(req *types.CompApiReq) bool {
  295. if !req.IsBatch || !IsValidURL(req.Callback) {
  296. return false
  297. }
  298. return true
  299. }