123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- from fastapi import FastAPI, Request
- from fastapi.exceptions import RequestValidationError
- from pydantic import ValidationError
- from pydantic.errors import PydanticUserError
- from starlette.exceptions import HTTPException
- from starlette.middleware.cors import CORSMiddleware
- from uvicorn.protocols.http.h11_impl import STATUS_PHRASES
- from common.exception.errors import BaseExceptionMixin
- from common.log import log
- from common.response.response_code import CustomResponseCode, StandardResponseCode
- from common.response.response_schema import response_base
- from common.schema import (
- CUSTOM_USAGE_ERROR_MESSAGES,
- CUSTOM_VALIDATION_ERROR_MESSAGES,
- )
- from core.conf import settings
- from utils.serializers import MsgSpecJSONResponse
- def _get_exception_code(status_code: int):
- """
- 获取返回状态码, OpenAPI, Uvicorn... 可用状态码基于 RFC 定义, 详细代码见下方链接
- `python 状态码标准支持 <https://github.com/python/cpython/blob/6e3cc72afeaee2532b4327776501eb8234ac787b/Lib/http
- /__init__.py#L7>`__
- `IANA 状态码注册表 <https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml>`__
- :param status_code:
- :return:
- """
- try:
- STATUS_PHRASES[status_code]
- except Exception:
- code = StandardResponseCode.HTTP_400
- else:
- code = status_code
- return code
- async def _validation_exception_handler(request: Request, e: RequestValidationError | ValidationError):
- """
- 数据验证异常处理
- :param e:
- :return:
- """
- errors = []
- for error in e.errors():
- custom_message = CUSTOM_VALIDATION_ERROR_MESSAGES.get(error['type'])
- if custom_message:
- ctx = error.get('ctx')
- if not ctx:
- error['msg'] = custom_message
- else:
- error['msg'] = custom_message.format(**ctx)
- ctx_error = ctx.get('error')
- if ctx_error:
- error['ctx']['error'] = (
- ctx_error.__str__().replace("'", '"') if isinstance(ctx_error, Exception) else None
- )
- errors.append(error)
- error = errors[0]
- if error.get('type') == 'json_invalid':
- message = 'json解析失败'
- else:
- error_input = error.get('input')
- field = str(error.get('loc')[-1])
- error_msg = error.get('msg')
- message = f'{error_msg}{field},输入:{error_input}' if settings.ENVIRONMENT == 'dev' else error_msg
- msg = f'请求参数非法: {message}'
- data = {'errors': errors} if settings.ENVIRONMENT == 'dev' else None
- content = {
- 'code': StandardResponseCode.HTTP_422,
- 'msg': msg,
- 'data': data,
- }
- request.state.__request_validation_exception__ = content # 用于在中间件中获取异常信息
- return MsgSpecJSONResponse(status_code=422, content=content)
- def register_exception(app: FastAPI):
- @app.exception_handler(HTTPException)
- async def http_exception_handler(request: Request, exc: HTTPException):
- """
- 全局HTTP异常处理
- :param request:
- :param exc:
- :return:
- """
- if settings.ENVIRONMENT == 'dev':
- content = {
- 'code': exc.status_code,
- 'msg': exc.detail,
- 'data': None,
- }
- else:
- res = response_base.fail(res=CustomResponseCode.HTTP_400)
- content = res.model_dump()
- request.state.__request_http_exception__ = content # 用于在中间件中获取异常信息
- return MsgSpecJSONResponse(
- status_code=_get_exception_code(exc.status_code),
- content=content,
- headers=exc.headers,
- )
- @app.exception_handler(RequestValidationError)
- async def fastapi_validation_exception_handler(request: Request, exc: RequestValidationError):
- """
- fastapi 数据验证异常处理
- :param request:
- :param exc:
- :return:
- """
- return await _validation_exception_handler(request, exc)
- @app.exception_handler(ValidationError)
- async def pydantic_validation_exception_handler(request: Request, exc: ValidationError):
- """
- pydantic 数据验证异常处理
- :param request:
- :param exc:
- :return:
- """
- return await _validation_exception_handler(request, exc)
- @app.exception_handler(PydanticUserError)
- async def pydantic_user_error_handler(request: Request, exc: PydanticUserError):
- """
- Pydantic 用户异常处理
- :param request:
- :param exc:
- :return:
- """
- return MsgSpecJSONResponse(
- status_code=StandardResponseCode.HTTP_500,
- content={
- 'code': StandardResponseCode.HTTP_500,
- 'msg': CUSTOM_USAGE_ERROR_MESSAGES.get(exc.code),
- 'data': None,
- },
- )
- @app.exception_handler(AssertionError)
- async def assertion_error_handler(request: Request, exc: AssertionError):
- """
- 断言错误处理
- :param request:
- :param exc:
- :return:
- """
- if settings.ENVIRONMENT == 'dev':
- content = {
- 'code': StandardResponseCode.HTTP_500,
- 'msg': str(''.join(exc.args) if exc.args else exc.__doc__),
- 'data': None,
- }
- else:
- res = response_base.fail(res=CustomResponseCode.HTTP_500)
- content = res.model_dump()
- return MsgSpecJSONResponse(
- status_code=StandardResponseCode.HTTP_500,
- content=content,
- )
- @app.exception_handler(Exception)
- async def all_exception_handler(request: Request, exc: Exception):
- """
- 全局异常处理
- :param request:
- :param exc:
- :return:
- """
- if isinstance(exc, BaseExceptionMixin):
- return MsgSpecJSONResponse(
- status_code=_get_exception_code(exc.code),
- content={
- 'code': exc.code,
- 'msg': str(exc.msg),
- 'data': exc.data if exc.data else None,
- },
- background=exc.background,
- )
- else:
- import traceback
- log.error(f'未知异常: {exc}')
- log.error(traceback.format_exc())
- if settings.ENVIRONMENT == 'dev':
- content = {
- 'code': StandardResponseCode.HTTP_500,
- 'msg': str(exc),
- 'data': None,
- }
- else:
- res = response_base.fail(res=CustomResponseCode.HTTP_500)
- content = res.model_dump()
- return MsgSpecJSONResponse(status_code=StandardResponseCode.HTTP_500, content=content)
- if settings.MIDDLEWARE_CORS:
- @app.exception_handler(StandardResponseCode.HTTP_500)
- async def cors_status_code_500_exception_handler(request, exc):
- """
- 跨域 500 异常处理
- `Related issue <https://github.com/encode/starlette/issues/1175>`_
- :param request:
- :param exc:
- :return:
- """
- if isinstance(exc, BaseExceptionMixin):
- content = {
- 'code': exc.code,
- 'msg': exc.msg,
- 'data': exc.data,
- }
- else:
- if settings.ENVIRONMENT == 'dev':
- content = {
- 'code': StandardResponseCode.HTTP_500,
- 'msg': str(exc),
- 'data': None,
- }
- else:
- res = response_base.fail(res=CustomResponseCode.HTTP_500)
- content = res.model_dump()
- response = MsgSpecJSONResponse(
- status_code=exc.code if isinstance(exc, BaseExceptionMixin) else StandardResponseCode.HTTP_500,
- content=content,
- background=exc.background if isinstance(exc, BaseExceptionMixin) else None,
- )
- origin = request.headers.get('origin')
- if origin:
- cors = CORSMiddleware(
- app=app,
- allow_origins=['*'],
- allow_credentials=True,
- allow_methods=['*'],
- allow_headers=['*'],
- )
- response.headers.update(cors.simple_headers)
- has_cookie = 'cookie' in request.headers
- if cors.allow_all_origins and has_cookie:
- response.headers['Access-Control-Allow-Origin'] = origin
- elif not cors.allow_all_origins and cors.is_allowed_origin(origin=origin):
- response.headers['Access-Control-Allow-Origin'] = origin
- response.headers.add_vary_header('Origin')
- return response
|