|
@@ -3,9 +3,8 @@ package middleware
|
|
import (
|
|
import (
|
|
"context"
|
|
"context"
|
|
"encoding/json"
|
|
"encoding/json"
|
|
- "fmt"
|
|
|
|
|
|
+ "errors"
|
|
"net/http"
|
|
"net/http"
|
|
- "reflect"
|
|
|
|
|
|
|
|
"wechat-api/ent"
|
|
"wechat-api/ent"
|
|
"wechat-api/ent/apikey"
|
|
"wechat-api/ent/apikey"
|
|
@@ -42,59 +41,74 @@ func NewOpenAuthorityMiddleware(db *ent.Client, rds redis.UniversalClient, c con
|
|
|
|
|
|
func (m *OpenAuthorityMiddleware) checkTokenUserInfo(ctx context.Context, authToken string) (*ent.ApiKey, int, error) {
|
|
func (m *OpenAuthorityMiddleware) checkTokenUserInfo(ctx context.Context, authToken string) (*ent.ApiKey, int, error) {
|
|
var (
|
|
var (
|
|
- rc int
|
|
|
|
- err error
|
|
|
|
- val *ent.ApiKey
|
|
|
|
|
|
+ rc int
|
|
|
|
+ err error
|
|
|
|
+ apiKeyObj *ent.ApiKey
|
|
|
|
+ fromId = -1
|
|
)
|
|
)
|
|
- val, rc, err = m.getTokenUserInfoByDb(ctx, authToken)
|
|
|
|
- return val, rc, err
|
|
|
|
-
|
|
|
|
- /*
|
|
|
|
- r, e = m.getTokenUserInfoByRds(ctx, loginToken)
|
|
|
|
- fmt.Println("redis:", "code-", r, "err-", e)
|
|
|
|
- */
|
|
|
|
- /*
|
|
|
|
- //首先从redis取数据
|
|
|
|
- _, rc, err = m.getTokenUserInfoByRds(ctx, authToken)
|
|
|
|
- fmt.Printf("++++++++++++++++++++++++get authinfo from rds out:%d/err:%s\n", rc, err)
|
|
|
|
- if rc <= 0 || err != nil { //无法获得后再从数据库获得
|
|
|
|
- rc = 0
|
|
|
|
- err = nil
|
|
|
|
- val, rc, err = m.getTokenUserInfoByDb(ctx, authToken)
|
|
|
|
- fmt.Println("----------------------After m.getTokenUserInfoByDb:", val)
|
|
|
|
- err = m.saveTokenUserInfoToRds(ctx, authToken, val)
|
|
|
|
- fmt.Println("------------save saveTokenUserInfoToRd err:", err)
|
|
|
|
|
|
+ _ = fromId
|
|
|
|
+
|
|
|
|
+ //首先从redis取数据
|
|
|
|
+ apiKeyObj, rc, err = m.getTokenUserInfoByRds(ctx, authToken)
|
|
|
|
+ if rc <= 0 || err != nil { //无法获得后再从数据库获得 {
|
|
|
|
+
|
|
|
|
+ rc = 0
|
|
|
|
+ err = nil
|
|
|
|
+ apiKeyObj, rc, err = m.getTokenUserInfoByDb(ctx, authToken)
|
|
|
|
+ if err == nil {
|
|
|
|
+ //get apiKeyObj from db succ
|
|
|
|
+ fromId = 1
|
|
|
|
+ err = m.saveTokenUserInfoToRds(ctx, authToken, apiKeyObj)
|
|
}
|
|
}
|
|
|
|
+ } else {
|
|
|
|
+ fromId = 2
|
|
|
|
+ }
|
|
|
|
|
|
- _ = rc
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, 0, err
|
|
|
|
|
|
+ /*
|
|
|
|
+ if err == nil {
|
|
|
|
+
|
|
|
|
+ fromStr := ""
|
|
|
|
+ switch fromId {
|
|
|
|
+ case 1:
|
|
|
|
+ fromStr = "DB"
|
|
|
|
+ case 2:
|
|
|
|
+ fromStr = "RDS"
|
|
|
|
+ }
|
|
|
|
+ fmt.Println("=========================================>>>")
|
|
|
|
+ fmt.Printf("In checkTokenUserInfo Get Token Info From %s\n", fromStr)
|
|
|
|
+ fmt.Printf("Key:'%s'\n", apiKeyObj.Key)
|
|
|
|
+ fmt.Printf("Title:'%s'\n", apiKeyObj.Title)
|
|
|
|
+ fmt.Printf("OpenaiBase:'%s'\n", apiKeyObj.OpenaiBase)
|
|
|
|
+ fmt.Printf("OpenaiKey:'%s'\n", apiKeyObj.OpenaiKey)
|
|
|
|
+ fmt.Println("<<<=========================================")
|
|
}
|
|
}
|
|
- return val, 0, nil
|
|
|
|
*/
|
|
*/
|
|
|
|
+ return apiKeyObj, rc, err
|
|
}
|
|
}
|
|
|
|
|
|
func (m *OpenAuthorityMiddleware) saveTokenUserInfoToRds(ctx context.Context, authToken string, saveInfo *ent.ApiKey) error {
|
|
func (m *OpenAuthorityMiddleware) saveTokenUserInfoToRds(ctx context.Context, authToken string, saveInfo *ent.ApiKey) error {
|
|
- if bs, err := json.Marshal(saveInfo); err == nil {
|
|
|
|
- return err
|
|
|
|
- } else {
|
|
|
|
- rc, err := m.Rds.HSet(ctx, compapi.APIAuthInfoKey, authToken, string(bs)).Result()
|
|
|
|
- fmt.Printf("#~~~~~~~~~~~~~~~++~~~~~~~~~~~~~HSet Val:%s get Result:%d/%s\n", string(bs), rc, err)
|
|
|
|
- return err
|
|
|
|
|
|
+ bs, err := json.Marshal(saveInfo)
|
|
|
|
+ if err == nil {
|
|
|
|
+
|
|
|
|
+ _, err = m.Rds.HSet(ctx, compapi.APIAuthInfoKey, authToken, string(bs)).Result()
|
|
}
|
|
}
|
|
|
|
+ return err
|
|
}
|
|
}
|
|
func (m *OpenAuthorityMiddleware) getTokenUserInfoByRds(ctx context.Context, authToken string) (*ent.ApiKey, int, error) {
|
|
func (m *OpenAuthorityMiddleware) getTokenUserInfoByRds(ctx context.Context, authToken string) (*ent.ApiKey, int, error) {
|
|
|
|
|
|
rcode := -1
|
|
rcode := -1
|
|
- val, err := m.Rds.HGet(ctx, compapi.APIAuthInfoKey, authToken).Result()
|
|
|
|
- if err == redis.Nil {
|
|
|
|
- rcode = 0
|
|
|
|
- } else if err == nil {
|
|
|
|
- rcode = 1
|
|
|
|
|
|
+ result := ent.ApiKey{}
|
|
|
|
+ jsonStr, err := m.Rds.HGet(ctx, compapi.APIAuthInfoKey, authToken).Result()
|
|
|
|
+ if errors.Is(err, redis.Nil) {
|
|
|
|
+ rcode = 0 //key not exist
|
|
|
|
+ } else if err == nil { //find key
|
|
|
|
+
|
|
|
|
+ err := json.Unmarshal([]byte(jsonStr), &result)
|
|
|
|
+ if err == nil {
|
|
|
|
+ rcode = 1
|
|
|
|
+ }
|
|
}
|
|
}
|
|
- fmt.Printf("#####################From Redis By Key:'%s' Get '%s'(%s/%T)\n", authToken, val, reflect.TypeOf(val), val)
|
|
|
|
- fmt.Println(val)
|
|
|
|
- return nil, rcode, err
|
|
|
|
|
|
+ return &result, rcode, err
|
|
}
|
|
}
|
|
|
|
|
|
func (m *OpenAuthorityMiddleware) getTokenUserInfoByDb(ctx context.Context, loginToken string) (*ent.ApiKey, int, error) {
|
|
func (m *OpenAuthorityMiddleware) getTokenUserInfoByDb(ctx context.Context, loginToken string) (*ent.ApiKey, int, error) {
|
|
@@ -124,16 +138,9 @@ func (m *OpenAuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc
|
|
httpx.Error(w, errorx.NewApiError(http.StatusForbidden, "无法获取合适的授权信息"))
|
|
httpx.Error(w, errorx.NewApiError(http.StatusForbidden, "无法获取合适的授权信息"))
|
|
return
|
|
return
|
|
}
|
|
}
|
|
- //ctx = contextkey.OpenapiTokenKey.WithValue(ctx, apiToken)
|
|
|
|
ctx = contextkey.AuthTokenInfoKey.WithValue(ctx, apiKeyObj)
|
|
ctx = contextkey.AuthTokenInfoKey.WithValue(ctx, apiKeyObj)
|
|
- /*
|
|
|
|
- fmt.Println("=========================================")
|
|
|
|
- fmt.Printf("In Middleware Get Token Info:\nKey:'%s'\n", apiKeyObj.Key)
|
|
|
|
- fmt.Printf("Title:'%s'\n", apiKeyObj.Title)
|
|
|
|
- fmt.Printf("OpenaiBase:'%s'\n", apiKeyObj.OpenaiBase)
|
|
|
|
- fmt.Printf("OpenaiKey:'%s'\n", apiKeyObj.OpenaiKey)
|
|
|
|
- fmt.Println("=========================================")
|
|
|
|
- */
|
|
|
|
|
|
+ ctx = contextkey.OpenapiTokenKey.WithValue(ctx, authToken)
|
|
|
|
+
|
|
newReq := r.WithContext(ctx)
|
|
newReq := r.WithContext(ctx)
|
|
// Passthrough to next handler if need
|
|
// Passthrough to next handler if need
|
|
next(w, newReq)
|
|
next(w, newReq)
|