test_cred.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. # Copyright (c) Twisted Matrix Laboratories.
  2. # See LICENSE for details.
  3. """
  4. Tests for L{twisted.cred}, now with 30% more starch.
  5. """
  6. from __future__ import absolute_import, division
  7. from zope.interface import implementer, Interface
  8. from binascii import hexlify, unhexlify
  9. from twisted.trial import unittest
  10. from twisted.python.compat import nativeString, networkString
  11. from twisted.python import components
  12. from twisted.internet import defer
  13. from twisted.cred import checkers, credentials, portal, error
# crypt is an optional dependency here: when it cannot be imported the name
# is bound to None, and HashedPasswordOnDiskDatabaseTests uses that to mark
# itself as skipped instead of failing.
try:
    from crypt import crypt
except ImportError:
    crypt = None
  18. class ITestable(Interface):
  19. """
  20. An interface for a theoretical protocol.
  21. """
  22. pass
  23. class TestAvatar(object):
  24. """
  25. A test avatar.
  26. """
  27. def __init__(self, name):
  28. self.name = name
  29. self.loggedIn = False
  30. self.loggedOut = False
  31. def login(self):
  32. assert not self.loggedIn
  33. self.loggedIn = True
  34. def logout(self):
  35. self.loggedOut = True
  36. @implementer(ITestable)
  37. class Testable(components.Adapter):
  38. """
  39. A theoretical protocol for testing.
  40. """
  41. pass
# Let ITestable(avatar) adapt a TestAvatar by wrapping it in a Testable; this
# is what TestRealm.requestAvatar relies on when it calls the interface.
components.registerAdapter(Testable, TestAvatar, ITestable)
  43. class IDerivedCredentials(credentials.IUsernamePassword):
  44. pass
  45. @implementer(IDerivedCredentials, ITestable)
  46. class DerivedCredentials(object):
  47. def __init__(self, username, password):
  48. self.username = username
  49. self.password = password
  50. def checkPassword(self, password):
  51. return password == self.password
  52. @implementer(portal.IRealm)
  53. class TestRealm(object):
  54. """
  55. A basic test realm.
  56. """
  57. def __init__(self):
  58. self.avatars = {}
  59. def requestAvatar(self, avatarId, mind, *interfaces):
  60. if avatarId in self.avatars:
  61. avatar = self.avatars[avatarId]
  62. else:
  63. avatar = TestAvatar(avatarId)
  64. self.avatars[avatarId] = avatar
  65. avatar.login()
  66. return (interfaces[0], interfaces[0](avatar),
  67. avatar.logout)
  68. class CredTests(unittest.TestCase):
  69. """
  70. Tests for the meat of L{twisted.cred} -- realms, portals, avatars, and
  71. checkers.
  72. """
  73. def setUp(self):
  74. self.realm = TestRealm()
  75. self.portal = portal.Portal(self.realm)
  76. self.checker = checkers.InMemoryUsernamePasswordDatabaseDontUse()
  77. self.checker.addUser(b"bob", b"hello")
  78. self.portal.registerChecker(self.checker)
  79. def test_listCheckers(self):
  80. """
  81. The checkers in a portal can check only certain types of credentials.
  82. Since this portal has
  83. L{checkers.InMemoryUsernamePasswordDatabaseDontUse} registered, it
  84. """
  85. expected = [credentials.IUsernamePassword,
  86. credentials.IUsernameHashedPassword]
  87. got = self.portal.listCredentialsInterfaces()
  88. self.assertEqual(sorted(got), sorted(expected))
  89. def test_basicLogin(self):
  90. """
  91. Calling C{login} on a portal with correct credentials and an interface
  92. that the portal's realm supports works.
  93. """
  94. login = self.successResultOf(self.portal.login(
  95. credentials.UsernamePassword(b"bob", b"hello"), self, ITestable))
  96. iface, impl, logout = login
  97. # whitebox
  98. self.assertEqual(iface, ITestable)
  99. self.assertTrue(iface.providedBy(impl),
  100. "%s does not implement %s" % (impl, iface))
  101. # greybox
  102. self.assertTrue(impl.original.loggedIn)
  103. self.assertTrue(not impl.original.loggedOut)
  104. logout()
  105. self.assertTrue(impl.original.loggedOut)
  106. def test_derivedInterface(self):
  107. """
  108. Logging in with correct derived credentials and an interface
  109. that the portal's realm supports works.
  110. """
  111. login = self.successResultOf(self.portal.login(
  112. DerivedCredentials(b"bob", b"hello"), self, ITestable))
  113. iface, impl, logout = login
  114. # whitebox
  115. self.assertEqual(iface, ITestable)
  116. self.assertTrue(iface.providedBy(impl),
  117. "%s does not implement %s" % (impl, iface))
  118. # greybox
  119. self.assertTrue(impl.original.loggedIn)
  120. self.assertTrue(not impl.original.loggedOut)
  121. logout()
  122. self.assertTrue(impl.original.loggedOut)
  123. def test_failedLoginPassword(self):
  124. """
  125. Calling C{login} with incorrect credentials (in this case a wrong
  126. password) causes L{error.UnauthorizedLogin} to be raised.
  127. """
  128. login = self.failureResultOf(self.portal.login(
  129. credentials.UsernamePassword(b"bob", b"h3llo"), self, ITestable))
  130. self.assertTrue(login)
  131. self.assertEqual(error.UnauthorizedLogin, login.type)
  132. def test_failedLoginName(self):
  133. """
  134. Calling C{login} with incorrect credentials (in this case no known
  135. user) causes L{error.UnauthorizedLogin} to be raised.
  136. """
  137. login = self.failureResultOf(self.portal.login(
  138. credentials.UsernamePassword(b"jay", b"hello"), self, ITestable))
  139. self.assertTrue(login)
  140. self.assertEqual(error.UnauthorizedLogin, login.type)
  141. class OnDiskDatabaseTests(unittest.TestCase):
  142. users = [
  143. (b'user1', b'pass1'),
  144. (b'user2', b'pass2'),
  145. (b'user3', b'pass3'),
  146. ]
  147. def setUp(self):
  148. self.dbfile = self.mktemp()
  149. with open(self.dbfile, 'wb') as f:
  150. for (u, p) in self.users:
  151. f.write(u + b":" + p + b"\n")
  152. def test_getUserNonexistentDatabase(self):
  153. """
  154. A missing db file will cause a permanent rejection of authorization
  155. attempts.
  156. """
  157. self.db = checkers.FilePasswordDB('test_thisbetternoteverexist.db')
  158. self.assertRaises(error.UnauthorizedLogin, self.db.getUser, 'user')
  159. def testUserLookup(self):
  160. self.db = checkers.FilePasswordDB(self.dbfile)
  161. for (u, p) in self.users:
  162. self.assertRaises(KeyError, self.db.getUser, u.upper())
  163. self.assertEqual(self.db.getUser(u), (u, p))
  164. def testCaseInSensitivity(self):
  165. self.db = checkers.FilePasswordDB(self.dbfile, caseSensitive=False)
  166. for (u, p) in self.users:
  167. self.assertEqual(self.db.getUser(u.upper()), (u, p))
  168. def testRequestAvatarId(self):
  169. self.db = checkers.FilePasswordDB(self.dbfile)
  170. creds = [credentials.UsernamePassword(u, p) for u, p in self.users]
  171. d = defer.gatherResults(
  172. [defer.maybeDeferred(self.db.requestAvatarId, c) for c in creds])
  173. d.addCallback(self.assertEqual, [u for u, p in self.users])
  174. return d
  175. def testRequestAvatarId_hashed(self):
  176. self.db = checkers.FilePasswordDB(self.dbfile)
  177. creds = [credentials.UsernameHashedPassword(u, p)
  178. for u, p in self.users]
  179. d = defer.gatherResults(
  180. [defer.maybeDeferred(self.db.requestAvatarId, c) for c in creds])
  181. d.addCallback(self.assertEqual, [u for u, p in self.users])
  182. return d
  183. class HashedPasswordOnDiskDatabaseTests(unittest.TestCase):
  184. users = [
  185. (b'user1', b'pass1'),
  186. (b'user2', b'pass2'),
  187. (b'user3', b'pass3'),
  188. ]
  189. def setUp(self):
  190. dbfile = self.mktemp()
  191. self.db = checkers.FilePasswordDB(dbfile, hash=self.hash)
  192. with open(dbfile, 'wb') as f:
  193. for (u, p) in self.users:
  194. f.write(u + b":" + self.hash(u, p, u[:2]) + b"\n")
  195. r = TestRealm()
  196. self.port = portal.Portal(r)
  197. self.port.registerChecker(self.db)
  198. def hash(self, u, p, s):
  199. return networkString(crypt(nativeString(p), nativeString(s)))
  200. def testGoodCredentials(self):
  201. goodCreds = [credentials.UsernamePassword(u, p) for u, p in self.users]
  202. d = defer.gatherResults([self.db.requestAvatarId(c)
  203. for c in goodCreds])
  204. d.addCallback(self.assertEqual, [u for u, p in self.users])
  205. return d
  206. def testGoodCredentials_login(self):
  207. goodCreds = [credentials.UsernamePassword(u, p) for u, p in self.users]
  208. d = defer.gatherResults([self.port.login(c, None, ITestable)
  209. for c in goodCreds])
  210. d.addCallback(lambda x: [a.original.name for i, a, l in x])
  211. d.addCallback(self.assertEqual, [u for u, p in self.users])
  212. return d
  213. def testBadCredentials(self):
  214. badCreds = [credentials.UsernamePassword(u, 'wrong password')
  215. for u, p in self.users]
  216. d = defer.DeferredList([self.port.login(c, None, ITestable)
  217. for c in badCreds], consumeErrors=True)
  218. d.addCallback(self._assertFailures, error.UnauthorizedLogin)
  219. return d
  220. def testHashedCredentials(self):
  221. hashedCreds = [credentials.UsernameHashedPassword(
  222. u, self.hash(None, p, u[:2])) for u, p in self.users]
  223. d = defer.DeferredList([self.port.login(c, None, ITestable)
  224. for c in hashedCreds], consumeErrors=True)
  225. d.addCallback(self._assertFailures, error.UnhandledCredentials)
  226. return d
  227. def _assertFailures(self, failures, *expectedFailures):
  228. for flag, failure in failures:
  229. self.assertEqual(flag, defer.FAILURE)
  230. failure.trap(*expectedFailures)
  231. return None
  232. if crypt is None:
  233. skip = "crypt module not available"
  234. class CheckersMixin(object):
  235. """
  236. L{unittest.TestCase} mixin for testing that some checkers accept
  237. and deny specified credentials.
  238. Subclasses must provide
  239. - C{getCheckers} which returns a sequence of
  240. L{checkers.ICredentialChecker}
  241. - C{getGoodCredentials} which returns a list of 2-tuples of
  242. credential to check and avaterId to expect.
  243. - C{getBadCredentials} which returns a list of credentials
  244. which are expected to be unauthorized.
  245. """
  246. @defer.inlineCallbacks
  247. def test_positive(self):
  248. """
  249. The given credentials are accepted by all the checkers, and give
  250. the expected C{avatarID}s
  251. """
  252. for chk in self.getCheckers():
  253. for (cred, avatarId) in self.getGoodCredentials():
  254. r = yield chk.requestAvatarId(cred)
  255. self.assertEqual(r, avatarId)
  256. @defer.inlineCallbacks
  257. def test_negative(self):
  258. """
  259. The given credentials are rejected by all the checkers.
  260. """
  261. for chk in self.getCheckers():
  262. for cred in self.getBadCredentials():
  263. d = chk.requestAvatarId(cred)
  264. yield self.assertFailure(d, error.UnauthorizedLogin)
  265. class HashlessFilePasswordDBMixin(object):
  266. credClass = credentials.UsernamePassword
  267. diskHash = None
  268. networkHash = staticmethod(lambda x: x)
  269. _validCredentials = [
  270. (b'user1', b'password1'),
  271. (b'user2', b'password2'),
  272. (b'user3', b'password3')]
  273. def getGoodCredentials(self):
  274. for u, p in self._validCredentials:
  275. yield self.credClass(u, self.networkHash(p)), u
  276. def getBadCredentials(self):
  277. for u, p in [(b'user1', b'password3'),
  278. (b'user2', b'password1'),
  279. (b'bloof', b'blarf')]:
  280. yield self.credClass(u, self.networkHash(p))
  281. def getCheckers(self):
  282. diskHash = self.diskHash or (lambda x: x)
  283. hashCheck = self.diskHash and (lambda username, password,
  284. stored: self.diskHash(password))
  285. for cache in True, False:
  286. fn = self.mktemp()
  287. with open(fn, 'wb') as fObj:
  288. for u, p in self._validCredentials:
  289. fObj.write(u + b":" + diskHash(p) + b"\n")
  290. yield checkers.FilePasswordDB(fn, cache=cache, hash=hashCheck)
  291. fn = self.mktemp()
  292. with open(fn, 'wb') as fObj:
  293. for u, p in self._validCredentials:
  294. fObj.write(diskHash(p) + b' dingle dongle ' + u + b'\n')
  295. yield checkers.FilePasswordDB(fn, b' ', 3, 0,
  296. cache=cache, hash=hashCheck)
  297. fn = self.mktemp()
  298. with open(fn, 'wb') as fObj:
  299. for u, p in self._validCredentials:
  300. fObj.write(b'zip,zap,' + u.title() + b',zup,'\
  301. + diskHash(p) + b'\n',)
  302. yield checkers.FilePasswordDB(fn, b',', 2, 4, False,
  303. cache=cache, hash=hashCheck)
  304. class LocallyHashedFilePasswordDBMixin(HashlessFilePasswordDBMixin):
  305. diskHash = staticmethod(lambda x: hexlify(x))
  306. class NetworkHashedFilePasswordDBMixin(HashlessFilePasswordDBMixin):
  307. networkHash = staticmethod(lambda x: hexlify(x))
  308. class credClass(credentials.UsernameHashedPassword):
  309. def checkPassword(self, password):
  310. return unhexlify(self.hashed) == password
  311. class HashlessFilePasswordDBCheckerTests(HashlessFilePasswordDBMixin,
  312. CheckersMixin, unittest.TestCase):
  313. pass
  314. class LocallyHashedFilePasswordDBCheckerTests(LocallyHashedFilePasswordDBMixin,
  315. CheckersMixin,
  316. unittest.TestCase):
  317. pass
  318. class NetworkHashedFilePasswordDBCheckerTests(NetworkHashedFilePasswordDBMixin,
  319. CheckersMixin,
  320. unittest.TestCase):
  321. pass