result.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. package compapi
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "strings"
  8. "wechat-api/internal/types"
  9. )
  10. // 包装了原始的 API 响应,并提供了解析助手方法
  11. type ChatResult struct {
  12. *types.CompOpenApiResp
  13. err error
  14. }
  15. // NewChatResult 是 ChatResult 的构造函数
  16. func NewChatResult(resp any) *ChatResult {
  17. nresp, err := AnyToCompOpenApiResp(resp)
  18. res := ChatResult{nresp, err}
  19. if nresp == nil {
  20. res.err = errors.New("CompOpenApiRes is nil")
  21. }
  22. return &res
  23. }
  24. // 以下是ChatResul助手方法
  25. // GetContentString 返回第一个 Choice 的 Message Content (标准方式)
  26. func (r *ChatResult) GetContentString() (string, error) {
  27. var (
  28. content string = ""
  29. err error
  30. )
  31. if r.err == nil && len(r.Choices) > 0 {
  32. content = r.Choices[0].Message.Content
  33. } else if r.err == nil && len(r.Choices) == 0 {
  34. err = errors.New("choices empty")
  35. }
  36. return content, err
  37. }
  38. func (r *ChatResult) GetContentJsonStr() (string, error) {
  39. var (
  40. content string = ""
  41. err error
  42. )
  43. if r.err == nil && len(r.Choices) > 0 {
  44. content = r.Choices[0].Message.Content
  45. } else if r.err == nil && len(r.Choices) == 0 {
  46. err = errors.New("choices empty")
  47. }
  48. if !IsOpenaiModel(r.Model) { //不支持Response Schema的要特殊处理一下
  49. content, _, err = ExtractJSONContent(content)
  50. if err != nil {
  51. return "", fmt.Errorf("GCJS ExtractJSONContent err:'%s'", err)
  52. }
  53. }
  54. return content, err
  55. }
  56. // ParseContentAs 解析 Message Content 中的 JSON 到指定的 Go 结构体
  57. // target 必须是一个指向目标结构体实例的指针 (e.g., &MyStruct{})
  58. func (r *ChatResult) ParseContentAs(target any) error {
  59. content, err := r.GetContentString()
  60. if err != nil {
  61. return fmt.Errorf("parseContent err: %s", err)
  62. } else if content == "" {
  63. return errors.New("parseContent err: content is empty or unavailable")
  64. }
  65. if !IsOpenaiModel(r.Model) { //不支持Response Schema的要特殊处理一下
  66. content, _, err = ExtractJSONContent(content)
  67. if err != nil {
  68. return fmt.Errorf("PCA ExtractJSONContent err:'%s'", err)
  69. }
  70. }
  71. return ParseContentAs(content, target, false)
  72. }
  73. func AnyToBytes(in any) ([]byte, error) {
  74. switch v := in.(type) {
  75. case string:
  76. return []byte(v), nil
  77. case []byte:
  78. return v, nil
  79. default:
  80. return json.Marshal(v)
  81. }
  82. }
  83. func AnyToCompOpenApiResp(in any) (*types.CompOpenApiResp, error) {
  84. if resp, ok := in.(*types.CompOpenApiResp); ok {
  85. return resp, nil
  86. }
  87. if resp, ok := in.(types.CompOpenApiResp); ok {
  88. return &resp, nil
  89. }
  90. bs, err := AnyToBytes(in)
  91. if err != nil {
  92. return nil, err
  93. }
  94. nresp := &types.CompOpenApiResp{}
  95. err = json.Unmarshal(bs, nresp)
  96. if err != nil {
  97. return nil, err
  98. }
  99. return nresp, nil
  100. }
  101. func AnyToCompApiReq(in any) (*types.CompApiReq, error) {
  102. if req, ok := in.(*types.CompApiReq); ok {
  103. return req, nil
  104. }
  105. if req, ok := in.(types.CompApiReq); ok {
  106. return &req, nil
  107. }
  108. bs, err := AnyToBytes(in)
  109. if err != nil {
  110. return nil, err
  111. }
  112. nreq := &types.CompApiReq{}
  113. err = json.Unmarshal(bs, nreq)
  114. if err != nil {
  115. return nil, err
  116. }
  117. return nreq, nil
  118. }
  119. func CheckJSON(input any, checkEmpty bool) (bool, error) {
  120. inputBytes, err := AnyToBytes(input)
  121. if err != nil {
  122. return false, err
  123. }
  124. var raw json.RawMessage
  125. err = json.Unmarshal(inputBytes, &raw)
  126. if err != nil {
  127. return false, fmt.Errorf("input is not valid JSON: %w", err)
  128. }
  129. if checkEmpty {
  130. trimmed := bytes.TrimSpace(inputBytes)
  131. if len(trimmed) == 0 {
  132. return false, fmt.Errorf("input is empty")
  133. }
  134. }
  135. return true, nil
  136. }
  137. func WrapJSON(input any, warpKey string, checkValid bool) ([]byte, error) {
  138. var (
  139. inputBytes []byte
  140. outputBytes []byte
  141. err error
  142. )
  143. if inputBytes, err = AnyToBytes(input); err != nil {
  144. return nil, err
  145. }
  146. if checkValid {
  147. if _, err = CheckJSON(inputBytes, true); err != nil {
  148. return nil, err
  149. }
  150. }
  151. if len(warpKey) == 0 {
  152. return inputBytes, nil
  153. }
  154. wrapper := map[string]json.RawMessage{
  155. warpKey: json.RawMessage(inputBytes),
  156. }
  157. if outputBytes, err = json.Marshal(wrapper); err != nil {
  158. return nil, fmt.Errorf("failed to marshal wrapper structure: %w", err)
  159. }
  160. return outputBytes, nil
  161. }
  162. func ParseContentAs(content string, target any, removeJsonBlock bool) error {
  163. if removeJsonBlock &&
  164. strings.HasPrefix(content, "```json") && strings.HasSuffix(content, "```") {
  165. content = strings.TrimSuffix(strings.TrimPrefix(content, "```json"), "```")
  166. content = strings.TrimSpace(content)
  167. }
  168. if err := json.Unmarshal([]byte(content), target); err != nil {
  169. return fmt.Errorf("parseContent err:failed to unmarshal"+
  170. " content JSON into target type '%w'", err)
  171. }
  172. return nil
  173. }
  174. func ExtractJSONContent(s string) (string, bool, error) {
  175. startMarker := "```json"
  176. endMarker := "```"
  177. // 寻找起始标记
  178. startIdx := strings.Index(s, startMarker)
  179. if startIdx == -1 {
  180. return s, false, nil // 没有起始标记就不进行之后步骤
  181. }
  182. // 寻找结束标记(需在起始标记之后查找)
  183. endIdx := strings.LastIndex(s, endMarker)
  184. if endIdx == -1 || endIdx <= startIdx {
  185. return s, false, errors.New("lost endMarker") // 没有结束标记或标记顺序错误
  186. }
  187. // 计算内容范围
  188. contentStart := startIdx + len(startMarker)
  189. contentEnd := endIdx
  190. // 提取内容并去除前后空白
  191. content := strings.TrimSpace(s[contentStart:contentEnd])
  192. // 若内容为空视为无效
  193. if content == "" {
  194. return s, false, errors.New("empty content")
  195. }
  196. return content, true, nil
  197. }