test_default.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. # Copyright (c) Twisted Matrix Laboratories.
  2. # See LICENSE for details.
  3. """
  4. Tests for L{twisted.conch.client.default}.
  5. """
  6. from __future__ import absolute_import, division
  7. import sys
  8. from twisted.python.reflect import requireModule
  9. if requireModule('cryptography') and requireModule('pyasn1'):
  10. from twisted.conch.client.agent import SSHAgentClient
  11. from twisted.conch.client.default import SSHUserAuthClient
  12. from twisted.conch.client.options import ConchOptions
  13. from twisted.conch.client import default
  14. from twisted.conch.ssh.keys import Key
  15. skip = None
  16. else:
  17. skip = "cryptography and PyASN1 required for twisted.conch.client.default."
  18. from twisted.trial.unittest import TestCase
  19. from twisted.python.filepath import FilePath
  20. from twisted.conch.error import ConchError
  21. from twisted.conch.test import keydata
  22. from twisted.test.proto_helpers import StringTransport
  23. from twisted.python.compat import nativeString
  24. from twisted.python.runtime import platform
  25. if platform.isWindows():
  26. windowsSkip = (
  27. "genericAnswers and getPassword does not work on Windows."
  28. " Should be fixed as part of fixing bug 6409 and 6410")
  29. else:
  30. windowsSkip = skip
  31. ttySkip = None
  32. if not sys.stdin.isatty():
  33. ttySkip = "sys.stdin is not an interactive tty"
  34. if not sys.stdout.isatty():
  35. ttySkip = "sys.stdout is not an interactive tty"
  36. class SSHUserAuthClientTests(TestCase):
  37. """
  38. Tests for L{SSHUserAuthClient}.
  39. @type rsaPublic: L{Key}
  40. @ivar rsaPublic: A public RSA key.
  41. """
  42. def setUp(self):
  43. self.rsaPublic = Key.fromString(keydata.publicRSA_openssh)
  44. self.tmpdir = FilePath(self.mktemp())
  45. self.tmpdir.makedirs()
  46. self.rsaFile = self.tmpdir.child('id_rsa')
  47. self.rsaFile.setContent(keydata.privateRSA_openssh)
  48. self.tmpdir.child('id_rsa.pub').setContent(keydata.publicRSA_openssh)
  49. def test_signDataWithAgent(self):
  50. """
  51. When connected to an agent, L{SSHUserAuthClient} can use it to
  52. request signatures of particular data with a particular L{Key}.
  53. """
  54. client = SSHUserAuthClient(b"user", ConchOptions(), None)
  55. agent = SSHAgentClient()
  56. transport = StringTransport()
  57. agent.makeConnection(transport)
  58. client.keyAgent = agent
  59. cleartext = b"Sign here"
  60. client.signData(self.rsaPublic, cleartext)
  61. self.assertEqual(
  62. transport.value(),
  63. b"\x00\x00\x00\x8b\r\x00\x00\x00u" + self.rsaPublic.blob() +
  64. b"\x00\x00\x00\t" + cleartext +
  65. b"\x00\x00\x00\x00")
  66. def test_agentGetPublicKey(self):
  67. """
  68. L{SSHUserAuthClient} looks up public keys from the agent using the
  69. L{SSHAgentClient} class. That L{SSHAgentClient.getPublicKey} returns a
  70. L{Key} object with one of the public keys in the agent. If no more
  71. keys are present, it returns L{None}.
  72. """
  73. agent = SSHAgentClient()
  74. agent.blobs = [self.rsaPublic.blob()]
  75. key = agent.getPublicKey()
  76. self.assertTrue(key.isPublic())
  77. self.assertEqual(key, self.rsaPublic)
  78. self.assertIsNone(agent.getPublicKey())
  79. def test_getPublicKeyFromFile(self):
  80. """
  81. L{SSHUserAuthClient.getPublicKey()} is able to get a public key from
  82. the first file described by its options' C{identitys} list, and return
  83. the corresponding public L{Key} object.
  84. """
  85. options = ConchOptions()
  86. options.identitys = [self.rsaFile.path]
  87. client = SSHUserAuthClient(b"user", options, None)
  88. key = client.getPublicKey()
  89. self.assertTrue(key.isPublic())
  90. self.assertEqual(key, self.rsaPublic)
  91. def test_getPublicKeyAgentFallback(self):
  92. """
  93. If an agent is present, but doesn't return a key,
  94. L{SSHUserAuthClient.getPublicKey} continue with the normal key lookup.
  95. """
  96. options = ConchOptions()
  97. options.identitys = [self.rsaFile.path]
  98. agent = SSHAgentClient()
  99. client = SSHUserAuthClient(b"user", options, None)
  100. client.keyAgent = agent
  101. key = client.getPublicKey()
  102. self.assertTrue(key.isPublic())
  103. self.assertEqual(key, self.rsaPublic)
  104. def test_getPublicKeyBadKeyError(self):
  105. """
  106. If L{keys.Key.fromFile} raises a L{keys.BadKeyError}, the
  107. L{SSHUserAuthClient.getPublicKey} tries again to get a public key by
  108. calling itself recursively.
  109. """
  110. options = ConchOptions()
  111. self.tmpdir.child('id_dsa.pub').setContent(keydata.publicDSA_openssh)
  112. dsaFile = self.tmpdir.child('id_dsa')
  113. dsaFile.setContent(keydata.privateDSA_openssh)
  114. options.identitys = [self.rsaFile.path, dsaFile.path]
  115. self.tmpdir.child('id_rsa.pub').setContent(b'not a key!')
  116. client = SSHUserAuthClient(b"user", options, None)
  117. key = client.getPublicKey()
  118. self.assertTrue(key.isPublic())
  119. self.assertEqual(key, Key.fromString(keydata.publicDSA_openssh))
  120. self.assertEqual(client.usedFiles, [self.rsaFile.path, dsaFile.path])
  121. def test_getPrivateKey(self):
  122. """
  123. L{SSHUserAuthClient.getPrivateKey} will load a private key from the
  124. last used file populated by L{SSHUserAuthClient.getPublicKey}, and
  125. return a L{Deferred} which fires with the corresponding private L{Key}.
  126. """
  127. rsaPrivate = Key.fromString(keydata.privateRSA_openssh)
  128. options = ConchOptions()
  129. options.identitys = [self.rsaFile.path]
  130. client = SSHUserAuthClient(b"user", options, None)
  131. # Populate the list of used files
  132. client.getPublicKey()
  133. def _cbGetPrivateKey(key):
  134. self.assertFalse(key.isPublic())
  135. self.assertEqual(key, rsaPrivate)
  136. return client.getPrivateKey().addCallback(_cbGetPrivateKey)
  137. def test_getPrivateKeyPassphrase(self):
  138. """
  139. L{SSHUserAuthClient} can get a private key from a file, and return a
  140. Deferred called back with a private L{Key} object, even if the key is
  141. encrypted.
  142. """
  143. rsaPrivate = Key.fromString(keydata.privateRSA_openssh)
  144. passphrase = b'this is the passphrase'
  145. self.rsaFile.setContent(rsaPrivate.toString('openssh', passphrase))
  146. options = ConchOptions()
  147. options.identitys = [self.rsaFile.path]
  148. client = SSHUserAuthClient(b"user", options, None)
  149. # Populate the list of used files
  150. client.getPublicKey()
  151. def _getPassword(prompt):
  152. self.assertEqual(
  153. prompt,
  154. "Enter passphrase for key '%s': " % (self.rsaFile.path,))
  155. return nativeString(passphrase)
  156. def _cbGetPrivateKey(key):
  157. self.assertFalse(key.isPublic())
  158. self.assertEqual(key, rsaPrivate)
  159. self.patch(client, '_getPassword', _getPassword)
  160. return client.getPrivateKey().addCallback(_cbGetPrivateKey)
  161. def test_getPassword(self):
  162. """
  163. Get the password using
  164. L{twisted.conch.client.default.SSHUserAuthClient.getPassword}
  165. """
  166. class FakeTransport:
  167. def __init__(self, host):
  168. self.transport = self
  169. self.host = host
  170. def getPeer(self):
  171. return self
  172. options = ConchOptions()
  173. client = SSHUserAuthClient(b"user", options, None)
  174. client.transport = FakeTransport("127.0.0.1")
  175. def getpass(prompt):
  176. self.assertEqual(prompt, "user@127.0.0.1's password: ")
  177. return 'bad password'
  178. self.patch(default.getpass, 'getpass', getpass)
  179. d = client.getPassword()
  180. d.addCallback(self.assertEqual, b'bad password')
  181. return d
  182. test_getPassword.skip = windowsSkip or ttySkip
  183. def test_getPasswordPrompt(self):
  184. """
  185. Get the password using
  186. L{twisted.conch.client.default.SSHUserAuthClient.getPassword}
  187. using a different prompt.
  188. """
  189. options = ConchOptions()
  190. client = SSHUserAuthClient(b"user", options, None)
  191. prompt = b"Give up your password"
  192. def getpass(p):
  193. self.assertEqual(p, nativeString(prompt))
  194. return 'bad password'
  195. self.patch(default.getpass, 'getpass', getpass)
  196. d = client.getPassword(prompt)
  197. d.addCallback(self.assertEqual, b'bad password')
  198. return d
  199. test_getPasswordPrompt.skip = windowsSkip or ttySkip
  200. def test_getPasswordConchError(self):
  201. """
  202. Get the password using
  203. L{twisted.conch.client.default.SSHUserAuthClient.getPassword}
  204. and trigger a {twisted.conch.error import ConchError}.
  205. """
  206. options = ConchOptions()
  207. client = SSHUserAuthClient(b"user", options, None)
  208. def getpass(prompt):
  209. raise KeyboardInterrupt("User pressed CTRL-C")
  210. self.patch(default.getpass, 'getpass', getpass)
  211. stdout, stdin = sys.stdout, sys.stdin
  212. d = client.getPassword(b'?')
  213. @d.addErrback
  214. def check_sys(fail):
  215. self.assertEqual(
  216. [stdout, stdin], [sys.stdout, sys.stdin])
  217. return fail
  218. self.assertFailure(d, ConchError)
  219. test_getPasswordConchError.skip = windowsSkip or ttySkip
  220. def test_getGenericAnswers(self):
  221. """
  222. L{twisted.conch.client.default.SSHUserAuthClient.getGenericAnswers}
  223. """
  224. options = ConchOptions()
  225. client = SSHUserAuthClient(b"user", options, None)
  226. def getpass(prompt):
  227. self.assertEqual(prompt, "pass prompt")
  228. return "getpass"
  229. self.patch(default.getpass, 'getpass', getpass)
  230. def raw_input(prompt):
  231. self.assertEqual(prompt, "raw_input prompt")
  232. return "raw_input"
  233. self.patch(default, 'raw_input', raw_input)
  234. d = client.getGenericAnswers(
  235. b"Name", b"Instruction", [
  236. (b"pass prompt", False), (b"raw_input prompt", True)])
  237. d.addCallback(
  238. self.assertListEqual, ["getpass", "raw_input"])
  239. return d
  240. test_getGenericAnswers.skip = windowsSkip or ttySkip
  241. class ConchOptionsParsing(TestCase):
  242. """
  243. Options parsing.
  244. """
  245. def test_macs(self):
  246. """
  247. Specify MAC algorithms.
  248. """
  249. opts = ConchOptions()
  250. e = self.assertRaises(SystemExit, opts.opt_macs, "invalid-mac")
  251. self.assertIn("Unknown mac type", e.code)
  252. opts = ConchOptions()
  253. opts.opt_macs("hmac-sha2-512")
  254. self.assertEqual(opts['macs'], [b"hmac-sha2-512"])
  255. opts.opt_macs(b"hmac-sha2-512")
  256. self.assertEqual(opts['macs'], [b"hmac-sha2-512"])
  257. opts.opt_macs("hmac-sha2-256,hmac-sha1,hmac-md5")
  258. self.assertEqual(opts['macs'], [b"hmac-sha2-256", b"hmac-sha1", b"hmac-md5"])
  259. def test_host_key_algorithms(self):
  260. """
  261. Specify host key algorithms.
  262. """
  263. opts = ConchOptions()
  264. e = self.assertRaises(SystemExit, opts.opt_host_key_algorithms, "invalid-key")
  265. self.assertIn("Unknown host key type", e.code)
  266. opts = ConchOptions()
  267. opts.opt_host_key_algorithms("ssh-rsa")
  268. self.assertEqual(opts['host-key-algorithms'], [b"ssh-rsa"])
  269. opts.opt_host_key_algorithms(b"ssh-dss")
  270. self.assertEqual(opts['host-key-algorithms'], [b"ssh-dss"])
  271. opts.opt_host_key_algorithms("ssh-rsa,ssh-dss")
  272. self.assertEqual(opts['host-key-algorithms'], [b"ssh-rsa", b"ssh-dss"])