util.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. # Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
  2. #
  3. # This file is part of paramiko.
  4. #
  5. # Paramiko is free software; you can redistribute it and/or modify it under the
  6. # terms of the GNU Lesser General Public License as published by the Free
  7. # Software Foundation; either version 2.1 of the License, or (at your option)
  8. # any later version.
  9. #
  10. # Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
  11. # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
  12. # A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
  13. # details.
  14. #
  15. # You should have received a copy of the GNU Lesser General Public License
  16. # along with Paramiko; if not, write to the Free Software Foundation, Inc.,
  17. # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
  18. """
  19. Useful functions used by the rest of paramiko.
  20. """
  21. from __future__ import generators
  22. import errno
  23. import sys
  24. import struct
  25. import traceback
  26. import threading
  27. import logging
  28. from paramiko.common import DEBUG, zero_byte, xffffffff, max_byte
  29. from paramiko.py3compat import PY2, long, byte_chr, byte_ord, b
  30. from paramiko.config import SSHConfig
  31. def inflate_long(s, always_positive=False):
  32. """turns a normalized byte string into a long-int
  33. (adapted from Crypto.Util.number)"""
  34. out = long(0)
  35. negative = 0
  36. if not always_positive and (len(s) > 0) and (byte_ord(s[0]) >= 0x80):
  37. negative = 1
  38. if len(s) % 4:
  39. filler = zero_byte
  40. if negative:
  41. filler = max_byte
  42. # never convert this to ``s +=`` because this is a string, not a number
  43. # noinspection PyAugmentAssignment
  44. s = filler * (4 - len(s) % 4) + s
  45. for i in range(0, len(s), 4):
  46. out = (out << 32) + struct.unpack('>I', s[i:i + 4])[0]
  47. if negative:
  48. out -= (long(1) << (8 * len(s)))
  49. return out
  50. deflate_zero = zero_byte if PY2 else 0
  51. deflate_ff = max_byte if PY2 else 0xff
  52. def deflate_long(n, add_sign_padding=True):
  53. """turns a long-int into a normalized byte string
  54. (adapted from Crypto.Util.number)"""
  55. # after much testing, this algorithm was deemed to be the fastest
  56. s = bytes()
  57. n = long(n)
  58. while (n != 0) and (n != -1):
  59. s = struct.pack('>I', n & xffffffff) + s
  60. n >>= 32
  61. # strip off leading zeros, FFs
  62. for i in enumerate(s):
  63. if (n == 0) and (i[1] != deflate_zero):
  64. break
  65. if (n == -1) and (i[1] != deflate_ff):
  66. break
  67. else:
  68. # degenerate case, n was either 0 or -1
  69. i = (0,)
  70. if n == 0:
  71. s = zero_byte
  72. else:
  73. s = max_byte
  74. s = s[i[0]:]
  75. if add_sign_padding:
  76. if (n == 0) and (byte_ord(s[0]) >= 0x80):
  77. s = zero_byte + s
  78. if (n == -1) and (byte_ord(s[0]) < 0x80):
  79. s = max_byte + s
  80. return s
  81. def format_binary(data, prefix=''):
  82. x = 0
  83. out = []
  84. while len(data) > x + 16:
  85. out.append(format_binary_line(data[x:x + 16]))
  86. x += 16
  87. if x < len(data):
  88. out.append(format_binary_line(data[x:]))
  89. return [prefix + line for line in out]
  90. def format_binary_line(data):
  91. left = ' '.join(['%02X' % byte_ord(c) for c in data])
  92. right = ''.join([('.%c..' % c)[(byte_ord(c) + 63) // 95] for c in data])
  93. return '%-50s %s' % (left, right)
  94. def safe_string(s):
  95. out = b('')
  96. for c in s:
  97. i = byte_ord(c)
  98. if 32 <= i <= 127:
  99. out += byte_chr(i)
  100. else:
  101. out += b('%%%02X' % i)
  102. return out
  103. def bit_length(n):
  104. try:
  105. return n.bit_length()
  106. except AttributeError:
  107. norm = deflate_long(n, False)
  108. hbyte = byte_ord(norm[0])
  109. if hbyte == 0:
  110. return 1
  111. bitlen = len(norm) * 8
  112. while not (hbyte & 0x80):
  113. hbyte <<= 1
  114. bitlen -= 1
  115. return bitlen
  116. def tb_strings():
  117. return ''.join(traceback.format_exception(*sys.exc_info())).split('\n')
  118. def generate_key_bytes(hash_alg, salt, key, nbytes):
  119. """
  120. Given a password, passphrase, or other human-source key, scramble it
  121. through a secure hash into some keyworthy bytes. This specific algorithm
  122. is used for encrypting/decrypting private key files.
  123. :param function hash_alg: A function which creates a new hash object, such
  124. as ``hashlib.sha256``.
  125. :param salt: data to salt the hash with.
  126. :type salt: byte string
  127. :param str key: human-entered password or passphrase.
  128. :param int nbytes: number of bytes to generate.
  129. :return: Key data `str`
  130. """
  131. keydata = bytes()
  132. digest = bytes()
  133. if len(salt) > 8:
  134. salt = salt[:8]
  135. while nbytes > 0:
  136. hash_obj = hash_alg()
  137. if len(digest) > 0:
  138. hash_obj.update(digest)
  139. hash_obj.update(b(key))
  140. hash_obj.update(salt)
  141. digest = hash_obj.digest()
  142. size = min(nbytes, len(digest))
  143. keydata += digest[:size]
  144. nbytes -= size
  145. return keydata
  146. def load_host_keys(filename):
  147. """
  148. Read a file of known SSH host keys, in the format used by openssh, and
  149. return a compound dict of ``hostname -> keytype ->`` `PKey
  150. <paramiko.pkey.PKey>`. The hostname may be an IP address or DNS name. The
  151. keytype will be either ``"ssh-rsa"`` or ``"ssh-dss"``.
  152. This type of file unfortunately doesn't exist on Windows, but on posix,
  153. it will usually be stored in ``os.path.expanduser("~/.ssh/known_hosts")``.
  154. Since 1.5.3, this is just a wrapper around `.HostKeys`.
  155. :param str filename: name of the file to read host keys from
  156. :return:
  157. nested dict of `.PKey` objects, indexed by hostname and then keytype
  158. """
  159. from paramiko.hostkeys import HostKeys
  160. return HostKeys(filename)
  161. def parse_ssh_config(file_obj):
  162. """
  163. Provided only as a backward-compatible wrapper around `.SSHConfig`.
  164. """
  165. config = SSHConfig()
  166. config.parse(file_obj)
  167. return config
  168. def lookup_ssh_host_config(hostname, config):
  169. """
  170. Provided only as a backward-compatible wrapper around `.SSHConfig`.
  171. """
  172. return config.lookup(hostname)
  173. def mod_inverse(x, m):
  174. # it's crazy how small Python can make this function.
  175. u1, u2, u3 = 1, 0, m
  176. v1, v2, v3 = 0, 1, x
  177. while v3 > 0:
  178. q = u3 // v3
  179. u1, v1 = v1, u1 - v1 * q
  180. u2, v2 = v2, u2 - v2 * q
  181. u3, v3 = v3, u3 - v3 * q
  182. if u2 < 0:
  183. u2 += m
  184. return u2
  185. _g_thread_ids = {}
  186. _g_thread_counter = 0
  187. _g_thread_lock = threading.Lock()
  188. def get_thread_id():
  189. global _g_thread_ids, _g_thread_counter, _g_thread_lock
  190. tid = id(threading.currentThread())
  191. try:
  192. return _g_thread_ids[tid]
  193. except KeyError:
  194. _g_thread_lock.acquire()
  195. try:
  196. _g_thread_counter += 1
  197. ret = _g_thread_ids[tid] = _g_thread_counter
  198. finally:
  199. _g_thread_lock.release()
  200. return ret
  201. def log_to_file(filename, level=DEBUG):
  202. """send paramiko logs to a logfile,
  203. if they're not already going somewhere"""
  204. l = logging.getLogger("paramiko")
  205. if len(l.handlers) > 0:
  206. return
  207. l.setLevel(level)
  208. f = open(filename, 'a')
  209. lh = logging.StreamHandler(f)
  210. frm = '%(levelname)-.3s [%(asctime)s.%(msecs)03d] thr=%(_threadid)-3d %(name)s: %(message)s' # noqa
  211. lh.setFormatter(logging.Formatter(frm, '%Y%m%d-%H:%M:%S'))
  212. l.addHandler(lh)
  213. # make only one filter object, so it doesn't get applied more than once
  214. class PFilter (object):
  215. def filter(self, record):
  216. record._threadid = get_thread_id()
  217. return True
  218. _pfilter = PFilter()
  219. def get_logger(name):
  220. l = logging.getLogger(name)
  221. l.addFilter(_pfilter)
  222. return l
  223. def retry_on_signal(function):
  224. """Retries function until it doesn't raise an EINTR error"""
  225. while True:
  226. try:
  227. return function()
  228. except EnvironmentError as e:
  229. if e.errno != errno.EINTR:
  230. raise
  231. def constant_time_bytes_eq(a, b):
  232. if len(a) != len(b):
  233. return False
  234. res = 0
  235. # noinspection PyUnresolvedReferences
  236. for i in (xrange if PY2 else range)(len(a)): # noqa: F821
  237. res |= byte_ord(a[i]) ^ byte_ord(b[i])
  238. return res == 0
  239. class ClosingContextManager(object):
  240. def __enter__(self):
  241. return self
  242. def __exit__(self, type, value, traceback):
  243. self.close()
  244. def clamp_value(minimum, val, maximum):
  245. return max(minimum, min(val, maximum))