_IntegerNative.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. # ===================================================================
  2. #
  3. # Copyright (c) 2014, Legrandin <helderijs@gmail.com>
  4. # All rights reserved.
  5. #
  6. # Redistribution and use in source and binary forms, with or without
  7. # modification, are permitted provided that the following conditions
  8. # are met:
  9. #
  10. # 1. Redistributions of source code must retain the above copyright
  11. # notice, this list of conditions and the following disclaimer.
  12. # 2. Redistributions in binary form must reproduce the above copyright
  13. # notice, this list of conditions and the following disclaimer in
  14. # the documentation and/or other materials provided with the
  15. # distribution.
  16. #
  17. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  18. # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  19. # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
  20. # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
  21. # COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
  22. # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
  23. # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  24. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  25. # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
  26. # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
  27. # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  28. # POSSIBILITY OF SUCH DAMAGE.
  29. # ===================================================================
  30. from ._IntegerBase import IntegerBase
  31. from Cryptodome.Util.number import long_to_bytes, bytes_to_long
  32. class IntegerNative(IntegerBase):
  33. """A class to model a natural integer (including zero)"""
  34. def __init__(self, value):
  35. if isinstance(value, float):
  36. raise ValueError("A floating point type is not a natural number")
  37. try:
  38. self._value = value._value
  39. except AttributeError:
  40. self._value = value
  41. # Conversions
  42. def __int__(self):
  43. return self._value
  44. def __str__(self):
  45. return str(int(self))
  46. def __repr__(self):
  47. return "Integer(%s)" % str(self)
  48. def to_bytes(self, block_size=0):
  49. if self._value < 0:
  50. raise ValueError("Conversion only valid for non-negative numbers")
  51. result = long_to_bytes(self._value, block_size)
  52. if len(result) > block_size > 0:
  53. raise ValueError("Value too large to encode")
  54. return result
  55. @classmethod
  56. def from_bytes(cls, byte_string):
  57. return cls(bytes_to_long(byte_string))
  58. # Relations
  59. def __eq__(self, term):
  60. if term is None:
  61. return False
  62. return self._value == int(term)
  63. def __ne__(self, term):
  64. return not self.__eq__(term)
  65. def __lt__(self, term):
  66. return self._value < int(term)
  67. def __le__(self, term):
  68. return self.__lt__(term) or self.__eq__(term)
  69. def __gt__(self, term):
  70. return not self.__le__(term)
  71. def __ge__(self, term):
  72. return not self.__lt__(term)
  73. def __nonzero__(self):
  74. return self._value != 0
  75. __bool__ = __nonzero__
  76. def is_negative(self):
  77. return self._value < 0
  78. # Arithmetic operations
  79. def __add__(self, term):
  80. return self.__class__(self._value + int(term))
  81. def __sub__(self, term):
  82. return self.__class__(self._value - int(term))
  83. def __mul__(self, factor):
  84. return self.__class__(self._value * int(factor))
  85. def __floordiv__(self, divisor):
  86. return self.__class__(self._value // int(divisor))
  87. def __mod__(self, divisor):
  88. divisor_value = int(divisor)
  89. if divisor_value < 0:
  90. raise ValueError("Modulus must be positive")
  91. return self.__class__(self._value % divisor_value)
  92. def inplace_pow(self, exponent, modulus=None):
  93. exp_value = int(exponent)
  94. if exp_value < 0:
  95. raise ValueError("Exponent must not be negative")
  96. if modulus is not None:
  97. mod_value = int(modulus)
  98. if mod_value < 0:
  99. raise ValueError("Modulus must be positive")
  100. if mod_value == 0:
  101. raise ZeroDivisionError("Modulus cannot be zero")
  102. else:
  103. mod_value = None
  104. self._value = pow(self._value, exp_value, mod_value)
  105. return self
  106. def __pow__(self, exponent, modulus=None):
  107. result = self.__class__(self)
  108. return result.inplace_pow(exponent, modulus)
  109. def __abs__(self):
  110. return abs(self._value)
  111. def sqrt(self, modulus=None):
  112. value = self._value
  113. if modulus is None:
  114. if value < 0:
  115. raise ValueError("Square root of negative value")
  116. # http://stackoverflow.com/questions/15390807/integer-square-root-in-python
  117. x = value
  118. y = (x + 1) // 2
  119. while y < x:
  120. x = y
  121. y = (x + value // x) // 2
  122. result = x
  123. else:
  124. if modulus <= 0:
  125. raise ValueError("Modulus must be positive")
  126. result = self._tonelli_shanks(self % modulus, modulus)
  127. return self.__class__(result)
  128. def __iadd__(self, term):
  129. self._value += int(term)
  130. return self
  131. def __isub__(self, term):
  132. self._value -= int(term)
  133. return self
  134. def __imul__(self, term):
  135. self._value *= int(term)
  136. return self
  137. def __imod__(self, term):
  138. modulus = int(term)
  139. if modulus == 0:
  140. raise ZeroDivisionError("Division by zero")
  141. if modulus < 0:
  142. raise ValueError("Modulus must be positive")
  143. self._value %= modulus
  144. return self
  145. # Boolean/bit operations
  146. def __and__(self, term):
  147. return self.__class__(self._value & int(term))
  148. def __or__(self, term):
  149. return self.__class__(self._value | int(term))
  150. def __rshift__(self, pos):
  151. try:
  152. return self.__class__(self._value >> int(pos))
  153. except OverflowError:
  154. if self._value >= 0:
  155. return 0
  156. else:
  157. return -1
  158. def __irshift__(self, pos):
  159. try:
  160. self._value >>= int(pos)
  161. except OverflowError:
  162. if self._value >= 0:
  163. return 0
  164. else:
  165. return -1
  166. return self
  167. def __lshift__(self, pos):
  168. try:
  169. return self.__class__(self._value << int(pos))
  170. except OverflowError:
  171. raise ValueError("Incorrect shift count")
  172. def __ilshift__(self, pos):
  173. try:
  174. self._value <<= int(pos)
  175. except OverflowError:
  176. raise ValueError("Incorrect shift count")
  177. return self
  178. def get_bit(self, n):
  179. if self._value < 0:
  180. raise ValueError("no bit representation for negative values")
  181. try:
  182. try:
  183. result = (self._value >> n._value) & 1
  184. if n._value < 0:
  185. raise ValueError("negative bit count")
  186. except AttributeError:
  187. result = (self._value >> n) & 1
  188. if n < 0:
  189. raise ValueError("negative bit count")
  190. except OverflowError:
  191. result = 0
  192. return result
  193. # Extra
  194. def is_odd(self):
  195. return (self._value & 1) == 1
  196. def is_even(self):
  197. return (self._value & 1) == 0
  198. def size_in_bits(self):
  199. if self._value < 0:
  200. raise ValueError("Conversion only valid for non-negative numbers")
  201. if self._value == 0:
  202. return 1
  203. bit_size = 0
  204. tmp = self._value
  205. while tmp:
  206. tmp >>= 1
  207. bit_size += 1
  208. return bit_size
  209. def size_in_bytes(self):
  210. return (self.size_in_bits() - 1) // 8 + 1
  211. def is_perfect_square(self):
  212. if self._value < 0:
  213. return False
  214. if self._value in (0, 1):
  215. return True
  216. x = self._value // 2
  217. square_x = x ** 2
  218. while square_x > self._value:
  219. x = (square_x + self._value) // (2 * x)
  220. square_x = x ** 2
  221. return self._value == x ** 2
  222. def fail_if_divisible_by(self, small_prime):
  223. if (self._value % int(small_prime)) == 0:
  224. raise ValueError("Value is composite")
  225. def multiply_accumulate(self, a, b):
  226. self._value += int(a) * int(b)
  227. return self
  228. def set(self, source):
  229. self._value = int(source)
  230. def inplace_inverse(self, modulus):
  231. modulus = int(modulus)
  232. if modulus == 0:
  233. raise ZeroDivisionError("Modulus cannot be zero")
  234. if modulus < 0:
  235. raise ValueError("Modulus cannot be negative")
  236. r_p, r_n = self._value, modulus
  237. s_p, s_n = 1, 0
  238. while r_n > 0:
  239. q = r_p // r_n
  240. r_p, r_n = r_n, r_p - q * r_n
  241. s_p, s_n = s_n, s_p - q * s_n
  242. if r_p != 1:
  243. raise ValueError("No inverse value can be computed" + str(r_p))
  244. while s_p < 0:
  245. s_p += modulus
  246. self._value = s_p
  247. return self
  248. def inverse(self, modulus):
  249. result = self.__class__(self)
  250. result.inplace_inverse(modulus)
  251. return result
  252. def gcd(self, term):
  253. r_p, r_n = abs(self._value), abs(int(term))
  254. while r_n > 0:
  255. q = r_p // r_n
  256. r_p, r_n = r_n, r_p - q * r_n
  257. return self.__class__(r_p)
  258. def lcm(self, term):
  259. term = int(term)
  260. if self._value == 0 or term == 0:
  261. return self.__class__(0)
  262. return self.__class__(abs((self._value * term) // self.gcd(term)._value))
  263. @staticmethod
  264. def jacobi_symbol(a, n):
  265. a = int(a)
  266. n = int(n)
  267. if n <= 0:
  268. raise ValueError("n must be a positive integer")
  269. if (n & 1) == 0:
  270. raise ValueError("n must be even for the Jacobi symbol")
  271. # Step 1
  272. a = a % n
  273. # Step 2
  274. if a == 1 or n == 1:
  275. return 1
  276. # Step 3
  277. if a == 0:
  278. return 0
  279. # Step 4
  280. e = 0
  281. a1 = a
  282. while (a1 & 1) == 0:
  283. a1 >>= 1
  284. e += 1
  285. # Step 5
  286. if (e & 1) == 0:
  287. s = 1
  288. elif n % 8 in (1, 7):
  289. s = 1
  290. else:
  291. s = -1
  292. # Step 6
  293. if n % 4 == 3 and a1 % 4 == 3:
  294. s = -s
  295. # Step 7
  296. n1 = n % a1
  297. # Step 8
  298. return s * IntegerNative.jacobi_symbol(n1, a1)