123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
- package middleware
- import (
- "context"
- "encoding/json"
- "fmt"
- "net/http"
- "reflect"
- "wechat-api/ent"
- "wechat-api/ent/apikey"
- "wechat-api/ent/predicate"
- "wechat-api/internal/utils/compapi"
- "wechat-api/internal/utils/contextkey"
- "wechat-api/internal/config"
- "github.com/redis/go-redis/v9"
- "github.com/suyuan32/simple-admin-common/utils/jwt"
- "github.com/zeromicro/go-zero/core/errorx"
- "github.com/zeromicro/go-zero/rest/httpx"
- )
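/*
Unused reference: to use the WeChat send payload type, import
"wechat-api/internal/types/payload" and declare the variable as

	var p payload.SendWxPayload

(assuming that package is named payload). A package-qualified name takes a
single selector, so `types.payload.SendWxPayload` is not valid Go.
*/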
- type OpenAuthorityMiddleware struct {
- DB *ent.Client
- Rds redis.UniversalClient
- Config config.Config
- }
- func NewOpenAuthorityMiddleware(db *ent.Client, rds redis.UniversalClient, c config.Config) *OpenAuthorityMiddleware {
- return &OpenAuthorityMiddleware{
- DB: db,
- Rds: rds,
- Config: c,
- }
- }
- func (m *OpenAuthorityMiddleware) checkTokenUserInfo(ctx context.Context, authToken string) (*ent.ApiKey, int, error) {
- var (
- rc int
- err error
- val *ent.ApiKey
- )
- val, rc, err = m.getTokenUserInfoByDb(ctx, authToken)
- return val, rc, err
- /*
- r, e = m.getTokenUserInfoByRds(ctx, loginToken)
- fmt.Println("redis:", "code-", r, "err-", e)
- */
- /*
- //首先从redis取数据
- _, rc, err = m.getTokenUserInfoByRds(ctx, authToken)
- fmt.Printf("++++++++++++++++++++++++get authinfo from rds out:%d/err:%s\n", rc, err)
- if rc <= 0 || err != nil { //无法获得后再从数据库获得
- rc = 0
- err = nil
- val, rc, err = m.getTokenUserInfoByDb(ctx, authToken)
- fmt.Println("----------------------After m.getTokenUserInfoByDb:", val)
- err = m.saveTokenUserInfoToRds(ctx, authToken, val)
- fmt.Println("------------save saveTokenUserInfoToRd err:", err)
- }
- _ = rc
- if err != nil {
- return nil, 0, err
- }
- return val, 0, nil
- */
- }
- func (m *OpenAuthorityMiddleware) saveTokenUserInfoToRds(ctx context.Context, authToken string, saveInfo *ent.ApiKey) error {
- if bs, err := json.Marshal(saveInfo); err == nil {
- return err
- } else {
- rc, err := m.Rds.HSet(ctx, compapi.APIAuthInfoKey, authToken, string(bs)).Result()
- fmt.Printf("#~~~~~~~~~~~~~~~++~~~~~~~~~~~~~HSet Val:%s get Result:%d/%s\n", string(bs), rc, err)
- return err
- }
- }
- func (m *OpenAuthorityMiddleware) getTokenUserInfoByRds(ctx context.Context, authToken string) (*ent.ApiKey, int, error) {
- rcode := -1
- val, err := m.Rds.HGet(ctx, compapi.APIAuthInfoKey, authToken).Result()
- if err == redis.Nil {
- rcode = 0
- } else if err == nil {
- rcode = 1
- }
- fmt.Printf("#####################From Redis By Key:'%s' Get '%s'(%s/%T)\n", authToken, val, reflect.TypeOf(val), val)
- fmt.Println(val)
- return nil, rcode, err
- }
- func (m *OpenAuthorityMiddleware) getTokenUserInfoByDb(ctx context.Context, loginToken string) (*ent.ApiKey, int, error) {
- rcode := -1
- var predicates []predicate.ApiKey
- predicates = append(predicates, apikey.KeyEQ(loginToken))
- val, err := m.DB.ApiKey.Query().Where(predicates...).WithAgent().Only(ctx)
- if err != nil {
- return nil, rcode, err
- }
- return val, 1, nil
- }
- func (m *OpenAuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- ctx := r.Context()
- ctx = contextkey.HttpResponseWriterKey.WithValue(ctx, w) //context存入http.ResponseWriter
- authToken := jwt.StripBearerPrefixFromToken(r.Header.Get("Authorization"))
- if len(authToken) == 0 {
- httpx.Error(w, errorx.NewApiError(http.StatusForbidden, "无法获取token"))
- return
- }
- apiKeyObj, _, err := m.checkTokenUserInfo(ctx, authToken)
- if err != nil {
- httpx.Error(w, errorx.NewApiError(http.StatusForbidden, "无法获取合适的授权信息"))
- return
- }
- //ctx = contextkey.OpenapiTokenKey.WithValue(ctx, apiToken)
- ctx = contextkey.AuthTokenInfoKey.WithValue(ctx, apiKeyObj)
- /*
- fmt.Println("=========================================")
- fmt.Printf("In Middleware Get Token Info:\nKey:'%s'\n", apiKeyObj.Key)
- fmt.Printf("Title:'%s'\n", apiKeyObj.Title)
- fmt.Printf("OpenaiBase:'%s'\n", apiKeyObj.OpenaiBase)
- fmt.Printf("OpenaiKey:'%s'\n", apiKeyObj.OpenaiKey)
- fmt.Println("=========================================")
- */
- newReq := r.WithContext(ctx)
- // Passthrough to next handler if need
- next(w, newReq)
- }
- }
|