opera_log_middleware.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. from asyncio import create_task
  4. from asgiref.sync import sync_to_async
  5. from fastapi import Response
  6. from starlette.datastructures import UploadFile
  7. from starlette.middleware.base import BaseHTTPMiddleware
  8. from starlette.requests import Request
  9. from common.dataclasses import RequestCallNextReturn
  10. from common.enums import OperaLogCipherType, StatusType
  11. from common.log import log
  12. from core.conf import settings
  13. from utils.encrypt import AESCipher, ItsDCipher, Md5Cipher
  14. from utils.request_parse import parse_ip_info, parse_user_agent_info
  15. from utils.timezone import timezone
  16. class OperaLogMiddleware(BaseHTTPMiddleware):
  17. """操作日志中间件"""
  18. async def dispatch(self, request: Request, call_next) -> Response:
  19. # 排除记录白名单
  20. path = request.url.path
  21. if path in settings.OPERA_LOG_EXCLUDE or not path.startswith(f'{settings.API_V1_STR}'):
  22. return await call_next(request)
  23. # 请求解析
  24. ip_info = await parse_ip_info(request)
  25. ua_info = await parse_user_agent_info(request)
  26. try:
  27. # 此信息依赖于 jwt 中间件
  28. username = request.user.username
  29. except AttributeError:
  30. username = None
  31. method = request.method
  32. args = await self.get_request_args(request)
  33. args = await self.desensitization(args)
  34. # 设置附加请求信息
  35. request.state.ip = ip_info.ip
  36. request.state.country = ip_info.country
  37. request.state.region = ip_info.region
  38. request.state.city = ip_info.city
  39. request.state.user_agent = ua_info.user_agent
  40. request.state.os = ua_info.os
  41. request.state.browser = ua_info.browser
  42. request.state.device = ua_info.device
  43. # 执行请求
  44. start_time = timezone.now()
  45. res = await self.execute_request(request, call_next)
  46. end_time = timezone.now()
  47. cost_time = (end_time - start_time).total_seconds() * 1000.0
  48. # 此信息只能在请求后获取
  49. _route = request.scope.get('route')
  50. summary = getattr(_route, 'summary', None) or ''
  51. # 日志创建
  52. # opera_log_in = CreateOperaLogParam(
  53. # username=username,
  54. # method=method,
  55. # title=summary,
  56. # path=path,
  57. # ip=request.state.ip,
  58. # country=request.state.country,
  59. # region=request.state.region,
  60. # city=request.state.city,
  61. # user_agent=request.state.user_agent,
  62. # os=request.state.os,
  63. # browser=request.state.browser,
  64. # device=request.state.device,
  65. # args=args,
  66. # status=res.status,
  67. # code=res.code,
  68. # msg=res.msg,
  69. # cost_time=cost_time,
  70. # opera_time=start_time,
  71. # )
  72. # create_task(OperaLogService.create(obj_in=opera_log_in)) # noqa: ignore
  73. # 错误抛出
  74. err = res.err
  75. if err:
  76. raise err from None
  77. return res.response
  78. async def execute_request(self, request: Request, call_next) -> RequestCallNextReturn:
  79. """执行请求"""
  80. code = 200
  81. msg = 'Success'
  82. status = StatusType.enable
  83. err = None
  84. response = None
  85. try:
  86. response = await call_next(request)
  87. except Exception as e:
  88. log.exception(e)
  89. code, msg = await self.request_exception_handler(request, code, msg)
  90. # code 处理包含 SQLAlchemy 和 Pydantic
  91. code = getattr(e, 'code', None) or code
  92. msg = getattr(e, 'msg', None) or msg
  93. status = StatusType.disable
  94. err = e
  95. return RequestCallNextReturn(code=str(code), msg=msg, status=status, err=err, response=response)
  96. @staticmethod
  97. @sync_to_async
  98. def request_exception_handler(request: Request, code: int, msg: str) -> tuple[str, str]:
  99. """请求异常处理器"""
  100. try:
  101. http_exception = request.state.__request_http_exception__
  102. except AttributeError:
  103. pass
  104. else:
  105. code = http_exception.get('code', 500)
  106. msg = http_exception.get('msg', 'Internal Server Error')
  107. try:
  108. validation_exception = request.state.__request_validation_exception__
  109. except AttributeError:
  110. pass
  111. else:
  112. code = validation_exception.get('code', 400)
  113. msg = validation_exception.get('msg', 'Bad Request')
  114. return code, msg
  115. @staticmethod
  116. async def get_request_args(request: Request) -> dict:
  117. """获取请求参数"""
  118. args = dict(request.query_params)
  119. args.update(request.path_params)
  120. # Tip: .body() 必须在 .form() 之前获取
  121. # https://github.com/encode/starlette/discussions/1933
  122. body_data = await request.body()
  123. form_data = await request.form()
  124. if len(form_data) > 0:
  125. args.update({k: v.filename if isinstance(v, UploadFile) else v for k, v in form_data.items()})
  126. else:
  127. if body_data:
  128. json_data = await request.json()
  129. if not isinstance(json_data, dict):
  130. json_data = {
  131. f'{type(json_data)}_to_dict_data': json_data.decode('utf-8')
  132. if isinstance(json_data, bytes)
  133. else json_data
  134. }
  135. args.update(json_data)
  136. return args
  137. @staticmethod
  138. @sync_to_async
  139. def desensitization(args: dict) -> dict | None:
  140. """
  141. 脱敏处理
  142. :param args:
  143. :return:
  144. """
  145. if not args:
  146. args = None
  147. else:
  148. match settings.OPERA_LOG_ENCRYPT:
  149. case OperaLogCipherType.aes:
  150. for key in args.keys():
  151. if key in settings.OPERA_LOG_ENCRYPT_INCLUDE:
  152. args[key] = (AESCipher(settings.OPERA_LOG_ENCRYPT_SECRET_KEY).encrypt(args[key])).hex()
  153. case OperaLogCipherType.md5:
  154. for key in args.keys():
  155. if key in settings.OPERA_LOG_ENCRYPT_INCLUDE:
  156. args[key] = Md5Cipher.encrypt(args[key])
  157. case OperaLogCipherType.itsdangerous:
  158. for key in args.keys():
  159. if key in settings.OPERA_LOG_ENCRYPT_INCLUDE:
  160. args[key] = ItsDCipher(settings.OPERA_LOG_ENCRYPT_SECRET_KEY).encrypt(args[key])
  161. case OperaLogCipherType.plan:
  162. pass
  163. case _:
  164. for key in args.keys():
  165. if key in settings.OPERA_LOG_ENCRYPT_INCLUDE:
  166. args[key] = '******'
  167. return args