chat.go 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. package channel
  2. import (
  3. "bytes"
  4. "context"
  5. "github.com/openai/openai-go"
  6. "github.com/openai/openai-go/option"
  7. "github.com/zeromicro/go-zero/core/logx"
  8. "io"
  9. "io/ioutil"
  10. "net/http"
  11. "strings"
  12. "time"
  13. )
  // ChatEngine carries the settings used to talk to a chat-completion endpoint.
  type ChatEngine struct {
  	// BaseURL is handed to the client as the API base URL; APIKey is handed
  	// to the client as the API key.
  	BaseURL, APIKey string
  	// IsFastGpt, when true, makes ChatCompletions add a "chatId" JSON field
  	// to the outgoing request.
  	IsFastGpt bool
  }
  19. func NewChatEngine(baseUrl string, apiKey string, isFastGpt bool) *ChatEngine {
  20. if baseUrl == "" {
  21. baseUrl = "https://api.openai.com/v1/"
  22. }
  23. if !strings.HasSuffix(baseUrl, "/") {
  24. baseUrl += "/"
  25. }
  26. return &ChatEngine{
  27. BaseURL: baseUrl,
  28. APIKey: apiKey,
  29. IsFastGpt: isFastGpt,
  30. }
  31. }
  32. func (o *ChatEngine) ChatCompletions(ctx context.Context, model openai.ChatModel, msg string, chatId string) (string, error) {
  33. client := openai.NewClient(
  34. option.WithBaseURL(o.BaseURL),
  35. option.WithAPIKey(o.APIKey),
  36. option.WithMiddleware(o.Logger),
  37. )
  38. options := make([]option.RequestOption, 0)
  39. if o.IsFastGpt {
  40. options = append(options, option.WithJSONSet("chatId", chatId))
  41. }
  42. param := openai.ChatCompletionNewParams{
  43. Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
  44. openai.UserMessage(msg),
  45. }),
  46. Model: openai.F(model),
  47. }
  48. chatCompletion, err := client.Chat.Completions.New(ctx, param, options...)
  49. if err != nil {
  50. return "", err
  51. }
  52. return chatCompletion.Choices[0].Message.Content, nil
  53. }
  54. func (o *ChatEngine) Logger(req *http.Request, next option.MiddlewareNext) (res *http.Response, err error) {
  55. // Before the request
  56. start := time.Now()
  57. //logx.Info("request:", req)
  58. // 读取并打印请求的 body
  59. var requestBody []byte
  60. if req.Body != nil {
  61. requestBody, err = ioutil.ReadAll(req.Body)
  62. if err != nil {
  63. logx.Error("读取请求 body 失败:", err)
  64. return nil, err
  65. }
  66. // 重新设置请求的 body
  67. req.Body = io.NopCloser(bytes.NewBuffer(requestBody))
  68. }
  69. logx.Info("request:", req)
  70. logx.Info("request body:", string(requestBody))
  71. // Forward the request to the next handler
  72. res, err = next(req)
  73. // Handle stuff after the request
  74. end := time.Now()
  75. //耗时
  76. logx.Info("resp: ", res, err, end.Sub(start))
  77. return res, err
  78. }