compsteam.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. package compapi
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/json"
  6. "errors"
  7. "io"
  8. "net/http"
  9. "strings"
  10. "wechat-api/internal/types"
  11. "github.com/openai/openai-go/packages/ssestream"
  12. )
  13. type ChatCompSteamChunk struct {
  14. types.StdCompApiResp
  15. RAW string `json:"-"`
  16. }
  17. type ApiRespStreamChunk struct {
  18. Event string `json:"event"`
  19. Data ChatCompSteamChunk `json:"data"`
  20. }
  21. type myStreamDecoder struct {
  22. evt ssestream.Event
  23. rc io.ReadCloser
  24. scn *bufio.Scanner
  25. err error
  26. closed bool
  27. // 用于处理多行事件
  28. pendingEvent string
  29. }
  30. func (r *ChatCompSteamChunk) UnmarshalJSON(data []byte) (err error) {
  31. r.RAW = string(data)
  32. type Alias ChatCompSteamChunk
  33. return json.Unmarshal(data, (*Alias)(r))
  34. }
  35. func ApiRespStreamDecoder(res any) ssestream.Decoder {
  36. var rc io.ReadCloser
  37. switch v := res.(type) {
  38. case *http.Response:
  39. rc = v.Body
  40. case []byte:
  41. rc = io.NopCloser(bytes.NewReader(v))
  42. case string:
  43. rc = io.NopCloser(bytes.NewReader([]byte(v)))
  44. default:
  45. rc = io.NopCloser(strings.NewReader(""))
  46. }
  47. return &myStreamDecoder{rc: rc, scn: bufio.NewScanner(rc)}
  48. }
  49. func (s *myStreamDecoder) Event() ssestream.Event {
  50. return s.evt
  51. }
  52. func (s *myStreamDecoder) Close() error {
  53. s.closed = true
  54. if closer, ok := s.rc.(io.Closer); ok {
  55. return closer.Close()
  56. }
  57. return nil
  58. }
  59. func (s *myStreamDecoder) Err() error {
  60. return s.err
  61. }
  62. func (s *myStreamDecoder) Next() bool {
  63. if s.err != nil {
  64. return false
  65. }
  66. eventType := ""
  67. dataBuffer := bytes.NewBuffer(nil)
  68. for s.scn.Scan() {
  69. line := strings.TrimSpace(s.scn.Text())
  70. if len(line) == 0 {
  71. continue //跳过空行
  72. }
  73. // 处理事件类型行
  74. if strings.HasPrefix(line, "event:") {
  75. s.pendingEvent = strings.TrimSpace(line[len("event:"):])
  76. continue
  77. }
  78. // 处理数据行
  79. if strings.HasPrefix(line, "data:") {
  80. tmpdata := strings.TrimSpace(line[len("data:"):])
  81. //确定事件类型
  82. if s.pendingEvent != "" {
  83. eventType = s.pendingEvent
  84. s.pendingEvent = ""
  85. } else {
  86. eventType = "answer" // 默认类型
  87. }
  88. _, s.err = dataBuffer.WriteString(tmpdata)
  89. break
  90. }
  91. //忽略无法识别的行
  92. }
  93. if dataBuffer.Len() > 0 {
  94. s.evt = ssestream.Event{
  95. Type: eventType,
  96. Data: dataBuffer.Bytes(),
  97. }
  98. return true
  99. }
  100. if err := s.scn.Err(); err != nil && !errors.Is(err, io.EOF) {
  101. s.err = s.scn.Err()
  102. }
  103. return false
  104. }