openauthority_middleware.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package middleware
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "net/http"
  7. "wechat-api/ent"
  8. "wechat-api/ent/apikey"
  9. "wechat-api/ent/predicate"
  10. "wechat-api/internal/utils/compapi"
  11. "wechat-api/internal/utils/contextkey"
  12. "wechat-api/internal/config"
  13. "github.com/redis/go-redis/v9"
  14. "github.com/suyuan32/simple-admin-common/utils/jwt"
  15. "github.com/zeromicro/go-zero/core/errorx"
  16. "github.com/zeromicro/go-zero/rest/httpx"
  17. )
  18. /*
  19. //"wechat-api/internal/types/payload"
  20. var p types.payload.SendWxPayload
  21. */
  22. type OpenAuthorityMiddleware struct {
  23. DB *ent.Client
  24. Rds redis.UniversalClient
  25. Config config.Config
  26. }
  27. func NewOpenAuthorityMiddleware(db *ent.Client, rds redis.UniversalClient, c config.Config) *OpenAuthorityMiddleware {
  28. return &OpenAuthorityMiddleware{
  29. DB: db,
  30. Rds: rds,
  31. Config: c,
  32. }
  33. }
  34. func (m *OpenAuthorityMiddleware) checkTokenUserInfo(ctx context.Context, authToken string) (*ent.ApiKey, int, error) {
  35. var (
  36. rc int
  37. err error
  38. apiKeyObj *ent.ApiKey
  39. fromId = -1
  40. )
  41. _ = fromId
  42. //首先从redis取数据
  43. apiKeyObj, rc, err = m.getTokenUserInfoByRds(ctx, authToken)
  44. if rc <= 0 || err != nil { //无法获得后再从数据库获得 {
  45. rc = 0
  46. err = nil
  47. apiKeyObj, rc, err = m.getTokenUserInfoByDb(ctx, authToken)
  48. if err == nil {
  49. //get apiKeyObj from db succ
  50. fromId = 1
  51. err = m.saveTokenUserInfoToRds(ctx, authToken, apiKeyObj)
  52. }
  53. } else {
  54. fromId = 2
  55. }
  56. /*
  57. if err == nil {
  58. fromStr := ""
  59. switch fromId {
  60. case 1:
  61. fromStr = "DB"
  62. case 2:
  63. fromStr = "RDS"
  64. }
  65. fmt.Println("=========================================>>>")
  66. fmt.Printf("In checkTokenUserInfo Get Token Info From %s\n", fromStr)
  67. fmt.Printf("Key:'%s'\n", apiKeyObj.Key)
  68. fmt.Printf("Title:'%s'\n", apiKeyObj.Title)
  69. fmt.Printf("OpenaiBase:'%s'\n", apiKeyObj.OpenaiBase)
  70. fmt.Printf("OpenaiKey:'%s'\n", apiKeyObj.OpenaiKey)
  71. fmt.Println("<<<=========================================")
  72. }
  73. */
  74. return apiKeyObj, rc, err
  75. }
  76. func (m *OpenAuthorityMiddleware) saveTokenUserInfoToRds(ctx context.Context, authToken string, saveInfo *ent.ApiKey) error {
  77. bs, err := json.Marshal(saveInfo)
  78. if err == nil {
  79. _, err = m.Rds.HSet(ctx, compapi.APIAuthInfoKey, authToken, string(bs)).Result()
  80. }
  81. return err
  82. }
  83. func (m *OpenAuthorityMiddleware) getTokenUserInfoByRds(ctx context.Context, authToken string) (*ent.ApiKey, int, error) {
  84. rcode := -1
  85. result := ent.ApiKey{}
  86. jsonStr, err := m.Rds.HGet(ctx, compapi.APIAuthInfoKey, authToken).Result()
  87. if errors.Is(err, redis.Nil) {
  88. rcode = 0 //key not exist
  89. } else if err == nil { //find key
  90. err := json.Unmarshal([]byte(jsonStr), &result)
  91. if err == nil {
  92. rcode = 1
  93. }
  94. }
  95. return &result, rcode, err
  96. }
  97. func (m *OpenAuthorityMiddleware) getTokenUserInfoByDb(ctx context.Context, loginToken string) (*ent.ApiKey, int, error) {
  98. rcode := -1
  99. var predicates []predicate.ApiKey
  100. predicates = append(predicates, apikey.KeyEQ(loginToken))
  101. val, err := m.DB.ApiKey.Query().Where(predicates...).WithAgent().Only(ctx)
  102. if err != nil {
  103. return nil, rcode, err
  104. }
  105. return val, 1, nil
  106. }
  107. func (m *OpenAuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
  108. return func(w http.ResponseWriter, r *http.Request) {
  109. ctx := r.Context()
  110. ctx = contextkey.HttpResponseWriterKey.WithValue(ctx, w) //context存入http.ResponseWriter
  111. authToken := jwt.StripBearerPrefixFromToken(r.Header.Get("Authorization"))
  112. if len(authToken) == 0 {
  113. httpx.Error(w, errorx.NewApiError(http.StatusForbidden, "无法获取token"))
  114. return
  115. }
  116. apiKeyObj, _, err := m.checkTokenUserInfo(ctx, authToken)
  117. if err != nil {
  118. httpx.Error(w, errorx.NewApiError(http.StatusForbidden, "无法获取合适的授权信息"))
  119. return
  120. }
  121. ctx = contextkey.AuthTokenInfoKey.WithValue(ctx, apiKeyObj)
  122. ctx = contextkey.OpenapiTokenKey.WithValue(ctx, authToken)
  123. newReq := r.WithContext(ctx)
  124. // Passthrough to next handler if need
  125. next(w, newReq)
  126. }
  127. }