123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- package middleware
- import (
- "context"
- "errors"
- "github.com/suyuan32/simple-admin-core/rpc/coreclient"
- "github.com/suyuan32/simple-admin-core/rpc/types/core"
- "net/http"
- "strings"
- "github.com/casbin/casbin/v2"
- "github.com/redis/go-redis/v9"
- "github.com/suyuan32/simple-admin-common/config"
- "github.com/suyuan32/simple-admin-common/enum/errorcode"
- "github.com/suyuan32/simple-admin-common/utils/jwt"
- "github.com/zeromicro/go-zero/core/errorx"
- "github.com/zeromicro/go-zero/core/logx"
- "github.com/zeromicro/go-zero/rest/httpx"
- )
- type AuthorityMiddleware struct {
- Cbn *casbin.Enforcer
- Rds redis.UniversalClient
- CoreRpc coreclient.Core
- }
- func NewAuthorityMiddleware(cbn *casbin.Enforcer, rds redis.UniversalClient, c coreclient.Core) *AuthorityMiddleware {
- return &AuthorityMiddleware{
- Cbn: cbn,
- Rds: rds,
- CoreRpc: c,
- }
- }
- func (m *AuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- // get the path
- obj := r.URL.Path
- // get the method
- act := r.Method
- // get the role id
- roleIds := r.Context().Value("roleId").(string)
- // check jwt blacklist
- jwtResult, err := m.Rds.Get(context.Background(), config.RedisTokenPrefix+jwt.StripBearerPrefixFromToken(r.Header.Get("Authorization"))).Result()
- if err != nil && !errors.Is(err, redis.Nil) {
- logx.Errorw("redis error in jwt", logx.Field("detail", err.Error()))
- httpx.Error(w, errorx.NewApiError(http.StatusInternalServerError, err.Error()))
- return
- }
- if jwtResult == "1" {
- logx.Errorw("token in blacklist", logx.Field("detail", r.Header.Get("Authorization")))
- httpx.Error(w, errorx.NewApiErrorWithoutMsg(http.StatusUnauthorized))
- return
- }
- result := batchCheck(m.Cbn, roleIds, act, obj)
- if result {
- userIdStr := r.Context().Value("userId").(string)
- logx.Infow("HTTP/HTTPS Request", logx.Field("UUID", userIdStr),
- logx.Field("path", obj), logx.Field("method", act))
- data, err := m.CoreRpc.GetUserById(r.Context(), &core.UUIDReq{Id: userIdStr})
- if err != nil {
- logx.Errorw("get user info error", logx.Field("detail", err.Error()))
- httpx.Error(w, errorx.NewApiError(http.StatusInternalServerError, err.Error()))
- return
- }
- // 将 data.DepartmentID 插入上下文,以供后续接口使用
- //fmt.Printf("---------------departmentId----------------: %d\n\n", *data.DepartmentId)
- r = r.WithContext(context.WithValue(r.Context(), "organizationId", *data.DepartmentId))
- r = r.WithContext(context.WithValue(r.Context(), "isAdmin", stringInSlice(roleIds, []string{"001"})))
- next(w, r)
- return
- } else {
- logx.Errorw("the role is not permitted to access the API", logx.Field("roleId", roleIds),
- logx.Field("path", obj), logx.Field("method", act))
- httpx.Error(w, errorx.NewCodeError(errorcode.PermissionDenied, "Permission Denied"))
- return
- }
- }
- }
- func batchCheck(cbn *casbin.Enforcer, roleIds, act, obj string) bool {
- var checkReq [][]any
- for _, v := range strings.Split(roleIds, ",") {
- checkReq = append(checkReq, []any{v, obj, act})
- }
- result, err := cbn.BatchEnforce(checkReq)
- if err != nil {
- logx.Errorw("Casbin enforce error", logx.Field("detail", err.Error()))
- return false
- }
- for _, v := range result {
- if v {
- return true
- }
- }
- return false
- }
// stringInSlice reports whether any element of the comma-separated roleIds
// string is exactly equal to one of the entries in list. Elements are not
// trimmed, so "001 " does not match "001". An empty roleIds yields a single
// empty element.
func stringInSlice(roleIds string, list []string) bool {
	wanted := make(map[string]struct{}, len(list))
	for _, code := range list {
		wanted[code] = struct{}{}
	}
	for _, role := range strings.Split(roleIds, ",") {
		if _, hit := wanted[role]; hit {
			return true
		}
	}
	return false
}
|