data.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. # coding=utf-8
  2. #
  3. # This file is part of Hypothesis, which may be found at
  4. # https://github.com/HypothesisWorks/hypothesis-python
  5. #
  6. # Most of this work is copyright (C) 2013-2018 David R. MacIver
  7. # (david@drmaciver.com), but it contains contributions by others. See
  8. # CONTRIBUTING.rst for a full list of people who may hold copyright, and
  9. # consult the git log if you need to determine who owns an individual
  10. # contribution.
  11. #
  12. # This Source Code Form is subject to the terms of the Mozilla Public License,
  13. # v. 2.0. If a copy of the MPL was not distributed with this file, You can
  14. # obtain one at http://mozilla.org/MPL/2.0/.
  15. #
  16. # END HEADER
  17. from __future__ import division, print_function, absolute_import
  18. import sys
  19. from enum import IntEnum
  20. from hypothesis.errors import Frozen, StopTest, InvalidArgument
  21. from hypothesis.internal.compat import hbytes, hrange, text_type, \
  22. bit_length, benchmark_time, int_from_bytes, unicode_safe_repr
  23. from hypothesis.internal.coverage import IN_COVERAGE_TESTS
  24. from hypothesis.internal.escalation import mark_for_escalation
  25. class Status(IntEnum):
  26. OVERRUN = 0
  27. INVALID = 1
  28. VALID = 2
  29. INTERESTING = 3
  30. global_test_counter = 0
  31. MAX_DEPTH = 100
  32. class ConjectureData(object):
  33. @classmethod
  34. def for_buffer(self, buffer):
  35. buffer = hbytes(buffer)
  36. return ConjectureData(
  37. max_length=len(buffer),
  38. draw_bytes=lambda data, n:
  39. hbytes(buffer[data.index:data.index + n])
  40. )
  41. def __init__(self, max_length, draw_bytes):
  42. self.max_length = max_length
  43. self.is_find = False
  44. self._draw_bytes = draw_bytes
  45. self.overdraw = 0
  46. self.level = 0
  47. self.block_starts = {}
  48. self.blocks = []
  49. self.buffer = bytearray()
  50. self.output = u''
  51. self.status = Status.VALID
  52. self.frozen = False
  53. self.intervals_by_level = []
  54. self.interval_stack = []
  55. global global_test_counter
  56. self.testcounter = global_test_counter
  57. global_test_counter += 1
  58. self.start_time = benchmark_time()
  59. self.events = set()
  60. self.forced_indices = set()
  61. self.capped_indices = {}
  62. self.interesting_origin = None
  63. self.tags = set()
  64. self.draw_times = []
  65. self.__intervals = None
  66. self.shrinking_blocks = set()
  67. self.discarded = set()
  68. def __assert_not_frozen(self, name):
  69. if self.frozen:
  70. raise Frozen(
  71. 'Cannot call %s on frozen ConjectureData' % (
  72. name,))
  73. def add_tag(self, tag):
  74. self.tags.add(tag)
  75. @property
  76. def depth(self):
  77. return len(self.interval_stack)
  78. @property
  79. def index(self):
  80. return len(self.buffer)
  81. def note(self, value):
  82. self.__assert_not_frozen('note')
  83. if not isinstance(value, text_type):
  84. value = unicode_safe_repr(value)
  85. self.output += value
  86. def draw(self, strategy):
  87. if self.is_find and not strategy.supports_find:
  88. raise InvalidArgument((
  89. 'Cannot use strategy %r within a call to find (presumably '
  90. 'because it would be invalid after the call had ended).'
  91. ) % (strategy,))
  92. if strategy.is_empty:
  93. self.mark_invalid()
  94. if self.depth >= MAX_DEPTH:
  95. self.mark_invalid()
  96. if self.depth == 0 and not IN_COVERAGE_TESTS: # pragma: no cover
  97. original_tracer = sys.gettrace()
  98. try:
  99. sys.settrace(None)
  100. return self.__draw(strategy)
  101. finally:
  102. sys.settrace(original_tracer)
  103. else:
  104. return self.__draw(strategy)
  105. def __draw(self, strategy):
  106. at_top_level = self.depth == 0
  107. self.start_example()
  108. try:
  109. if not at_top_level:
  110. return strategy.do_draw(self)
  111. else:
  112. start_time = benchmark_time()
  113. try:
  114. return strategy.do_draw(self)
  115. except BaseException as e:
  116. mark_for_escalation(e)
  117. raise
  118. finally:
  119. self.draw_times.append(benchmark_time() - start_time)
  120. finally:
  121. if not self.frozen:
  122. self.stop_example()
  123. def start_example(self):
  124. self.__assert_not_frozen('start_example')
  125. self.interval_stack.append(self.index)
  126. self.level += 1
  127. def stop_example(self, discard=False):
  128. if self.frozen:
  129. return
  130. self.level -= 1
  131. while self.level >= len(self.intervals_by_level):
  132. self.intervals_by_level.append([])
  133. k = self.interval_stack.pop()
  134. if k != self.index:
  135. t = (k, self.index)
  136. self.intervals_by_level[self.level].append(t)
  137. if discard:
  138. self.discarded.add(t)
  139. def note_event(self, event):
  140. self.events.add(event)
  141. @property
  142. def intervals(self):
  143. assert self.frozen
  144. if self.__intervals is None:
  145. intervals = set(self.blocks)
  146. for l in self.intervals_by_level:
  147. intervals.update(l)
  148. for i in hrange(len(l) - 1):
  149. if (
  150. l[i] not in self.discarded and
  151. l[i + 1] not in self.discarded and
  152. l[i][1] == l[i + 1][0]
  153. ):
  154. intervals.add((l[i][0], l[i + 1][1]))
  155. for i in hrange(len(self.blocks) - 1):
  156. intervals.add((self.blocks[i][0], self.blocks[i + 1][1]))
  157. # Intervals are sorted as longest first, then by interval start.
  158. self.__intervals = tuple(sorted(
  159. set(intervals),
  160. key=lambda se: (se[0] - se[1], se[0])
  161. ))
  162. del self.intervals_by_level
  163. return self.__intervals
  164. def freeze(self):
  165. if self.frozen:
  166. assert isinstance(self.buffer, hbytes)
  167. return
  168. self.frozen = True
  169. self.finish_time = benchmark_time()
  170. self.buffer = hbytes(self.buffer)
  171. self.events = frozenset(self.events)
  172. del self._draw_bytes
  173. def draw_bits(self, n):
  174. self.__assert_not_frozen('draw_bits')
  175. if n == 0:
  176. result = 0
  177. elif n % 8 == 0:
  178. return int_from_bytes(self.draw_bytes(n // 8))
  179. else:
  180. n_bytes = (n // 8) + 1
  181. self.__check_capacity(n_bytes)
  182. buf = bytearray(self._draw_bytes(self, n_bytes))
  183. assert len(buf) == n_bytes
  184. mask = (1 << (n % 8)) - 1
  185. buf[0] &= mask
  186. self.capped_indices[self.index] = mask
  187. buf = hbytes(buf)
  188. self.__write(buf)
  189. result = int_from_bytes(buf)
  190. assert bit_length(result) <= n
  191. return result
  192. def write(self, string):
  193. self.__assert_not_frozen('write')
  194. self.__check_capacity(len(string))
  195. assert isinstance(string, hbytes)
  196. original = self.index
  197. self.__write(string)
  198. self.forced_indices.update(hrange(original, self.index))
  199. return string
  200. def __check_capacity(self, n):
  201. if self.index + n > self.max_length:
  202. self.overdraw = self.index + n - self.max_length
  203. self.status = Status.OVERRUN
  204. self.freeze()
  205. raise StopTest(self.testcounter)
  206. def __write(self, result):
  207. initial = self.index
  208. n = len(result)
  209. self.block_starts.setdefault(n, []).append(initial)
  210. self.blocks.append((initial, initial + n))
  211. assert len(result) == n
  212. assert self.index == initial
  213. self.buffer.extend(result)
  214. def draw_bytes(self, n):
  215. self.__assert_not_frozen('draw_bytes')
  216. if n == 0:
  217. return hbytes(b'')
  218. self.__check_capacity(n)
  219. result = self._draw_bytes(self, n)
  220. assert len(result) == n
  221. self.__write(result)
  222. return hbytes(result)
  223. def mark_interesting(self, interesting_origin=None):
  224. self.__assert_not_frozen('mark_interesting')
  225. self.interesting_origin = interesting_origin
  226. self.status = Status.INTERESTING
  227. self.freeze()
  228. raise StopTest(self.testcounter)
  229. def mark_invalid(self):
  230. self.__assert_not_frozen('mark_invalid')
  231. self.status = Status.INVALID
  232. self.freeze()
  233. raise StopTest(self.testcounter)