test_number.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. # -*- coding: utf-8 -*-
  2. #
  3. # SelfTest/Util/test_number.py: Self-test for parts of the Crypto.Util.number module
  4. #
  5. # Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
  6. #
  7. # ===================================================================
  8. # The contents of this file are dedicated to the public domain. To
  9. # the extent that dedication to the public domain is not available,
  10. # everyone is granted a worldwide, perpetual, royalty-free,
  11. # non-exclusive license to exercise all rights associated with the
  12. # contents of this file for any purpose whatsoever.
  13. # No rights are reserved.
  14. #
  15. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  16. # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  17. # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  18. # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  19. # BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  20. # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  21. # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  22. # SOFTWARE.
  23. # ===================================================================
  24. """Self-tests for (some of) Crypto.Util.number"""
  25. __revision__ = "$Id$"
  26. import sys
  27. if sys.version_info[0] == 2 and sys.version_info[1] == 1:
  28. from Crypto.Util.py21compat import *
  29. import unittest
  30. # NB: In some places, we compare tuples instead of just output values so that
  31. # if any inputs cause a test failure, we'll be able to tell which ones.
  32. class MiscTests(unittest.TestCase):
  33. def setUp(self):
  34. global number, math
  35. from Crypto.Util import number
  36. import math
  37. def test_ceil_shift(self):
  38. """Util.number.ceil_shift"""
  39. self.assertRaises(AssertionError, number.ceil_shift, -1, 1)
  40. self.assertRaises(AssertionError, number.ceil_shift, 1, -1)
  41. # b = 0
  42. self.assertEqual(0, number.ceil_shift(0, 0))
  43. self.assertEqual(1, number.ceil_shift(1, 0))
  44. self.assertEqual(2, number.ceil_shift(2, 0))
  45. self.assertEqual(3, number.ceil_shift(3, 0))
  46. # b = 1
  47. self.assertEqual(0, number.ceil_shift(0, 1))
  48. self.assertEqual(1, number.ceil_shift(1, 1))
  49. self.assertEqual(1, number.ceil_shift(2, 1))
  50. self.assertEqual(2, number.ceil_shift(3, 1))
  51. # b = 2
  52. self.assertEqual(0, number.ceil_shift(0, 2))
  53. self.assertEqual(1, number.ceil_shift(1, 2))
  54. self.assertEqual(1, number.ceil_shift(2, 2))
  55. self.assertEqual(1, number.ceil_shift(3, 2))
  56. self.assertEqual(1, number.ceil_shift(4, 2))
  57. self.assertEqual(2, number.ceil_shift(5, 2))
  58. self.assertEqual(2, number.ceil_shift(6, 2))
  59. self.assertEqual(2, number.ceil_shift(7, 2))
  60. self.assertEqual(2, number.ceil_shift(8, 2))
  61. self.assertEqual(3, number.ceil_shift(9, 2))
  62. for b in range(3, 1+129, 3): # 3, 6, ... , 129
  63. self.assertEqual(0, number.ceil_shift(0, b))
  64. n = 1L
  65. while n <= 2L**(b+2):
  66. (q, r) = divmod(n-1, 2L**b)
  67. expected = q + int(not not r)
  68. self.assertEqual((n-1, b, expected),
  69. (n-1, b, number.ceil_shift(n-1, b)))
  70. (q, r) = divmod(n, 2L**b)
  71. expected = q + int(not not r)
  72. self.assertEqual((n, b, expected),
  73. (n, b, number.ceil_shift(n, b)))
  74. (q, r) = divmod(n+1, 2L**b)
  75. expected = q + int(not not r)
  76. self.assertEqual((n+1, b, expected),
  77. (n+1, b, number.ceil_shift(n+1, b)))
  78. n *= 2
  79. def test_ceil_div(self):
  80. """Util.number.ceil_div"""
  81. self.assertRaises(TypeError, number.ceil_div, "1", 1)
  82. self.assertRaises(ZeroDivisionError, number.ceil_div, 1, 0)
  83. self.assertRaises(ZeroDivisionError, number.ceil_div, -1, 0)
  84. # b = -1
  85. self.assertEqual(0, number.ceil_div(0, -1))
  86. self.assertEqual(-1, number.ceil_div(1, -1))
  87. self.assertEqual(-2, number.ceil_div(2, -1))
  88. self.assertEqual(-3, number.ceil_div(3, -1))
  89. # b = 1
  90. self.assertEqual(0, number.ceil_div(0, 1))
  91. self.assertEqual(1, number.ceil_div(1, 1))
  92. self.assertEqual(2, number.ceil_div(2, 1))
  93. self.assertEqual(3, number.ceil_div(3, 1))
  94. # b = 2
  95. self.assertEqual(0, number.ceil_div(0, 2))
  96. self.assertEqual(1, number.ceil_div(1, 2))
  97. self.assertEqual(1, number.ceil_div(2, 2))
  98. self.assertEqual(2, number.ceil_div(3, 2))
  99. self.assertEqual(2, number.ceil_div(4, 2))
  100. self.assertEqual(3, number.ceil_div(5, 2))
  101. # b = 3
  102. self.assertEqual(0, number.ceil_div(0, 3))
  103. self.assertEqual(1, number.ceil_div(1, 3))
  104. self.assertEqual(1, number.ceil_div(2, 3))
  105. self.assertEqual(1, number.ceil_div(3, 3))
  106. self.assertEqual(2, number.ceil_div(4, 3))
  107. self.assertEqual(2, number.ceil_div(5, 3))
  108. self.assertEqual(2, number.ceil_div(6, 3))
  109. self.assertEqual(3, number.ceil_div(7, 3))
  110. # b = 4
  111. self.assertEqual(0, number.ceil_div(0, 4))
  112. self.assertEqual(1, number.ceil_div(1, 4))
  113. self.assertEqual(1, number.ceil_div(2, 4))
  114. self.assertEqual(1, number.ceil_div(3, 4))
  115. self.assertEqual(1, number.ceil_div(4, 4))
  116. self.assertEqual(2, number.ceil_div(5, 4))
  117. self.assertEqual(2, number.ceil_div(6, 4))
  118. self.assertEqual(2, number.ceil_div(7, 4))
  119. self.assertEqual(2, number.ceil_div(8, 4))
  120. self.assertEqual(3, number.ceil_div(9, 4))
  121. # b = -4
  122. self.assertEqual(3, number.ceil_div(-9, -4))
  123. self.assertEqual(2, number.ceil_div(-8, -4))
  124. self.assertEqual(2, number.ceil_div(-7, -4))
  125. self.assertEqual(2, number.ceil_div(-6, -4))
  126. self.assertEqual(2, number.ceil_div(-5, -4))
  127. self.assertEqual(1, number.ceil_div(-4, -4))
  128. self.assertEqual(1, number.ceil_div(-3, -4))
  129. self.assertEqual(1, number.ceil_div(-2, -4))
  130. self.assertEqual(1, number.ceil_div(-1, -4))
  131. self.assertEqual(0, number.ceil_div(0, -4))
  132. self.assertEqual(0, number.ceil_div(1, -4))
  133. self.assertEqual(0, number.ceil_div(2, -4))
  134. self.assertEqual(0, number.ceil_div(3, -4))
  135. self.assertEqual(-1, number.ceil_div(4, -4))
  136. self.assertEqual(-1, number.ceil_div(5, -4))
  137. self.assertEqual(-1, number.ceil_div(6, -4))
  138. self.assertEqual(-1, number.ceil_div(7, -4))
  139. self.assertEqual(-2, number.ceil_div(8, -4))
  140. self.assertEqual(-2, number.ceil_div(9, -4))
  141. def test_exact_log2(self):
  142. """Util.number.exact_log2"""
  143. self.assertRaises(TypeError, number.exact_log2, "0")
  144. self.assertRaises(ValueError, number.exact_log2, -1)
  145. self.assertRaises(ValueError, number.exact_log2, 0)
  146. self.assertEqual(0, number.exact_log2(1))
  147. self.assertEqual(1, number.exact_log2(2))
  148. self.assertRaises(ValueError, number.exact_log2, 3)
  149. self.assertEqual(2, number.exact_log2(4))
  150. self.assertRaises(ValueError, number.exact_log2, 5)
  151. self.assertRaises(ValueError, number.exact_log2, 6)
  152. self.assertRaises(ValueError, number.exact_log2, 7)
  153. e = 3
  154. n = 8
  155. while e < 16:
  156. if n == 2**e:
  157. self.assertEqual(e, number.exact_log2(n), "expected=2**%d, n=%d" % (e, n))
  158. e += 1
  159. else:
  160. self.assertRaises(ValueError, number.exact_log2, n)
  161. n += 1
  162. for e in range(16, 1+64, 2):
  163. self.assertRaises(ValueError, number.exact_log2, 2L**e-1)
  164. self.assertEqual(e, number.exact_log2(2L**e))
  165. self.assertRaises(ValueError, number.exact_log2, 2L**e+1)
  166. def test_exact_div(self):
  167. """Util.number.exact_div"""
  168. # Positive numbers
  169. self.assertEqual(1, number.exact_div(1, 1))
  170. self.assertRaises(ValueError, number.exact_div, 1, 2)
  171. self.assertEqual(1, number.exact_div(2, 2))
  172. self.assertRaises(ValueError, number.exact_div, 3, 2)
  173. self.assertEqual(2, number.exact_div(4, 2))
  174. # Negative numbers
  175. self.assertEqual(-1, number.exact_div(-1, 1))
  176. self.assertEqual(-1, number.exact_div(1, -1))
  177. self.assertRaises(ValueError, number.exact_div, -1, 2)
  178. self.assertEqual(1, number.exact_div(-2, -2))
  179. self.assertEqual(-2, number.exact_div(-4, 2))
  180. # Zero dividend
  181. self.assertEqual(0, number.exact_div(0, 1))
  182. self.assertEqual(0, number.exact_div(0, 2))
  183. # Zero divisor (allow_divzero == False)
  184. self.assertRaises(ZeroDivisionError, number.exact_div, 0, 0)
  185. self.assertRaises(ZeroDivisionError, number.exact_div, 1, 0)
  186. # Zero divisor (allow_divzero == True)
  187. self.assertEqual(0, number.exact_div(0, 0, allow_divzero=True))
  188. self.assertRaises(ValueError, number.exact_div, 1, 0, allow_divzero=True)
  189. def test_floor_div(self):
  190. """Util.number.floor_div"""
  191. self.assertRaises(TypeError, number.floor_div, "1", 1)
  192. for a in range(-10, 10):
  193. for b in range(-10, 10):
  194. if b == 0:
  195. self.assertRaises(ZeroDivisionError, number.floor_div, a, b)
  196. else:
  197. self.assertEqual((a, b, int(math.floor(float(a) / b))),
  198. (a, b, number.floor_div(a, b)))
  199. def test_getStrongPrime(self):
  200. """Util.number.getStrongPrime"""
  201. self.assertRaises(ValueError, number.getStrongPrime, 256)
  202. self.assertRaises(ValueError, number.getStrongPrime, 513)
  203. bits = 512
  204. x = number.getStrongPrime(bits)
  205. self.assertNotEqual(x % 2, 0)
  206. self.assertEqual(x > (1L << bits-1)-1, 1)
  207. self.assertEqual(x < (1L << bits), 1)
  208. e = 2**16+1
  209. x = number.getStrongPrime(bits, e)
  210. self.assertEqual(number.GCD(x-1, e), 1)
  211. self.assertNotEqual(x % 2, 0)
  212. self.assertEqual(x > (1L << bits-1)-1, 1)
  213. self.assertEqual(x < (1L << bits), 1)
  214. e = 2**16+2
  215. x = number.getStrongPrime(bits, e)
  216. self.assertEqual(number.GCD((x-1)>>1, e), 1)
  217. self.assertNotEqual(x % 2, 0)
  218. self.assertEqual(x > (1L << bits-1)-1, 1)
  219. self.assertEqual(x < (1L << bits), 1)
  220. def test_isPrime(self):
  221. """Util.number.isPrime"""
  222. self.assertEqual(number.isPrime(-3), False) # Regression test: negative numbers should not be prime
  223. self.assertEqual(number.isPrime(-2), False) # Regression test: negative numbers should not be prime
  224. self.assertEqual(number.isPrime(1), False) # Regression test: isPrime(1) caused some versions of PyCrypto to crash.
  225. self.assertEqual(number.isPrime(2), True)
  226. self.assertEqual(number.isPrime(3), True)
  227. self.assertEqual(number.isPrime(4), False)
  228. self.assertEqual(number.isPrime(2L**1279-1), True)
  229. self.assertEqual(number.isPrime(-(2L**1279-1)), False) # Regression test: negative numbers should not be prime
  230. # test some known gmp pseudo-primes taken from
  231. # http://www.trnicely.net/misc/mpzspsp.html
  232. for composite in (43 * 127 * 211, 61 * 151 * 211, 15259 * 30517,
  233. 346141L * 692281L, 1007119L * 2014237L, 3589477L * 7178953L,
  234. 4859419L * 9718837L, 2730439L * 5460877L,
  235. 245127919L * 490255837L, 963939391L * 1927878781L,
  236. 4186358431L * 8372716861L, 1576820467L * 3153640933L):
  237. self.assertEqual(number.isPrime(long(composite)), False)
  238. def test_size(self):
  239. self.assertEqual(number.size(2),2)
  240. self.assertEqual(number.size(3),2)
  241. self.assertEqual(number.size(0xa2),8)
  242. self.assertEqual(number.size(0xa2ba40),8*3)
  243. self.assertEqual(number.size(0xa2ba40ee07e3b2bd2f02ce227f36a195024486e49c19cb41bbbdfbba98b22b0e577c2eeaffa20d883a76e65e394c69d4b3c05a1e8fadda27edb2a42bc000fe888b9b32c22d15add0cd76b3e7936e19955b220dd17d4ea904b1ec102b2e4de7751222aa99151024c7cb41cc5ea21d00eeb41f7c800834d2c6e06bce3bce7ea9a5L), 1024)
  244. def test_negative_number_roundtrip_mpzToLongObj_longObjToMPZ(self):
  245. """Test that mpzToLongObj and longObjToMPZ (internal functions) roundtrip negative numbers correctly."""
  246. n = -100000000000000000000000000000000000L
  247. e = 2L
  248. k = number._fastmath.rsa_construct(n, e)
  249. self.assertEqual(n, k.n)
  250. self.assertEqual(e, k.e)
  251. def get_tests(config={}):
  252. from Crypto.SelfTest.st_common import list_test_cases
  253. return list_test_cases(MiscTests)
  254. if __name__ == '__main__':
  255. suite = lambda: unittest.TestSuite(get_tests())
  256. unittest.main(defaultTest='suite')
  257. # vim:set ts=4 sw=4 sts=4 expandtab: