utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. # -*- coding: utf-8 -*-
  2. """Utility methods for marshmallow."""
  3. from __future__ import absolute_import, unicode_literals
  4. import collections
  5. import functools
  6. import datetime
  7. import inspect
  8. import json
  9. import re
  10. import time
  11. from calendar import timegm
  12. from email.utils import formatdate, parsedate
  13. from pprint import pprint as py_pprint
  14. from marshmallow.base import FieldABC
  15. from marshmallow.compat import binary_type, text_type, Mapping, Iterable
  16. from marshmallow.exceptions import FieldInstanceResolutionError
  17. EXCLUDE = 'exclude'
  18. INCLUDE = 'include'
  19. RAISE = 'raise'
  20. dateutil_available = False
  21. try:
  22. from dateutil import parser
  23. dateutil_available = True
  24. except ImportError:
  25. dateutil_available = False
  26. class _Missing(object):
  27. def __bool__(self):
  28. return False
  29. __nonzero__ = __bool__ # PY2 compat
  30. def __copy__(self):
  31. return self
  32. def __deepcopy__(self, _):
  33. return self
  34. def __repr__(self):
  35. return '<marshmallow.missing>'
  36. # Singleton value that indicates that a field's value is missing from input
  37. # dict passed to :meth:`Schema.load`. If the field's value is not required,
  38. # it's ``default`` value is used.
  39. missing = _Missing()
  40. def is_generator(obj):
  41. """Return True if ``obj`` is a generator
  42. """
  43. return inspect.isgeneratorfunction(obj) or inspect.isgenerator(obj)
  44. def is_iterable_but_not_string(obj):
  45. """Return True if ``obj`` is an iterable object that isn't a string."""
  46. return (
  47. (isinstance(obj, Iterable) and not hasattr(obj, 'strip')) or is_generator(obj)
  48. )
  49. def is_collection(obj):
  50. """Return True if ``obj`` is a collection type, e.g list, tuple, queryset."""
  51. return is_iterable_but_not_string(obj) and not isinstance(obj, Mapping)
  52. def is_instance_or_subclass(val, class_):
  53. """Return True if ``val`` is either a subclass or instance of ``class_``."""
  54. try:
  55. return issubclass(val, class_)
  56. except TypeError:
  57. return isinstance(val, class_)
  58. def is_keyed_tuple(obj):
  59. """Return True if ``obj`` has keyed tuple behavior, such as
  60. namedtuples or SQLAlchemy's KeyedTuples.
  61. """
  62. return isinstance(obj, tuple) and hasattr(obj, '_fields')
  63. def pprint(obj, *args, **kwargs):
  64. """Pretty-printing function that can pretty-print OrderedDicts
  65. like regular dictionaries. Useful for printing the output of
  66. :meth:`marshmallow.Schema.dump`.
  67. """
  68. if isinstance(obj, collections.OrderedDict):
  69. print(json.dumps(obj, *args, **kwargs))
  70. else:
  71. py_pprint(obj, *args, **kwargs)
  72. # From pytz: http://pytz.sourceforge.net/
  73. ZERO = datetime.timedelta(0)
  74. class UTC(datetime.tzinfo):
  75. """UTC
  76. Optimized UTC implementation. It unpickles using the single module global
  77. instance defined beneath this class declaration.
  78. """
  79. zone = 'UTC'
  80. _utcoffset = ZERO
  81. _dst = ZERO
  82. _tzname = zone
  83. def fromutc(self, dt):
  84. if dt.tzinfo is None:
  85. return self.localize(dt)
  86. return super(utc.__class__, self).fromutc(dt)
  87. def utcoffset(self, dt):
  88. return ZERO
  89. def tzname(self, dt):
  90. return 'UTC'
  91. def dst(self, dt):
  92. return ZERO
  93. def localize(self, dt, is_dst=False):
  94. """Convert naive time to local time"""
  95. if dt.tzinfo is not None:
  96. raise ValueError('Not naive datetime (tzinfo is already set)')
  97. return dt.replace(tzinfo=self)
  98. def normalize(self, dt, is_dst=False):
  99. """Correct the timezone information on the given datetime"""
  100. if dt.tzinfo is self:
  101. return dt
  102. if dt.tzinfo is None:
  103. raise ValueError('Naive time - no tzinfo set')
  104. return dt.astimezone(self)
  105. def __repr__(self):
  106. return '<UTC>'
  107. def __str__(self):
  108. return 'UTC'
  109. UTC = utc = UTC() # UTC is a singleton
  110. def local_rfcformat(dt):
  111. """Return the RFC822-formatted representation of a timezone-aware datetime
  112. with the UTC offset.
  113. """
  114. weekday = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'][dt.weekday()]
  115. month = [
  116. 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep',
  117. 'Oct', 'Nov', 'Dec',
  118. ][dt.month - 1]
  119. tz_offset = dt.strftime('%z')
  120. return '%s, %02d %s %04d %02d:%02d:%02d %s' % (
  121. weekday, dt.day, month,
  122. dt.year, dt.hour, dt.minute, dt.second, tz_offset,
  123. )
  124. def rfcformat(dt, localtime=False):
  125. """Return the RFC822-formatted representation of a datetime object.
  126. :param datetime dt: The datetime.
  127. :param bool localtime: If ``True``, return the date relative to the local
  128. timezone instead of UTC, displaying the proper offset,
  129. e.g. "Sun, 10 Nov 2013 08:23:45 -0600"
  130. """
  131. if not localtime:
  132. return formatdate(timegm(dt.utctimetuple()))
  133. else:
  134. return local_rfcformat(dt)
# From Django
# Validating regexes for the ISO8601 parsers below. Month, day, hour, minute
# and second each accept one or two digits.

# Date and time separated by 'T' or a space; seconds are optional. Only the
# first 6 fractional digits are captured as ``microsecond``; up to 6 more are
# accepted and discarded. An optional 'Z' or +/-HH[[:]MM] offset is captured
# as ``tzinfo``. Anchored at the end with ``$``.
_iso8601_datetime_re = re.compile(
    r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})'
    r'[T ](?P<hour>\d{1,2}):(?P<minute>\d{1,2})'
    r'(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?'
    r'(?P<tzinfo>Z|[+-]\d{2}(?::?\d{2})?)?$',
)
# YYYY-M[M]-D[D], anchored at the end with ``$``.
_iso8601_date_re = re.compile(
    r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})$',
)
# H[H]:M[M] with optional seconds and fraction (same digit rules as above).
# NOTE(review): unlike the two patterns above this one has no ``$`` anchor, so
# ``match`` only validates a prefix of the string -- confirm this is intended.
_iso8601_time_re = re.compile(
    r'(?P<hour>\d{1,2}):(?P<minute>\d{1,2})'
    r'(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?',
)
  149. def isoformat(dt, localtime=False, *args, **kwargs):
  150. """Return the ISO8601-formatted UTC representation of a datetime object."""
  151. if localtime and dt.tzinfo is not None:
  152. localized = dt
  153. else:
  154. if dt.tzinfo is None:
  155. localized = UTC.localize(dt)
  156. else:
  157. localized = dt.astimezone(UTC)
  158. return localized.isoformat(*args, **kwargs)
  159. def from_rfc(datestring, use_dateutil=True):
  160. """Parse a RFC822-formatted datetime string and return a datetime object.
  161. Use dateutil's parser if possible.
  162. https://stackoverflow.com/questions/885015/how-to-parse-a-rfc-2822-date-time-into-a-python-datetime
  163. """
  164. # Use dateutil's parser if possible
  165. if dateutil_available and use_dateutil:
  166. return parser.parse(datestring)
  167. else:
  168. parsed = parsedate(datestring) # as a tuple
  169. timestamp = time.mktime(parsed)
  170. return datetime.datetime.fromtimestamp(timestamp)
  171. def from_iso_datetime(datetimestring, use_dateutil=True):
  172. """Parse an ISO8601-formatted datetime string and return a datetime object.
  173. Use dateutil's parser if possible and return a timezone-aware datetime.
  174. """
  175. if not _iso8601_datetime_re.match(datetimestring):
  176. raise ValueError('Not a valid ISO8601-formatted datetime string')
  177. # Use dateutil's parser if possible
  178. if dateutil_available and use_dateutil:
  179. return parser.isoparse(datetimestring)
  180. else:
  181. # Strip off timezone info.
  182. return datetime.datetime.strptime(datetimestring[:19], '%Y-%m-%dT%H:%M:%S')
  183. def from_iso_time(timestring, use_dateutil=True):
  184. """Parse an ISO8601-formatted datetime string and return a datetime.time
  185. object.
  186. """
  187. if not _iso8601_time_re.match(timestring):
  188. raise ValueError('Not a valid ISO8601-formatted time string')
  189. if dateutil_available and use_dateutil:
  190. return parser.parse(timestring).time()
  191. else:
  192. if len(timestring) > 8: # has microseconds
  193. fmt = '%H:%M:%S.%f'
  194. else:
  195. fmt = '%H:%M:%S'
  196. return datetime.datetime.strptime(timestring, fmt).time()
  197. def from_iso_date(datestring, use_dateutil=True):
  198. if not _iso8601_date_re.match(datestring):
  199. raise ValueError('Not a valid ISO8601-formatted date string')
  200. if dateutil_available and use_dateutil:
  201. return parser.isoparse(datestring).date()
  202. else:
  203. return datetime.datetime.strptime(datestring[:10], '%Y-%m-%d').date()
  204. def to_iso_date(date, *args, **kwargs):
  205. return datetime.date.isoformat(date)
  206. def ensure_text_type(val):
  207. if isinstance(val, binary_type):
  208. val = val.decode('utf-8')
  209. return text_type(val)
  210. def pluck(dictlist, key):
  211. """Extracts a list of dictionary values from a list of dictionaries.
  212. ::
  213. >>> dlist = [{'id': 1, 'name': 'foo'}, {'id': 2, 'name': 'bar'}]
  214. >>> pluck(dlist, 'id')
  215. [1, 2]
  216. """
  217. return [d[key] for d in dictlist]
  218. # Various utilities for pulling keyed values from objects
  219. def get_value(obj, key, default=missing):
  220. """Helper for pulling a keyed value off various types of objects. Fields use
  221. this method by default to access attributes of the source object. For object `x`
  222. and attribute `i`, this method first tries to access `x[i]`, and then falls back to
  223. `x.i` if an exception is raised.
  224. .. warning::
  225. If an object `x` does not raise an exception when `x[i]` does not exist,
  226. `get_value` will never check the value `x.i`. Consider overriding
  227. `marshmallow.fields.Field.get_value` in this case.
  228. """
  229. if not isinstance(key, int) and '.' in key:
  230. return _get_value_for_keys(obj, key.split('.'), default)
  231. else:
  232. return _get_value_for_key(obj, key, default)
  233. def _get_value_for_keys(obj, keys, default):
  234. if len(keys) == 1:
  235. return _get_value_for_key(obj, keys[0], default)
  236. else:
  237. return _get_value_for_keys(
  238. _get_value_for_key(obj, keys[0], default), keys[1:], default,
  239. )
  240. def _get_value_for_key(obj, key, default):
  241. if not hasattr(obj, '__getitem__'):
  242. return getattr(obj, key, default)
  243. try:
  244. return obj[key]
  245. except (KeyError, IndexError, TypeError, AttributeError):
  246. return getattr(obj, key, default)
  247. def set_value(dct, key, value):
  248. """Set a value in a dict. If `key` contains a '.', it is assumed
  249. be a path (i.e. dot-delimited string) to the value's location.
  250. ::
  251. >>> d = {}
  252. >>> set_value(d, 'foo.bar', 42)
  253. >>> d
  254. {'foo': {'bar': 42}}
  255. """
  256. if '.' in key:
  257. head, rest = key.split('.', 1)
  258. target = dct.setdefault(head, {})
  259. if not isinstance(target, dict):
  260. raise ValueError(
  261. 'Cannot set {key} in {head} '
  262. 'due to existing value: {target}'.format(key=key, head=head, target=target),
  263. )
  264. set_value(target, rest, value)
  265. else:
  266. dct[key] = value
  267. def callable_or_raise(obj):
  268. """Check that an object is callable, else raise a :exc:`ValueError`.
  269. """
  270. if not callable(obj):
  271. raise ValueError('Object {0!r} is not callable.'.format(obj))
  272. return obj
  273. def _signature(func):
  274. if hasattr(inspect, 'signature'):
  275. return list(inspect.signature(func).parameters.keys())
  276. if hasattr(func, '__self__'):
  277. # Remove bound arg to match inspect.signature()
  278. return inspect.getargspec(func).args[1:]
  279. # All args are unbound
  280. return inspect.getargspec(func).args
  281. def get_func_args(func):
  282. """Given a callable, return a tuple of argument names. Handles
  283. `functools.partial` objects and class-based callables.
  284. .. versionchanged:: 3.0.0a1
  285. Do not return bound arguments, eg. ``self``.
  286. """
  287. if isinstance(func, functools.partial):
  288. return _signature(func.func)
  289. if inspect.isfunction(func) or inspect.ismethod(func):
  290. return _signature(func)
  291. # Callable class
  292. return _signature(func.__call__)
  293. def resolve_field_instance(cls_or_instance):
  294. """Return a Schema instance from a Schema class or instance.
  295. :param type|Schema cls_or_instance: Marshmallow Schema class or instance.
  296. """
  297. if isinstance(cls_or_instance, type):
  298. if not issubclass(cls_or_instance, FieldABC):
  299. raise FieldInstanceResolutionError
  300. return cls_or_instance()
  301. else:
  302. if not isinstance(cls_or_instance, FieldABC):
  303. raise FieldInstanceResolutionError
  304. return cls_or_instance