shells.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. # -*- coding: utf-8 -*-
  2. import ast
  3. import six
  4. import traceback
  5. import warnings
  6. import importlib
  7. from typing import ( # NOQA
  8. Dict,
  9. List,
  10. Tuple,
  11. Union,
  12. )
  13. from django import VERSION as DJANGO_VERSION
  14. from django.apps.config import MODELS_MODULE_NAME
  15. from django.utils.module_loading import import_string
  16. from django_extensions.collision_resolvers import CollisionResolvingRunner
  17. from django_extensions.import_subclasses import SubclassesFinder
  18. from django_extensions.utils.deprecation import RemovedInNextVersionWarning
  19. SHELL_PLUS_DJANGO_IMPORTS = [
  20. 'from django.core.cache import cache',
  21. 'from django.conf import settings',
  22. 'from django.contrib.auth import get_user_model',
  23. 'from django.db import transaction',
  24. 'from django.db.models import Avg, Case, Count, F, Max, Min, Prefetch, Q, Sum, When',
  25. 'from django.utils import timezone',
  26. ]
  27. if DJANGO_VERSION < (1, 10):
  28. SHELL_PLUS_DJANGO_IMPORTS.append(
  29. 'from django.core.urlresolvers import reverse',
  30. )
  31. else:
  32. SHELL_PLUS_DJANGO_IMPORTS.append(
  33. 'from django.urls import reverse',
  34. )
  35. if DJANGO_VERSION >= (1, 11):
  36. SHELL_PLUS_DJANGO_IMPORTS.append(
  37. 'from django.db.models import Exists, OuterRef, Subquery',
  38. )
  39. class ObjectImportError(Exception):
  40. pass
  41. def get_app_name(mod_name):
  42. """
  43. Retrieve application name from models.py module path
  44. >>> get_app_name('testapp.models.foo')
  45. 'testapp'
  46. 'testapp' instead of 'some.testapp' for compatibility:
  47. >>> get_app_name('some.testapp.models.foo')
  48. 'testapp'
  49. >>> get_app_name('some.models.testapp.models.foo')
  50. 'testapp'
  51. >>> get_app_name('testapp.foo')
  52. 'testapp'
  53. >>> get_app_name('some.testapp.foo')
  54. 'testapp'
  55. """
  56. rparts = list(reversed(mod_name.split('.')))
  57. try:
  58. try:
  59. return rparts[rparts.index(MODELS_MODULE_NAME) + 1]
  60. except ValueError:
  61. # MODELS_MODULE_NAME ('models' string) is not found
  62. return rparts[1]
  63. except IndexError:
  64. # Some weird model naming scheme like in Sentry.
  65. return mod_name
  66. def import_items(import_directives, style, quiet_load=False):
  67. """
  68. Import the items in import_directives and return a list of the imported items
  69. Each item in import_directives should be one of the following forms
  70. * a tuple like ('module.submodule', ('classname1', 'classname2')), which indicates a 'from module.submodule import classname1, classname2'
  71. * a tuple like ('module.submodule', 'classname1'), which indicates a 'from module.submodule import classname1'
  72. * a tuple like ('module.submodule', '*'), which indicates a 'from module.submodule import *'
  73. * a simple 'module.submodule' which indicates 'import module.submodule'.
  74. Returns a dict mapping the names to the imported items
  75. """
  76. imported_objects = {}
  77. for directive in import_directives:
  78. if isinstance(directive, six.string_types):
  79. directive = directive.strip()
  80. try:
  81. if isinstance(directive, six.string_types) and directive.startswith(("from ", "import ")):
  82. try:
  83. node = ast.parse(directive)
  84. except Exception as exc:
  85. if not quiet_load:
  86. print(style.ERROR("Error parsing: %r %s" % (directive, exc)))
  87. continue
  88. if not all(isinstance(body, (ast.Import, ast.ImportFrom)) for body in node.body):
  89. if not quiet_load:
  90. print(style.ERROR("Only specify import statements: %r" % directive))
  91. continue
  92. if not quiet_load:
  93. print(style.SQL_COLTYPE("%s" % directive))
  94. for body in node.body:
  95. if isinstance(body, ast.Import):
  96. for name in body.names:
  97. asname = name.asname or name.name
  98. imported_objects[asname] = importlib.import_module(name.name)
  99. if isinstance(body, ast.ImportFrom):
  100. imported_object = importlib.__import__(body.module, {}, {}, [name.name for name in body.names])
  101. for name in body.names:
  102. asname = name.asname or name.name
  103. try:
  104. if name.name == "*":
  105. for k in dir(imported_object):
  106. imported_objects[k] = getattr(imported_object, k)
  107. else:
  108. imported_objects[asname] = getattr(imported_object, name.name)
  109. except AttributeError as exc:
  110. print(dir(imported_object))
  111. # raise
  112. raise ImportError(exc)
  113. else:
  114. warnings.warn("Old style import definitions are deprecated. You should use the new style which is similar to normal Python imports. ", RemovedInNextVersionWarning, stacklevel=2)
  115. if isinstance(directive, six.string_types):
  116. imported_object = __import__(directive)
  117. imported_objects[directive.split('.')[0]] = imported_object
  118. if not quiet_load:
  119. print(style.SQL_COLTYPE("import %s" % directive))
  120. continue
  121. elif isinstance(directive, (list, tuple)) and len(directive) == 2:
  122. if not isinstance(directive[0], six.string_types):
  123. if not quiet_load:
  124. print(style.ERROR("Unable to import %r: module name must be of type string" % directive[0]))
  125. continue
  126. if isinstance(directive[1], (list, tuple)) and all(isinstance(e, six.string_types) for e in directive[1]):
  127. # Try the ('module.submodule', ('classname1', 'classname2')) form
  128. imported_object = __import__(directive[0], {}, {}, directive[1])
  129. imported_names = []
  130. for name in directive[1]:
  131. try:
  132. imported_objects[name] = getattr(imported_object, name)
  133. except AttributeError:
  134. if not quiet_load:
  135. print(style.ERROR("Unable to import %r from %r: %r does not exist" % (name, directive[0], name)))
  136. else:
  137. imported_names.append(name)
  138. if not quiet_load:
  139. print(style.SQL_COLTYPE("from %s import %s" % (directive[0], ', '.join(imported_names))))
  140. elif isinstance(directive[1], six.string_types):
  141. # If it is a tuple, but the second item isn't a list, so we have something like ('module.submodule', 'classname1')
  142. # Check for the special '*' to import all
  143. if directive[1] == '*':
  144. imported_object = __import__(directive[0], {}, {}, directive[1])
  145. for k in dir(imported_object):
  146. imported_objects[k] = getattr(imported_object, k)
  147. if not quiet_load:
  148. print(style.SQL_COLTYPE("from %s import *" % directive[0]))
  149. else:
  150. imported_object = getattr(__import__(directive[0], {}, {}, [directive[1]]), directive[1])
  151. imported_objects[directive[1]] = imported_object
  152. if not quiet_load:
  153. print(style.SQL_COLTYPE("from %s import %s" % (directive[0], directive[1])))
  154. else:
  155. if not quiet_load:
  156. print(style.ERROR("Unable to import %r from %r: names must be of type string" % (directive[1], directive[0])))
  157. else:
  158. if not quiet_load:
  159. print(style.ERROR("Unable to import %r: names must be of type string" % directive))
  160. except ImportError:
  161. if not quiet_load:
  162. print(style.ERROR("Unable to import %r" % directive))
  163. return imported_objects
  164. def import_objects(options, style):
  165. from django.apps import apps
  166. from django import setup
  167. if not apps.ready:
  168. setup()
  169. from django.conf import settings
  170. dont_load_cli = options.get('dont_load', [])
  171. dont_load_conf = getattr(settings, 'SHELL_PLUS_DONT_LOAD', [])
  172. dont_load = dont_load_cli + dont_load_conf
  173. dont_load_any_models = '*' in dont_load
  174. quiet_load = options.get('quiet_load')
  175. model_aliases = getattr(settings, 'SHELL_PLUS_MODEL_ALIASES', {})
  176. app_prefixes = getattr(settings, 'SHELL_PLUS_APP_PREFIXES', {})
  177. SHELL_PLUS_PRE_IMPORTS = getattr(settings, 'SHELL_PLUS_PRE_IMPORTS', {})
  178. imported_objects = {}
  179. load_models = {}
  180. def get_dict_from_names_to_possible_models(): # type: () -> Dict[str, List[str]]
  181. """
  182. Collect dictionary from names to possible models. Model is represented as his full path.
  183. Name of model can be alias if SHELL_PLUS_MODEL_ALIASES or SHELL_PLUS_APP_PREFIXES is specified for this model.
  184. This dictionary is used by collision resolver.
  185. At this phase we can't import any models, because collision resolver can change results.
  186. :return: Dict[str, List[str]]. Key is name, value is list of full model's path's.
  187. """
  188. models_to_import = {} # type: Dict[str, List[str]]
  189. for app_mod, models in sorted(six.iteritems(load_models)):
  190. app_name = get_app_name(app_mod)
  191. app_aliases = model_aliases.get(app_name, {})
  192. prefix = app_prefixes.get(app_name)
  193. for model_name in sorted(models):
  194. if "%s.%s" % (app_name, model_name) in dont_load:
  195. continue
  196. alias = app_aliases.get(model_name)
  197. if not alias:
  198. if prefix:
  199. alias = "%s_%s" % (prefix, model_name)
  200. else:
  201. alias = model_name
  202. models_to_import.setdefault(alias, [])
  203. models_to_import[alias].append("%s.%s" % (app_mod, model_name))
  204. return models_to_import
  205. def import_subclasses():
  206. base_classes_to_import = getattr(settings, 'SHELL_PLUS_SUBCLASSES_IMPORT', []) # type: List[Union[str, type]]
  207. if base_classes_to_import:
  208. if not quiet_load:
  209. print(style.SQL_TABLE("# Shell Plus Subclasses Imports"))
  210. perform_automatic_imports(SubclassesFinder(base_classes_to_import).collect_subclasses())
  211. def import_models():
  212. """
  213. Perform collision resolving and imports all models.
  214. When collisions are resolved we can perform imports and print information's, because it is last phase.
  215. This function updates imported_objects dictionary.
  216. """
  217. modules_to_models = CollisionResolvingRunner().run_collision_resolver(get_dict_from_names_to_possible_models())
  218. perform_automatic_imports(modules_to_models)
  219. def perform_automatic_imports(modules_to_classes): # type: (Dict[str, List[Tuple[str, str]]]) -> ()
  220. """
  221. Import elements from given dictionary.
  222. :param modules_to_classes: dictionary from module name to tuple.
  223. First element of tuple is model name, second is model alias.
  224. If both elements are equal than element is imported without alias.
  225. """
  226. for full_module_path, models in modules_to_classes.items():
  227. model_labels = []
  228. for (model_name, alias) in sorted(models):
  229. try:
  230. imported_objects[alias] = import_string("%s.%s" % (full_module_path, model_name))
  231. if model_name == alias:
  232. model_labels.append(model_name)
  233. else:
  234. model_labels.append("%s (as %s)" % (model_name, alias))
  235. except ImportError as e:
  236. if options.get("traceback"):
  237. traceback.print_exc()
  238. if not options.get('quiet_load'):
  239. print(style.ERROR(
  240. "Failed to import '%s' from '%s' reason: %s" % (model_name, full_module_path, str(e))))
  241. if not options.get('quiet_load'):
  242. print(style.SQL_COLTYPE("from %s import %s" % (full_module_path, ", ".join(model_labels))))
  243. def get_apps_and_models():
  244. for app in apps.get_app_configs():
  245. if app.models_module:
  246. yield app.models_module, app.get_models()
  247. mongoengine = False
  248. try:
  249. from mongoengine.base import _document_registry
  250. mongoengine = True
  251. except ImportError:
  252. pass
  253. # Perform pre-imports before any other imports
  254. if SHELL_PLUS_PRE_IMPORTS:
  255. if not quiet_load:
  256. print(style.SQL_TABLE("# Shell Plus User Pre Imports"))
  257. imports = import_items(SHELL_PLUS_PRE_IMPORTS, style, quiet_load=quiet_load)
  258. for k, v in six.iteritems(imports):
  259. imported_objects[k] = v
  260. if mongoengine and not dont_load_any_models:
  261. for name, mod in six.iteritems(_document_registry):
  262. name = name.split('.')[-1]
  263. app_name = get_app_name(mod.__module__)
  264. if app_name in dont_load or ("%s.%s" % (app_name, name)) in dont_load:
  265. continue
  266. load_models.setdefault(mod.__module__, [])
  267. load_models[mod.__module__].append(name)
  268. if not dont_load_any_models:
  269. for app_mod, app_models in get_apps_and_models():
  270. if not app_models:
  271. continue
  272. app_name = get_app_name(app_mod.__name__)
  273. if app_name in dont_load:
  274. continue
  275. for mod in app_models:
  276. if "%s.%s" % (app_name, mod.__name__) in dont_load:
  277. continue
  278. if mod.__module__:
  279. # Only add the module to the dict if `__module__` is not empty.
  280. load_models.setdefault(mod.__module__, [])
  281. load_models[mod.__module__].append(mod.__name__)
  282. import_subclasses()
  283. if not quiet_load:
  284. print(style.SQL_TABLE("# Shell Plus Model Imports%s") % (' SKIPPED' if dont_load_any_models else ''))
  285. import_models()
  286. # Imports often used from Django
  287. if getattr(settings, 'SHELL_PLUS_DJANGO_IMPORTS', True):
  288. if not quiet_load:
  289. print(style.SQL_TABLE("# Shell Plus Django Imports"))
  290. imports = import_items(SHELL_PLUS_DJANGO_IMPORTS, style, quiet_load=quiet_load)
  291. for k, v in six.iteritems(imports):
  292. imported_objects[k] = v
  293. SHELL_PLUS_IMPORTS = getattr(settings, 'SHELL_PLUS_IMPORTS', {})
  294. if SHELL_PLUS_IMPORTS:
  295. if not quiet_load:
  296. print(style.SQL_TABLE("# Shell Plus User Imports"))
  297. imports = import_items(SHELL_PLUS_IMPORTS, style, quiet_load=quiet_load)
  298. for k, v in six.iteritems(imports):
  299. imported_objects[k] = v
  300. # Perform post-imports after any other imports
  301. SHELL_PLUS_POST_IMPORTS = getattr(settings, 'SHELL_PLUS_POST_IMPORTS', {})
  302. if SHELL_PLUS_POST_IMPORTS:
  303. if not quiet_load:
  304. print(style.SQL_TABLE("# Shell Plus User Post Imports"))
  305. imports = import_items(SHELL_PLUS_POST_IMPORTS, style, quiet_load=quiet_load)
  306. for k, v in six.iteritems(imports):
  307. imported_objects[k] = v
  308. return imported_objects