test_serialization.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from __future__ import absolute_import
  2. import os
  3. import base64
  4. from kombu.serialization import registry
  5. from celery.exceptions import SecurityError
  6. from celery.security.serialization import SecureSerializer, register_auth
  7. from celery.security.certificate import Certificate, CertStore
  8. from celery.security.key import PrivateKey
  9. from . import CERT1, CERT2, KEY1, KEY2
  10. from .case import SecurityCase
  11. class test_SecureSerializer(SecurityCase):
  12. def _get_s(self, key, cert, certs):
  13. store = CertStore()
  14. for c in certs:
  15. store.add_cert(Certificate(c))
  16. return SecureSerializer(PrivateKey(key), Certificate(cert), store)
  17. def test_serialize(self):
  18. s = self._get_s(KEY1, CERT1, [CERT1])
  19. self.assertEqual(s.deserialize(s.serialize('foo')), 'foo')
  20. def test_deserialize(self):
  21. s = self._get_s(KEY1, CERT1, [CERT1])
  22. self.assertRaises(SecurityError, s.deserialize, 'bad data')
  23. def test_unmatched_key_cert(self):
  24. s = self._get_s(KEY1, CERT2, [CERT1, CERT2])
  25. self.assertRaises(SecurityError,
  26. s.deserialize, s.serialize('foo'))
  27. def test_unknown_source(self):
  28. s1 = self._get_s(KEY1, CERT1, [CERT2])
  29. s2 = self._get_s(KEY1, CERT1, [])
  30. self.assertRaises(SecurityError,
  31. s1.deserialize, s1.serialize('foo'))
  32. self.assertRaises(SecurityError,
  33. s2.deserialize, s2.serialize('foo'))
  34. def test_self_send(self):
  35. s1 = self._get_s(KEY1, CERT1, [CERT1])
  36. s2 = self._get_s(KEY1, CERT1, [CERT1])
  37. self.assertEqual(s2.deserialize(s1.serialize('foo')), 'foo')
  38. def test_separate_ends(self):
  39. s1 = self._get_s(KEY1, CERT1, [CERT2])
  40. s2 = self._get_s(KEY2, CERT2, [CERT1])
  41. self.assertEqual(s2.deserialize(s1.serialize('foo')), 'foo')
  42. def test_register_auth(self):
  43. register_auth(KEY1, CERT1, '')
  44. self.assertIn('application/data', registry._decoders)
  45. def test_lots_of_sign(self):
  46. for i in range(1000):
  47. rdata = base64.urlsafe_b64encode(os.urandom(265))
  48. s = self._get_s(KEY1, CERT1, [CERT1])
  49. self.assertEqual(s.deserialize(s.serialize(rdata)), rdata)