func.go

package compapi

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"reflect"
	"strings"

	"wechat-api/internal/types"
	"wechat-api/internal/utils/contextkey"

	openai "github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/packages/ssestream"
	"github.com/zeromicro/go-zero/rest/httpx"
)
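
// IsOpenaiModel reports whether the given model name looks like an OpenAI model,
// based on a small set of known name prefixes (gpt-4, gpt-3, o1, o3).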
func IsOpenaiModel(model string) bool {
	prefixes := []string{"gpt-4", "gpt-3", "o1", "o3"}
	// Check the model name against every known prefix.
	for _, prefix := range prefixes {
		if strings.HasPrefix(model, prefix) {
			return true
		}
	}
	return false
}
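
// EntStructGenScanField builds a comma-separated list of column names (taken from
// the exported fields' json tags) and the matching scan destinations for a struct
// pointer. Fields whose type (or pointed-to type) appears in ignoredTypes are skipped.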
func EntStructGenScanField(structPtr any, ignoredTypes ...reflect.Type) (string, []any, error) {
	t := reflect.TypeOf(structPtr)
	v := reflect.ValueOf(structPtr)
	if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
		return "", nil, errors.New("input must be a pointer to a struct")
	}
	t = t.Elem()
	v = v.Elem()

	var fields []string
	var scanArgs []any

	ignoredMap := make(map[reflect.Type]struct{})
	// Collect any types the caller asked to ignore.
	for _, ignoredType := range ignoredTypes {
		if ignoredType != nil { // keep nil types out of the map
			ignoredMap[ignoredType] = struct{}{}
		}
	}

	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		value := v.Field(i)

		// Skip unexported fields.
		if !field.IsExported() {
			continue
		}

		// Get the json tag.
		jsonTag := field.Tag.Get("json")
		if jsonTag == "-" || jsonTag == "" {
			continue
		}
		jsonParts := strings.Split(jsonTag, ",")
		jsonName := jsonParts[0]
		if jsonName == "" {
			continue
		}

		// Honor the ignore list, if one was supplied.
		if len(ignoredMap) > 0 {
			fieldType := field.Type // the field's actual Go type
			// For pointer fields we usually care about the element type.
			if fieldType.Kind() == reflect.Ptr {
				fieldType = fieldType.Elem()
			}
			if _, shouldIgnore := ignoredMap[fieldType]; shouldIgnore {
				continue // the field's type is in the ignore list
			}
		}

		fields = append(fields, jsonName)
		scanArgs = append(scanArgs, value.Addr().Interface())
	}
	return strings.Join(fields, ", "), scanArgs, nil
}
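
// StdChatClient is a thin wrapper around openai.Client for standard,
// OpenAI-compatible chat endpoints.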
type StdChatClient struct {
	*openai.Client
}

func NewStdChatClient(apiKey string, apiBase string) *StdChatClient {
	opts := []option.RequestOption{}
	if len(apiKey) > 0 {
		opts = append(opts, option.WithAPIKey(apiKey))
	}
	opts = append(opts, option.WithBaseURL(apiBase))
	client := openai.NewClient(opts...)
	return &StdChatClient{&client}
}

func NewAiClient(apiKey string, apiBase string) *openai.Client {
	opts := []option.RequestOption{}
	if len(apiKey) > 0 {
		opts = append(opts, option.WithAPIKey(apiKey))
	}
	opts = append(opts, option.WithBaseURL(apiBase))
	client := openai.NewClient(opts...)
	return &client
}

func NewFastgptClient(apiKey string) *openai.Client {
	//http://fastgpt.ascrm.cn/api/v1/
	client := openai.NewClient(option.WithAPIKey(apiKey),
		option.WithBaseURL("http://fastgpt.ascrm.cn/api/v1/"))
	return &client
}

func NewDeepSeekClient(apiKey string) *openai.Client {
	client := openai.NewClient(option.WithAPIKey(apiKey),
		option.WithBaseURL("https://api.deepseek.com"))
	return &client
}
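
// DoChatCompletions sends a non-streaming chat completion request by marshalling
// chatInfo as the raw request body and decoding the response into types.CompOpenApiResp.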
func DoChatCompletions(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
	var (
		jsonBytes []byte
		err       error
	)
	emptyParams := openai.ChatCompletionNewParams{}
	if jsonBytes, err = json.Marshal(chatInfo); err != nil {
		return nil, err
	}
	//fmt.Printf("In DoChatCompletions, req: '%s'\n", string(jsonBytes))
	// TODO: the request body should perhaps be checked for non-standard member names.
	customResp := types.CompOpenApiResp{}
	reqBodyOps := option.WithRequestBody("application/json", jsonBytes)
	respBodyOps := option.WithResponseBodyInto(&customResp)
	if _, err = client.Chat.Completions.New(ctx, emptyParams, reqBodyOps, respBodyOps); err != nil {
		return nil, err
	}
	// Handle the case where fastgpt reports an error but New() itself does not return one.
	if customResp.FgtErrCode != nil && customResp.FgtErrStatusTxt != nil {
		return nil, fmt.Errorf("%s(%d)", *customResp.FgtErrStatusTxt, *customResp.FgtErrCode)
	}
	return &customResp, nil
}
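
// DoChatCompletionsStream forwards a streaming chat completion request and relays the
// upstream SSE chunks to the http.ResponseWriter stored in the request context,
// terminating the stream with a [DONE] event.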
func DoChatCompletionsStream(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (res *types.CompOpenApiResp, err error) {
	var (
		jsonBytes []byte
		raw       *http.Response
		ok        bool
		hw        http.ResponseWriter
	)
	// Retrieve the http.ResponseWriter from the context.
	hw, ok = contextkey.HttpResponseWriterKey.GetValue(ctx)
	if !ok {
		return nil, errors.New("failed to get http.ResponseWriter from context")
	}
	flusher, ok := (hw).(http.Flusher)
	if !ok {
		http.Error(hw, "Streaming unsupported!", http.StatusInternalServerError)
		return nil, errors.New("streaming unsupported by the underlying http.ResponseWriter")
	}
	emptyParams := openai.ChatCompletionNewParams{}
	if jsonBytes, err = json.Marshal(chatInfo); err != nil {
		return nil, err
	}
	reqBodyOps := option.WithRequestBody("application/json", jsonBytes)
	respBodyOps := option.WithResponseBodyInto(&raw)
	if _, err = client.Chat.Completions.New(ctx, emptyParams, reqBodyOps, respBodyOps, option.WithJSONSet("stream", true)); err != nil {
		return nil, err
	}
	// Set the SSE response headers (HTTP/1.1).
	hw.Header().Set("Content-Type", "text/event-stream;charset=utf-8")
	hw.Header().Set("Connection", "keep-alive")
	hw.Header().Set("Cache-Control", "no-cache")

	chatStream := ssestream.NewStream[ApiRespStreamChunk](ApiRespStreamDecoder(raw), err)
	defer chatStream.Close()
	for chatStream.Next() {
		chunk := chatStream.Current()
		fmt.Fprintf(hw, "data:%s\n\n", chunk.Data.RAW)
		flusher.Flush()
		//time.Sleep(1 * time.Millisecond)
	}
	fmt.Fprintf(hw, "data:%s\n\n", "[DONE]")
	flusher.Flush()
	httpx.Ok(hw)
	return nil, nil
}
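
// NewChatCompletions dispatches to the streaming or non-streaming implementation
// depending on chatInfo.Stream.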
func NewChatCompletions(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
	if chatInfo.Stream {
		return DoChatCompletionsStream(ctx, client, chatInfo)
	}
	return DoChatCompletions(ctx, client, chatInfo)
}

func NewMismatchChatCompletions(ctx context.Context, apiKey string, apiBase string, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
	client := NewAiClient(apiKey, apiBase)
	return NewChatCompletions(ctx, client, chatInfo)
}

func NewFastgptChatCompletions(ctx context.Context, apiKey string, apiBase string, chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
	client := NewAiClient(apiKey, apiBase)
	return NewChatCompletions(ctx, client, chatInfo)
}

func NewDeepSeekChatCompletions(ctx context.Context, apiKey string, chatInfo *types.CompApiReq, chatModel openai.ChatModel) (res *types.CompOpenApiResp, err error) {
	client := NewDeepSeekClient(apiKey)
	// Fall back to DeepSeek-R1 unless DeepSeek-V3 was explicitly requested.
	if chatModel != ChatModelDeepSeekV3 {
		chatModel = ChatModelDeepSeekR1
	}
	chatInfo.Model = chatModel
	return NewChatCompletions(ctx, client, chatInfo)
}
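
// DoChatCompletionsStreamOld is an earlier streaming implementation: it reads the
// http.ResponseWriter from the context under the "HttpResp-Writer" key and
// accumulates chunks with openai.ChatCompletionAccumulator while writing them out.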
func DoChatCompletionsStreamOld(ctx context.Context, client *openai.Client, chatInfo *types.CompApiReq) (res *types.CompOpenApiResp, err error) {
	var jsonBytes []byte
	emptyParams := openai.ChatCompletionNewParams{}
	if jsonBytes, err = json.Marshal(chatInfo); err != nil {
		return nil, err
	}
	reqBodyOps := option.WithRequestBody("application/json", jsonBytes)
	//customResp := types.CompOpenApiResp{}
	//respBodyOps := option.WithResponseBodyInto(&customResp)
	//chatStream := client.Chat.Completions.NewStreaming(ctx, emptyParams, reqBodyOps, respBodyOps)
	chatStream := client.Chat.Completions.NewStreaming(ctx, emptyParams, reqBodyOps)
	// Optionally, an accumulator helper can be used.
	acc := openai.ChatCompletionAccumulator{}
	httpWriter, ok := ctx.Value("HttpResp-Writer").(http.ResponseWriter)
	if !ok {
		return nil, errors.New("failed to get http.ResponseWriter from context")
	}
	//httpWriter.Header().Set("Content-Type", "text/event-stream;charset=utf-8")
	//httpWriter.Header().Set("Connection", "keep-alive")
	//httpWriter.Header().Set("Cache-Control", "no-cache")
	idx := 0
	for chatStream.Next() {
		chunk := chatStream.Current()
		acc.AddChunk(chunk)
		fmt.Printf("=====>get %d chunk:%v\n", idx, chunk)
		if _, err := fmt.Fprintf(httpWriter, "%v", chunk); err != nil {
			fmt.Printf("Error writing to client:%v \n", err)
			break
		}
		if content, ok := acc.JustFinishedContent(); ok {
			println("Content stream finished:", content)
		}
		// If using tool calls.
		if tool, ok := acc.JustFinishedToolCall(); ok {
			println("Tool call stream finished:", tool.Index, tool.Name, tool.Arguments)
		}
		if refusal, ok := acc.JustFinishedRefusal(); ok {
			println("Refusal stream finished:", refusal)
		}
		// It's best to use chunks after handling the JustFinished events.
		if len(chunk.Choices) > 0 {
			idx++
			fmt.Printf("idx:%d get =>'%s'\n", idx, chunk.Choices[0].Delta.Content)
		}
	}
	if err := chatStream.Err(); err != nil {
		return nil, err
	}
	return nil, nil
}
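
// GetWorkInfoByID looks up the fastgpt work entry for workId, falling back to the
// "default" entry when the id is unknown, and returns its token and index.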
func GetWorkInfoByID(eventType string, workId string) (string, uint) {
	val, exist := fastgptWorkIdMap[workId]
	if !exist {
		val = fastgptWorkIdMap["default"]
	}
	return val.Id, val.Idx
}

// GetWorkTokenByID returns the work token for the given work id.
func GetWorkTokenByID(eventType string, workId string) string {
	id, _ := GetWorkInfoByID(eventType, workId)
	return id
}

// GetWorkIdxByID returns the work index for the given work id.
func GetWorkIdxByID(eventType string, workId string) uint {
	_, idx := GetWorkInfoByID(eventType, workId)
	return idx
}