ssh.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. from __future__ import absolute_import, division, print_function
  5. import binascii
  6. import os
  7. import re
  8. import struct
  9. import six
  10. from cryptography import utils
  11. from cryptography.exceptions import UnsupportedAlgorithm
  12. from cryptography.hazmat.backends import _get_backend
  13. from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, rsa
  14. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  15. from cryptography.hazmat.primitives.serialization import (
  16. Encoding,
  17. NoEncryption,
  18. PrivateFormat,
  19. PublicFormat,
  20. )
  21. try:
  22. from bcrypt import kdf as _bcrypt_kdf
  23. _bcrypt_supported = True
  24. except ImportError:
  25. _bcrypt_supported = False
  26. def _bcrypt_kdf(*args, **kwargs):
  27. raise UnsupportedAlgorithm("Need bcrypt module")
  28. try:
  29. from base64 import encodebytes as _base64_encode
  30. except ImportError:
  31. from base64 import encodestring as _base64_encode
  32. _SSH_ED25519 = b"ssh-ed25519"
  33. _SSH_RSA = b"ssh-rsa"
  34. _SSH_DSA = b"ssh-dss"
  35. _ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
  36. _ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
  37. _ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
  38. _CERT_SUFFIX = b"-cert-v01@openssh.com"
  39. _SSH_PUBKEY_RC = re.compile(br"\A(\S+)[ \t]+(\S+)")
  40. _SK_MAGIC = b"openssh-key-v1\0"
  41. _SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
  42. _SK_END = b"-----END OPENSSH PRIVATE KEY-----"
  43. _BCRYPT = b"bcrypt"
  44. _NONE = b"none"
  45. _DEFAULT_CIPHER = b"aes256-ctr"
  46. _DEFAULT_ROUNDS = 16
  47. _MAX_PASSWORD = 72
  48. # re is only way to work on bytes-like data
  49. _PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)
  50. # padding for max blocksize
  51. _PADDING = memoryview(bytearray(range(1, 1 + 16)))
  52. # ciphers that are actually used in key wrapping
  53. _SSH_CIPHERS = {
  54. b"aes256-ctr": (algorithms.AES, 32, modes.CTR, 16),
  55. b"aes256-cbc": (algorithms.AES, 32, modes.CBC, 16),
  56. }
  57. # map local curve name to key type
  58. _ECDSA_KEY_TYPE = {
  59. "secp256r1": _ECDSA_NISTP256,
  60. "secp384r1": _ECDSA_NISTP384,
  61. "secp521r1": _ECDSA_NISTP521,
  62. }
  63. _U32 = struct.Struct(b">I")
  64. _U64 = struct.Struct(b">Q")
  65. def _ecdsa_key_type(public_key):
  66. """Return SSH key_type and curve_name for private key."""
  67. curve = public_key.curve
  68. if curve.name not in _ECDSA_KEY_TYPE:
  69. raise ValueError(
  70. "Unsupported curve for ssh private key: %r" % curve.name
  71. )
  72. return _ECDSA_KEY_TYPE[curve.name]
  73. def _ssh_pem_encode(data, prefix=_SK_START + b"\n", suffix=_SK_END + b"\n"):
  74. return b"".join([prefix, _base64_encode(data), suffix])
  75. def _check_block_size(data, block_len):
  76. """Require data to be full blocks"""
  77. if not data or len(data) % block_len != 0:
  78. raise ValueError("Corrupt data: missing padding")
  79. def _check_empty(data):
  80. """All data should have been parsed."""
  81. if data:
  82. raise ValueError("Corrupt data: unparsed data")
  83. def _init_cipher(ciphername, password, salt, rounds, backend):
  84. """Generate key + iv and return cipher."""
  85. if not password:
  86. raise ValueError("Key is password-protected.")
  87. algo, key_len, mode, iv_len = _SSH_CIPHERS[ciphername]
  88. seed = _bcrypt_kdf(password, salt, key_len + iv_len, rounds, True)
  89. return Cipher(algo(seed[:key_len]), mode(seed[key_len:]), backend)
  90. def _get_u32(data):
  91. """Uint32"""
  92. if len(data) < 4:
  93. raise ValueError("Invalid data")
  94. return _U32.unpack(data[:4])[0], data[4:]
  95. def _get_u64(data):
  96. """Uint64"""
  97. if len(data) < 8:
  98. raise ValueError("Invalid data")
  99. return _U64.unpack(data[:8])[0], data[8:]
  100. def _get_sshstr(data):
  101. """Bytes with u32 length prefix"""
  102. n, data = _get_u32(data)
  103. if n > len(data):
  104. raise ValueError("Invalid data")
  105. return data[:n], data[n:]
  106. def _get_mpint(data):
  107. """Big integer."""
  108. val, data = _get_sshstr(data)
  109. if val and six.indexbytes(val, 0) > 0x7F:
  110. raise ValueError("Invalid data")
  111. return utils.int_from_bytes(val, "big"), data
  112. def _to_mpint(val):
  113. """Storage format for signed bigint."""
  114. if val < 0:
  115. raise ValueError("negative mpint not allowed")
  116. if not val:
  117. return b""
  118. nbytes = (val.bit_length() + 8) // 8
  119. return utils.int_to_bytes(val, nbytes)
  120. class _FragList(object):
  121. """Build recursive structure without data copy."""
  122. def __init__(self, init=None):
  123. self.flist = []
  124. if init:
  125. self.flist.extend(init)
  126. def put_raw(self, val):
  127. """Add plain bytes"""
  128. self.flist.append(val)
  129. def put_u32(self, val):
  130. """Big-endian uint32"""
  131. self.flist.append(_U32.pack(val))
  132. def put_sshstr(self, val):
  133. """Bytes prefixed with u32 length"""
  134. if isinstance(val, (bytes, memoryview, bytearray)):
  135. self.put_u32(len(val))
  136. self.flist.append(val)
  137. else:
  138. self.put_u32(val.size())
  139. self.flist.extend(val.flist)
  140. def put_mpint(self, val):
  141. """Big-endian bigint prefixed with u32 length"""
  142. self.put_sshstr(_to_mpint(val))
  143. def size(self):
  144. """Current number of bytes"""
  145. return sum(map(len, self.flist))
  146. def render(self, dstbuf, pos=0):
  147. """Write into bytearray"""
  148. for frag in self.flist:
  149. flen = len(frag)
  150. start, pos = pos, pos + flen
  151. dstbuf[start:pos] = frag
  152. return pos
  153. def tobytes(self):
  154. """Return as bytes"""
  155. buf = memoryview(bytearray(self.size()))
  156. self.render(buf)
  157. return buf.tobytes()
  158. class _SSHFormatRSA(object):
  159. """Format for RSA keys.
  160. Public:
  161. mpint e, n
  162. Private:
  163. mpint n, e, d, iqmp, p, q
  164. """
  165. def get_public(self, data):
  166. """RSA public fields"""
  167. e, data = _get_mpint(data)
  168. n, data = _get_mpint(data)
  169. return (e, n), data
  170. def load_public(self, key_type, data, backend):
  171. """Make RSA public key from data."""
  172. (e, n), data = self.get_public(data)
  173. public_numbers = rsa.RSAPublicNumbers(e, n)
  174. public_key = public_numbers.public_key(backend)
  175. return public_key, data
  176. def load_private(self, data, pubfields, backend):
  177. """Make RSA private key from data."""
  178. n, data = _get_mpint(data)
  179. e, data = _get_mpint(data)
  180. d, data = _get_mpint(data)
  181. iqmp, data = _get_mpint(data)
  182. p, data = _get_mpint(data)
  183. q, data = _get_mpint(data)
  184. if (e, n) != pubfields:
  185. raise ValueError("Corrupt data: rsa field mismatch")
  186. dmp1 = rsa.rsa_crt_dmp1(d, p)
  187. dmq1 = rsa.rsa_crt_dmq1(d, q)
  188. public_numbers = rsa.RSAPublicNumbers(e, n)
  189. private_numbers = rsa.RSAPrivateNumbers(
  190. p, q, d, dmp1, dmq1, iqmp, public_numbers
  191. )
  192. private_key = private_numbers.private_key(backend)
  193. return private_key, data
  194. def encode_public(self, public_key, f_pub):
  195. """Write RSA public key"""
  196. pubn = public_key.public_numbers()
  197. f_pub.put_mpint(pubn.e)
  198. f_pub.put_mpint(pubn.n)
  199. def encode_private(self, private_key, f_priv):
  200. """Write RSA private key"""
  201. private_numbers = private_key.private_numbers()
  202. public_numbers = private_numbers.public_numbers
  203. f_priv.put_mpint(public_numbers.n)
  204. f_priv.put_mpint(public_numbers.e)
  205. f_priv.put_mpint(private_numbers.d)
  206. f_priv.put_mpint(private_numbers.iqmp)
  207. f_priv.put_mpint(private_numbers.p)
  208. f_priv.put_mpint(private_numbers.q)
  209. class _SSHFormatDSA(object):
  210. """Format for DSA keys.
  211. Public:
  212. mpint p, q, g, y
  213. Private:
  214. mpint p, q, g, y, x
  215. """
  216. def get_public(self, data):
  217. """DSA public fields"""
  218. p, data = _get_mpint(data)
  219. q, data = _get_mpint(data)
  220. g, data = _get_mpint(data)
  221. y, data = _get_mpint(data)
  222. return (p, q, g, y), data
  223. def load_public(self, key_type, data, backend):
  224. """Make DSA public key from data."""
  225. (p, q, g, y), data = self.get_public(data)
  226. parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
  227. public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers)
  228. self._validate(public_numbers)
  229. public_key = public_numbers.public_key(backend)
  230. return public_key, data
  231. def load_private(self, data, pubfields, backend):
  232. """Make DSA private key from data."""
  233. (p, q, g, y), data = self.get_public(data)
  234. x, data = _get_mpint(data)
  235. if (p, q, g, y) != pubfields:
  236. raise ValueError("Corrupt data: dsa field mismatch")
  237. parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
  238. public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers)
  239. self._validate(public_numbers)
  240. private_numbers = dsa.DSAPrivateNumbers(x, public_numbers)
  241. private_key = private_numbers.private_key(backend)
  242. return private_key, data
  243. def encode_public(self, public_key, f_pub):
  244. """Write DSA public key"""
  245. public_numbers = public_key.public_numbers()
  246. parameter_numbers = public_numbers.parameter_numbers
  247. self._validate(public_numbers)
  248. f_pub.put_mpint(parameter_numbers.p)
  249. f_pub.put_mpint(parameter_numbers.q)
  250. f_pub.put_mpint(parameter_numbers.g)
  251. f_pub.put_mpint(public_numbers.y)
  252. def encode_private(self, private_key, f_priv):
  253. """Write DSA private key"""
  254. self.encode_public(private_key.public_key(), f_priv)
  255. f_priv.put_mpint(private_key.private_numbers().x)
  256. def _validate(self, public_numbers):
  257. parameter_numbers = public_numbers.parameter_numbers
  258. if parameter_numbers.p.bit_length() != 1024:
  259. raise ValueError("SSH supports only 1024 bit DSA keys")
  260. class _SSHFormatECDSA(object):
  261. """Format for ECDSA keys.
  262. Public:
  263. str curve
  264. bytes point
  265. Private:
  266. str curve
  267. bytes point
  268. mpint secret
  269. """
  270. def __init__(self, ssh_curve_name, curve):
  271. self.ssh_curve_name = ssh_curve_name
  272. self.curve = curve
  273. def get_public(self, data):
  274. """ECDSA public fields"""
  275. curve, data = _get_sshstr(data)
  276. point, data = _get_sshstr(data)
  277. if curve != self.ssh_curve_name:
  278. raise ValueError("Curve name mismatch")
  279. if six.indexbytes(point, 0) != 4:
  280. raise NotImplementedError("Need uncompressed point")
  281. return (curve, point), data
  282. def load_public(self, key_type, data, backend):
  283. """Make ECDSA public key from data."""
  284. (curve_name, point), data = self.get_public(data)
  285. public_key = ec.EllipticCurvePublicKey.from_encoded_point(
  286. self.curve, point.tobytes()
  287. )
  288. return public_key, data
  289. def load_private(self, data, pubfields, backend):
  290. """Make ECDSA private key from data."""
  291. (curve_name, point), data = self.get_public(data)
  292. secret, data = _get_mpint(data)
  293. if (curve_name, point) != pubfields:
  294. raise ValueError("Corrupt data: ecdsa field mismatch")
  295. private_key = ec.derive_private_key(secret, self.curve, backend)
  296. return private_key, data
  297. def encode_public(self, public_key, f_pub):
  298. """Write ECDSA public key"""
  299. point = public_key.public_bytes(
  300. Encoding.X962, PublicFormat.UncompressedPoint
  301. )
  302. f_pub.put_sshstr(self.ssh_curve_name)
  303. f_pub.put_sshstr(point)
  304. def encode_private(self, private_key, f_priv):
  305. """Write ECDSA private key"""
  306. public_key = private_key.public_key()
  307. private_numbers = private_key.private_numbers()
  308. self.encode_public(public_key, f_priv)
  309. f_priv.put_mpint(private_numbers.private_value)
  310. class _SSHFormatEd25519(object):
  311. """Format for Ed25519 keys.
  312. Public:
  313. bytes point
  314. Private:
  315. bytes point
  316. bytes secret_and_point
  317. """
  318. def get_public(self, data):
  319. """Ed25519 public fields"""
  320. point, data = _get_sshstr(data)
  321. return (point,), data
  322. def load_public(self, key_type, data, backend):
  323. """Make Ed25519 public key from data."""
  324. (point,), data = self.get_public(data)
  325. public_key = ed25519.Ed25519PublicKey.from_public_bytes(
  326. point.tobytes()
  327. )
  328. return public_key, data
  329. def load_private(self, data, pubfields, backend):
  330. """Make Ed25519 private key from data."""
  331. (point,), data = self.get_public(data)
  332. keypair, data = _get_sshstr(data)
  333. secret = keypair[:32]
  334. point2 = keypair[32:]
  335. if point != point2 or (point,) != pubfields:
  336. raise ValueError("Corrupt data: ed25519 field mismatch")
  337. private_key = ed25519.Ed25519PrivateKey.from_private_bytes(secret)
  338. return private_key, data
  339. def encode_public(self, public_key, f_pub):
  340. """Write Ed25519 public key"""
  341. raw_public_key = public_key.public_bytes(
  342. Encoding.Raw, PublicFormat.Raw
  343. )
  344. f_pub.put_sshstr(raw_public_key)
  345. def encode_private(self, private_key, f_priv):
  346. """Write Ed25519 private key"""
  347. public_key = private_key.public_key()
  348. raw_private_key = private_key.private_bytes(
  349. Encoding.Raw, PrivateFormat.Raw, NoEncryption()
  350. )
  351. raw_public_key = public_key.public_bytes(
  352. Encoding.Raw, PublicFormat.Raw
  353. )
  354. f_keypair = _FragList([raw_private_key, raw_public_key])
  355. self.encode_public(public_key, f_priv)
  356. f_priv.put_sshstr(f_keypair)
  357. _KEY_FORMATS = {
  358. _SSH_RSA: _SSHFormatRSA(),
  359. _SSH_DSA: _SSHFormatDSA(),
  360. _SSH_ED25519: _SSHFormatEd25519(),
  361. _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
  362. _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
  363. _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
  364. }
  365. def _lookup_kformat(key_type):
  366. """Return valid format or throw error"""
  367. if not isinstance(key_type, bytes):
  368. key_type = memoryview(key_type).tobytes()
  369. if key_type in _KEY_FORMATS:
  370. return _KEY_FORMATS[key_type]
  371. raise UnsupportedAlgorithm("Unsupported key type: %r" % key_type)
  372. def load_ssh_private_key(data, password, backend=None):
  373. """Load private key from OpenSSH custom encoding."""
  374. utils._check_byteslike("data", data)
  375. backend = _get_backend(backend)
  376. if password is not None:
  377. utils._check_bytes("password", password)
  378. m = _PEM_RC.search(data)
  379. if not m:
  380. raise ValueError("Not OpenSSH private key format")
  381. p1 = m.start(1)
  382. p2 = m.end(1)
  383. data = binascii.a2b_base64(memoryview(data)[p1:p2])
  384. if not data.startswith(_SK_MAGIC):
  385. raise ValueError("Not OpenSSH private key format")
  386. data = memoryview(data)[len(_SK_MAGIC) :]
  387. # parse header
  388. ciphername, data = _get_sshstr(data)
  389. kdfname, data = _get_sshstr(data)
  390. kdfoptions, data = _get_sshstr(data)
  391. nkeys, data = _get_u32(data)
  392. if nkeys != 1:
  393. raise ValueError("Only one key supported")
  394. # load public key data
  395. pubdata, data = _get_sshstr(data)
  396. pub_key_type, pubdata = _get_sshstr(pubdata)
  397. kformat = _lookup_kformat(pub_key_type)
  398. pubfields, pubdata = kformat.get_public(pubdata)
  399. _check_empty(pubdata)
  400. # load secret data
  401. edata, data = _get_sshstr(data)
  402. _check_empty(data)
  403. if (ciphername, kdfname) != (_NONE, _NONE):
  404. ciphername = ciphername.tobytes()
  405. if ciphername not in _SSH_CIPHERS:
  406. raise UnsupportedAlgorithm("Unsupported cipher: %r" % ciphername)
  407. if kdfname != _BCRYPT:
  408. raise UnsupportedAlgorithm("Unsupported KDF: %r" % kdfname)
  409. blklen = _SSH_CIPHERS[ciphername][3]
  410. _check_block_size(edata, blklen)
  411. salt, kbuf = _get_sshstr(kdfoptions)
  412. rounds, kbuf = _get_u32(kbuf)
  413. _check_empty(kbuf)
  414. ciph = _init_cipher(
  415. ciphername, password, salt.tobytes(), rounds, backend
  416. )
  417. edata = memoryview(ciph.decryptor().update(edata))
  418. else:
  419. blklen = 8
  420. _check_block_size(edata, blklen)
  421. ck1, edata = _get_u32(edata)
  422. ck2, edata = _get_u32(edata)
  423. if ck1 != ck2:
  424. raise ValueError("Corrupt data: broken checksum")
  425. # load per-key struct
  426. key_type, edata = _get_sshstr(edata)
  427. if key_type != pub_key_type:
  428. raise ValueError("Corrupt data: key type mismatch")
  429. private_key, edata = kformat.load_private(edata, pubfields, backend)
  430. comment, edata = _get_sshstr(edata)
  431. # yes, SSH does padding check *after* all other parsing is done.
  432. # need to follow as it writes zero-byte padding too.
  433. if edata != _PADDING[: len(edata)]:
  434. raise ValueError("Corrupt data: invalid padding")
  435. return private_key
  436. def serialize_ssh_private_key(private_key, password=None):
  437. """Serialize private key with OpenSSH custom encoding."""
  438. if password is not None:
  439. utils._check_bytes("password", password)
  440. if password and len(password) > _MAX_PASSWORD:
  441. raise ValueError(
  442. "Passwords longer than 72 bytes are not supported by "
  443. "OpenSSH private key format"
  444. )
  445. if isinstance(private_key, ec.EllipticCurvePrivateKey):
  446. key_type = _ecdsa_key_type(private_key.public_key())
  447. elif isinstance(private_key, rsa.RSAPrivateKey):
  448. key_type = _SSH_RSA
  449. elif isinstance(private_key, dsa.DSAPrivateKey):
  450. key_type = _SSH_DSA
  451. elif isinstance(private_key, ed25519.Ed25519PrivateKey):
  452. key_type = _SSH_ED25519
  453. else:
  454. raise ValueError("Unsupported key type")
  455. kformat = _lookup_kformat(key_type)
  456. # setup parameters
  457. f_kdfoptions = _FragList()
  458. if password:
  459. ciphername = _DEFAULT_CIPHER
  460. blklen = _SSH_CIPHERS[ciphername][3]
  461. kdfname = _BCRYPT
  462. rounds = _DEFAULT_ROUNDS
  463. salt = os.urandom(16)
  464. f_kdfoptions.put_sshstr(salt)
  465. f_kdfoptions.put_u32(rounds)
  466. backend = _get_backend(None)
  467. ciph = _init_cipher(ciphername, password, salt, rounds, backend)
  468. else:
  469. ciphername = kdfname = _NONE
  470. blklen = 8
  471. ciph = None
  472. nkeys = 1
  473. checkval = os.urandom(4)
  474. comment = b""
  475. # encode public and private parts together
  476. f_public_key = _FragList()
  477. f_public_key.put_sshstr(key_type)
  478. kformat.encode_public(private_key.public_key(), f_public_key)
  479. f_secrets = _FragList([checkval, checkval])
  480. f_secrets.put_sshstr(key_type)
  481. kformat.encode_private(private_key, f_secrets)
  482. f_secrets.put_sshstr(comment)
  483. f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])
  484. # top-level structure
  485. f_main = _FragList()
  486. f_main.put_raw(_SK_MAGIC)
  487. f_main.put_sshstr(ciphername)
  488. f_main.put_sshstr(kdfname)
  489. f_main.put_sshstr(f_kdfoptions)
  490. f_main.put_u32(nkeys)
  491. f_main.put_sshstr(f_public_key)
  492. f_main.put_sshstr(f_secrets)
  493. # copy result info bytearray
  494. slen = f_secrets.size()
  495. mlen = f_main.size()
  496. buf = memoryview(bytearray(mlen + blklen))
  497. f_main.render(buf)
  498. ofs = mlen - slen
  499. # encrypt in-place
  500. if ciph is not None:
  501. ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])
  502. txt = _ssh_pem_encode(buf[:mlen])
  503. buf[ofs:mlen] = bytearray(slen)
  504. return txt
  505. def load_ssh_public_key(data, backend=None):
  506. """Load public key from OpenSSH one-line format."""
  507. backend = _get_backend(backend)
  508. utils._check_byteslike("data", data)
  509. m = _SSH_PUBKEY_RC.match(data)
  510. if not m:
  511. raise ValueError("Invalid line format")
  512. key_type = orig_key_type = m.group(1)
  513. key_body = m.group(2)
  514. with_cert = False
  515. if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]:
  516. with_cert = True
  517. key_type = key_type[: -len(_CERT_SUFFIX)]
  518. kformat = _lookup_kformat(key_type)
  519. try:
  520. data = memoryview(binascii.a2b_base64(key_body))
  521. except (TypeError, binascii.Error):
  522. raise ValueError("Invalid key format")
  523. inner_key_type, data = _get_sshstr(data)
  524. if inner_key_type != orig_key_type:
  525. raise ValueError("Invalid key format")
  526. if with_cert:
  527. nonce, data = _get_sshstr(data)
  528. public_key, data = kformat.load_public(key_type, data, backend)
  529. if with_cert:
  530. serial, data = _get_u64(data)
  531. cctype, data = _get_u32(data)
  532. key_id, data = _get_sshstr(data)
  533. principals, data = _get_sshstr(data)
  534. valid_after, data = _get_u64(data)
  535. valid_before, data = _get_u64(data)
  536. crit_options, data = _get_sshstr(data)
  537. extensions, data = _get_sshstr(data)
  538. reserved, data = _get_sshstr(data)
  539. sig_key, data = _get_sshstr(data)
  540. signature, data = _get_sshstr(data)
  541. _check_empty(data)
  542. return public_key
  543. def serialize_ssh_public_key(public_key):
  544. """One-line public key format for OpenSSH"""
  545. if isinstance(public_key, ec.EllipticCurvePublicKey):
  546. key_type = _ecdsa_key_type(public_key)
  547. elif isinstance(public_key, rsa.RSAPublicKey):
  548. key_type = _SSH_RSA
  549. elif isinstance(public_key, dsa.DSAPublicKey):
  550. key_type = _SSH_DSA
  551. elif isinstance(public_key, ed25519.Ed25519PublicKey):
  552. key_type = _SSH_ED25519
  553. else:
  554. raise ValueError("Unsupported key type")
  555. kformat = _lookup_kformat(key_type)
  556. f_pub = _FragList()
  557. f_pub.put_sshstr(key_type)
  558. kformat.encode_public(public_key, f_pub)
  559. pub = binascii.b2a_base64(f_pub.tobytes()).strip()
  560. return b"".join([key_type, b" ", pub])