class_registry.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # -*- coding: utf-8 -*-
"""A registry of :class:`Schema <marshmallow.Schema>` classes. This allows for string
lookup of schemas, which may be used with
:class:`fields.Nested <marshmallow.fields.Nested>`.

.. warning::

    This module is treated as private API.
    Users should not need to use this module directly.
"""
  9. from __future__ import unicode_literals
  10. from marshmallow.exceptions import RegistryError
  11. # {
  12. # <class_name>: <list of class objects>
  13. # <module_path_to_class>: <list of class objects>
  14. # }
  15. _registry = {}
  16. def register(classname, cls):
  17. """Add a class to the registry of serializer classes. When a class is
  18. registered, an entry for both its classname and its full, module-qualified
  19. path are added to the registry.
  20. Example: ::
  21. class MyClass:
  22. pass
  23. register('MyClass', MyClass)
  24. # Registry:
  25. # {
  26. # 'MyClass': [path.to.MyClass],
  27. # 'path.to.MyClass': [path.to.MyClass],
  28. # }
  29. """
  30. # Module where the class is located
  31. module = cls.__module__
  32. # Full module path to the class
  33. # e.g. user.schemas.UserSchema
  34. fullpath = '.'.join([module, classname])
  35. # If the class is already registered; need to check if the entries are
  36. # in the same module as cls to avoid having multiple instances of the same
  37. # class in the registry
  38. if classname in _registry and not \
  39. any(each.__module__ == module for each in _registry[classname]):
  40. _registry[classname].append(cls)
  41. elif classname not in _registry:
  42. _registry[classname] = [cls]
  43. # Also register the full path
  44. if fullpath not in _registry:
  45. _registry.setdefault(fullpath, []).append(cls)
  46. else:
  47. # If fullpath does exist, replace existing entry
  48. _registry[fullpath] = [cls]
  49. return None
  50. def get_class(classname, all=False):
  51. """Retrieve a class from the registry.
  52. :raises: marshmallow.exceptions.RegistryError if the class cannot be found
  53. or if there are multiple entries for the given class name.
  54. """
  55. try:
  56. classes = _registry[classname]
  57. except KeyError:
  58. raise RegistryError('Class with name {0!r} was not found. You may need '
  59. 'to import the class.'.format(classname))
  60. if len(classes) > 1:
  61. if all:
  62. return _registry[classname]
  63. raise RegistryError('Multiple classes with name {0!r} '
  64. 'were found. Please use the full, '
  65. 'module-qualified path.'.format(classname))
  66. else:
  67. return _registry[classname][0]