123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277 |
- # Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
- #
- # This file is part of paramiko.
- #
- # Paramiko is free software; you can redistribute it and/or modify it under the
- # terms of the GNU Lesser General Public License as published by the Free
- # Software Foundation; either version 2.1 of the License, or (at your option)
- # any later version.
- #
- # Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
- # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
- # A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
- # details.
- #
- # You should have received a copy of the GNU Lesser General Public License
- # along with Paramiko; if not, write to the Free Software Foundation, Inc.,
- # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
- """
- ECDSA keys
- """
- from cryptography.exceptions import InvalidSignature
- from cryptography.hazmat.backends import default_backend
- from cryptography.hazmat.primitives import hashes, serialization
- from cryptography.hazmat.primitives.asymmetric import ec
- from cryptography.hazmat.primitives.asymmetric.utils import (
- decode_dss_signature, encode_dss_signature
- )
- from paramiko.common import four_byte
- from paramiko.message import Message
- from paramiko.pkey import PKey
- from paramiko.ssh_exception import SSHException
- from paramiko.util import deflate_long
class _ECDSACurve(object):
    """
    Describes one ECDSA curve (nistp256, nistp384, etc).

    Bundles the curve class handed in by the caller with the values derived
    from it: the key length, the SSH key format identifier and the hash
    class (taken from ``cryptography``'s ``hashes`` module) used when
    signing with that curve.
    """

    def __init__(self, curve_class, nist_name):
        """
        :param curve_class: curve class exposing a ``key_size`` attribute.
        :param str nist_name: NIST name of the curve, e.g. ``'nistp256'``.
        """
        bits = curve_class.key_size

        # Hash selection by key size, as defined in RFC 5656 6.2.1
        if bits > 384:
            digest = hashes.SHA512
        elif bits > 256:
            digest = hashes.SHA384
        else:
            digest = hashes.SHA256

        self.curve_class = curve_class
        self.nist_name = nist_name
        self.key_length = bits
        # Identifier layout defined in RFC 5656 6.2
        self.key_format_identifier = "ecdsa-sha2-" + nist_name
        self.hash_object = digest
- class _ECDSACurveSet(object):
- """
- A collection to hold the ECDSA curves. Allows querying by oid and by key
- format identifier. The two ways in which ECDSAKey needs to be able to look
- up curves.
- """
- def __init__(self, ecdsa_curves):
- self.ecdsa_curves = ecdsa_curves
- def get_key_format_identifier_list(self):
- return [curve.key_format_identifier for curve in self.ecdsa_curves]
- def get_by_curve_class(self, curve_class):
- for curve in self.ecdsa_curves:
- if curve.curve_class == curve_class:
- return curve
- def get_by_key_format_identifier(self, key_format_identifier):
- for curve in self.ecdsa_curves:
- if curve.key_format_identifier == key_format_identifier:
- return curve
- def get_by_key_length(self, key_length):
- for curve in self.ecdsa_curves:
- if curve.key_length == key_length:
- return curve
class ECDSAKey(PKey):
    """
    Representation of an ECDSA key which can be used to sign and verify SSH2
    data.

    Only the curves registered in ``_ECDSA_CURVES`` are handled; a key on
    any other curve is rejected with `.SSHException` when the object is
    built.
    """

    # Curves this class can name, pick a hash for and (de)serialize.
    _ECDSA_CURVES = _ECDSACurveSet([
        _ECDSACurve(ec.SECP256R1, 'nistp256'),
        _ECDSACurve(ec.SECP384R1, 'nistp384'),
        _ECDSACurve(ec.SECP521R1, 'nistp521'),
    ])

    def __init__(self, msg=None, data=None, filename=None, password=None,
                 vals=None, file_obj=None, validate_point=True):
        """
        Build a key from one of several sources, checked in this order:
        ``file_obj``, ``filename``, ``vals``, then ``msg`` / ``data``.

        :param msg: `.Message` holding a public key blob (key format
            identifier, curve name, encoded point).
        :param data: raw blob, wrapped in a `.Message` when ``msg`` is None.
        :param filename: path of a private key file to load.
        :param password: passphrase passed on to the private key readers.
        :param vals: ``(signing_key, verifying_key)`` pair of already built
            key objects.
        :param file_obj: file-like object holding a private key.
        :param validate_point: accepted for interface compatibility; this
            constructor does not consult it.
        :raises SSHException: if no source is given, the blob is malformed,
            or the key uses a curve that is not registered.
        """
        self.verifying_key = None
        self.signing_key = None
        if file_obj is not None:
            self._from_private_key(file_obj, password)
            return
        if filename is not None:
            self._from_private_key_file(filename, password)
            return
        if (msg is None) and (data is not None):
            msg = Message(data)
        if vals is not None:
            self.signing_key, self.verifying_key = vals
            # Reject unregistered curves here instead of leaving
            # ``ecdsa_curve`` as None and failing later on first use.
            self.ecdsa_curve = self._curve_for(self.signing_key)
        else:
            if msg is None:
                raise SSHException('Key object may not be empty')
            self.ecdsa_curve = self._ECDSA_CURVES.get_by_key_format_identifier(
                msg.get_text())
            if self.ecdsa_curve is None:
                raise SSHException('Invalid key')
            # The blob repeats the curve name; it must agree with the
            # key format identifier read just before it.
            curvename = msg.get_text()
            if curvename != self.ecdsa_curve.nist_name:
                raise SSHException("Can't handle curve of type %s" % curvename)

            pointinfo = msg.get_binary()
            try:
                numbers = ec.EllipticCurvePublicNumbers.from_encoded_point(
                    self.ecdsa_curve.curve_class(), pointinfo
                )
            except ValueError:
                raise SSHException("Invalid public key")
            self.verifying_key = numbers.public_key(backend=default_backend())

    @classmethod
    def supported_key_format_identifiers(cls):
        """Return the key format identifiers of all registered curves."""
        return cls._ECDSA_CURVES.get_key_format_identifier_list()

    def asbytes(self):
        """
        Serialize the public part of the key: key format identifier, curve
        name, then the point as ``four_byte + x + y`` with both coordinates
        left-padded with zero bytes to the curve's size in bytes.
        """
        key = self.verifying_key
        m = Message()
        m.add_string(self.ecdsa_curve.key_format_identifier)
        m.add_string(self.ecdsa_curve.nist_name)

        numbers = key.public_numbers()

        # Round the bit size up to whole bytes (521 bits -> 66 bytes).
        key_size_bytes = (key.curve.key_size + 7) // 8

        x_bytes = deflate_long(numbers.x, add_sign_padding=False)
        x_bytes = b'\x00' * (key_size_bytes - len(x_bytes)) + x_bytes

        y_bytes = deflate_long(numbers.y, add_sign_padding=False)
        y_bytes = b'\x00' * (key_size_bytes - len(y_bytes)) + y_bytes

        point_str = four_byte + x_bytes + y_bytes
        m.add_string(point_str)
        return m.asbytes()

    def __str__(self):
        # NOTE(review): returns whatever asbytes() yields; on Python 3 that
        # is presumably bytes, which str() rejects -- confirm intent.
        return self.asbytes()

    def __hash__(self):
        """Hash on the key name and the public point coordinates."""
        return hash((self.get_name(), self.verifying_key.public_numbers().x,
                     self.verifying_key.public_numbers().y))

    def get_name(self):
        """Return the key format identifier, e.g. ``ecdsa-sha2-nistp256``."""
        return self.ecdsa_curve.key_format_identifier

    def get_bits(self):
        """Return the key length of the curve in bits."""
        return self.ecdsa_curve.key_length

    def can_sign(self):
        """Return True when a private (signing) key is present."""
        return self.signing_key is not None

    def sign_ssh_data(self, data):
        """
        Sign ``data`` with the private key.

        :returns: `.Message` holding the key format identifier followed by
            the signature encoded as two mpints (r, s).
        """
        ecdsa = ec.ECDSA(self.ecdsa_curve.hash_object())
        # Newer cryptography releases dropped the streaming signer() API in
        # favour of one-shot sign(); use it when present and fall back to
        # signer() for older releases that lack it.
        if hasattr(self.signing_key, 'sign'):
            sig = self.signing_key.sign(data, ecdsa)
        else:
            signer = self.signing_key.signer(ecdsa)
            signer.update(data)
            sig = signer.finalize()
        r, s = decode_dss_signature(sig)

        m = Message()
        m.add_string(self.ecdsa_curve.key_format_identifier)
        m.add_string(self._sigencode(r, s))
        return m

    def verify_ssh_sig(self, data, msg):
        """
        Check a signature over ``data``.

        :param data: the data that was signed.
        :param msg: `.Message` holding the key format identifier and the
            signature blob.
        :returns: True if the signature verifies, False if the identifier
            does not match this key or the signature is invalid.
        """
        if msg.get_text() != self.ecdsa_curve.key_format_identifier:
            return False
        sig = msg.get_binary()
        sigR, sigS = self._sigdecode(sig)
        signature = encode_dss_signature(sigR, sigS)
        algorithm = ec.ECDSA(self.ecdsa_curve.hash_object())

        # Same API transition as in sign_ssh_data: prefer one-shot verify(),
        # fall back to the older verifier() context when it is missing.
        if hasattr(self.verifying_key, 'verify'):
            try:
                self.verifying_key.verify(signature, data, algorithm)
            except InvalidSignature:
                return False
            else:
                return True

        verifier = self.verifying_key.verifier(signature, algorithm)
        verifier.update(data)
        try:
            verifier.verify()
        except InvalidSignature:
            return False
        else:
            return True

    def write_private_key_file(self, filename, password=None):
        """Write the private key to ``filename`` in TraditionalOpenSSL form."""
        self._write_private_key_file(
            filename,
            self.signing_key,
            serialization.PrivateFormat.TraditionalOpenSSL,
            password=password
        )

    def write_private_key(self, file_obj, password=None):
        """Write the private key to ``file_obj`` in TraditionalOpenSSL form."""
        self._write_private_key(
            file_obj,
            self.signing_key,
            serialization.PrivateFormat.TraditionalOpenSSL,
            password=password
        )

    @classmethod
    def generate(cls, curve=ec.SECP256R1(), progress_func=None, bits=None):
        """
        Generate a new private ECDSA key.  This factory function can be used to
        generate a new host key or authentication key.

        :param curve: curve instance to generate the key on; ignored when
            ``bits`` is given.
        :param progress_func: Not used for this type of key.
        :param bits: optional key length; when given, the registered curve
            with that key length is used instead of ``curve``.
        :returns: A new private key (`.ECDSAKey`) object
        :raises ValueError: if no registered curve has length ``bits``.
        :raises SSHException: if ``curve`` is not a registered curve.
        """
        if bits is not None:
            curve = cls._ECDSA_CURVES.get_by_key_length(bits)
            if curve is None:
                raise ValueError("Unsupported key length: %d" % bits)
            curve = curve.curve_class()

        private_key = ec.generate_private_key(curve, backend=default_backend())
        return ECDSAKey(vals=(private_key, private_key.public_key()))

    # ...internals...

    def _curve_for(self, key):
        """
        Return the registered curve descriptor matching ``key.curve``.

        :raises SSHException: if the key's curve class is not registered.
        """
        curve_class = key.curve.__class__
        ecdsa_curve = self._ECDSA_CURVES.get_by_curve_class(curve_class)
        if ecdsa_curve is None:
            raise SSHException(
                'Unsupported ECDSA curve: %s' % curve_class.__name__)
        return ecdsa_curve

    def _from_private_key_file(self, filename, password):
        """Load the private key stored in the file at ``filename``."""
        data = self._read_private_key_file('EC', filename, password)
        self._decode_key(data)

    def _from_private_key(self, file_obj, password):
        """Load the private key stored in ``file_obj``."""
        data = self._read_private_key('EC', file_obj, password)
        self._decode_key(data)

    def _decode_key(self, data):
        """
        Parse DER-encoded private key ``data`` and populate the signing key,
        verifying key and curve descriptor.

        :raises SSHException: if the data cannot be parsed or the key uses
            a curve that is not registered.
        """
        try:
            key = serialization.load_der_private_key(
                data, password=None, backend=default_backend()
            )
        except (ValueError, AssertionError) as e:
            raise SSHException(str(e))

        # Resolve the curve before touching instance state so a rejected
        # key does not leave a half-initialized object behind.
        ecdsa_curve = self._curve_for(key)
        self.signing_key = key
        self.verifying_key = key.public_key()
        self.ecdsa_curve = ecdsa_curve

    def _sigencode(self, r, s):
        """Encode the signature integers ``r`` and ``s`` as two mpints."""
        msg = Message()
        msg.add_mpint(r)
        msg.add_mpint(s)
        return msg.asbytes()

    def _sigdecode(self, sig):
        """Decode a signature blob of two mpints into ``(r, s)``."""
        msg = Message(sig)
        r = msg.get_mpint()
        s = msg.get_mpint()
        return r, s
|