# import_subclasses.py (2.2 KB)

  1. # -*- coding: utf-8 -*-
  2. from importlib import import_module
  3. from inspect import (
  4. getmembers,
  5. isclass,
  6. )
  7. from pkgutil import walk_packages
  8. from typing import ( # NOQA
  9. Dict,
  10. List,
  11. Tuple,
  12. Union,
  13. )
  14. from django.conf import settings
  15. from django.utils.module_loading import import_string
  16. class SubclassesFinder:
  17. def __init__(self, base_classes_from_settings):
  18. self.base_classes = []
  19. for element in base_classes_from_settings:
  20. if isinstance(element, str):
  21. element = import_string(element)
  22. self.base_classes.append(element)
  23. def _should_be_imported(self, candidate_to_import): # type: (Tuple[str, type]) -> bool
  24. for base_class in self.base_classes:
  25. if issubclass(candidate_to_import[1], base_class):
  26. return True
  27. return False
  28. def collect_subclasses(self): # type: () -> Dict[str, List[Tuple[str, str]]]
  29. """
  30. Collect all subclasses of user-defined base classes from project.
  31. :return: Dictionary from module name to list of tuples.
  32. First element of tuple is model name and second is alias.
  33. Currently we set alias equal to model name,
  34. but in future functionality of aliasing subclasses can be added.
  35. """
  36. result = {} # type: Dict[str, List[Tuple[str, str]]]
  37. for loader, module_name, is_pkg in walk_packages(path=[settings.BASE_DIR]):
  38. subclasses_from_module = self._collect_classes_from_module(module_name)
  39. if subclasses_from_module:
  40. result[module_name] = subclasses_from_module
  41. return result
  42. def _collect_classes_from_module(self, module_name): # type: (str) -> List[Tuple[str, str]]
  43. for excluded_module in getattr(settings, 'SHELL_PLUS_SUBCLASSES_IMPORT_MODULES_BLACKLIST', []):
  44. if module_name.startswith(excluded_module):
  45. return []
  46. imported_module = import_module(module_name)
  47. classes_to_import = getmembers(
  48. imported_module, lambda element: isclass(element) and element.__module__ == imported_module.__name__
  49. )
  50. classes_to_import = list(filter(self._should_be_imported, classes_to_import))
  51. return [(name, name) for name, _ in classes_to_import]