openauthority_middleware.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. package middleware
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "wechat-api/ent"
  9. "wechat-api/ent/apikey"
  10. "wechat-api/ent/predicate"
  11. "wechat-api/internal/utils/compapi"
  12. "wechat-api/internal/utils/contextkey"
  13. "wechat-api/internal/config"
  14. "github.com/redis/go-redis/v9"
  15. "github.com/suyuan32/simple-admin-common/utils/jwt"
  16. "github.com/zeromicro/go-zero/core/errorx"
  17. "github.com/zeromicro/go-zero/rest/httpx"
  18. )
  19. /*
  20. //"wechat-api/internal/types/payload"
  21. var p types.payload.SendWxPayload
  22. */
  23. type OpenAuthorityMiddleware struct {
  24. DB *ent.Client
  25. Rds redis.UniversalClient
  26. Config config.Config
  27. }
  28. func NewOpenAuthorityMiddleware(db *ent.Client, rds redis.UniversalClient, c config.Config) *OpenAuthorityMiddleware {
  29. return &OpenAuthorityMiddleware{
  30. DB: db,
  31. Rds: rds,
  32. Config: c,
  33. }
  34. }
  35. func (m *OpenAuthorityMiddleware) checkTokenUserInfo(ctx context.Context, authToken string) (*ent.ApiKey, int, error) {
  36. var (
  37. rc int
  38. err error
  39. apiKeyObj *ent.ApiKey
  40. fromId = -1
  41. )
  42. _ = fromId
  43. //首先从redis取数据
  44. apiKeyObj, rc, err = m.getTokenUserInfoByRds(ctx, authToken)
  45. if rc <= 0 || err != nil { //无法获得后再从数据库获得
  46. rc = 0
  47. err = nil
  48. apiKeyObj, rc, err = m.getTokenUserInfoByDb(ctx, authToken)
  49. if err == nil {
  50. //get apiKeyObj from db succ
  51. fromId = 1
  52. err = m.saveTokenUserInfoToRds(ctx, authToken, apiKeyObj)
  53. }
  54. } else {
  55. fromId = 2
  56. }
  57. if err == nil {
  58. /*
  59. fromStr := ""
  60. switch fromId {
  61. case 1:
  62. fromStr = "DB"
  63. case 2:
  64. fromStr = "RDS"
  65. }
  66. fmt.Println("=========================================>>>")
  67. fmt.Printf("In checkTokenUserInfo Get Token Info From %s\n", fromStr)
  68. fmt.Printf("Auth Token:'%s'\n", apiKeyObj.Key)
  69. fmt.Printf("AgentID:%d\n", apiKeyObj.AgentID)
  70. fmt.Printf("APIBase:'%s'\n", apiKeyObj.Edges.Agent.APIBase)
  71. fmt.Printf("APIKey:'%s'\n", apiKeyObj.Edges.Agent.APIKey)
  72. fmt.Printf("Type:%d\n", apiKeyObj.Edges.Agent.Type)
  73. fmt.Printf("Model:'%s'\n", apiKeyObj.Edges.Agent.Model)
  74. //fmt.Println(typekit.PrettyPrint(apiKeyObj))
  75. //fmt.Printf("OpenaiBase:'%s'\n", apiKeyObj.OpenaiBase)
  76. //fmt.Printf("OpenaiKey:'%s'\n", apiKeyObj.OpenaiKey)
  77. fmt.Println("<<<=========================================")
  78. */
  79. }
  80. return apiKeyObj, rc, err
  81. }
  82. func (m *OpenAuthorityMiddleware) saveTokenUserInfoToRds(ctx context.Context, authToken string, saveInfo *ent.ApiKey) error {
  83. bs, err := json.Marshal(saveInfo)
  84. if err == nil {
  85. _, err = m.Rds.HSet(ctx, compapi.APIAuthInfoKey, authToken, string(bs)).Result()
  86. }
  87. return err
  88. }
  89. func (m *OpenAuthorityMiddleware) getTokenUserInfoByRds(ctx context.Context, authToken string) (*ent.ApiKey, int, error) {
  90. rcode := -1
  91. result := ent.ApiKey{}
  92. m.Rds.Del(ctx, compapi.APIAuthInfoKey) //for debug
  93. jsonStr, err := m.Rds.HGet(ctx, compapi.APIAuthInfoKey, authToken).Result()
  94. if errors.Is(err, redis.Nil) {
  95. rcode = 0 //key not exist
  96. } else if err == nil { //find key
  97. err := json.Unmarshal([]byte(jsonStr), &result)
  98. if err == nil {
  99. rcode = 1
  100. }
  101. }
  102. return &result, rcode, err
  103. }
  104. func (m *OpenAuthorityMiddleware) getTokenUserInfoByDb(ctx context.Context, loginToken string) (*ent.ApiKey, int, error) {
  105. rcode := -1
  106. var predicates []predicate.ApiKey
  107. predicates = append(predicates, apikey.KeyEQ(loginToken))
  108. val, err := m.DB.ApiKey.Query().Where(predicates...).WithAgent().Only(ctx)
  109. //val, err := m.DB.ApiKey.Query().Where(predicates...).First(ctx)
  110. if err != nil {
  111. return nil, rcode, err
  112. }
  113. if val.Edges.Agent == nil {
  114. return nil, rcode, fmt.Errorf("edge get agent info is nil by agentid:%d", val.AgentID)
  115. }
  116. return val, 1, nil
  117. }
  118. func (m *OpenAuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
  119. return func(w http.ResponseWriter, r *http.Request) {
  120. ctx := r.Context()
  121. ctx = contextkey.HttpResponseWriterKey.WithValue(ctx, w) //context存入http.ResponseWriter
  122. authToken := jwt.StripBearerPrefixFromToken(r.Header.Get("Authorization"))
  123. if len(authToken) == 0 {
  124. httpx.Error(w, errorx.NewApiError(http.StatusForbidden, "无法获取token"))
  125. return
  126. }
  127. apiKeyObj, _, err := m.checkTokenUserInfo(ctx, authToken)
  128. if err != nil {
  129. httpx.Error(w, errorx.NewApiError(http.StatusForbidden, "无法获取合适的授权信息"))
  130. return
  131. }
  132. if len(apiKeyObj.OpenaiBase) == 0 {
  133. httpx.Error(w, errorx.NewApiError(http.StatusForbidden, "缺失OpenaiBase"))
  134. return
  135. }
  136. ctx = contextkey.AuthTokenInfoKey.WithValue(ctx, apiKeyObj)
  137. newReq := r.WithContext(ctx)
  138. // Passthrough to next handler if need
  139. next(w, newReq)
  140. }
  141. }