authority_middleware.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. package middleware
  2. import (
  3. "context"
  4. "errors"
  5. "github.com/suyuan32/simple-admin-core/rpc/coreclient"
  6. "github.com/suyuan32/simple-admin-core/rpc/types/core"
  7. "net/http"
  8. "strings"
  9. "github.com/casbin/casbin/v2"
  10. "github.com/redis/go-redis/v9"
  11. "github.com/suyuan32/simple-admin-common/config"
  12. "github.com/suyuan32/simple-admin-common/enum/errorcode"
  13. "github.com/suyuan32/simple-admin-common/utils/jwt"
  14. "github.com/zeromicro/go-zero/core/errorx"
  15. "github.com/zeromicro/go-zero/core/logx"
  16. "github.com/zeromicro/go-zero/rest/httpx"
  17. )
  18. type AuthorityMiddleware struct {
  19. Cbn *casbin.Enforcer
  20. Rds redis.UniversalClient
  21. CoreRpc coreclient.Core
  22. }
  23. func NewAuthorityMiddleware(cbn *casbin.Enforcer, rds redis.UniversalClient, c coreclient.Core) *AuthorityMiddleware {
  24. return &AuthorityMiddleware{
  25. Cbn: cbn,
  26. Rds: rds,
  27. CoreRpc: c,
  28. }
  29. }
  30. func (m *AuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
  31. return func(w http.ResponseWriter, r *http.Request) {
  32. // get the path
  33. obj := r.URL.Path
  34. // get the method
  35. act := r.Method
  36. // get the role id
  37. roleIds := r.Context().Value("roleId").(string)
  38. // check jwt blacklist
  39. jwtResult, err := m.Rds.Get(context.Background(), config.RedisTokenPrefix+jwt.StripBearerPrefixFromToken(r.Header.Get("Authorization"))).Result()
  40. if err != nil && !errors.Is(err, redis.Nil) {
  41. logx.Errorw("redis error in jwt", logx.Field("detail", err.Error()))
  42. httpx.Error(w, errorx.NewApiError(http.StatusInternalServerError, err.Error()))
  43. return
  44. }
  45. if jwtResult == "1" {
  46. logx.Errorw("token in blacklist", logx.Field("detail", r.Header.Get("Authorization")))
  47. httpx.Error(w, errorx.NewApiErrorWithoutMsg(http.StatusUnauthorized))
  48. return
  49. }
  50. result := batchCheck(m.Cbn, roleIds, act, obj)
  51. logx.Errorw("------------------------result--------------------------- ", logx.Field("detail", result))
  52. if result {
  53. userIdStr := r.Context().Value("userId").(string)
  54. logx.Errorw("------------------------userIdStr--------------------------- ", logx.Field("detail", userIdStr))
  55. logx.Infow("HTTP/HTTPS Request", logx.Field("UUID", userIdStr),
  56. logx.Field("path", obj), logx.Field("method", act))
  57. data, err := m.CoreRpc.GetUserById(r.Context(), &core.UUIDReq{Id: userIdStr})
  58. if err != nil {
  59. logx.Errorw("get user info error", logx.Field("detail", err.Error()))
  60. httpx.Error(w, errorx.NewApiError(http.StatusInternalServerError, err.Error()))
  61. return
  62. }
  63. // 将 data.DepartmentID 插入上下文,以供后续接口使用
  64. //fmt.Printf("---------------departmentId----------------: %d\n\n", *data.DepartmentId)
  65. r = r.WithContext(context.WithValue(r.Context(), "organizationId", *data.DepartmentId))
  66. r = r.WithContext(context.WithValue(r.Context(), "isAdmin", stringInSlice(roleIds, []string{"001"})))
  67. next(w, r)
  68. return
  69. } else {
  70. logx.Errorw("the role is not permitted to access the API", logx.Field("roleId", roleIds),
  71. logx.Field("path", obj), logx.Field("method", act))
  72. httpx.Error(w, errorx.NewCodeError(errorcode.PermissionDenied, "Permission Denied"))
  73. return
  74. }
  75. }
  76. }
  77. func batchCheck(cbn *casbin.Enforcer, roleIds, act, obj string) bool {
  78. var checkReq [][]any
  79. for _, v := range strings.Split(roleIds, ",") {
  80. checkReq = append(checkReq, []any{v, obj, act})
  81. }
  82. result, err := cbn.BatchEnforce(checkReq)
  83. if err != nil {
  84. logx.Errorw("Casbin enforce error", logx.Field("detail", err.Error()))
  85. return false
  86. }
  87. for _, v := range result {
  88. if v {
  89. return true
  90. }
  91. }
  92. return false
  93. }
  // stringInSlice reports whether any entry of the comma-separated roleIds
  // string is equal to an element of list. Entries are compared verbatim; no
  // whitespace trimming or case folding is applied.
  func stringInSlice(roleIds string, list []string) bool {
  	candidates := strings.Split(roleIds, ",")
  	for i := range list {
  		for j := range candidates {
  			if candidates[j] == list[i] {
  				return true
  			}
  		}
  	}
  	return false
  }