build_tree.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. from typing import Any, Sequence
  4. from common.enums import BuildTreeType
  5. from utils.serializers import RowData, select_list_serialize
  6. def get_tree_nodes(row: Sequence[RowData]) -> list[dict[str, Any]]:
  7. """获取所有树形结构节点"""
  8. tree_nodes = select_list_serialize(row)
  9. tree_nodes.sort(key=lambda x: x['sort'])
  10. return tree_nodes
  11. def traversal_to_tree(nodes: list[dict[str, Any]]) -> list[dict[str, Any]]:
  12. """
  13. 通过遍历算法构造树形结构
  14. :param nodes:
  15. :return:
  16. """
  17. tree = []
  18. node_dict = {node['id']: node for node in nodes}
  19. for node in nodes:
  20. parent_id = node['parent_id']
  21. if parent_id is None:
  22. tree.append(node)
  23. else:
  24. parent_node = node_dict.get(parent_id)
  25. if parent_node is not None:
  26. if 'children' not in parent_node:
  27. parent_node['children'] = []
  28. if node not in parent_node['children']:
  29. parent_node['children'].append(node)
  30. else:
  31. if node not in tree:
  32. tree.append(node)
  33. return tree
  34. def recursive_to_tree(nodes: list[dict[str, Any]], *, parent_id: int | None = None) -> list[dict[str, Any]]:
  35. """
  36. 通过递归算法构造树形结构(性能影响较大)
  37. :param nodes:
  38. :param parent_id:
  39. :return:
  40. """
  41. tree = []
  42. for node in nodes:
  43. if node['parent_id'] == parent_id:
  44. child_node = recursive_to_tree(nodes, parent_id=node['id'])
  45. if child_node:
  46. node['children'] = child_node
  47. tree.append(node)
  48. return tree
  49. def get_tree_data(
  50. row: Sequence[RowData], build_type: BuildTreeType = BuildTreeType.traversal, *, parent_id: int | None = None
  51. ) -> list[dict[str, Any]]:
  52. """
  53. 获取树形结构数据
  54. :param row:
  55. :param build_type:
  56. :param parent_id:
  57. :return:
  58. """
  59. nodes = get_tree_nodes(row)
  60. match build_type:
  61. case BuildTreeType.traversal:
  62. tree = traversal_to_tree(nodes)
  63. case BuildTreeType.recursive:
  64. tree = recursive_to_tree(nodes, parent_id=parent_id)
  65. case _:
  66. raise ValueError(f'无效的算法类型:{build_type}')
  67. return tree