package middleware import ( "context" "errors" "github.com/suyuan32/simple-admin-core/rpc/coreclient" "github.com/suyuan32/simple-admin-core/rpc/types/core" "net/http" "strings" "github.com/casbin/casbin/v2" "github.com/redis/go-redis/v9" "github.com/suyuan32/simple-admin-common/config" "github.com/suyuan32/simple-admin-common/enum/errorcode" "github.com/suyuan32/simple-admin-common/utils/jwt" "github.com/zeromicro/go-zero/core/errorx" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/rest/httpx" ) type AuthorityMiddleware struct { Cbn *casbin.Enforcer Rds redis.UniversalClient CoreRpc coreclient.Core } func NewAuthorityMiddleware(cbn *casbin.Enforcer, rds redis.UniversalClient, c coreclient.Core) *AuthorityMiddleware { return &AuthorityMiddleware{ Cbn: cbn, Rds: rds, CoreRpc: c, } } func (m *AuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // get the path obj := r.URL.Path // get the method act := r.Method // get the role id roleIds := r.Context().Value("roleId").(string) // check jwt blacklist jwtResult, err := m.Rds.Get(context.Background(), config.RedisTokenPrefix+jwt.StripBearerPrefixFromToken(r.Header.Get("Authorization"))).Result() if err != nil && !errors.Is(err, redis.Nil) { logx.Errorw("redis error in jwt", logx.Field("detail", err.Error())) httpx.Error(w, errorx.NewApiError(http.StatusInternalServerError, err.Error())) return } if jwtResult == "1" { logx.Errorw("token in blacklist", logx.Field("detail", r.Header.Get("Authorization"))) httpx.Error(w, errorx.NewApiErrorWithoutMsg(http.StatusUnauthorized)) return } result := batchCheck(m.Cbn, roleIds, act, obj) if result { userIdStr := r.Context().Value("userId").(string) logx.Infow("HTTP/HTTPS Request", logx.Field("UUID", userIdStr), logx.Field("path", obj), logx.Field("method", act)) data, err := m.CoreRpc.GetUserById(r.Context(), &core.UUIDReq{Id: userIdStr}) if err != nil { logx.Errorw("get user info error", logx.Field("detail", err.Error())) httpx.Error(w, errorx.NewApiError(http.StatusInternalServerError, err.Error())) return } // 将 data.DepartmentID 插入上下文,以供后续接口使用 //fmt.Printf("---------------departmentId----------------: %d\n\n", *data.DepartmentId) r = r.WithContext(context.WithValue(r.Context(), "organizationId", *data.DepartmentId)) r = r.WithContext(context.WithValue(r.Context(), "isAdmin", stringInSlice(roleIds, []string{"001"}))) next(w, r) return } else { logx.Errorw("the role is not permitted to access the API", logx.Field("roleId", roleIds), logx.Field("path", obj), logx.Field("method", act)) httpx.Error(w, errorx.NewCodeError(errorcode.PermissionDenied, "Permission Denied")) return } } } func batchCheck(cbn *casbin.Enforcer, roleIds, act, obj string) bool { var checkReq [][]any for _, v := range strings.Split(roleIds, ",") { checkReq = append(checkReq, []any{v, obj, act}) } result, err := cbn.BatchEnforce(checkReq) if err != nil { logx.Errorw("Casbin enforce error", logx.Field("detail", err.Error())) return false } for _, v := range result { if v { return true } } return false } func stringInSlice(roleIds string, list []string) bool { for _, r := range strings.Split(roleIds, ",") { for _, v := range list { if v == r { return true } } } return false }