floats.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. # coding=utf-8
  2. #
  3. # This file is part of Hypothesis, which may be found at
  4. # https://github.com/HypothesisWorks/hypothesis-python
  5. #
  6. # Most of this work is copyright (C) 2013-2018 David R. MacIver
  7. # (david@drmaciver.com), but it contains contributions by others. See
  8. # CONTRIBUTING.rst for a full list of people who may hold copyright, and
  9. # consult the git log if you need to determine who owns an individual
  10. # contribution.
  11. #
  12. # This Source Code Form is subject to the terms of the Mozilla Public License,
  13. # v. 2.0. If a copy of the MPL was not distributed with this file, You can
  14. # obtain one at http://mozilla.org/MPL/2.0/.
  15. #
  16. # END HEADER
  17. from __future__ import division, print_function, absolute_import
  18. from array import array
  19. from hypothesis.internal.compat import hbytes, hrange, int_to_bytes
  20. from hypothesis.internal.floats import float_to_int, int_to_float
  21. """
  22. This module implements support for arbitrary floating point numbers in
  23. Conjecture. It doesn't make any attempt to get a good distribution, only to
  24. get a format that will shrink well.
  25. It works by defining an encoding of non-negative floating point numbers
  26. (including NaN values with a zero sign bit) that has good lexical shrinking
  27. properties.
  28. This encoding is a tagged union of two separate encodings for floating point
  29. numbers, with the tag being the first of the 64 bits and the remaining 63 bits being
  30. the payload.
  31. If the tag bit is 0, the next 7 bits are ignored, and the remaining 7 bytes are
  32. interpreted as a 7 byte integer in big-endian order and then converted to a
  33. float (there is some redundancy here, as 7 * 8 = 56, which is larger than the
  34. largest integer that floating point numbers can represent exactly, so multiple
  35. encodings may map to the same float).
  36. If the tag bit is 1, we instead use something that is closer to the normal
  37. representation of floats (and can represent every non-negative float exactly)
  38. but has a better ordering:
  39. 1. NaNs are ordered after everything else.
  40. 2. Infinity is ordered after every finite number.
  41. 3. The sign is ignored unless two floating point numbers are identical in
  42. absolute magnitude. In that case, the positive is ordered before the
  43. negative.
  44. 4. Positive floating point numbers are ordered first by int(x), so that
  45. encoding(x) < encoding(y) whenever int(x) < int(y).
  46. 5. If int(x) == int(y) then x and y are sorted towards lower denominators of
  47. their fractional parts.
  48. The format of this encoding of floating point goes as follows:
  49. [exponent] [mantissa]
  50. Each of these is the same size as its equivalent in IEEE floating point, but is
  51. in a different format.
  52. We translate exponents as follows:
  53. 1. The maximum exponent (2 ** 11 - 1) is left unchanged.
  54. 2. We reorder the remaining exponents so that all of the positive exponents
  55. are first, in increasing order, followed by all of the negative
  56. exponents in decreasing order (where positive/negative is done by the
  57. unbiased exponent e - 1023).
  58. We translate the mantissa as follows:
  59. 1. If the unbiased exponent is <= 0 we reverse it bitwise.
  60. 2. If the unbiased exponent is >= 52 we leave it alone.
  61. 3. If the unbiased exponent is in the range [1, 51] then we reverse the
  62. low k bits, where k is 52 - unbiased exponent.
  63. The low bits correspond to the fractional part of the floating point number.
  64. Reversing it bitwise means that we try to minimize the low bits, which kills
  65. off the higher powers of 2 in the fraction first.
  66. """
  67. MAX_EXPONENT = 0x7ff
  68. SPECIAL_EXPONENTS = (0, MAX_EXPONENT)
  69. BIAS = 1023
  70. MAX_POSITIVE_EXPONENT = (MAX_EXPONENT - 1 - BIAS)
  71. def exponent_key(e):
  72. if e == MAX_EXPONENT:
  73. return float('inf')
  74. unbiased = e - BIAS
  75. if unbiased < 0:
  76. return 10000 - unbiased
  77. else:
  78. return unbiased
  79. ENCODING_TABLE = array('H', sorted(hrange(MAX_EXPONENT + 1), key=exponent_key))
  80. DECODING_TABLE = array('H', [0]) * len(ENCODING_TABLE)
  81. for i, b in enumerate(ENCODING_TABLE):
  82. DECODING_TABLE[b] = i
  83. del i, b
  84. def decode_exponent(e):
  85. """Take draw_bits(11) and turn it into a suitable floating point exponent
  86. such that lexicographically simpler leads to simpler floats."""
  87. assert 0 <= e <= MAX_EXPONENT
  88. return ENCODING_TABLE[e]
  89. def encode_exponent(e):
  90. """Take a floating point exponent and turn it back into the equivalent
  91. result from conjecture."""
  92. assert 0 <= e <= MAX_EXPONENT
  93. return DECODING_TABLE[e]
  94. def reverse_byte(b):
  95. result = 0
  96. for _ in range(8):
  97. result <<= 1
  98. result |= (b & 1)
  99. b >>= 1
  100. return result
  101. # Table mapping individual bytes to the equivalent byte with the bits of the
  102. # byte reversed. e.g. 1=0b1 is mapped to 0xb10000000=0x80=128. We use this
  103. # precalculated table to simplify calculating the bitwise reversal of a longer
  104. # integer.
  105. REVERSE_BITS_TABLE = bytearray(map(reverse_byte, range(256)))
  106. def reverse64(v):
  107. """Reverse a 64-bit integer bitwise.
  108. We do this by breaking it up into 8 bytes. The 64-bit integer is then the
  109. concatenation of each of these bytes. We reverse it by reversing each byte
  110. on its own using the REVERSE_BITS_TABLE above, and then concatenating the
  111. reversed bytes.
  112. In this case concatenating consists of shifting them into the right
  113. position for the word and then oring the bits together.
  114. """
  115. assert v.bit_length() <= 64
  116. return (
  117. (REVERSE_BITS_TABLE[(v >> 0) & 0xff] << 56) |
  118. (REVERSE_BITS_TABLE[(v >> 8) & 0xff] << 48) |
  119. (REVERSE_BITS_TABLE[(v >> 16) & 0xff] << 40) |
  120. (REVERSE_BITS_TABLE[(v >> 24) & 0xff] << 32) |
  121. (REVERSE_BITS_TABLE[(v >> 32) & 0xff] << 24) |
  122. (REVERSE_BITS_TABLE[(v >> 40) & 0xff] << 16) |
  123. (REVERSE_BITS_TABLE[(v >> 48) & 0xff] << 8) |
  124. (REVERSE_BITS_TABLE[(v >> 56) & 0xff] << 0)
  125. )
  126. MANTISSA_MASK = ((1 << 52) - 1)
  127. def reverse_bits(x, n):
  128. assert x.bit_length() <= n <= 64
  129. x = reverse64(x)
  130. x >>= (64 - n)
  131. return x
  132. def update_mantissa(unbiased_exponent, mantissa):
  133. if unbiased_exponent <= 0:
  134. mantissa = reverse_bits(mantissa, 52)
  135. elif unbiased_exponent <= 51:
  136. n_fractional_bits = (52 - unbiased_exponent)
  137. fractional_part = mantissa & ((1 << n_fractional_bits) - 1)
  138. mantissa ^= fractional_part
  139. mantissa |= reverse_bits(fractional_part, n_fractional_bits)
  140. return mantissa
  141. def lex_to_float(i):
  142. assert i.bit_length() <= 64
  143. has_fractional_part = i >> 63
  144. if has_fractional_part:
  145. exponent = (i >> 52) & ((1 << 11) - 1)
  146. exponent = decode_exponent(exponent)
  147. mantissa = i & MANTISSA_MASK
  148. mantissa = update_mantissa(exponent - BIAS, mantissa)
  149. assert mantissa.bit_length() <= 52
  150. return int_to_float((exponent << 52) | mantissa)
  151. else:
  152. integral_part = i & ((1 << 56) - 1)
  153. return float(integral_part)
  154. def float_to_lex(f):
  155. if is_simple(f):
  156. assert f >= 0
  157. return int(f)
  158. i = float_to_int(f)
  159. i &= ((1 << 63) - 1)
  160. exponent = i >> 52
  161. mantissa = i & MANTISSA_MASK
  162. mantissa = update_mantissa(exponent - BIAS, mantissa)
  163. exponent = encode_exponent(exponent)
  164. assert mantissa.bit_length() <= 52
  165. return (1 << 63) | (exponent << 52) | mantissa
  166. def is_simple(f):
  167. try:
  168. i = int(f)
  169. except (ValueError, OverflowError):
  170. return False
  171. if i != f:
  172. return False
  173. return i.bit_length() <= 56
  174. def draw_float(data):
  175. try:
  176. data.start_example()
  177. f = lex_to_float(data.draw_bits(64))
  178. if data.draw_bits(1):
  179. f = -f
  180. return f
  181. finally:
  182. data.stop_example()
  183. def write_float(data, f):
  184. data.write(int_to_bytes(float_to_lex(abs(f)), 8))
  185. sign = float_to_int(f) >> 63
  186. data.write(hbytes([sign]))