authority_middleware.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. package middleware
  2. import (
  3. "context"
  4. "errors"
  5. "github.com/suyuan32/simple-admin-core/rpc/coreclient"
  6. "github.com/suyuan32/simple-admin-core/rpc/types/core"
  7. "net/http"
  8. "strings"
  9. "github.com/casbin/casbin/v2"
  10. "github.com/redis/go-redis/v9"
  11. "github.com/suyuan32/simple-admin-common/config"
  12. "github.com/suyuan32/simple-admin-common/enum/errorcode"
  13. "github.com/suyuan32/simple-admin-common/utils/jwt"
  14. "github.com/zeromicro/go-zero/core/errorx"
  15. "github.com/zeromicro/go-zero/core/logx"
  16. "github.com/zeromicro/go-zero/rest/httpx"
  17. )
  18. type AuthorityMiddleware struct {
  19. Cbn *casbin.Enforcer
  20. Rds redis.UniversalClient
  21. CoreRpc coreclient.Core
  22. }
  23. func NewAuthorityMiddleware(cbn *casbin.Enforcer, rds redis.UniversalClient, c coreclient.Core) *AuthorityMiddleware {
  24. return &AuthorityMiddleware{
  25. Cbn: cbn,
  26. Rds: rds,
  27. CoreRpc: c,
  28. }
  29. }
  30. func (m *AuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
  31. return func(w http.ResponseWriter, r *http.Request) {
  32. // get the path
  33. obj := r.URL.Path
  34. // get the method
  35. act := r.Method
  36. // get the role id
  37. roleIds := r.Context().Value("roleId").(string)
  38. // check jwt blacklist
  39. jwtResult, err := m.Rds.Get(context.Background(), config.RedisTokenPrefix+jwt.StripBearerPrefixFromToken(r.Header.Get("Authorization"))).Result()
  40. if err != nil && !errors.Is(err, redis.Nil) {
  41. logx.Errorw("redis error in jwt", logx.Field("detail", err.Error()))
  42. httpx.Error(w, errorx.NewApiError(http.StatusInternalServerError, err.Error()))
  43. return
  44. }
  45. if jwtResult == "1" {
  46. logx.Errorw("token in blacklist", logx.Field("detail", r.Header.Get("Authorization")))
  47. httpx.Error(w, errorx.NewApiErrorWithoutMsg(http.StatusUnauthorized))
  48. return
  49. }
  50. result := batchCheck(m.Cbn, roleIds, act, obj)
  51. if result {
  52. userIdStr := r.Context().Value("userId").(string)
  53. logx.Infow("HTTP/HTTPS Request", logx.Field("UUID", userIdStr),
  54. logx.Field("path", obj), logx.Field("method", act))
  55. data, err := m.CoreRpc.GetUserById(r.Context(), &core.UUIDReq{Id: userIdStr})
  56. if err != nil {
  57. logx.Errorw("get user info error", logx.Field("detail", err.Error()))
  58. httpx.Error(w, errorx.NewApiError(http.StatusInternalServerError, err.Error()))
  59. return
  60. }
  61. // 将 data.DepartmentID 插入上下文,以供后续接口使用
  62. //fmt.Printf("---------------departmentId----------------: %d\n\n", *data.DepartmentId)
  63. r = r.WithContext(context.WithValue(r.Context(), "organizationId", *data.DepartmentId))
  64. r = r.WithContext(context.WithValue(r.Context(), "isAdmin", stringInSlice(roleIds, []string{"001"})))
  65. next(w, r)
  66. return
  67. } else {
  68. logx.Errorw("the role is not permitted to access the API", logx.Field("roleId", roleIds),
  69. logx.Field("path", obj), logx.Field("method", act))
  70. httpx.Error(w, errorx.NewCodeError(errorcode.PermissionDenied, "Permission Denied"))
  71. return
  72. }
  73. }
  74. }
  75. func batchCheck(cbn *casbin.Enforcer, roleIds, act, obj string) bool {
  76. var checkReq [][]any
  77. for _, v := range strings.Split(roleIds, ",") {
  78. checkReq = append(checkReq, []any{v, obj, act})
  79. }
  80. result, err := cbn.BatchEnforce(checkReq)
  81. if err != nil {
  82. logx.Errorw("Casbin enforce error", logx.Field("detail", err.Error()))
  83. return false
  84. }
  85. for _, v := range result {
  86. if v {
  87. return true
  88. }
  89. }
  90. return false
  91. }
  // stringInSlice reports whether any comma-separated entry of roleIds is
  // exactly equal to one of the strings in list. Entries are not trimmed, so
  // "001" does not match " 001".
  func stringInSlice(roleIds string, list []string) bool {
  	ids := strings.Split(roleIds, ",")
  	for i := range list {
  		for j := range ids {
  			if ids[j] == list[i] {
  				return true
  			}
  		}
  	}
  	return false
  }