recursive.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # coding=utf-8
  2. #
  3. # This file is part of Hypothesis, which may be found at
  4. # https://github.com/HypothesisWorks/hypothesis-python
  5. #
  6. # Most of this work is copyright (C) 2013-2018 David R. MacIver
  7. # (david@drmaciver.com), but it contains contributions by others. See
  8. # CONTRIBUTING.rst for a full list of people who may hold copyright, and
  9. # consult the git log if you need to determine who owns an individual
  10. # contribution.
  11. #
  12. # This Source Code Form is subject to the terms of the Mozilla Public License,
  13. # v. 2.0. If a copy of the MPL was not distributed with this file, You can
  14. # obtain one at http://mozilla.org/MPL/2.0/.
  15. #
  16. # END HEADER
  17. from __future__ import division, print_function, absolute_import
  18. from contextlib import contextmanager
  19. from hypothesis.errors import InvalidArgument
  20. from hypothesis.internal.lazyformat import lazyformat
  21. from hypothesis.internal.reflection import get_pretty_function_description
  22. from hypothesis.searchstrategy.strategies import OneOfStrategy, \
  23. SearchStrategy
  24. class LimitReached(BaseException):
  25. pass
  26. class LimitedStrategy(SearchStrategy):
  27. def __init__(self, strategy):
  28. super(LimitedStrategy, self).__init__()
  29. self.base_strategy = strategy
  30. self.marker = 0
  31. self.currently_capped = False
  32. def do_validate(self):
  33. self.base_strategy.validate()
  34. def do_draw(self, data):
  35. assert self.currently_capped
  36. if self.marker <= 0:
  37. raise LimitReached()
  38. self.marker -= 1
  39. return data.draw(self.base_strategy)
  40. @contextmanager
  41. def capped(self, max_templates):
  42. assert not self.currently_capped
  43. try:
  44. self.currently_capped = True
  45. self.marker = max_templates
  46. yield
  47. finally:
  48. self.currently_capped = False
  49. class RecursiveStrategy(SearchStrategy):
  50. def __init__(self, base, extend, max_leaves):
  51. self.max_leaves = max_leaves
  52. self.base = base
  53. self.limited_base = LimitedStrategy(base)
  54. self.extend = extend
  55. strategies = [self.limited_base, self.extend(self.limited_base)]
  56. while 2 ** len(strategies) <= max_leaves:
  57. strategies.append(
  58. extend(OneOfStrategy(tuple(strategies), bias=0.8)))
  59. self.strategy = OneOfStrategy(strategies)
  60. def __repr__(self):
  61. if not hasattr(self, '_cached_repr'):
  62. self._cached_repr = 'recursive(%r, %s, max_leaves=%d)' % (
  63. self.base, get_pretty_function_description(self.extend),
  64. self.max_leaves
  65. )
  66. return self._cached_repr
  67. def do_validate(self):
  68. if not isinstance(self.base, SearchStrategy):
  69. raise InvalidArgument(
  70. 'Expected base to be SearchStrategy but got %r' % (self.base,)
  71. )
  72. extended = self.extend(self.limited_base)
  73. if not isinstance(extended, SearchStrategy):
  74. raise InvalidArgument(
  75. 'Expected extend(%r) to be a SearchStrategy but got %r' % (
  76. self.limited_base, extended
  77. ))
  78. self.limited_base.validate()
  79. self.extend(self.limited_base).validate()
  80. def do_draw(self, data):
  81. count = 0
  82. while True:
  83. try:
  84. with self.limited_base.capped(self.max_leaves):
  85. return data.draw(self.strategy)
  86. except LimitReached:
  87. if count == 0:
  88. data.note_event(lazyformat(
  89. 'Draw for %r exceeded max_leaves '
  90. 'and had to be retried', self,))
  91. count += 1