12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- from typing import Any
- from fastapi import Request, Response
- from fastapi.security.utils import get_authorization_scheme_param
- from starlette.authentication import AuthCredentials, AuthenticationBackend, AuthenticationError
- from starlette.requests import HTTPConnection
- from app.admin.schema.org import CurrentIntentOrgIns
- from common.exception.errors import TokenError
- from common.log import log
- from common.security.jwt_call_center import jwt_call_center_authentication
- from core.conf import settings
- from utils.serializers import MsgSpecJSONResponse
class _AuthenticationError(AuthenticationError):
    """Override of Starlette's ``AuthenticationError`` that carries response details.

    Raised by ``JwtCallCenterAuthMiddleware.authenticate``; the values are meant
    to be rendered by ``JwtCallCenterAuthMiddleware.auth_exception_handler``.

    :param code: status/error code to report; may be ``None``.
    :param msg: human-readable error message; may be ``None``.
    :param headers: optional extra headers associated with the error.
    """

    def __init__(self, *, code: int | None = None, msg: str | None = None, headers: dict[str, Any] | None = None):
        # Forward the message to the Exception base so ``exc.args`` / ``str(exc)``
        # are populated. Without this call they stay empty, because every
        # argument here is keyword-only.
        if msg is None:
            super().__init__()
        else:
            super().__init__(msg)
        self.code = code
        self.msg = msg
        self.headers = headers
class JwtCallCenterAuthMiddleware(AuthenticationBackend):
    """JWT authentication backend for call-center organisation tokens.

    Reads a ``Bearer`` token from the ``Authorization`` header and resolves it
    via ``jwt_call_center_authentication``. Requests with no header, with a
    path listed in ``settings.TOKEN_EXCLUDE``, or with a non-Bearer scheme are
    left unauthenticated (``authenticate`` returns ``None``).
    """

    @staticmethod
    def _safe_status_code(code: Any, default: int = 500) -> int:
        """Return *code* if it is a usable HTTP status code (100-599), else *default*.

        Exceptions caught generically may expose a ``code`` attribute that is not
        an HTTP status (e.g. ``None``, a string, or an application error number).
        Using such a value as a response status breaks the error response.
        """
        if isinstance(code, int) and not isinstance(code, bool) and 100 <= code <= 599:
            return code
        return default

    @staticmethod
    def auth_exception_handler(conn: HTTPConnection, exc: _AuthenticationError) -> Response:
        """Render an authentication failure as a JSON response.

        The body is ``{'code': exc.code, 'msg': exc.msg, 'data': None}``. The HTTP
        status is ``exc.code`` when that is a valid HTTP status, otherwise 500.
        This keeps a missing or non-HTTP code from producing an invalid response.
        """
        status_code = JwtCallCenterAuthMiddleware._safe_status_code(exc.code)
        return MsgSpecJSONResponse(content={'code': exc.code, 'msg': exc.msg, 'data': None}, status_code=status_code)

    async def authenticate(self, request: Request) -> tuple[AuthCredentials, CurrentIntentOrgIns] | None:
        """Authenticate *request* from its ``Authorization: Bearer <token>`` header.

        :return: ``(AuthCredentials(['authenticated']), org)`` on success, or
            ``None`` when there is no ``Authorization`` header, the path is in
            ``settings.TOKEN_EXCLUDE``, or the scheme is not Bearer.
        :raises _AuthenticationError: when token resolution fails.
            ``TokenError`` code/detail/headers are passed through. Any other
            exception is logged and reported with its ``code`` (if it is a valid
            HTTP status, else 500) and its ``msg`` (else 'Internal Server Error').
        """
        authorization = request.headers.get('Authorization')
        if not authorization:
            return None
        if request.url.path in settings.TOKEN_EXCLUDE:
            return None
        scheme, token = get_authorization_scheme_param(authorization)
        if scheme.lower() != "bearer":
            return None
        try:
            org = await jwt_call_center_authentication(token)
        except TokenError as exc:
            raise _AuthenticationError(code=exc.code, msg=exc.detail, headers=exc.headers) from exc
        except Exception as e:
            log.error(f'JWT 授权异常:{e}')
            # The code ends up as the response status in auth_exception_handler,
            # so an arbitrary exception's ``code`` attribute must be a real HTTP status.
            code = self._safe_status_code(getattr(e, 'code', 500))
            raise _AuthenticationError(code=code, msg=getattr(e, 'msg', 'Internal Server Error')) from e
        # NOTE: this return uses a non-standard pattern (the org object instead
        # of a Starlette ``BaseUser``), so some standard features are lost once
        # authentication succeeds. For the standard return pattern see:
        # https://www.starlette.io/authentication/
        return AuthCredentials(['authenticated']), org
|