openauthority_middleware.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. package middleware
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "reflect"
  8. "wechat-api/ent"
  9. "wechat-api/ent/apikey"
  10. "wechat-api/ent/predicate"
  11. "wechat-api/internal/utils/compapi"
  12. "wechat-api/internal/utils/contextkey"
  13. "wechat-api/internal/config"
  14. "github.com/redis/go-redis/v9"
  15. "github.com/suyuan32/simple-admin-common/utils/jwt"
  16. "github.com/zeromicro/go-zero/core/errorx"
  17. "github.com/zeromicro/go-zero/rest/httpx"
  18. )
  19. /*
  20. //"wechat-api/internal/types/payload"
  21. var p types.payload.SendWxPayload
  22. */
  23. type OpenAuthorityMiddleware struct {
  24. DB *ent.Client
  25. Rds redis.UniversalClient
  26. Config config.Config
  27. }
  28. func NewOpenAuthorityMiddleware(db *ent.Client, rds redis.UniversalClient, c config.Config) *OpenAuthorityMiddleware {
  29. return &OpenAuthorityMiddleware{
  30. DB: db,
  31. Rds: rds,
  32. Config: c,
  33. }
  34. }
  35. func (m *OpenAuthorityMiddleware) checkTokenUserInfo(ctx context.Context, authToken string) (*ent.ApiKey, int, error) {
  36. var (
  37. rc int
  38. err error
  39. val *ent.ApiKey
  40. )
  41. val, rc, err = m.getTokenUserInfoByDb(ctx, authToken)
  42. return val, rc, err
  43. /*
  44. r, e = m.getTokenUserInfoByRds(ctx, loginToken)
  45. fmt.Println("redis:", "code-", r, "err-", e)
  46. */
  47. /*
  48. //首先从redis取数据
  49. _, rc, err = m.getTokenUserInfoByRds(ctx, authToken)
  50. fmt.Printf("++++++++++++++++++++++++get authinfo from rds out:%d/err:%s\n", rc, err)
  51. if rc <= 0 || err != nil { //无法获得后再从数据库获得
  52. rc = 0
  53. err = nil
  54. val, rc, err = m.getTokenUserInfoByDb(ctx, authToken)
  55. fmt.Println("----------------------After m.getTokenUserInfoByDb:", val)
  56. err = m.saveTokenUserInfoToRds(ctx, authToken, val)
  57. fmt.Println("------------save saveTokenUserInfoToRd err:", err)
  58. }
  59. _ = rc
  60. if err != nil {
  61. return nil, 0, err
  62. }
  63. return val, 0, nil
  64. */
  65. }
  66. func (m *OpenAuthorityMiddleware) saveTokenUserInfoToRds(ctx context.Context, authToken string, saveInfo *ent.ApiKey) error {
  67. if bs, err := json.Marshal(saveInfo); err == nil {
  68. return err
  69. } else {
  70. rc, err := m.Rds.HSet(ctx, compapi.APIAuthInfoKey, authToken, string(bs)).Result()
  71. fmt.Printf("#~~~~~~~~~~~~~~~++~~~~~~~~~~~~~HSet Val:%s get Result:%d/%s\n", string(bs), rc, err)
  72. return err
  73. }
  74. }
  75. func (m *OpenAuthorityMiddleware) getTokenUserInfoByRds(ctx context.Context, authToken string) (*ent.ApiKey, int, error) {
  76. rcode := -1
  77. val, err := m.Rds.HGet(ctx, compapi.APIAuthInfoKey, authToken).Result()
  78. if err == redis.Nil {
  79. rcode = 0
  80. } else if err == nil {
  81. rcode = 1
  82. }
  83. fmt.Printf("#####################From Redis By Key:'%s' Get '%s'(%s/%T)\n", authToken, val, reflect.TypeOf(val), val)
  84. fmt.Println(val)
  85. return nil, rcode, err
  86. }
  87. func (m *OpenAuthorityMiddleware) getTokenUserInfoByDb(ctx context.Context, loginToken string) (*ent.ApiKey, int, error) {
  88. rcode := -1
  89. var predicates []predicate.ApiKey
  90. predicates = append(predicates, apikey.KeyEQ(loginToken))
  91. val, err := m.DB.ApiKey.Query().Where(predicates...).WithAgent().Only(ctx)
  92. if err != nil {
  93. return nil, rcode, err
  94. }
  95. return val, 1, nil
  96. }
  97. func (m *OpenAuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
  98. return func(w http.ResponseWriter, r *http.Request) {
  99. ctx := r.Context()
  100. ctx = contextkey.HttpResponseWriterKey.WithValue(ctx, w) //context存入http.ResponseWriter
  101. authToken := jwt.StripBearerPrefixFromToken(r.Header.Get("Authorization"))
  102. if len(authToken) == 0 {
  103. httpx.Error(w, errorx.NewApiError(http.StatusForbidden, "无法获取token"))
  104. return
  105. }
  106. apiKeyObj, _, err := m.checkTokenUserInfo(ctx, authToken)
  107. if err != nil {
  108. httpx.Error(w, errorx.NewApiError(http.StatusForbidden, "无法获取合适的授权信息"))
  109. return
  110. }
  111. //ctx = contextkey.OpenapiTokenKey.WithValue(ctx, apiToken)
  112. ctx = contextkey.AuthTokenInfoKey.WithValue(ctx, apiKeyObj)
  113. /*
  114. fmt.Println("=========================================")
  115. fmt.Printf("In Middleware Get Token Info:\nKey:'%s'\n", apiKeyObj.Key)
  116. fmt.Printf("Title:'%s'\n", apiKeyObj.Title)
  117. fmt.Printf("OpenaiBase:'%s'\n", apiKeyObj.OpenaiBase)
  118. fmt.Printf("OpenaiKey:'%s'\n", apiKeyObj.OpenaiKey)
  119. fmt.Println("=========================================")
  120. */
  121. newReq := r.WithContext(ctx)
  122. // Passthrough to next handler if need
  123. next(w, newReq)
  124. }
  125. }