123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188 |
- # -*- coding: utf-8 -*-
- #
- # Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # https://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
"""Common functionality shared by several modules."""

from rsa._compat import zip
class NotRelativePrimeError(ValueError):
    """Signals that two values expected to be coprime share a divisor.

    The two offending values are exposed as ``a`` and ``b`` and the divisor
    supplied by the raiser as ``d``.  Because this subclasses ``ValueError``,
    handlers that catch ``ValueError`` also catch it.
    """

    def __init__(self, a, b, d, msg=None):
        # Fall back to a generated message whenever no (or an empty)
        # message was supplied.
        if not msg:
            msg = "%d and %d are not relatively prime, divider=%i" % (a, b, d)
        super(NotRelativePrimeError, self).__init__(msg)
        self.a = a
        self.b = b
        self.d = d
def bit_size(num):
    """
    Number of bits needed to represent an integer, excluding any leading
    zero bits.

    Usage::

        >>> bit_size(1023)
        10
        >>> bit_size(1024)
        11
        >>> bit_size(1025)
        11

    :param num:
        Integer value. If num is 0, returns 0. Only the absolute value of the
        number is considered, so the sign of a negative integer is ignored
        when its bit length is determined.
    :returns:
        The number of bits in the integer.
    :raises TypeError:
        If ``num`` has no ``bit_length()`` method (for example a float).
    """
    try:
        length = num.bit_length()
    except AttributeError:
        # Replace the generic "no attribute" error with a clearer type error.
        raise TypeError('bit_size(num) only supports integers, not %r' % type(num))
    return length
def byte_size(number):
    """
    Number of whole bytes required to hold ``number``, rounded up.

    Usage::

        >>> byte_size(1 << 1023)
        128
        >>> byte_size((1 << 1024) - 1)
        128
        >>> byte_size(1 << 1024)
        129

    :param number:
        An unsigned integer.
    :returns:
        The number of bytes needed for ``number``; zero is reported as
        needing one byte rather than none.
    """
    if number == 0:
        return 1
    bits = bit_size(number)
    # Round the bit count up to a whole number of bytes: ceil(bits / 8).
    return (bits + 7) // 8
def ceil_div(num, div):
    """
    Divide ``num`` by ``div`` and round the quotient up (towards +infinity).

    Usage::

        >>> ceil_div(100, 7)
        15
        >>> ceil_div(100, 10)
        10
        >>> ceil_div(1, 4)
        1

    :param num: Division's numerator, a number
    :param div: Division's divisor, a number
    :return: Rounded up result of the division between the parameters.
    """
    # divmod floors the quotient.  A non-zero remainder means the exact
    # quotient lies strictly above that floor, so bump it by one.
    quotient, remainder = divmod(num, div)
    return quotient + 1 if remainder else quotient
def extended_gcd(a, b):
    """Extended Euclidean algorithm.

    Returns a tuple ``(r, i, j)`` with ``r = gcd(a, b)``.

    The loop computes Bezout coefficients satisfying ``r = i*a + j*b``.
    Before returning, a negative ``i`` is shifted up once by the original
    ``b`` and a negative ``j`` once by the original ``a``.  That shift
    generally breaks the equality ``r = i*a + j*b``.  What still holds is::

        i*a == r (mod b)    and    j*b == r (mod a)

    So when ``r == 1``, ``i`` is a multiplicative inverse of ``a`` modulo
    ``b``, and ``j`` is one of ``b`` modulo ``a``.

    >>> extended_gcd(3, 7)
    (1, 5, 1)
    >>> extended_gcd(6, 3)
    (3, 0, 1)
    """
    # Remember the original inputs; they are needed to shift negative
    # coefficients at the end.
    orig_a, orig_b = a, b

    # Invariant: current a == last_x*orig_a + last_y*orig_b
    #            current b == x*orig_a      + y*orig_b
    x, last_x = 0, 1
    y, last_y = 1, 0

    # Iterative rather than recursive, so large inputs cannot exhaust the
    # recursion limit.
    while b != 0:
        quotient = a // b
        a, b = b, a % b
        x, last_x = last_x - quotient * x, x
        y, last_y = last_y - quotient * y, y

    # Shift negative coefficients by the original modulus.  This preserves
    # the congruences documented above, but not the Bezout equality.
    if last_x < 0:
        last_x += orig_b
    if last_y < 0:
        last_y += orig_a
    return a, last_x, last_y
def inverse(x, n):
    """Returns the inverse of x % n under multiplication, a.k.a x^-1 (mod n)

    >>> inverse(7, 4)
    3
    >>> (inverse(143, 4) * 143) % 4
    1

    :raises NotRelativePrimeError: when the gcd reported by ``extended_gcd``
        is not 1, i.e. ``x`` and ``n`` are not coprime and no inverse exists.
    """
    gcd, coefficient, _ = extended_gcd(x, n)
    if gcd == 1:
        return coefficient
    raise NotRelativePrimeError(x, n, gcd)
def crt(a_values, modulo_values):
    """Chinese Remainder Theorem.

    Calculates x such that x = a[i] (mod m[i]) for each i.

    :param a_values: the a-values of the above equation
    :param modulo_values: the m-values of the above equation.  This is
        iterated twice, so pass a re-iterable collection; a one-shot iterator
        would leave the second pass empty.
    :returns: x such that x = a[i] (mod m[i]) for each i, reduced modulo the
        product of all m-values (0 for empty input).
    :raises NotRelativePrimeError: raised by ``inverse`` when the m-values
        are not pairwise coprime.

    >>> crt([2, 3], [3, 5])
    8
    >>> crt([2, 3, 2], [3, 5, 7])
    23
    >>> crt([2, 3, 0], [7, 11, 15])
    135
    """
    product = 1
    for modulus in modulo_values:
        product *= modulus

    result = 0
    for modulus, remainder in zip(modulo_values, a_values):
        # Product of all the *other* moduli, and its inverse mod this one.
        cofactor = product // modulus
        cofactor_inv = inverse(cofactor, modulus)
        result = (result + remainder * cofactor * cofactor_inv) % product
    return result
if __name__ == '__main__':
    # Run the usage examples embedded in this module's docstrings.
    from doctest import testmod
    testmod()
|