numbertheory.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614
#! /usr/bin/env python
#
# Provide some simple capabilities from number theory.
#
# Version of 2008.11.14.
#
# Written in 2005 and 2006 by Peter Pearson and placed in the public domain.
# Revision history:
#   2008.11.14: Use pow( base, exponent, modulus ) for modular_exp.
#               Make gcd and lcm accept arbitrarily many arguments.
  11. from __future__ import division
  12. from .six import print_, integer_types
  13. from .six.moves import reduce
  14. import math
  15. import types
  16. class Error( Exception ):
  17. """Base class for exceptions in this module."""
  18. pass
  19. class SquareRootError( Error ):
  20. pass
  21. class NegativeExponentError( Error ):
  22. pass
  23. def modular_exp( base, exponent, modulus ):
  24. "Raise base to exponent, reducing by modulus"
  25. if exponent < 0:
  26. raise NegativeExponentError( "Negative exponents (%d) not allowed" \
  27. % exponent )
  28. return pow( base, exponent, modulus )
  29. # result = 1L
  30. # x = exponent
  31. # b = base + 0L
  32. # while x > 0:
  33. # if x % 2 > 0: result = (result * b) % modulus
  34. # x = x // 2
  35. # b = ( b * b ) % modulus
  36. # return result
  37. def polynomial_reduce_mod( poly, polymod, p ):
  38. """Reduce poly by polymod, integer arithmetic modulo p.
  39. Polynomials are represented as lists of coefficients
  40. of increasing powers of x."""
  41. # This module has been tested only by extensive use
  42. # in calculating modular square roots.
  43. # Just to make this easy, require a monic polynomial:
  44. assert polymod[-1] == 1
  45. assert len( polymod ) > 1
  46. while len( poly ) >= len( polymod ):
  47. if poly[-1] != 0:
  48. for i in range( 2, len( polymod ) + 1 ):
  49. poly[-i] = ( poly[-i] - poly[-1] * polymod[-i] ) % p
  50. poly = poly[0:-1]
  51. return poly
  52. def polynomial_multiply_mod( m1, m2, polymod, p ):
  53. """Polynomial multiplication modulo a polynomial over ints mod p.
  54. Polynomials are represented as lists of coefficients
  55. of increasing powers of x."""
  56. # This is just a seat-of-the-pants implementation.
  57. # This module has been tested only by extensive use
  58. # in calculating modular square roots.
  59. # Initialize the product to zero:
  60. prod = ( len( m1 ) + len( m2 ) - 1 ) * [0]
  61. # Add together all the cross-terms:
  62. for i in range( len( m1 ) ):
  63. for j in range( len( m2 ) ):
  64. prod[i+j] = ( prod[i+j] + m1[i] * m2[j] ) % p
  65. return polynomial_reduce_mod( prod, polymod, p )
  66. def polynomial_exp_mod( base, exponent, polymod, p ):
  67. """Polynomial exponentiation modulo a polynomial over ints mod p.
  68. Polynomials are represented as lists of coefficients
  69. of increasing powers of x."""
  70. # Based on the Handbook of Applied Cryptography, algorithm 2.227.
  71. # This module has been tested only by extensive use
  72. # in calculating modular square roots.
  73. assert exponent < p
  74. if exponent == 0: return [ 1 ]
  75. G = base
  76. k = exponent
  77. if k%2 == 1: s = G
  78. else: s = [ 1 ]
  79. while k > 1:
  80. k = k // 2
  81. G = polynomial_multiply_mod( G, G, polymod, p )
  82. if k%2 == 1: s = polynomial_multiply_mod( G, s, polymod, p )
  83. return s
  84. def jacobi( a, n ):
  85. """Jacobi symbol"""
  86. # Based on the Handbook of Applied Cryptography (HAC), algorithm 2.149.
  87. # This function has been tested by comparison with a small
  88. # table printed in HAC, and by extensive use in calculating
  89. # modular square roots.
  90. assert n >= 3
  91. assert n%2 == 1
  92. a = a % n
  93. if a == 0: return 0
  94. if a == 1: return 1
  95. a1, e = a, 0
  96. while a1%2 == 0:
  97. a1, e = a1//2, e+1
  98. if e%2 == 0 or n%8 == 1 or n%8 == 7: s = 1
  99. else: s = -1
  100. if a1 == 1: return s
  101. if n%4 == 3 and a1%4 == 3: s = -s
  102. return s * jacobi( n % a1, a1 )
  103. def square_root_mod_prime( a, p ):
  104. """Modular square root of a, mod p, p prime."""
  105. # Based on the Handbook of Applied Cryptography, algorithms 3.34 to 3.39.
  106. # This module has been tested for all values in [0,p-1] for
  107. # every prime p from 3 to 1229.
  108. assert 0 <= a < p
  109. assert 1 < p
  110. if a == 0: return 0
  111. if p == 2: return a
  112. jac = jacobi( a, p )
  113. if jac == -1: raise SquareRootError( "%d has no square root modulo %d" \
  114. % ( a, p ) )
  115. if p % 4 == 3: return modular_exp( a, (p+1)//4, p )
  116. if p % 8 == 5:
  117. d = modular_exp( a, (p-1)//4, p )
  118. if d == 1: return modular_exp( a, (p+3)//8, p )
  119. if d == p-1: return ( 2 * a * modular_exp( 4*a, (p-5)//8, p ) ) % p
  120. raise RuntimeError("Shouldn't get here.")
  121. for b in range( 2, p ):
  122. if jacobi( b*b-4*a, p ) == -1:
  123. f = ( a, -b, 1 )
  124. ff = polynomial_exp_mod( ( 0, 1 ), (p+1)//2, f, p )
  125. assert ff[1] == 0
  126. return ff[0]
  127. raise RuntimeError("No b found.")
  128. def inverse_mod( a, m ):
  129. """Inverse of a mod m."""
  130. if a < 0 or m <= a: a = a % m
  131. # From Ferguson and Schneier, roughly:
  132. c, d = a, m
  133. uc, vc, ud, vd = 1, 0, 0, 1
  134. while c != 0:
  135. q, c, d = divmod( d, c ) + ( c, )
  136. uc, vc, ud, vd = ud - q*uc, vd - q*vc, uc, vc
  137. # At this point, d is the GCD, and ud*a+vd*m = d.
  138. # If d == 1, this means that ud is a inverse.
  139. assert d == 1
  140. if ud > 0: return ud
  141. else: return ud + m
  142. def gcd2(a, b):
  143. """Greatest common divisor using Euclid's algorithm."""
  144. while a:
  145. a, b = b%a, a
  146. return b
  147. def gcd( *a ):
  148. """Greatest common divisor.
  149. Usage: gcd( [ 2, 4, 6 ] )
  150. or: gcd( 2, 4, 6 )
  151. """
  152. if len( a ) > 1: return reduce( gcd2, a )
  153. if hasattr( a[0], "__iter__" ): return reduce( gcd2, a[0] )
  154. return a[0]
  155. def lcm2(a,b):
  156. """Least common multiple of two integers."""
  157. return (a*b)//gcd(a,b)
  158. def lcm( *a ):
  159. """Least common multiple.
  160. Usage: lcm( [ 3, 4, 5 ] )
  161. or: lcm( 3, 4, 5 )
  162. """
  163. if len( a ) > 1: return reduce( lcm2, a )
  164. if hasattr( a[0], "__iter__" ): return reduce( lcm2, a[0] )
  165. return a[0]
  166. def factorization( n ):
  167. """Decompose n into a list of (prime,exponent) pairs."""
  168. assert isinstance( n, integer_types )
  169. if n < 2: return []
  170. result = []
  171. d = 2
  172. # Test the small primes:
  173. for d in smallprimes:
  174. if d > n: break
  175. q, r = divmod( n, d )
  176. if r == 0:
  177. count = 1
  178. while d <= n:
  179. n = q
  180. q, r = divmod( n, d )
  181. if r != 0: break
  182. count = count + 1
  183. result.append( ( d, count ) )
  184. # If n is still greater than the last of our small primes,
  185. # it may require further work:
  186. if n > smallprimes[-1]:
  187. if is_prime( n ): # If what's left is prime, it's easy:
  188. result.append( ( n, 1 ) )
  189. else: # Ugh. Search stupidly for a divisor:
  190. d = smallprimes[-1]
  191. while 1:
  192. d = d + 2 # Try the next divisor.
  193. q, r = divmod( n, d )
  194. if q < d: break # n < d*d means we're done, n = 1 or prime.
  195. if r == 0: # d divides n. How many times?
  196. count = 1
  197. n = q
  198. while d <= n: # As long as d might still divide n,
  199. q, r = divmod( n, d ) # see if it does.
  200. if r != 0: break
  201. n = q # It does. Reduce n, increase count.
  202. count = count + 1
  203. result.append( ( d, count ) )
  204. if n > 1: result.append( ( n, 1 ) )
  205. return result
  206. def phi( n ):
  207. """Return the Euler totient function of n."""
  208. assert isinstance( n, integer_types )
  209. if n < 3: return 1
  210. result = 1
  211. ff = factorization( n )
  212. for f in ff:
  213. e = f[1]
  214. if e > 1:
  215. result = result * f[0] ** (e-1) * ( f[0] - 1 )
  216. else:
  217. result = result * ( f[0] - 1 )
  218. return result
  219. def carmichael( n ):
  220. """Return Carmichael function of n.
  221. Carmichael(n) is the smallest integer x such that
  222. m**x = 1 mod n for all m relatively prime to n.
  223. """
  224. return carmichael_of_factorized( factorization( n ) )
  225. def carmichael_of_factorized( f_list ):
  226. """Return the Carmichael function of a number that is
  227. represented as a list of (prime,exponent) pairs.
  228. """
  229. if len( f_list ) < 1: return 1
  230. result = carmichael_of_ppower( f_list[0] )
  231. for i in range( 1, len( f_list ) ):
  232. result = lcm( result, carmichael_of_ppower( f_list[i] ) )
  233. return result
  234. def carmichael_of_ppower( pp ):
  235. """Carmichael function of the given power of the given prime.
  236. """
  237. p, a = pp
  238. if p == 2 and a > 2: return 2**(a-2)
  239. else: return (p-1) * p**(a-1)
  240. def order_mod( x, m ):
  241. """Return the order of x in the multiplicative group mod m.
  242. """
  243. # Warning: this implementation is not very clever, and will
  244. # take a long time if m is very large.
  245. if m <= 1: return 0
  246. assert gcd( x, m ) == 1
  247. z = x
  248. result = 1
  249. while z != 1:
  250. z = ( z * x ) % m
  251. result = result + 1
  252. return result
  253. def largest_factor_relatively_prime( a, b ):
  254. """Return the largest factor of a relatively prime to b.
  255. """
  256. while 1:
  257. d = gcd( a, b )
  258. if d <= 1: break
  259. b = d
  260. while 1:
  261. q, r = divmod( a, d )
  262. if r > 0:
  263. break
  264. a = q
  265. return a
  266. def kinda_order_mod( x, m ):
  267. """Return the order of x in the multiplicative group mod m',
  268. where m' is the largest factor of m relatively prime to x.
  269. """
  270. return order_mod( x, largest_factor_relatively_prime( m, x ) )
  271. def is_prime( n ):
  272. """Return True if x is prime, False otherwise.
  273. We use the Miller-Rabin test, as given in Menezes et al. p. 138.
  274. This test is not exact: there are composite values n for which
  275. it returns True.
  276. In testing the odd numbers from 10000001 to 19999999,
  277. about 66 composites got past the first test,
  278. 5 got past the second test, and none got past the third.
  279. Since factors of 2, 3, 5, 7, and 11 were detected during
  280. preliminary screening, the number of numbers tested by
  281. Miller-Rabin was (19999999 - 10000001)*(2/3)*(4/5)*(6/7)
  282. = 4.57 million.
  283. """
  284. # (This is used to study the risk of false positives:)
  285. global miller_rabin_test_count
  286. miller_rabin_test_count = 0
  287. if n <= smallprimes[-1]:
  288. if n in smallprimes: return True
  289. else: return False
  290. if gcd( n, 2*3*5*7*11 ) != 1: return False
  291. # Choose a number of iterations sufficient to reduce the
  292. # probability of accepting a composite below 2**-80
  293. # (from Menezes et al. Table 4.4):
  294. t = 40
  295. n_bits = 1 + int( math.log( n, 2 ) )
  296. for k, tt in ( ( 100, 27 ),
  297. ( 150, 18 ),
  298. ( 200, 15 ),
  299. ( 250, 12 ),
  300. ( 300, 9 ),
  301. ( 350, 8 ),
  302. ( 400, 7 ),
  303. ( 450, 6 ),
  304. ( 550, 5 ),
  305. ( 650, 4 ),
  306. ( 850, 3 ),
  307. ( 1300, 2 ),
  308. ):
  309. if n_bits < k: break
  310. t = tt
  311. # Run the test t times:
  312. s = 0
  313. r = n - 1
  314. while ( r % 2 ) == 0:
  315. s = s + 1
  316. r = r // 2
  317. for i in range( t ):
  318. a = smallprimes[ i ]
  319. y = modular_exp( a, r, n )
  320. if y != 1 and y != n-1:
  321. j = 1
  322. while j <= s - 1 and y != n - 1:
  323. y = modular_exp( y, 2, n )
  324. if y == 1:
  325. miller_rabin_test_count = i + 1
  326. return False
  327. j = j + 1
  328. if y != n-1:
  329. miller_rabin_test_count = i + 1
  330. return False
  331. return True
  332. def next_prime( starting_value ):
  333. "Return the smallest prime larger than the starting value."
  334. if starting_value < 2: return 2
  335. result = ( starting_value + 1 ) | 1
  336. while not is_prime( result ): result = result + 2
  337. return result
  338. smallprimes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41,
  339. 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97,
  340. 101, 103, 107, 109, 113, 127, 131, 137, 139, 149,
  341. 151, 157, 163, 167, 173, 179, 181, 191, 193, 197,
  342. 199, 211, 223, 227, 229, 233, 239, 241, 251, 257,
  343. 263, 269, 271, 277, 281, 283, 293, 307, 311, 313,
  344. 317, 331, 337, 347, 349, 353, 359, 367, 373, 379,
  345. 383, 389, 397, 401, 409, 419, 421, 431, 433, 439,
  346. 443, 449, 457, 461, 463, 467, 479, 487, 491, 499,
  347. 503, 509, 521, 523, 541, 547, 557, 563, 569, 571,
  348. 577, 587, 593, 599, 601, 607, 613, 617, 619, 631,
  349. 641, 643, 647, 653, 659, 661, 673, 677, 683, 691,
  350. 701, 709, 719, 727, 733, 739, 743, 751, 757, 761,
  351. 769, 773, 787, 797, 809, 811, 821, 823, 827, 829,
  352. 839, 853, 857, 859, 863, 877, 881, 883, 887, 907,
  353. 911, 919, 929, 937, 941, 947, 953, 967, 971, 977,
  354. 983, 991, 997, 1009, 1013, 1019, 1021, 1031, 1033,
  355. 1039, 1049, 1051, 1061, 1063, 1069, 1087, 1091, 1093,
  356. 1097, 1103, 1109, 1117, 1123, 1129, 1151, 1153, 1163,
  357. 1171, 1181, 1187, 1193, 1201, 1213, 1217, 1223, 1229]
  358. miller_rabin_test_count = 0
  359. def __main__():
  360. # Making sure locally defined exceptions work:
  361. # p = modular_exp( 2, -2, 3 )
  362. # p = square_root_mod_prime( 2, 3 )
  363. print_("Testing gcd...")
  364. assert gcd( 3*5*7, 3*5*11, 3*5*13 ) == 3*5
  365. assert gcd( [ 3*5*7, 3*5*11, 3*5*13 ] ) == 3*5
  366. assert gcd( 3 ) == 3
  367. print_("Testing lcm...")
  368. assert lcm( 3, 5*3, 7*3 ) == 3*5*7
  369. assert lcm( [ 3, 5*3, 7*3 ] ) == 3*5*7
  370. assert lcm( 3 ) == 3
  371. print_("Testing next_prime...")
  372. bigprimes = ( 999671,
  373. 999683,
  374. 999721,
  375. 999727,
  376. 999749,
  377. 999763,
  378. 999769,
  379. 999773,
  380. 999809,
  381. 999853,
  382. 999863,
  383. 999883,
  384. 999907,
  385. 999917,
  386. 999931,
  387. 999953,
  388. 999959,
  389. 999961,
  390. 999979,
  391. 999983 )
  392. for i in range( len( bigprimes ) - 1 ):
  393. assert next_prime( bigprimes[i] ) == bigprimes[ i+1 ]
  394. error_tally = 0
  395. # Test the square_root_mod_prime function:
  396. for p in smallprimes:
  397. print_("Testing square_root_mod_prime for modulus p = %d." % p)
  398. squares = []
  399. for root in range( 0, 1+p//2 ):
  400. sq = ( root * root ) % p
  401. squares.append( sq )
  402. calculated = square_root_mod_prime( sq, p )
  403. if ( calculated * calculated ) % p != sq:
  404. error_tally = error_tally + 1
  405. print_("Failed to find %d as sqrt( %d ) mod %d. Said %d." % \
  406. ( root, sq, p, calculated ))
  407. for nonsquare in range( 0, p ):
  408. if nonsquare not in squares:
  409. try:
  410. calculated = square_root_mod_prime( nonsquare, p )
  411. except SquareRootError:
  412. pass
  413. else:
  414. error_tally = error_tally + 1
  415. print_("Failed to report no root for sqrt( %d ) mod %d." % \
  416. ( nonsquare, p ))
  417. # Test the jacobi function:
  418. for m in range( 3, 400, 2 ):
  419. print_("Testing jacobi for modulus m = %d." % m)
  420. if is_prime( m ):
  421. squares = []
  422. for root in range( 1, m ):
  423. if jacobi( root * root, m ) != 1:
  424. error_tally = error_tally + 1
  425. print_("jacobi( %d * %d, %d ) != 1" % ( root, root, m ))
  426. squares.append( root * root % m )
  427. for i in range( 1, m ):
  428. if not i in squares:
  429. if jacobi( i, m ) != -1:
  430. error_tally = error_tally + 1
  431. print_("jacobi( %d, %d ) != -1" % ( i, m ))
  432. else: # m is not prime.
  433. f = factorization( m )
  434. for a in range( 1, m ):
  435. c = 1
  436. for i in f:
  437. c = c * jacobi( a, i[0] ) ** i[1]
  438. if c != jacobi( a, m ):
  439. error_tally = error_tally + 1
  440. print_("%d != jacobi( %d, %d )" % ( c, a, m ))
  441. # Test the inverse_mod function:
  442. print_("Testing inverse_mod . . .")
  443. import random
  444. n_tests = 0
  445. for i in range( 100 ):
  446. m = random.randint( 20, 10000 )
  447. for j in range( 100 ):
  448. a = random.randint( 1, m-1 )
  449. if gcd( a, m ) == 1:
  450. n_tests = n_tests + 1
  451. inv = inverse_mod( a, m )
  452. if inv <= 0 or inv >= m or ( a * inv ) % m != 1:
  453. error_tally = error_tally + 1
  454. print_("%d = inverse_mod( %d, %d ) is wrong." % ( inv, a, m ))
  455. assert n_tests > 1000
  456. print_(n_tests, " tests of inverse_mod completed.")
  457. class FailedTest(Exception): pass
  458. print_(error_tally, "errors detected.")
  459. if error_tally != 0:
  460. raise FailedTest("%d errors detected" % error_tally)
  461. if __name__ == '__main__':
  462. __main__()