package compapi

import (
	"context"
	"errors"
	"fmt"
	"net/http"

	"wechat-api/internal/types"
	"wechat-api/internal/utils/contextkey"

	"github.com/invopop/jsonschema"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/packages/ssestream"
	"github.com/zeromicro/go-zero/rest/httpx"
)

// ClientConfig holds the settings used to build the underlying
// OpenAI-compatible client (see NewOAC).
type ClientConfig struct {
	ApiKey  string // API key; no key option is passed to the SDK when empty
	ApiBase string // base URL; no base-URL option is passed to the SDK when empty
}

// ClientOption mutates a ClientConfig; options are applied in order by NewClient.
type ClientOption func(*ClientConfig)

// WithApiKey returns an option that sets ClientConfig.ApiKey.
func WithApiKey(ApiKey string) ClientOption {
	return func(cfg *ClientConfig) {
		cfg.ApiKey = ApiKey
	}
}

// WithApiBase returns an option that sets ClientConfig.ApiBase.
func WithApiBase(ApiBase string) ClientOption {
	return func(cfg *ClientConfig) {
		cfg.ApiBase = ApiBase
	}
}

// clientActionFace is the per-backend behavior chosen by getClientActFace:
// building a request, executing it (plain or streaming) and preparing the
// payload for Callback.
type clientActionFace interface {
	DoRequest(req *types.CompApiReq) (*types.CompOpenApiResp, error)
	DoRequestStream(req *types.CompApiReq) (*types.CompOpenApiResp, error)
	BuildRequest(req *types.CompApiReq) error
	CallbackPrepare(params any) ([]byte, error)
}

// Client bundles an openai-go client with the configuration it was built from
// and the context used for outgoing calls made by Callback.
type Client struct {
	OAC    *openai.Client
	Config ClientConfig
	ctx    context.Context
}

// NewClient applies opts to an empty ClientConfig, builds the OpenAI client
// from the result and remembers ctx for later requests.
func NewClient(ctx context.Context, opts ...ClientOption) *Client {
	client := Client{}
	for _, opt := range opts {
		opt(&client.Config)
	}
	client.NewOAC() // initialize the openai client from client.Config
	client.ctx = ctx
	return &client
}

// getClientActFace selects the backend implementation for clientType:
// "mismatch" -> MismatchClient, "form" -> FormClient, anything else ->
// FastgptClient. All of them share this Client. The error is currently
// always nil; it is kept in the signature for callers.
func (me *Client) getClientActFace(clientType string) (clientActionFace, error) {
	var (
		err     error
		actFace clientActionFace
	)
	switch clientType {
	case "mismatch":
		actFace = &MismatchClient{StdClient: StdClient{Client: me}}
	case "form":
		actFace = &FormClient{StdClient: StdClient{Client: me}}
	default:
		actFace = &FastgptClient{StdClient: StdClient{Client: me}}
	}
	return actFace, err
}

// NewOAC (re)builds me.OAC from me.Config. The API key and base URL options
// are only passed to openai.NewClient when the corresponding field is non-empty.
func (me *Client) NewOAC() {
	opts := []option.RequestOption{}
	if len(me.Config.ApiKey) > 0 {
		opts = append(opts, option.WithAPIKey(me.Config.ApiKey))
	}
	if len(me.Config.ApiBase) > 0 {
		opts = append(opts, option.WithBaseURL(me.Config.ApiBase))
	}
	oac := openai.NewClient(opts...)
	me.OAC = &oac
}

// Callback turns params into a request body using the backend selected by
// clientType (CallbackPrepare) and POSTs it to callbackUrl with the client's
// stored context. The reply is decoded by the SDK into a map[string]any.
func (me *Client) Callback(clientType string, callbackUrl string, params any) (map[string]any, error) {
	actFace, err := me.getClientActFace(clientType)
	if err != nil {
		return nil, err
	}

	var newParams []byte
	if newParams, err = actFace.CallbackPrepare(params); err != nil {
		return nil, err
	}

	// Signature: OAC.Post(ctx, path, params, res, opts...); callbackUrl is the path.
	resp := map[string]any{}
	err = me.OAC.Post(me.ctx, callbackUrl, newParams, &resp)
	if err != nil {
		fmt.Printf("Callback Post(%s) By Params:'%s' error\n", callbackUrl, string(newParams))
		return nil, err
	}
	return resp, nil
}

// Chat builds the request for chatInfo with the backend selected by
// chatInfo.EventType and executes it, streaming when chatInfo.Stream is set.
func (me *Client) Chat(chatInfo *types.CompApiReq) (*types.CompOpenApiResp, error) {
	actFace, err := me.getClientActFace(chatInfo.EventType)
	if err != nil {
		return nil, err
	}
	if err = actFace.BuildRequest(chatInfo); err != nil {
		return nil, err
	}
	if chatInfo.Stream {
		return actFace.DoRequestStream(chatInfo)
	}
	return actFace.DoRequest(chatInfo)
}

// GenerateSchema reflects a JSON schema for T. Structured Outputs accepts only
// a subset of JSON Schema, so additional properties are disallowed and type
// definitions are inlined instead of referenced.
func GenerateSchema[T any]() any {
	reflector := jsonschema.Reflector{
		AllowAdditionalProperties: false,
		DoNotReference:            true,
	}
	var v T
	schema := reflector.Reflect(v)
	return schema
}

// getHttpResponseTools fetches the http.ResponseWriter stored in ctx under
// contextkey.HttpResponseWriterKey and its http.Flusher view. It fails when
// no writer is stored or when the writer cannot flush.
func getHttpResponseTools(ctx context.Context) (*http.ResponseWriter, *http.Flusher, error) {
	hw, ok := contextkey.HttpResponseWriterKey.GetValue(ctx) // take the http.ResponseWriter out of ctx
	if !ok {
		return nil, nil, errors.New("http response writer not found in context")
	}
	flusher, ok := hw.(http.Flusher)
	if !ok {
		return nil, nil, errors.New("streaming unsupported")
	}
	return &hw, &flusher, nil
}

// streamOut relays the upstream SSE response res to the client writer found in
// ctx. Each decoded chunk's raw data is written as a "data:" event and flushed
// immediately. A final "data:[DONE]" event ends the stream.
//
// If ctx has no flushable writer, nothing is streamed. When a writer exists
// but cannot flush, a 500 "Streaming unsupported!" is written to it. In either
// case res.Body is closed, because nothing else will consume it.
func streamOut(ctx context.Context, res *http.Response) {
	hw, flusher, err := getHttpResponseTools(ctx)
	if err != nil {
		// hw/flusher are nil here. Report on the real writer if there is one
		// (a nil writer would panic) and stop instead of dereferencing them.
		if w, ok := contextkey.HttpResponseWriterKey.GetValue(ctx); ok && w != nil {
			http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
		}
		if res != nil && res.Body != nil {
			_ = res.Body.Close() // best effort: release the unread upstream stream
		}
		return
	}
	w, f := *hw, *flusher

	// Decode the upstream API response as a stream of chunks; the deferred
	// Close shuts the stream's decoder down when we return.
	chatStream := ssestream.NewStream[ApiRespStreamChunk](ApiRespStreamDecoder(res), nil)
	defer chatStream.Close()

	// Streaming response headers (HTTP/1.1 style). Send the 200 status
	// explicitly before any data. Once the body has been written, a later
	// WriteHeader (httpx.Ok) is a superfluous call that net/http logs.
	header := w.Header()
	header.Set("Content-Type", "text/event-stream;charset=utf-8")
	header.Set("Connection", "keep-alive")
	header.Set("Cache-Control", "no-cache")
	httpx.Ok(w)
	f.Flush()

	for chatStream.Next() {
		chunk := chatStream.Current()
		fmt.Fprintf(w, "data:%s\n\n", chunk.Data.RAW)
		f.Flush()
	}
	// Don't swallow upstream failures silently. The terminator is still sent
	// so clients waiting for [DONE] do not hang.
	if streamErr := chatStream.Err(); streamErr != nil {
		fmt.Printf("streamOut: upstream stream error: %v\n", streamErr)
	}
	fmt.Fprintf(w, "data:%s\n\n", "[DONE]")
	f.Flush()
}