package compapi

import (
	"bufio"
	"bytes"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"strings"

	"wechat-api/internal/types"

	"github.com/openai/openai-go/packages/ssestream"
)

// ChatCompSteamChunk is a single streamed chat-completion chunk. RAW keeps
// the exact JSON text the chunk was decoded from; it is never marshaled.
type ChatCompSteamChunk struct {
	types.StdCompApiResp
	RAW string `json:"-"`
}

// ApiRespStreamChunk pairs an SSE event name with its decoded payload.
// NOTE(review): presumably this matches the {"event": ..., "data": ...}
// envelope that ssestream.Stream builds for typed events — confirm against
// the openai-go version in use.
type ApiRespStreamChunk struct {
	Event string             `json:"event"`
	Data  ChatCompSteamChunk `json:"data"`
}

// myStreamDecoder implements ssestream.Decoder over a line-oriented SSE body.
// Each non-empty "data:" line becomes one event; a preceding "event:" line
// supplies that event's type, otherwise the type defaults to "answer".
type myStreamDecoder struct {
	evt    ssestream.Event
	rc     io.ReadCloser
	scn    *bufio.Scanner
	err    error
	closed bool

	// pendingEvent holds the type read from an "event:" line until the next
	// "data:" line consumes it, because the two arrive on separate lines.
	pendingEvent string
}

// UnmarshalJSON decodes data into r and records the raw JSON text in r.RAW.
func (r *ChatCompSteamChunk) UnmarshalJSON(data []byte) (err error) {
	r.RAW = string(data)
	// Alias has the same fields but not this method, so decoding through it
	// does not recurse back into UnmarshalJSON.
	type Alias ChatCompSteamChunk
	return json.Unmarshal(data, (*Alias)(r))
}

// ApiRespStreamDecoder returns an ssestream.Decoder that reads SSE lines
// from res.
//
// res may be an *http.Response (its Body is read, and closed by Close), a
// []byte, a string, an io.ReadCloser, or an io.Reader. A nil response, a
// response with a nil Body, or any other value yields a decoder over an
// empty stream.
func ApiRespStreamDecoder(res any) ssestream.Decoder {
	var rc io.ReadCloser
	switch v := res.(type) {
	case *http.Response:
		// Guard against a nil response/body instead of panicking.
		if v != nil && v.Body != nil {
			rc = v.Body
		}
	case []byte:
		rc = io.NopCloser(bytes.NewReader(v))
	case string:
		rc = io.NopCloser(strings.NewReader(v))
	case io.ReadCloser:
		rc = v
	case io.Reader:
		rc = io.NopCloser(v)
	}
	if rc == nil {
		rc = io.NopCloser(strings.NewReader(""))
	}

	scn := bufio.NewScanner(rc)
	// The default 64 KiB line limit makes Scan fail with bufio.ErrTooLong on
	// large chunks; allow lines up to 32 MiB.
	scn.Buffer(nil, bufio.MaxScanTokenSize<<9)
	return &myStreamDecoder{rc: rc, scn: scn}
}

// Event returns the event produced by the most recent successful Next.
func (s *myStreamDecoder) Event() ssestream.Event {
	return s.evt
}

// Close marks the decoder closed and closes the underlying reader.
func (s *myStreamDecoder) Close() error {
	s.closed = true
	if s.rc == nil {
		return nil
	}
	return s.rc.Close()
}

// Err returns the scan error that stopped decoding, or nil.
func (s *myStreamDecoder) Err() error {
	return s.err
}

// Next advances to the next event and reports whether one is available.
//
// Blank lines and unrecognized lines (":" comments, "id:", "retry:", ...)
// are skipped. An "event:" line sets the type for the following "data:"
// line, and every non-empty "data:" line is emitted as its own event. A
// "data:" line with an empty payload is discarded together with any pending
// event type, rather than ending the stream. Next returns false at end of
// input or on a scan error; the error is then available from Err.
func (s *myStreamDecoder) Next() bool {
	if s.err != nil {
		return false
	}

	for s.scn.Scan() {
		line := strings.TrimSpace(s.scn.Text())
		switch {
		case strings.HasPrefix(line, "event:"):
			s.pendingEvent = strings.TrimSpace(line[len("event:"):])
		case strings.HasPrefix(line, "data:"):
			payload := strings.TrimSpace(line[len("data:"):])
			eventType := s.pendingEvent
			s.pendingEvent = ""
			if payload == "" {
				// Nothing to dispatch; keep scanning instead of treating an
				// empty data line as the end of the stream.
				continue
			}
			if eventType == "" {
				eventType = "answer" // default type
			}
			s.evt = ssestream.Event{Type: eventType, Data: []byte(payload)}
			return true
		}
		// Blank and unrecognized lines fall through and are ignored.
	}

	if err := s.scn.Err(); err != nil && !errors.Is(err, io.EOF) {
		s.err = err
	}
	return false
}