pagination.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. from __future__ import annotations
  4. import math
  5. from typing import TYPE_CHECKING, Dict, Generic, Sequence, TypeVar
  6. from fastapi import Depends, Query
  7. from fastapi_pagination import pagination_ctx
  8. from fastapi_pagination.bases import AbstractPage, AbstractParams, RawParams
  9. from fastapi_pagination.ext.sqlalchemy import paginate
  10. from fastapi_pagination.links.bases import create_links
  11. from pydantic import BaseModel
  12. if TYPE_CHECKING:
  13. from sqlalchemy import Select
  14. from sqlalchemy.ext.asyncio import AsyncSession
  15. T = TypeVar('T')
  16. DataT = TypeVar('DataT')
  17. SchemaT = TypeVar('SchemaT')
  18. class _Params(BaseModel, AbstractParams):
  19. page: int = Query(1, ge=1, description='Page number')
  20. size: int = Query(20, gt=0, le=100, description='Page size') # 默认 20 条记录
  21. def to_raw_params(self) -> RawParams:
  22. return RawParams(
  23. limit=self.size,
  24. offset=self.size * (self.page - 1),
  25. )
  26. class _Page(AbstractPage[T], Generic[T]):
  27. items: Sequence[T] # 数据
  28. total: int # 总数据数
  29. page: int # 第n页
  30. size: int # 每页数量
  31. total_pages: int # 总页数
  32. links: Dict[str, str | None] # 跳转链接
  33. __params_type__ = _Params # 使用自定义的Params
  34. @classmethod
  35. def create(
  36. cls,
  37. items: Sequence[T],
  38. total: int,
  39. params: _Params,
  40. ) -> _Page[T]:
  41. page = params.page
  42. size = params.size
  43. total_pages = math.ceil(total / params.size)
  44. links = create_links(**{
  45. 'first': {'page': 1, 'size': f'{size}'},
  46. 'last': {'page': f'{math.ceil(total / params.size)}', 'size': f'{size}'} if total > 0 else None,
  47. 'next': {'page': f'{page + 1}', 'size': f'{size}'} if (page + 1) <= total_pages else None,
  48. 'prev': {'page': f'{page - 1}', 'size': f'{size}'} if (page - 1) >= 1 else None,
  49. }).model_dump()
  50. return cls(items=items, total=total, page=params.page, size=params.size, total_pages=total_pages, links=links)
  51. class _PageData(BaseModel, Generic[DataT]):
  52. page_data: DataT | None = None
  53. async def paging_data(db: AsyncSession, select: Select, page_data_schema: SchemaT) -> dict:
  54. """
  55. 基于 SQLAlchemy 创建分页数据
  56. :param db:
  57. :param select:
  58. :param page_data_schema:
  59. :return:
  60. """
  61. _paginate = await paginate(db, select)
  62. page_data = _PageData[_Page[page_data_schema]](page_data=_paginate).model_dump()['page_data']
  63. return page_data
  64. # 分页依赖注入
  65. DependsPagination = Depends(pagination_ctx(_Page))