# coding=utf-8
#
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis-python
#
# Most of this work is copyright (C) 2013-2018 David R. MacIver
# (david@drmaciver.com), but it contains contributions by others. See
# CONTRIBUTING.rst for a full list of people who may hold copyright, and
# consult the git log if you need to determine who owns an individual
# contribution.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at http://mozilla.org/MPL/2.0/.
#
# END HEADER

from __future__ import division, print_function, absolute_import

from copy import copy
from collections import Iterable, OrderedDict

import attr
import numpy as np
import pandas

import hypothesis.strategies as st
import hypothesis.extra.numpy as npst
import hypothesis.internal.conjecture.utils as cu
from hypothesis.errors import InvalidArgument
from hypothesis.control import reject
from hypothesis.internal.compat import hrange
from hypothesis.internal.coverage import check, check_function
from hypothesis.internal.validation import check_type, try_convert, \
    check_strategy, check_valid_size, check_valid_interval

try:
    from pandas.api.types import is_categorical_dtype
except ImportError:  # pragma: no cover
    def is_categorical_dtype(dt):
        if isinstance(dt, np.dtype):
            return False
        return dt == 'category'


def dtype_for_elements_strategy(s):
    return st.shared(
        s.map(lambda x: pandas.Series([x]).dtype),
        key=('hypothesis.extra.pandas.dtype_for_elements_strategy', s),
    )


def infer_dtype_if_necessary(dtype, values, elements, draw):
    if dtype is None and not values:
        return draw(dtype_for_elements_strategy(elements))
    return dtype


@check_function
def elements_and_dtype(elements, dtype, source=None):
    if source is None:
        prefix = ''
    else:
        prefix = '%s.' % (source,)

    if elements is not None:
        check_strategy(elements, '%selements' % (prefix,))
    else:
        with check('dtype is not None'):
            if dtype is None:
                raise InvalidArgument((
                    'At least one of %(prefix)selements or %(prefix)sdtype '
                    'must be provided.') % {'prefix': prefix})

    with check('is_categorical_dtype'):
        if is_categorical_dtype(dtype):
            raise InvalidArgument(
                '%sdtype is categorical, which is currently unsupported' % (
                    prefix,
                ))

    dtype = try_convert(np.dtype, dtype, 'dtype')

    if elements is None:
        elements = npst.from_dtype(dtype)
    elif dtype is not None:
        def convert_element(value):
            name = 'draw(%selements)' % (prefix,)
            try:
                return np.array([value], dtype=dtype)[0]
            except TypeError:
                raise InvalidArgument(
                    'Cannot convert %s=%r of type %s to dtype %s' % (
                        name, value, type(value).__name__, dtype.str
                    )
                )
            except ValueError:
                raise InvalidArgument(
                    'Cannot convert %s=%r to type %s' % (
                        name, value, dtype.str,
                    )
                )
        elements = elements.map(convert_element)
    assert elements is not None

    return elements, dtype


class ValueIndexStrategy(st.SearchStrategy):
    def __init__(self, elements, dtype, min_size, max_size, unique):
        super(ValueIndexStrategy, self).__init__()
        self.elements = elements
        self.dtype = dtype
        self.min_size = min_size
        self.max_size = max_size
        self.unique = unique

    def do_draw(self, data):
        result = []
        seen = set()

        iterator = cu.many(
            data, min_size=self.min_size, max_size=self.max_size,
            average_size=(self.min_size + self.max_size) / 2
        )

        while iterator.more():
            elt = data.draw(self.elements)

            if self.unique:
                if elt in seen:
                    iterator.reject()
                    continue
                seen.add(elt)
            result.append(elt)

        dtype = infer_dtype_if_necessary(
            dtype=self.dtype, values=result, elements=self.elements,
            draw=data.draw
        )
        return pandas.Index(result, dtype=dtype, tupleize_cols=False)


DEFAULT_MAX_SIZE = 10


@st.cacheable
@st.defines_strategy
def range_indexes(min_size=0, max_size=None):
    """Provides a strategy which generates an :class:`~pandas.Index` whose
    values are 0, 1, ..., n for some n.

    Arguments:

    * min_size is the smallest number of elements the index can have.
    * max_size is the largest number of elements the index can have. If None
      it will default to some suitable value based on min_size.
    """
    check_valid_size(min_size, 'min_size')
    check_valid_size(max_size, 'max_size')
    if max_size is None:
        max_size = min([min_size + DEFAULT_MAX_SIZE, 2 ** 63 - 1])
    check_valid_interval(min_size, max_size, 'min_size', 'max_size')
    return st.integers(min_size, max_size).map(pandas.RangeIndex)
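
# A minimal usage sketch (the test name is illustrative): range_indexes()
# always yields a RangeIndex counting up from zero, so inside a @given test
# the index values equal their positions.
#
#     >>> from hypothesis import given
#     >>> from hypothesis.extra.pandas import range_indexes
#     >>> @given(range_indexes(min_size=1, max_size=5))
#     ... def test_is_contiguous(ix):
#     ...     assert list(ix) == list(range(len(ix)))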


@st.cacheable
@st.defines_strategy
def indexes(
    elements=None, dtype=None, min_size=0, max_size=None, unique=True,
):
    """Provides a strategy for producing a :class:`pandas.Index`.

    Arguments:

    * elements is a strategy which will be used to generate the individual
      values of the index. If None, it will be inferred from the dtype. Note:
      even if the elements strategy produces tuples, the generated value
      will not be a MultiIndex, but instead be a normal index whose elements
      are tuples.
    * dtype is the dtype of the resulting index. If None, it will be inferred
      from the elements strategy. At least one of dtype or elements must be
      provided.
    * min_size is the minimum number of elements in the index.
    * max_size is the maximum number of elements in the index. If None then it
      will default to a suitable small size. If you want larger indexes you
      should pass a max_size explicitly.
    * unique specifies whether all of the elements in the resulting index
      should be distinct.
    """
    check_valid_size(min_size, 'min_size')
    check_valid_size(max_size, 'max_size')
    check_valid_interval(min_size, max_size, 'min_size', 'max_size')
    check_type(bool, unique, 'unique')

    elements, dtype = elements_and_dtype(elements, dtype)

    if max_size is None:
        max_size = min_size + DEFAULT_MAX_SIZE
    return ValueIndexStrategy(
        elements, dtype, min_size, max_size, unique)
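
# A minimal usage sketch (the test name is illustrative): unique defaults to
# True, so an index drawn from indexes() has no repeated values.
#
#     >>> from hypothesis import given
#     >>> from hypothesis.extra.pandas import indexes
#     >>> @given(indexes(dtype=int, min_size=1))
#     ... def test_no_duplicate_labels(ix):
#     ...     assert ix.is_unique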


@st.defines_strategy
def series(elements=None, dtype=None, index=None, fill=None, unique=False):
    """Provides a strategy for producing a :class:`pandas.Series`.

    Arguments:

    * elements: a strategy that will be used to generate the individual
      values in the series. If None, we will attempt to infer a suitable
      default from the dtype.
    * dtype: the dtype of the resulting series and may be any value
      that can be passed to :class:`numpy.dtype`. If None, will use
      pandas's standard behaviour to infer it from the type of the elements
      values. Note that if the type of values that comes out of your
      elements strategy varies, then so will the resulting dtype of the
      series.
    * index: If not None, a strategy for generating indexes for the
      resulting Series. This can generate either :class:`pandas.Index`
      objects or any sequence of values (which will be passed to the
      Index constructor).

      You will probably find it most convenient to use the
      :func:`~hypothesis.extra.pandas.indexes` or
      :func:`~hypothesis.extra.pandas.range_indexes` function to produce
      values for this argument.

    Usage:

    .. code-block:: pycon

        >>> series(dtype=int).example()
        0   -2001747478
        1    1153062837
    """
    if index is None:
        index = range_indexes()
    else:
        check_strategy(index)

    elements, dtype = elements_and_dtype(elements, dtype)
    index_strategy = index

    @st.composite
    def result(draw):
        index = draw(index_strategy)

        if len(index) > 0:
            if dtype is not None:
                result_data = draw(npst.arrays(
                    dtype=dtype, elements=elements, shape=len(index),
                    fill=fill, unique=unique,
                ))
            else:
                result_data = list(draw(npst.arrays(
                    dtype=object, elements=elements, shape=len(index),
                    fill=fill, unique=unique,
                )))

            return pandas.Series(
                result_data, index=index, dtype=dtype
            )
        else:
            return pandas.Series(
                (), index=index,
                dtype=dtype if dtype is not None else draw(
                    dtype_for_elements_strategy(elements)))

    return result()
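
# A minimal usage sketch (the test name is illustrative): the index strategy
# is drawn first and the data is generated to match it, so the Series length
# always agrees with its index.
#
#     >>> from hypothesis import given
#     >>> from hypothesis.extra.pandas import series, range_indexes
#     >>> @given(series(dtype=float, index=range_indexes(max_size=5)))
#     ... def test_length_matches_index(s):
#     ...     assert len(s) == len(s.index) <= 5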


@attr.s(slots=True)
class column(object):
    """Data object for describing a column in a DataFrame.

    Arguments:

    * name: the column name, or None to default to the column position. Must
      be hashable, but can otherwise be any value supported as a pandas column
      name.
    * elements: the strategy for generating values in this column, or None
      to infer it from the dtype.
    * dtype: the dtype of the column, or None to infer it from the element
      strategy. At least one of dtype or elements must be provided.
    * fill: A default value for elements of the column. See
      :func:`~hypothesis.extra.numpy.arrays` for a full explanation.
    * unique: If all values in this column should be distinct.
    """

    name = attr.ib(default=None)
    elements = attr.ib(default=None)
    dtype = attr.ib(default=None)
    fill = attr.ib(default=None)
    unique = attr.ib(default=False)
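
# A minimal usage sketch (column names and the test name are illustrative):
# each column pins down its own dtype and uniqueness independently of the
# other columns in the frame.
#
#     >>> from hypothesis import given
#     >>> from hypothesis.extra.pandas import column, data_frames
#     >>> @given(data_frames([column('key', dtype=int, unique=True),
#     ...                     column('value', dtype=float)]))
#     ... def test_keys_are_distinct(df):
#     ...     assert df['key'].is_unique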


def columns(
    names_or_number, dtype=None, elements=None, fill=None, unique=False
):
    """A convenience function for producing a list of :class:`column` objects
    of the same general shape.

    The names_or_number argument is either a sequence of values, the
    elements of which will be used as the name for individual column
    objects, or a number, in which case that many unnamed columns will
    be created. All other arguments are passed through verbatim to
    create the columns.
    """
    try:
        names = list(names_or_number)
    except TypeError:
        names = [None] * names_or_number
    return [
        column(
            name=n, dtype=dtype, elements=elements, fill=fill, unique=unique
        ) for n in names
    ]
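
# A minimal usage sketch: columns() expands either a list of names or a bare
# count into a list of identically configured column objects (unnamed columns
# are labelled by position later, inside data_frames).
#
#     >>> from hypothesis.extra.pandas import columns
#     >>> [c.name for c in columns(['A', 'B', 'C'], dtype=float)]
#     ['A', 'B', 'C']
#     >>> [c.name for c in columns(3, dtype=float)]
#     [None, None, None]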


@st.defines_strategy
def data_frames(
    columns=None, rows=None, index=None
):
    """Provides a strategy for producing a :class:`pandas.DataFrame`.

    Arguments:

    * columns: An iterable of :class:`column` objects describing the shape
      of the generated DataFrame.

    * rows: A strategy for generating a row object. Should generate
      either dicts mapping column names to values or a sequence mapping
      column position to the value in that position (note that unlike the
      :class:`pandas.DataFrame` constructor, single values are not allowed
      here. Passing e.g. an integer is an error, even if there is only one
      column).

      At least one of rows and columns must be provided. If both are
      provided then the generated rows will be validated against the
      columns and an error will be raised if they don't match.

      Caveats on using rows:

      * In general you should prefer using columns to rows, and only use
        rows if the columns interface is insufficiently flexible to
        describe what you need - you will get better performance and
        example quality that way.
      * If you provide rows and not columns, then the shape and dtype of
        the resulting DataFrame may vary. e.g. if you have a mix of int
        and float in the values for one column in your row entries, the
        column will sometimes have an integral dtype and sometimes a float.

    * index: If not None, a strategy for generating indexes for the
      resulting DataFrame. This can generate either :class:`pandas.Index`
      objects or any sequence of values (which will be passed to the
      Index constructor).

      You will probably find it most convenient to use the
      :func:`~hypothesis.extra.pandas.indexes` or
      :func:`~hypothesis.extra.pandas.range_indexes` function to produce
      values for this argument.

    Usage:

    The expected usage pattern is that you use :class:`column` and
    :func:`columns` to specify a fixed shape of the DataFrame you want as
    follows. For example the following gives a two column data frame:

    .. code-block:: pycon

        >>> from hypothesis.extra.pandas import column, data_frames
        >>> data_frames([
        ... column('A', dtype=int), column('B', dtype=float)]).example()
                    A              B
        0  2021915903  1.793898e+232
        1  1146643993            inf
        2 -2096165693   1.000000e+07

    If you want the values in different columns to interact in some way you
    can use the rows argument. For example the following gives a two column
    DataFrame where the value in the first column is always at most the value
    in the second:

    .. code-block:: pycon

        >>> from hypothesis.extra.pandas import column, data_frames
        >>> import hypothesis.strategies as st
        >>> data_frames(
        ...     rows=st.tuples(st.floats(allow_nan=False),
        ...                    st.floats(allow_nan=False)).map(sorted)
        ... ).example()
                        0             1
        0   -3.402823e+38  9.007199e+15
        1  -1.562796e-298  5.000000e-01

    You can also combine the two:

    .. code-block:: pycon

        >>> from hypothesis.extra.pandas import columns, data_frames
        >>> import hypothesis.strategies as st
        >>> data_frames(
        ...     columns=columns(["lo", "hi"], dtype=float),
        ...     rows=st.tuples(st.floats(allow_nan=False),
        ...                    st.floats(allow_nan=False)).map(sorted)
        ... ).example()
                       lo            hi
        0    9.314723e-49  4.353037e+45
        1   -9.999900e-01  1.000000e+07
        2  -2.152861e+134 -1.069317e-73

    (Note that the column dtype must still be specified and will not be
    inferred from the rows. This restriction may be lifted in future).

    Combining rows and columns has the following behaviour:

    * The column names and dtypes will be used.
    * If the column is required to be unique, this will be enforced.
    * Any values missing from the generated rows will be provided using the
      column's fill.
    * Any values in the row not present in the column specification (a dict
      key with no corresponding column name, or a sequence with too many
      items) will result in InvalidArgument being raised.
    """
    if index is None:
        index = range_indexes()
    else:
        check_strategy(index)

    index_strategy = index

    if columns is None:
        if rows is None:
            raise InvalidArgument(
                'At least one of rows and columns must be provided'
            )
        else:
            @st.composite
            def rows_only(draw):
                index = draw(index_strategy)

                @check_function
                def row():
                    result = draw(rows)
                    check_type(Iterable, result, 'draw(row)')
                    return result

                if len(index) > 0:
                    return pandas.DataFrame(
                        [row() for _ in index],
                        index=index
                    )
                else:
                    # If we haven't drawn any rows we need to draw one row and
                    # then discard it so that we get a consistent shape for
                    # the DataFrame.
                    base = pandas.DataFrame([row()])
                    return base.drop(0)
            return rows_only()

    assert columns is not None

    columns = try_convert(tuple, columns, 'columns')

    rewritten_columns = []
    column_names = set()

    for i, c in enumerate(columns):
        check_type(column, c, 'columns[%d]' % (i,))

        c = copy(c)
        if c.name is None:
            label = 'columns[%d]' % (i,)
            c.name = i
        else:
            label = c.name

            try:
                hash(c.name)
            except TypeError:
                raise InvalidArgument(
                    'Column names must be hashable, but columns[%d].name was '
                    '%r of type %s, which cannot be hashed.' % (
                        i, c.name, type(c.name).__name__,))

        if c.name in column_names:
            raise InvalidArgument(
                'duplicate definition of column name %r' % (c.name,))

        column_names.add(c.name)

        c.elements, c.dtype = elements_and_dtype(
            c.elements, c.dtype, label
        )

        if c.dtype is None and rows is not None:
            raise InvalidArgument(
                'Must specify a dtype for all columns when combining rows '
                'with columns.'
            )

        c.fill = npst.fill_for(
            fill=c.fill, elements=c.elements, unique=c.unique,
            name=label
        )

        rewritten_columns.append(c)

    if rows is None:
        @st.composite
        def just_draw_columns(draw):
            index = draw(index_strategy)
            local_index_strategy = st.just(index)

            data = OrderedDict((c.name, None) for c in rewritten_columns)

            # Depending on how the columns are going to be generated we group
            # them differently to get better shrinking. For columns with fill
            # enabled, the elements can be shrunk independently of the size,
            # so we can just shrink by shrinking the index then shrinking the
            # length and are generally much more free to move data around.
            #
            # For columns with no filling the problem is harder, and drawing
            # them like that would result in rows being very far apart from
            # each other in the underlying data stream, which gets in the way
            # of shrinking. So what we do is reorder and draw those columns
            # row wise, so that the values of each row are next to each other.
            # This makes life easier for the shrinker when deleting blocks of
            # data.
            columns_without_fill = [
                c for c in rewritten_columns if c.fill.is_empty]

            if columns_without_fill:
                for c in columns_without_fill:
                    data[c.name] = pandas.Series(
                        np.zeros(shape=len(index), dtype=c.dtype),
                        index=index,
                    )
                seen = {
                    c.name: set() for c in columns_without_fill if c.unique}

                for i in hrange(len(index)):
                    for c in columns_without_fill:
                        if c.unique:
                            for _ in range(5):
                                value = draw(c.elements)
                                if value not in seen[c.name]:
                                    seen[c.name].add(value)
                                    break
                            else:
                                reject()
                        else:
                            value = draw(c.elements)
                        data[c.name][i] = value

            for c in rewritten_columns:
                if not c.fill.is_empty:
                    data[c.name] = draw(series(
                        index=local_index_strategy, dtype=c.dtype,
                        elements=c.elements, fill=c.fill, unique=c.unique))

            return pandas.DataFrame(data, index=index)
        return just_draw_columns()
    else:
        @st.composite
        def assign_rows(draw):
            index = draw(index_strategy)

            result = pandas.DataFrame(OrderedDict(
                (c.name, pandas.Series(
                    np.zeros(dtype=c.dtype, shape=len(index)), dtype=c.dtype))
                for c in rewritten_columns
            ), index=index)

            fills = {}

            any_unique = any(c.unique for c in rewritten_columns)

            if any_unique:
                all_seen = [
                    set() if c.unique else None for c in rewritten_columns]
                while all_seen[-1] is None:
                    all_seen.pop()

            for row_index in hrange(len(index)):
                for _ in hrange(5):
                    original_row = draw(rows)
                    row = original_row

                    if isinstance(row, dict):
                        as_list = [None] * len(rewritten_columns)

                        for i, c in enumerate(rewritten_columns):
                            try:
                                as_list[i] = row[c.name]
                            except KeyError:
                                try:
                                    as_list[i] = fills[i]
                                except KeyError:
                                    fills[i] = draw(c.fill)
                                    as_list[i] = fills[i]

                        for k in row:
                            if k not in column_names:
                                raise InvalidArgument((
                                    'Row %r contains column %r not in '
                                    'columns %r' % (
                                        row, k, [
                                            c.name for c in rewritten_columns
                                        ])))

                        row = as_list

                    if any_unique:
                        has_duplicate = False
                        for seen, value in zip(all_seen, row):
                            if seen is None:
                                continue
                            if value in seen:
                                has_duplicate = True
                                break
                            seen.add(value)
                        if has_duplicate:
                            continue

                    row = list(try_convert(tuple, row, 'draw(rows)'))

                    if len(row) > len(rewritten_columns):
                        raise InvalidArgument((
                            'Row %r contains too many entries. Has %d but '
                            'expected at most %d') % (
                                original_row, len(row), len(rewritten_columns)
                            ))
                    while len(row) < len(rewritten_columns):
                        row.append(draw(rewritten_columns[len(row)].fill))

                    result.iloc[row_index] = row
                    break
                else:
                    reject()
            return result
        return assign_rows()
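
# A minimal usage sketch (column names and the test name are illustrative):
# when rows and columns are combined, any trailing entries missing from a
# drawn row are taken from the corresponding column's fill strategy, so a
# fixed fill shows up wherever the row ran short.
#
#     >>> from hypothesis import given
#     >>> import hypothesis.strategies as st
#     >>> from hypothesis.extra.pandas import column, data_frames
#     >>> @given(data_frames(
#     ...     columns=[column('x', dtype=float),
#     ...              column('y', dtype=float, fill=st.just(0.0))],
#     ...     rows=st.tuples(st.floats(allow_nan=False))))
#     ... def test_missing_entries_use_fill(df):
#     ...     assert (df['y'] == 0.0).all()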