#!/usr/bin/env python3 # -*- coding: utf-8 -*- from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from pydantic import ValidationError from pydantic.errors import PydanticUserError from starlette.exceptions import HTTPException from starlette.middleware.cors import CORSMiddleware from uvicorn.protocols.http.h11_impl import STATUS_PHRASES from common.exception.errors import BaseExceptionMixin from common.log import log from common.response.response_code import CustomResponseCode, StandardResponseCode from common.response.response_schema import response_base from common.schema import ( CUSTOM_USAGE_ERROR_MESSAGES, CUSTOM_VALIDATION_ERROR_MESSAGES, ) from core.conf import settings from utils.serializers import MsgSpecJSONResponse def _get_exception_code(status_code: int): """ 获取返回状态码, OpenAPI, Uvicorn... 可用状态码基于 RFC 定义, 详细代码见下方链接 `python 状态码标准支持 `__ `IANA 状态码注册表 `__ :param status_code: :return: """ try: STATUS_PHRASES[status_code] except Exception: code = StandardResponseCode.HTTP_400 else: code = status_code return code async def _validation_exception_handler(request: Request, e: RequestValidationError | ValidationError): """ 数据验证异常处理 :param e: :return: """ errors = [] for error in e.errors(): custom_message = CUSTOM_VALIDATION_ERROR_MESSAGES.get(error['type']) if custom_message: ctx = error.get('ctx') if not ctx: error['msg'] = custom_message else: error['msg'] = custom_message.format(**ctx) ctx_error = ctx.get('error') if ctx_error: error['ctx']['error'] = ( ctx_error.__str__().replace("'", '"') if isinstance(ctx_error, Exception) else None ) errors.append(error) error = errors[0] if error.get('type') == 'json_invalid': message = 'json解析失败' else: error_input = error.get('input') field = str(error.get('loc')[-1]) error_msg = error.get('msg') message = f'{error_msg}{field},输入:{error_input}' if settings.ENVIRONMENT == 'dev' else error_msg msg = f'请求参数非法: {message}' data = {'errors': errors} if settings.ENVIRONMENT == 'dev' else None content = { 'code': StandardResponseCode.HTTP_422, 'msg': msg, 'data': data, } request.state.__request_validation_exception__ = content # 用于在中间件中获取异常信息 return MsgSpecJSONResponse(status_code=422, content=content) def register_exception(app: FastAPI): @app.exception_handler(HTTPException) async def http_exception_handler(request: Request, exc: HTTPException): """ 全局HTTP异常处理 :param request: :param exc: :return: """ if settings.ENVIRONMENT == 'dev': content = { 'code': exc.status_code, 'msg': exc.detail, 'data': None, } else: res = response_base.fail(res=CustomResponseCode.HTTP_400) content = res.model_dump() request.state.__request_http_exception__ = content # 用于在中间件中获取异常信息 return MsgSpecJSONResponse( status_code=_get_exception_code(exc.status_code), content=content, headers=exc.headers, ) @app.exception_handler(RequestValidationError) async def fastapi_validation_exception_handler(request: Request, exc: RequestValidationError): """ fastapi 数据验证异常处理 :param request: :param exc: :return: """ return await _validation_exception_handler(request, exc) @app.exception_handler(ValidationError) async def pydantic_validation_exception_handler(request: Request, exc: ValidationError): """ pydantic 数据验证异常处理 :param request: :param exc: :return: """ return await _validation_exception_handler(request, exc) @app.exception_handler(PydanticUserError) async def pydantic_user_error_handler(request: Request, exc: PydanticUserError): """ Pydantic 用户异常处理 :param request: :param exc: :return: """ return MsgSpecJSONResponse( status_code=StandardResponseCode.HTTP_500, content={ 'code': StandardResponseCode.HTTP_500, 'msg': CUSTOM_USAGE_ERROR_MESSAGES.get(exc.code), 'data': None, }, ) @app.exception_handler(AssertionError) async def assertion_error_handler(request: Request, exc: AssertionError): """ 断言错误处理 :param request: :param exc: :return: """ if settings.ENVIRONMENT == 'dev': content = { 'code': StandardResponseCode.HTTP_500, 'msg': str(''.join(exc.args) if exc.args else exc.__doc__), 'data': None, } else: res = response_base.fail(res=CustomResponseCode.HTTP_500) content = res.model_dump() return MsgSpecJSONResponse( status_code=StandardResponseCode.HTTP_500, content=content, ) @app.exception_handler(Exception) async def all_exception_handler(request: Request, exc: Exception): """ 全局异常处理 :param request: :param exc: :return: """ if isinstance(exc, BaseExceptionMixin): return MsgSpecJSONResponse( status_code=_get_exception_code(exc.code), content={ 'code': exc.code, 'msg': str(exc.msg), 'data': exc.data if exc.data else None, }, background=exc.background, ) else: import traceback log.error(f'未知异常: {exc}') log.error(traceback.format_exc()) if settings.ENVIRONMENT == 'dev': content = { 'code': StandardResponseCode.HTTP_500, 'msg': str(exc), 'data': None, } else: res = response_base.fail(res=CustomResponseCode.HTTP_500) content = res.model_dump() return MsgSpecJSONResponse(status_code=StandardResponseCode.HTTP_500, content=content) if settings.MIDDLEWARE_CORS: @app.exception_handler(StandardResponseCode.HTTP_500) async def cors_status_code_500_exception_handler(request, exc): """ 跨域 500 异常处理 `Related issue `_ :param request: :param exc: :return: """ if isinstance(exc, BaseExceptionMixin): content = { 'code': exc.code, 'msg': exc.msg, 'data': exc.data, } else: if settings.ENVIRONMENT == 'dev': content = { 'code': StandardResponseCode.HTTP_500, 'msg': str(exc), 'data': None, } else: res = response_base.fail(res=CustomResponseCode.HTTP_500) content = res.model_dump() response = MsgSpecJSONResponse( status_code=exc.code if isinstance(exc, BaseExceptionMixin) else StandardResponseCode.HTTP_500, content=content, background=exc.background if isinstance(exc, BaseExceptionMixin) else None, ) origin = request.headers.get('origin') if origin: cors = CORSMiddleware( app=app, allow_origins=['*'], allow_credentials=True, allow_methods=['*'], allow_headers=['*'], ) response.headers.update(cors.simple_headers) has_cookie = 'cookie' in request.headers if cors.allow_all_origins and has_cookie: response.headers['Access-Control-Allow-Origin'] = origin elif not cors.allow_all_origins and cors.is_allowed_origin(origin=origin): response.headers['Access-Control-Allow-Origin'] = origin response.headers.add_vary_header('Origin') return response