exception_handler.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. from fastapi import FastAPI, Request
  4. from fastapi.exceptions import RequestValidationError
  5. from pydantic import ValidationError
  6. from pydantic.errors import PydanticUserError
  7. from starlette.exceptions import HTTPException
  8. from starlette.middleware.cors import CORSMiddleware
  9. from uvicorn.protocols.http.h11_impl import STATUS_PHRASES
  10. from common.exception.errors import BaseExceptionMixin
  11. from common.log import log
  12. from common.response.response_code import CustomResponseCode, StandardResponseCode
  13. from common.response.response_schema import response_base
  14. from common.schema import (
  15. CUSTOM_USAGE_ERROR_MESSAGES,
  16. CUSTOM_VALIDATION_ERROR_MESSAGES,
  17. )
  18. from core.conf import settings
  19. from utils.serializers import MsgSpecJSONResponse
  20. def _get_exception_code(status_code: int):
  21. """
  22. 获取返回状态码, OpenAPI, Uvicorn... 可用状态码基于 RFC 定义, 详细代码见下方链接
  23. `python 状态码标准支持 <https://github.com/python/cpython/blob/6e3cc72afeaee2532b4327776501eb8234ac787b/Lib/http
  24. /__init__.py#L7>`__
  25. `IANA 状态码注册表 <https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml>`__
  26. :param status_code:
  27. :return:
  28. """
  29. try:
  30. STATUS_PHRASES[status_code]
  31. except Exception:
  32. code = StandardResponseCode.HTTP_400
  33. else:
  34. code = status_code
  35. return code
  36. async def _validation_exception_handler(request: Request, e: RequestValidationError | ValidationError):
  37. """
  38. 数据验证异常处理
  39. :param e:
  40. :return:
  41. """
  42. errors = []
  43. for error in e.errors():
  44. custom_message = CUSTOM_VALIDATION_ERROR_MESSAGES.get(error['type'])
  45. if custom_message:
  46. ctx = error.get('ctx')
  47. if not ctx:
  48. error['msg'] = custom_message
  49. else:
  50. error['msg'] = custom_message.format(**ctx)
  51. ctx_error = ctx.get('error')
  52. if ctx_error:
  53. error['ctx']['error'] = (
  54. ctx_error.__str__().replace("'", '"') if isinstance(ctx_error, Exception) else None
  55. )
  56. errors.append(error)
  57. error = errors[0]
  58. if error.get('type') == 'json_invalid':
  59. message = 'json解析失败'
  60. else:
  61. error_input = error.get('input')
  62. field = str(error.get('loc')[-1])
  63. error_msg = error.get('msg')
  64. message = f'{error_msg}{field},输入:{error_input}' if settings.ENVIRONMENT == 'dev' else error_msg
  65. msg = f'请求参数非法: {message}'
  66. data = {'errors': errors} if settings.ENVIRONMENT == 'dev' else None
  67. content = {
  68. 'code': StandardResponseCode.HTTP_422,
  69. 'msg': msg,
  70. 'data': data,
  71. }
  72. request.state.__request_validation_exception__ = content # 用于在中间件中获取异常信息
  73. return MsgSpecJSONResponse(status_code=422, content=content)
  74. def register_exception(app: FastAPI):
  75. @app.exception_handler(HTTPException)
  76. async def http_exception_handler(request: Request, exc: HTTPException):
  77. """
  78. 全局HTTP异常处理
  79. :param request:
  80. :param exc:
  81. :return:
  82. """
  83. if settings.ENVIRONMENT == 'dev':
  84. content = {
  85. 'code': exc.status_code,
  86. 'msg': exc.detail,
  87. 'data': None,
  88. }
  89. else:
  90. res = response_base.fail(res=CustomResponseCode.HTTP_400)
  91. content = res.model_dump()
  92. request.state.__request_http_exception__ = content # 用于在中间件中获取异常信息
  93. return MsgSpecJSONResponse(
  94. status_code=_get_exception_code(exc.status_code),
  95. content=content,
  96. headers=exc.headers,
  97. )
  98. @app.exception_handler(RequestValidationError)
  99. async def fastapi_validation_exception_handler(request: Request, exc: RequestValidationError):
  100. """
  101. fastapi 数据验证异常处理
  102. :param request:
  103. :param exc:
  104. :return:
  105. """
  106. return await _validation_exception_handler(request, exc)
  107. @app.exception_handler(ValidationError)
  108. async def pydantic_validation_exception_handler(request: Request, exc: ValidationError):
  109. """
  110. pydantic 数据验证异常处理
  111. :param request:
  112. :param exc:
  113. :return:
  114. """
  115. return await _validation_exception_handler(request, exc)
  116. @app.exception_handler(PydanticUserError)
  117. async def pydantic_user_error_handler(request: Request, exc: PydanticUserError):
  118. """
  119. Pydantic 用户异常处理
  120. :param request:
  121. :param exc:
  122. :return:
  123. """
  124. return MsgSpecJSONResponse(
  125. status_code=StandardResponseCode.HTTP_500,
  126. content={
  127. 'code': StandardResponseCode.HTTP_500,
  128. 'msg': CUSTOM_USAGE_ERROR_MESSAGES.get(exc.code),
  129. 'data': None,
  130. },
  131. )
  132. @app.exception_handler(AssertionError)
  133. async def assertion_error_handler(request: Request, exc: AssertionError):
  134. """
  135. 断言错误处理
  136. :param request:
  137. :param exc:
  138. :return:
  139. """
  140. if settings.ENVIRONMENT == 'dev':
  141. content = {
  142. 'code': StandardResponseCode.HTTP_500,
  143. 'msg': str(''.join(exc.args) if exc.args else exc.__doc__),
  144. 'data': None,
  145. }
  146. else:
  147. res = response_base.fail(res=CustomResponseCode.HTTP_500)
  148. content = res.model_dump()
  149. return MsgSpecJSONResponse(
  150. status_code=StandardResponseCode.HTTP_500,
  151. content=content,
  152. )
  153. @app.exception_handler(Exception)
  154. async def all_exception_handler(request: Request, exc: Exception):
  155. """
  156. 全局异常处理
  157. :param request:
  158. :param exc:
  159. :return:
  160. """
  161. if isinstance(exc, BaseExceptionMixin):
  162. return MsgSpecJSONResponse(
  163. status_code=_get_exception_code(exc.code),
  164. content={
  165. 'code': exc.code,
  166. 'msg': str(exc.msg),
  167. 'data': exc.data if exc.data else None,
  168. },
  169. background=exc.background,
  170. )
  171. else:
  172. import traceback
  173. log.error(f'未知异常: {exc}')
  174. log.error(traceback.format_exc())
  175. if settings.ENVIRONMENT == 'dev':
  176. content = {
  177. 'code': StandardResponseCode.HTTP_500,
  178. 'msg': str(exc),
  179. 'data': None,
  180. }
  181. else:
  182. res = response_base.fail(res=CustomResponseCode.HTTP_500)
  183. content = res.model_dump()
  184. return MsgSpecJSONResponse(status_code=StandardResponseCode.HTTP_500, content=content)
  185. if settings.MIDDLEWARE_CORS:
  186. @app.exception_handler(StandardResponseCode.HTTP_500)
  187. async def cors_status_code_500_exception_handler(request, exc):
  188. """
  189. 跨域 500 异常处理
  190. `Related issue <https://github.com/encode/starlette/issues/1175>`_
  191. :param request:
  192. :param exc:
  193. :return:
  194. """
  195. if isinstance(exc, BaseExceptionMixin):
  196. content = {
  197. 'code': exc.code,
  198. 'msg': exc.msg,
  199. 'data': exc.data,
  200. }
  201. else:
  202. if settings.ENVIRONMENT == 'dev':
  203. content = {
  204. 'code': StandardResponseCode.HTTP_500,
  205. 'msg': str(exc),
  206. 'data': None,
  207. }
  208. else:
  209. res = response_base.fail(res=CustomResponseCode.HTTP_500)
  210. content = res.model_dump()
  211. response = MsgSpecJSONResponse(
  212. status_code=exc.code if isinstance(exc, BaseExceptionMixin) else StandardResponseCode.HTTP_500,
  213. content=content,
  214. background=exc.background if isinstance(exc, BaseExceptionMixin) else None,
  215. )
  216. origin = request.headers.get('origin')
  217. if origin:
  218. cors = CORSMiddleware(
  219. app=app,
  220. allow_origins=['*'],
  221. allow_credentials=True,
  222. allow_methods=['*'],
  223. allow_headers=['*'],
  224. )
  225. response.headers.update(cors.simple_headers)
  226. has_cookie = 'cookie' in request.headers
  227. if cors.allow_all_origins and has_cookie:
  228. response.headers['Access-Control-Allow-Origin'] = origin
  229. elif not cors.allow_all_origins and cors.is_allowed_origin(origin=origin):
  230. response.headers['Access-Control-Allow-Origin'] = origin
  231. response.headers.add_vary_header('Origin')
  232. return response