openauthority_middleware.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. package middleware
  2. import (
  3. "context"
  4. "net/http"
  5. "wechat-api/ent"
  6. "wechat-api/ent/apikey"
  7. "wechat-api/ent/predicate"
  8. "wechat-api/internal/utils/contextkey"
  9. "wechat-api/internal/config"
  10. "github.com/redis/go-redis/v9"
  11. "github.com/suyuan32/simple-admin-common/utils/jwt"
  12. "github.com/zeromicro/go-zero/core/errorx"
  13. "github.com/zeromicro/go-zero/rest/httpx"
  14. )
  15. /*
  16. //"wechat-api/internal/types/payload"
  17. var p types.payload.SendWxPayload
  18. */
  19. type OpenAuthorityMiddleware struct {
  20. DB *ent.Client
  21. Rds redis.UniversalClient
  22. Config config.Config
  23. }
  24. func NewOpenAuthorityMiddleware(db *ent.Client, rds redis.UniversalClient, c config.Config) *OpenAuthorityMiddleware {
  25. return &OpenAuthorityMiddleware{
  26. DB: db,
  27. Rds: rds,
  28. Config: c,
  29. }
  30. }
  31. func (m *OpenAuthorityMiddleware) checkTokenUserInfo(ctx context.Context, loginToken string) (*ent.ApiKey, error) {
  32. var (
  33. rc int
  34. err error
  35. val *ent.ApiKey
  36. )
  37. /*
  38. r, e = m.getTokenUserInfoByRds(ctx, loginToken)
  39. fmt.Println("redis:", "code-", r, "err-", e)
  40. */
  41. val, rc, err = m.getTokenUserInfoByDb(ctx, loginToken)
  42. _ = rc
  43. if err != nil {
  44. return nil, err
  45. }
  46. return val, nil
  47. }
  48. /*
  49. func (m *OpenAuthorityMiddleware) getTokenUserInfoByRds(ctx context.Context, loginToken string) (code int, err error) {
  50. rcode := -1
  51. val, err := m.Rds.HGet(ctx, "api_key", loginToken).Result()
  52. if err == redis.Nil {
  53. rcode = 0
  54. } else if err == nil {
  55. rcode = 1
  56. }
  57. fmt.Printf("From Redis By Key:'%s' Get '%s'(%v)\n", loginToken, val, val)
  58. fmt.Println(val)
  59. return rcode, err
  60. }
  61. */
  62. func (m *OpenAuthorityMiddleware) getTokenUserInfoByDb(ctx context.Context, loginToken string) (*ent.ApiKey, int, error) {
  63. rcode := -1
  64. var predicates []predicate.ApiKey
  65. predicates = append(predicates, apikey.KeyEQ(loginToken))
  66. val, err := m.DB.ApiKey.Query().Where(predicates...).WithAgent().Only(ctx)
  67. if err != nil {
  68. return nil, rcode, err
  69. }
  70. return val, 1, nil
  71. }
  72. func (m *OpenAuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
  73. return func(w http.ResponseWriter, r *http.Request) {
  74. ctx := r.Context()
  75. ctx = contextkey.HttpResponseWriterKey.WithValue(ctx, w) //context存入http.ResponseWriter
  76. authToken := jwt.StripBearerPrefixFromToken(r.Header.Get("Authorization"))
  77. if len(authToken) == 0 {
  78. httpx.Error(w, errorx.NewApiError(http.StatusForbidden, "无法获取token"))
  79. return
  80. }
  81. apiKeyObj, err := m.checkTokenUserInfo(ctx, authToken)
  82. if err != nil {
  83. httpx.Error(w, errorx.NewApiError(http.StatusForbidden, "无法获取合适的授权信息"))
  84. return
  85. }
  86. //ctx = contextkey.OpenapiTokenKey.WithValue(ctx, apiToken)
  87. ctx = contextkey.AuthTokenInfoKey.WithValue(ctx, apiKeyObj)
  88. /*
  89. fmt.Println("=========================================")
  90. fmt.Printf("In Middleware Get Token Info:\nKey:'%s'\n", apiKeyObj.Key)
  91. fmt.Printf("Title:'%s'\n", apiKeyObj.Title)
  92. fmt.Printf("OpenaiBase:'%s'\n", apiKeyObj.OpenaiBase)
  93. fmt.Printf("OpenaiKey:'%s'\n", apiKeyObj.OpenaiKey)
  94. fmt.Println("=========================================")
  95. claims, err := jwtutils.ParseJwtToken(m.Config.Auth.AccessSecret, authToken)
  96. fmt.Println("claims")
  97. fmt.Printf("%+v\n", claims)
  98. if err != nil {
  99. logx.Errorw("check user auth error", logx.Field("detail", err.Error()))
  100. httpx.Error(w, errorx.NewApiError(http.StatusInternalServerError, err.Error()))
  101. return
  102. }
  103. */
  104. newReq := r.WithContext(ctx)
  105. // Passthrough to next handler if need
  106. next(w, newReq)
  107. }
  108. }