numpy.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. # coding=utf-8
  2. #
  3. # This file is part of Hypothesis, which may be found at
  4. # https://github.com/HypothesisWorks/hypothesis-python
  5. #
  6. # Most of this work is copyright (C) 2013-2018 David R. MacIver
  7. # (david@drmaciver.com), but it contains contributions by others. See
  8. # CONTRIBUTING.rst for a full list of people who may hold copyright, and
  9. # consult the git log if you need to determine who owns an individual
  10. # contribution.
  11. #
  12. # This Source Code Form is subject to the terms of the Mozilla Public License,
  13. # v. 2.0. If a copy of the MPL was not distributed with this file, You can
  14. # obtain one at http://mozilla.org/MPL/2.0/.
  15. #
  16. # END HEADER
  17. from __future__ import division, print_function, absolute_import
  18. import math
  19. import numpy as np
  20. import hypothesis.strategies as st
  21. import hypothesis.internal.conjecture.utils as cu
  22. from hypothesis.errors import InvalidArgument
  23. from hypothesis.searchstrategy import SearchStrategy
  24. from hypothesis.internal.compat import hrange, text_type
  25. from hypothesis.internal.coverage import check_function
  26. from hypothesis.internal.reflection import proxies
  27. TIME_RESOLUTIONS = tuple('Y M D h m s ms us ns ps fs as'.split())
  28. @st.defines_strategy_with_reusable_values
  29. def from_dtype(dtype):
  30. # Compound datatypes, eg 'f4,f4,f4'
  31. if dtype.names is not None:
  32. # mapping np.void.type over a strategy is nonsense, so return now.
  33. return st.tuples(
  34. *[from_dtype(dtype.fields[name][0]) for name in dtype.names])
  35. # Subarray datatypes, eg '(2, 3)i4'
  36. if dtype.subdtype is not None:
  37. subtype, shape = dtype.subdtype
  38. return arrays(subtype, shape)
  39. # Scalar datatypes
  40. if dtype.kind == u'b':
  41. result = st.booleans()
  42. elif dtype.kind == u'f':
  43. result = st.floats()
  44. elif dtype.kind == u'c':
  45. result = st.complex_numbers()
  46. elif dtype.kind in (u'S', u'a'):
  47. # Numpy strings are null-terminated; only allow round-trippable values.
  48. # `itemsize == 0` means 'fixed length determined at array creation'
  49. result = st.binary(max_size=dtype.itemsize or None
  50. ).filter(lambda b: b[-1:] != b'\0')
  51. elif dtype.kind == u'u':
  52. result = st.integers(min_value=0,
  53. max_value=2 ** (8 * dtype.itemsize) - 1)
  54. elif dtype.kind == u'i':
  55. overflow = 2 ** (8 * dtype.itemsize - 1)
  56. result = st.integers(min_value=-overflow, max_value=overflow - 1)
  57. elif dtype.kind == u'U':
  58. # Encoded in UTF-32 (four bytes/codepoint) and null-terminated
  59. result = st.text(max_size=(dtype.itemsize or 0) // 4 or None
  60. ).filter(lambda b: b[-1:] != u'\0')
  61. elif dtype.kind in (u'm', u'M'):
  62. if '[' in dtype.str:
  63. res = st.just(dtype.str.split('[')[-1][:-1])
  64. else:
  65. res = st.sampled_from(TIME_RESOLUTIONS)
  66. result = st.builds(dtype.type, st.integers(-2**63, 2**63 - 1), res)
  67. else:
  68. raise InvalidArgument(u'No strategy inference for {}'.format(dtype))
  69. return result.map(dtype.type)
  70. @check_function
  71. def check_argument(condition, fail_message, *f_args, **f_kwargs):
  72. if not condition:
  73. raise InvalidArgument(fail_message.format(*f_args, **f_kwargs))
  74. @check_function
  75. def order_check(name, floor, small, large):
  76. check_argument(
  77. floor <= small, u'min_{name} must be at least {} but was {}',
  78. floor, small, name=name
  79. )
  80. check_argument(
  81. small <= large, u'min_{name}={} is larger than max_{name}={}',
  82. small, large, name=name
  83. )
  84. class ArrayStrategy(SearchStrategy):
  85. def __init__(self, element_strategy, shape, dtype, fill, unique):
  86. self.shape = tuple(shape)
  87. self.fill = fill
  88. check_argument(shape,
  89. u'Array shape must have at least one dimension, '
  90. u'provided shape was {}', shape)
  91. check_argument(all(isinstance(s, int) for s in shape),
  92. u'Array shape must be integer in each dimension, '
  93. u'provided shape was {}', shape)
  94. self.array_size = int(np.prod(shape))
  95. self.dtype = dtype
  96. self.element_strategy = element_strategy
  97. self.unique = unique
  98. def do_draw(self, data):
  99. if 0 in self.shape:
  100. return np.zeros(dtype=self.dtype, shape=self.shape)
  101. # This could legitimately be a np.empty, but the performance gains for
  102. # that would be so marginal that there's really not much point risking
  103. # undefined behaviour shenanigans.
  104. result = np.zeros(shape=self.array_size, dtype=self.dtype)
  105. if self.fill.is_empty:
  106. # We have no fill value (either because the user explicitly
  107. # disabled it or because the default behaviour was used and our
  108. # elements strategy does not produce reusable values), so we must
  109. # generate a fully dense array with a freshly drawn value for each
  110. # entry.
  111. if self.unique:
  112. seen = set()
  113. elements = cu.many(
  114. data,
  115. min_size=self.array_size, max_size=self.array_size,
  116. average_size=self.array_size
  117. )
  118. i = 0
  119. while elements.more():
  120. # We assign first because this means we check for
  121. # uniqueness after numpy has converted it to the relevant
  122. # type for us. Because we don't increment the counter on
  123. # a duplicate we will overwrite it on the next draw.
  124. result[i] = data.draw(self.element_strategy)
  125. if result[i] not in seen:
  126. seen.add(result[i])
  127. i += 1
  128. else:
  129. elements.reject()
  130. else:
  131. for i in hrange(len(result)):
  132. result[i] = data.draw(self.element_strategy)
  133. else:
  134. # We draw numpy arrays as "sparse with an offset". We draw a
  135. # collection of index assignments within the array and assign
  136. # fresh values from our elements strategy to those indices. If at
  137. # the end we have not assigned every element then we draw a single
  138. # value from our fill strategy and use that to populate the
  139. # remaining positions with that strategy.
  140. elements = cu.many(
  141. data,
  142. min_size=0, max_size=self.array_size,
  143. # sqrt isn't chosen for any particularly principled reason. It
  144. # just grows reasonably quickly but sublinearly, and for small
  145. # arrays it represents a decent fraction of the array size.
  146. average_size=math.sqrt(self.array_size),
  147. )
  148. needs_fill = np.full(self.array_size, True)
  149. seen = set()
  150. while elements.more():
  151. i = cu.integer_range(data, 0, self.array_size - 1)
  152. if not needs_fill[i]:
  153. elements.reject()
  154. continue
  155. result[i] = data.draw(self.element_strategy)
  156. if self.unique:
  157. if result[i] in seen:
  158. elements.reject()
  159. continue
  160. else:
  161. seen.add(result[i])
  162. needs_fill[i] = False
  163. if needs_fill.any():
  164. # We didn't fill all of the indices in the early loop, so we
  165. # put a fill value into the rest.
  166. # We have to do this hilarious little song and dance to work
  167. # around numpy's special handling of iterable values. If the
  168. # value here were e.g. a tuple then neither array creation
  169. # nor putmask would do the right thing. But by creating an
  170. # array of size one and then assigning the fill value as a
  171. # single element, we both get an array with the right value in
  172. # it and putmask will do the right thing by repeating the
  173. # values of the array across the mask.
  174. one_element = np.zeros(shape=1, dtype=self.dtype)
  175. one_element[0] = data.draw(self.fill)
  176. fill_value = one_element[0]
  177. if self.unique:
  178. try:
  179. is_nan = np.isnan(fill_value)
  180. except TypeError:
  181. is_nan = False
  182. if not is_nan:
  183. raise InvalidArgument(
  184. 'Cannot fill unique array with non-NaN '
  185. 'value %r' % (fill_value,))
  186. np.putmask(result, needs_fill, one_element)
  187. return result.reshape(self.shape)
  188. @check_function
  189. def fill_for(elements, unique, fill, name=''):
  190. if fill is None:
  191. if unique or not elements.has_reusable_values:
  192. fill = st.nothing()
  193. else:
  194. fill = elements
  195. else:
  196. st.check_strategy(fill, '%s.fill' % (name,) if name else 'fill')
  197. return fill
  198. @st.composite
  199. def arrays(
  200. draw, dtype, shape, elements=None, fill=None, unique=False
  201. ):
  202. """Returns a strategy for generating :class:`numpy's
  203. ndarrays<numpy.ndarray>`.
  204. * ``dtype`` may be any valid input to :class:`numpy.dtype <numpy.dtype>`
  205. (this includes ``dtype`` objects), or a strategy that generates such
  206. values.
  207. * ``shape`` may be an integer >= 0, a tuple of length >= 0 of such
  208. integers, or a strategy that generates such values.
  209. * ``elements`` is a strategy for generating values to put in the array.
  210. If it is None a suitable value will be inferred based on the dtype,
  211. which may give any legal value (including eg ``NaN`` for floats).
  212. If you have more specific requirements, you should supply your own
  213. elements strategy.
  214. * ``fill`` is a strategy that may be used to generate a single background
  215. value for the array. If None, a suitable default will be inferred
  216. based on the other arguments. If set to
  217. :func:`st.nothing() <hypothesis.strategies.nothing>` then filling
  218. behaviour will be disabled entirely and every element will be generated
  219. independently.
  220. * ``unique`` specifies if the elements of the array should all be
  221. distinct from one another. Note that in this case multiple NaN values
  222. may still be allowed. If fill is also set, the only valid values for
  223. it to return are NaN values (anything for which :func:`numpy.isnan`
  224. returns True. So e.g. for complex numbers (nan+1j) is also a valid fill).
  225. Note that if unique is set to True the generated values must be hashable.
  226. Arrays of specified ``dtype`` and ``shape`` are generated for example
  227. like this:
  228. .. code-block:: pycon
  229. >>> import numpy as np
  230. >>> arrays(np.int8, (2, 3)).example()
  231. array([[-8, 6, 3],
  232. [-6, 4, 6]], dtype=int8)
  233. - See :doc:`What you can generate and how <data>`.
  234. .. code-block:: pycon
  235. >>> import numpy as np
  236. >>> from hypothesis.strategies import floats
  237. >>> arrays(np.float, 3, elements=floats(0, 1)).example()
  238. array([ 0.88974794, 0.77387938, 0.1977879 ])
  239. Array values are generated in two parts:
  240. 1. Some subset of the coordinates of the array are populated with a value
  241. drawn from the elements strategy (or its inferred form).
  242. 2. If any coordinates were not assigned in the previous step, a single
  243. value is drawn from the fill strategy and is assigned to all remaining
  244. places.
  245. You can set fill to :func:`~hypothesis.strategies.nothing` if you want to
  246. disable this behaviour and draw a value for every element.
  247. If fill is set to None then it will attempt to infer the correct behaviour
  248. automatically: If unique is True, no filling will occur by default.
  249. Otherwise, if it looks safe to reuse the values of elements across
  250. multiple coordinates (this will be the case for any inferred strategy, and
  251. for most of the builtins, but is not the case for mutable values or
  252. strategies built with flatmap, map, composite, etc) then it will use the
  253. elements strategy as the fill, else it will default to having no fill.
  254. Having a fill helps Hypothesis craft high quality examples, but its
  255. main importance is when the array generated is large: Hypothesis is
  256. primarily designed around testing small examples. If you have arrays with
  257. hundreds or more elements, having a fill value is essential if you want
  258. your tests to run in reasonable time.
  259. """
  260. if isinstance(dtype, SearchStrategy):
  261. dtype = draw(dtype)
  262. dtype = np.dtype(dtype)
  263. if elements is None:
  264. elements = from_dtype(dtype)
  265. if isinstance(shape, SearchStrategy):
  266. shape = draw(shape)
  267. if isinstance(shape, int):
  268. shape = (shape,)
  269. shape = tuple(shape)
  270. if not shape:
  271. if dtype.kind != u'O':
  272. return draw(elements)
  273. fill = fill_for(
  274. elements=elements, unique=unique, fill=fill
  275. )
  276. return draw(ArrayStrategy(elements, shape, dtype, fill, unique))
  277. @st.defines_strategy
  278. def array_shapes(min_dims=1, max_dims=3, min_side=1, max_side=10):
  279. """Return a strategy for array shapes (tuples of int >= 1)."""
  280. order_check('dims', 1, min_dims, max_dims)
  281. order_check('side', 1, min_side, max_side)
  282. return st.lists(st.integers(min_side, max_side),
  283. min_size=min_dims, max_size=max_dims).map(tuple)
  284. @st.defines_strategy
  285. def scalar_dtypes():
  286. """Return a strategy that can return any non-flexible scalar dtype."""
  287. return st.one_of(boolean_dtypes(),
  288. integer_dtypes(), unsigned_integer_dtypes(),
  289. floating_dtypes(), complex_number_dtypes(),
  290. datetime64_dtypes(), timedelta64_dtypes())
  291. def defines_dtype_strategy(strat):
  292. @st.defines_strategy
  293. @proxies(strat)
  294. def inner(*args, **kwargs):
  295. return strat(*args, **kwargs).map(np.dtype)
  296. return inner
  297. @defines_dtype_strategy
  298. def boolean_dtypes():
  299. return st.just('?')
  300. def dtype_factory(kind, sizes, valid_sizes, endianness):
  301. # Utility function, shared logic for most integer and string types
  302. valid_endian = ('?', '<', '=', '>')
  303. check_argument(endianness in valid_endian,
  304. u'Unknown endianness: was {}, must be in {}',
  305. endianness, valid_endian)
  306. if valid_sizes is not None:
  307. if isinstance(sizes, int):
  308. sizes = (sizes,)
  309. check_argument(sizes, 'Dtype must have at least one possible size.')
  310. check_argument(all(s in valid_sizes for s in sizes),
  311. u'Invalid sizes: was {} must be an item or sequence '
  312. u'in {}', sizes, valid_sizes)
  313. if all(isinstance(s, int) for s in sizes):
  314. sizes = sorted(set(s // 8 for s in sizes))
  315. strat = st.sampled_from(sizes)
  316. if '{}' not in kind:
  317. kind += '{}'
  318. if endianness == '?':
  319. return strat.map(('<' + kind).format) | strat.map(('>' + kind).format)
  320. return strat.map((endianness + kind).format)
  321. @defines_dtype_strategy
  322. def unsigned_integer_dtypes(endianness='?', sizes=(8, 16, 32, 64)):
  323. """Return a strategy for unsigned integer dtypes.
  324. endianness may be ``<`` for little-endian, ``>`` for big-endian,
  325. ``=`` for native byte order, or ``?`` to allow either byte order.
  326. This argument only applies to dtypes of more than one byte.
  327. sizes must be a collection of integer sizes in bits. The default
  328. (8, 16, 32, 64) covers the full range of sizes.
  329. """
  330. return dtype_factory('u', sizes, (8, 16, 32, 64), endianness)
  331. @defines_dtype_strategy
  332. def integer_dtypes(endianness='?', sizes=(8, 16, 32, 64)):
  333. """Return a strategy for signed integer dtypes.
  334. endianness and sizes are treated as for
  335. :func:`unsigned_integer_dtypes`.
  336. """
  337. return dtype_factory('i', sizes, (8, 16, 32, 64), endianness)
  338. @defines_dtype_strategy
  339. def floating_dtypes(endianness='?', sizes=(16, 32, 64)):
  340. """Return a strategy for floating-point dtypes.
  341. sizes is the size in bits of floating-point number. Some machines support
  342. 96- or 128-bit floats, but these are not generated by default.
  343. Larger floats (96 and 128 bit real parts) are not supported on all
  344. platforms and therefore disabled by default. To generate these dtypes,
  345. include these values in the sizes argument.
  346. """
  347. return dtype_factory('f', sizes, (16, 32, 64, 96, 128), endianness)
  348. @defines_dtype_strategy
  349. def complex_number_dtypes(endianness='?', sizes=(64, 128)):
  350. """Return a strategy for complex-number dtypes.
  351. sizes is the total size in bits of a complex number, which consists
  352. of two floats. Complex halfs (a 16-bit real part) are not supported
  353. by numpy and will not be generated by this strategy.
  354. """
  355. return dtype_factory('c', sizes, (64, 128, 192, 256), endianness)
  356. @check_function
  357. def validate_time_slice(max_period, min_period):
  358. check_argument(max_period in TIME_RESOLUTIONS,
  359. u'max_period {} must be a valid resolution in {}',
  360. max_period, TIME_RESOLUTIONS)
  361. check_argument(min_period in TIME_RESOLUTIONS,
  362. u'min_period {} must be a valid resolution in {}',
  363. min_period, TIME_RESOLUTIONS)
  364. start = TIME_RESOLUTIONS.index(max_period)
  365. end = TIME_RESOLUTIONS.index(min_period) + 1
  366. check_argument(start < end,
  367. u'max_period {} must be earlier in sequence {} than '
  368. u'min_period {}', max_period, TIME_RESOLUTIONS, min_period)
  369. return TIME_RESOLUTIONS[start:end]
  370. @defines_dtype_strategy
  371. def datetime64_dtypes(max_period='Y', min_period='ns', endianness='?'):
  372. """Return a strategy for datetime64 dtypes, with various precisions from
  373. year to attosecond."""
  374. return dtype_factory('datetime64[{}]',
  375. validate_time_slice(max_period, min_period),
  376. TIME_RESOLUTIONS, endianness)
  377. @defines_dtype_strategy
  378. def timedelta64_dtypes(max_period='Y', min_period='ns', endianness='?'):
  379. """Return a strategy for timedelta64 dtypes, with various precisions from
  380. year to attosecond."""
  381. return dtype_factory('timedelta64[{}]',
  382. validate_time_slice(max_period, min_period),
  383. TIME_RESOLUTIONS, endianness)
  384. @defines_dtype_strategy
  385. def byte_string_dtypes(endianness='?', min_len=0, max_len=16):
  386. """Return a strategy for generating bytestring dtypes, of various lengths
  387. and byteorder."""
  388. order_check('len', 0, min_len, max_len)
  389. return dtype_factory('S', list(range(min_len, max_len + 1)),
  390. None, endianness)
  391. @defines_dtype_strategy
  392. def unicode_string_dtypes(endianness='?', min_len=0, max_len=16):
  393. """Return a strategy for generating unicode string dtypes, of various
  394. lengths and byteorder."""
  395. order_check('len', 0, min_len, max_len)
  396. return dtype_factory('U', list(range(min_len, max_len + 1)),
  397. None, endianness)
  398. @defines_dtype_strategy
  399. def array_dtypes(subtype_strategy=scalar_dtypes(),
  400. min_size=1, max_size=5, allow_subarrays=False):
  401. """Return a strategy for generating array (compound) dtypes, with members
  402. drawn from the given subtype strategy."""
  403. order_check('size', 0, min_size, max_size)
  404. native_strings = st.text if text_type is str else st.binary
  405. elements = st.tuples(native_strings(), subtype_strategy)
  406. if allow_subarrays:
  407. elements |= st.tuples(native_strings(), subtype_strategy,
  408. array_shapes(max_dims=2, max_side=2))
  409. return st.lists(elements=elements, min_size=min_size, max_size=max_size,
  410. unique_by=lambda d: d[0])
  411. @st.defines_strategy
  412. def nested_dtypes(subtype_strategy=scalar_dtypes(),
  413. max_leaves=10, max_itemsize=None):
  414. """Return the most-general dtype strategy.
  415. Elements drawn from this strategy may be simple (from the
  416. subtype_strategy), or several such values drawn from
  417. :func:`array_dtypes` with ``allow_subarrays=True``. Subdtypes in an
  418. array dtype may be nested to any depth, subject to the max_leaves
  419. argument.
  420. """
  421. return st.recursive(subtype_strategy,
  422. lambda x: array_dtypes(x, allow_subarrays=True),
  423. max_leaves).filter(
  424. lambda d: max_itemsize is None or d.itemsize <= max_itemsize)