common.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. # -*- coding: utf-8 -*-
  2. #
  3. # Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # https://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from rsa._compat import zip
# NOTE(review): this string comes after the import above, so Python does not
# treat it as the module docstring (__doc__ stays None); a module docstring
# must be the first statement. Move it above the import to make it real.
"""Common functionality shared by several modules."""
  18. class NotRelativePrimeError(ValueError):
  19. def __init__(self, a, b, d, msg=None):
  20. super(NotRelativePrimeError, self).__init__(
  21. msg or "%d and %d are not relatively prime, divider=%i" % (a, b, d))
  22. self.a = a
  23. self.b = b
  24. self.d = d
  25. def bit_size(num):
  26. """
  27. Number of bits needed to represent a integer excluding any prefix
  28. 0 bits.
  29. Usage::
  30. >>> bit_size(1023)
  31. 10
  32. >>> bit_size(1024)
  33. 11
  34. >>> bit_size(1025)
  35. 11
  36. :param num:
  37. Integer value. If num is 0, returns 0. Only the absolute value of the
  38. number is considered. Therefore, signed integers will be abs(num)
  39. before the number's bit length is determined.
  40. :returns:
  41. Returns the number of bits in the integer.
  42. """
  43. try:
  44. return num.bit_length()
  45. except AttributeError:
  46. raise TypeError('bit_size(num) only supports integers, not %r' % type(num))
  47. def byte_size(number):
  48. """
  49. Returns the number of bytes required to hold a specific long number.
  50. The number of bytes is rounded up.
  51. Usage::
  52. >>> byte_size(1 << 1023)
  53. 128
  54. >>> byte_size((1 << 1024) - 1)
  55. 128
  56. >>> byte_size(1 << 1024)
  57. 129
  58. :param number:
  59. An unsigned integer
  60. :returns:
  61. The number of bytes required to hold a specific long number.
  62. """
  63. if number == 0:
  64. return 1
  65. return ceil_div(bit_size(number), 8)
  66. def ceil_div(num, div):
  67. """
  68. Returns the ceiling function of a division between `num` and `div`.
  69. Usage::
  70. >>> ceil_div(100, 7)
  71. 15
  72. >>> ceil_div(100, 10)
  73. 10
  74. >>> ceil_div(1, 4)
  75. 1
  76. :param num: Division's numerator, a number
  77. :param div: Division's divisor, a number
  78. :return: Rounded up result of the division between the parameters.
  79. """
  80. quanta, mod = divmod(num, div)
  81. if mod:
  82. quanta += 1
  83. return quanta
  84. def extended_gcd(a, b):
  85. """Returns a tuple (r, i, j) such that r = gcd(a, b) = ia + jb
  86. """
  87. # r = gcd(a,b) i = multiplicitive inverse of a mod b
  88. # or j = multiplicitive inverse of b mod a
  89. # Neg return values for i or j are made positive mod b or a respectively
  90. # Iterateive Version is faster and uses much less stack space
  91. x = 0
  92. y = 1
  93. lx = 1
  94. ly = 0
  95. oa = a # Remember original a/b to remove
  96. ob = b # negative values from return results
  97. while b != 0:
  98. q = a // b
  99. (a, b) = (b, a % b)
  100. (x, lx) = ((lx - (q * x)), x)
  101. (y, ly) = ((ly - (q * y)), y)
  102. if lx < 0:
  103. lx += ob # If neg wrap modulo orignal b
  104. if ly < 0:
  105. ly += oa # If neg wrap modulo orignal a
  106. return a, lx, ly # Return only positive values
  107. def inverse(x, n):
  108. """Returns the inverse of x % n under multiplication, a.k.a x^-1 (mod n)
  109. >>> inverse(7, 4)
  110. 3
  111. >>> (inverse(143, 4) * 143) % 4
  112. 1
  113. """
  114. (divider, inv, _) = extended_gcd(x, n)
  115. if divider != 1:
  116. raise NotRelativePrimeError(x, n, divider)
  117. return inv
  118. def crt(a_values, modulo_values):
  119. """Chinese Remainder Theorem.
  120. Calculates x such that x = a[i] (mod m[i]) for each i.
  121. :param a_values: the a-values of the above equation
  122. :param modulo_values: the m-values of the above equation
  123. :returns: x such that x = a[i] (mod m[i]) for each i
  124. >>> crt([2, 3], [3, 5])
  125. 8
  126. >>> crt([2, 3, 2], [3, 5, 7])
  127. 23
  128. >>> crt([2, 3, 0], [7, 11, 15])
  129. 135
  130. """
  131. m = 1
  132. x = 0
  133. for modulo in modulo_values:
  134. m *= modulo
  135. for (m_i, a_i) in zip(modulo_values, a_values):
  136. M_i = m // m_i
  137. inv = inverse(M_i, m_i)
  138. x = (x + a_i * M_i * inv) % m
  139. return x
  140. if __name__ == '__main__':
  141. import doctest
  142. doctest.testmod()