test_banana.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. # Copyright (c) Twisted Matrix Laboratories.
  2. # See LICENSE for details.
  3. from __future__ import absolute_import, division
  4. import sys
  5. from functools import partial
  6. from io import BytesIO
  7. from twisted.trial import unittest
  8. from twisted.spread import banana
  9. from twisted.python import failure
  10. from twisted.python.compat import long, iterbytes, _bytesChr as chr, _PY3
  11. from twisted.internet import protocol, main
  12. from twisted.test.proto_helpers import StringTransport
  if not _PY3:
      # Python 2: use the interpreter's own sys.maxint.
      from sys import maxint as _maxint
  else:
      # Python 3 has no sys.maxint; use the largest signed 64-bit integer.
      _maxint = 2 ** 63 - 1
  class MathTests(unittest.TestCase):
      """
      Tests for Banana's base-128 integer helpers.
      """
      def test_int2b128(self):
          """
          Encoding an integer with L{banana.int2b128} and decoding the result
          with L{banana.b1282int} gives back the original integer.
          """
          samples = list(range(100))
          samples.extend(range(1000, 1100))
          samples.extend(range(1000000, 1000100))
          samples.append(1024 ** 10)
          for expected in samples:
              sink = BytesIO()
              banana.int2b128(expected, sink.write)
              decoded = banana.b1282int(sink.getvalue())
              self.assertEqual(decoded, expected)
  def selectDialect(protocol, dialect):
      """
      Force a Banana protocol instance to use a particular dialect.

      Normal dialect negotiation happens by delivering bytes, but other setup
      gets in the way (clients and servers negotiate this step incompatibly),
      so this goes through the private C{_selectDialect} API instead.

      @param protocol: A L{banana.Banana} instance which has not yet had a
          dialect negotiated.
      @param dialect: A L{bytes} instance naming a Banana dialect to select.
      """
      chooser = protocol._selectDialect
      chooser(dialect)
  def encode(bananaFactory, obj):
      """
      Banana encode an object using L{banana.Banana.sendEncoded}.

      @param bananaFactory: A no-argument callable which will return a new,
          unconnected protocol instance to use to do the encoding (this should
          most likely be a L{banana.Banana} instance).
      @param obj: The object to encode.
      @type obj: Any type supported by Banana.
      @return: A L{bytes} instance giving the encoded form of C{obj}.
      """
      transport = StringTransport()
      # Named C{proto} rather than C{banana} so the module-level C{banana}
      # import is not shadowed inside this function.
      proto = bananaFactory()
      proto.makeConnection(transport)
      # Throw away anything written while connecting so the returned bytes are
      # only the encoding of C{obj}.
      transport.clear()
      proto.sendEncoded(obj)
      return transport.value()
  class BananaTestBase(unittest.TestCase):
      """
      Shared fixture for Banana tests.

      Each test gets a fresh C{encClass} instance, C{self.enc}, connected to
      an in-memory byte buffer, C{self.io}, with the C{b"none"} dialect
      selected. Every expression it decodes is stored on C{self.result}.
      C{self.encode} encodes an object with a separate, fresh C{encClass}
      instance and returns the bytes.
      """
      encClass = banana.Banana

      def setUp(self):
          self.io = buf = BytesIO()
          self.enc = enc = self.encClass()
          enc.makeConnection(protocol.FileWrapper(buf))
          selectDialect(enc, b"none")
          # Route every decoded expression to putResult so tests can see it.
          enc.expressionReceived = self.putResult
          self.encode = partial(encode, self.encClass)

      def putResult(self, result):
          """
          Remember the most recent expression decoded by C{self.enc}.

          @param result: The decoded object.
          @type result: Any type supported by Banana.
          """
          self.result = result

      def tearDown(self):
          reason = failure.Failure(main.CONNECTION_DONE)
          self.enc.connectionLost(reason)
          del self.enc
  class BananaTests(BananaTestBase):
      """
      General banana tests.
      """
      def test_string(self):
          """
          A byte string round-trips through Banana unchanged.
          """
          self.enc.sendEncoded(b"hello")
          self.enc.dataReceived(self.io.getvalue())
          # A trial assertion rather than a bare assert, which would be
          # stripped when Python runs with -O.
          self.assertEqual(self.result, b'hello')

      def test_unsupportedUnicode(self):
          """
          Banana does not support unicode. ``Banana.sendEncoded`` raises
          ``BananaError`` if called with an instance of ``unicode``.
          """
          if _PY3:
              self._unsupportedTypeTest(u"hello", "builtins.str")
          else:
              self._unsupportedTypeTest(u"hello", "__builtin__.unicode")

      def test_unsupportedBuiltinType(self):
          """
          Banana does not support arbitrary builtin types like L{type}.
          L{banana.Banana.sendEncoded} raises L{banana.BananaError} if called
          with an instance of L{type}.
          """
          # type is an instance of type
          if _PY3:
              self._unsupportedTypeTest(type, "builtins.type")
          else:
              self._unsupportedTypeTest(type, "__builtin__.type")

      def test_unsupportedUserType(self):
          """
          Banana does not support arbitrary user-defined types (such as those
          defined with the ``class`` statement). ``Banana.sendEncoded`` raises
          ``BananaError`` if called with an instance of such a type.
          """
          self._unsupportedTypeTest(MathTests(), __name__ + ".MathTests")

      def _unsupportedTypeTest(self, obj, name):
          """
          Assert that L{banana.Banana.sendEncoded} raises L{banana.BananaError}
          if called with the given object.

          @param obj: Some object that Banana does not support.
          @param name: The name of the type of the object.
          @raise: The failure exception is raised if L{Banana.sendEncoded} does
              not raise L{banana.BananaError} or if the message associated
              with the exception is not formatted to include the type of the
              unsupported object.
          """
          exc = self.assertRaises(banana.BananaError, self.enc.sendEncoded, obj)
          self.assertIn("Banana cannot send {0} objects".format(name), str(exc))

      def test_int(self):
          """
          A positive integer less than 2 ** 32 should round-trip through
          banana without changing value and should come out represented
          as an C{int} (regardless of the type which was encoded).
          """
          for value in (10151, long(10151)):
              # Encode each value on its own. Re-feeding the shared C{self.io}
              # buffer would also replay the encodings from earlier
              # iterations.
              self.enc.dataReceived(self.encode(value))
              self.assertEqual(self.result, 10151)
              self.assertIsInstance(self.result, int)

      def test_largeLong(self):
          """
          Integers greater than 2 ** 32 and less than -2 ** 32 should
          round-trip through banana without changing value and should
          come out represented as C{int} instances if the value fits
          into that type on the receiving platform.
          """
          for exp in (32, 64, 128, 256):
              for add in (0, 1):
                  m = 2 ** exp + add
                  for n in (m, -m - 1):
                      self.enc.dataReceived(self.encode(n))
                      self.assertEqual(self.result, n)
                      if n > _maxint or n < -_maxint - 1:
                          self.assertIsInstance(self.result, long)
                      else:
                          self.assertIsInstance(self.result, int)
      if _PY3:
          test_largeLong.skip = (
              "Python 3 has unified int/long into an int type of unlimited size")

      def _getSmallest(self):
          """
          Return the smallest positive integer that C{self.enc} must refuse to
          encode with its current prefix limit.
          """
          # Banana writes integers as base-128 digits, one digit per prefix
          # byte, so a prefix of at most prefixLimit bytes carries
          # prefixLimit * 7 useful bits. (Named to avoid shadowing the
          # bytes builtin.)
          prefixBytes = self.enc.prefixLimit
          bits = prefixBytes * 7
          # The largest encodable number is 2 ** bits - 1; one more is too big.
          largest = 2 ** bits - 1
          return largest + 1

      def test_encodeTooLargeLong(self):
          """
          Test that a long above the implementation-specific limit is rejected
          as too large to be encoded.
          """
          smallest = self._getSmallest()
          self.assertRaises(banana.BananaError, self.enc.sendEncoded, smallest)

      def test_decodeTooLargeLong(self):
          """
          Test that a long above the implementation specific limit is rejected
          as too large to be decoded.
          """
          smallest = self._getSmallest()
          # Temporarily raise the limit so the value can be encoded at all.
          self.enc.setPrefixLimit(self.enc.prefixLimit * 2)
          self.enc.sendEncoded(smallest)
          encoded = self.io.getvalue()
          self.io.truncate(0)
          self.enc.setPrefixLimit(self.enc.prefixLimit // 2)
          self.assertRaises(banana.BananaError, self.enc.dataReceived, encoded)

      def _getLargest(self):
          """
          Return the negation of L{_getSmallest}: a negative integer too far
          below zero for C{self.enc} to encode with its current prefix limit.
          """
          return -self._getSmallest()

      def test_encodeTooSmallLong(self):
          """
          Test that a negative long below the implementation-specific limit is
          rejected as too small to be encoded.
          """
          largest = self._getLargest()
          self.assertRaises(banana.BananaError, self.enc.sendEncoded, largest)

      def test_decodeTooSmallLong(self):
          """
          Test that a negative long below the implementation specific limit is
          rejected as too small to be decoded.
          """
          largest = self._getLargest()
          # Temporarily raise the limit so the value can be encoded at all.
          self.enc.setPrefixLimit(self.enc.prefixLimit * 2)
          self.enc.sendEncoded(largest)
          encoded = self.io.getvalue()
          self.io.truncate(0)
          self.enc.setPrefixLimit(self.enc.prefixLimit // 2)
          self.assertRaises(banana.BananaError, self.enc.dataReceived, encoded)

      def test_integer(self):
          """
          A small positive integer round-trips through Banana.
          """
          self.enc.sendEncoded(1015)
          self.enc.dataReceived(self.io.getvalue())
          self.assertEqual(self.result, 1015)

      def test_negative(self):
          """
          A small negative integer round-trips through Banana.
          """
          self.enc.sendEncoded(-1015)
          self.enc.dataReceived(self.io.getvalue())
          self.assertEqual(self.result, -1015)

      def test_float(self):
          """
          A float round-trips through Banana.
          """
          self.enc.sendEncoded(1015.)
          self.enc.dataReceived(self.io.getvalue())
          self.assertEqual(self.result, 1015.0)

      def test_list(self):
          """
          A nested list of mixed supported types round-trips through Banana.
          """
          foo = ([1, 2, [3, 4], [30.5, 40.2], 5,
                  [b"six", b"seven", [b"eight", 9]], [10], []])
          self.enc.sendEncoded(foo)
          self.enc.dataReceived(self.io.getvalue())
          self.assertEqual(self.result, foo)

      def test_partial(self):
          """
          Test feeding the data byte per byte to the receiver. Normally
          data is not split.
          """
          foo = [1, 2, [3, 4], [30.5, 40.2], 5,
                 [b"six", b"seven", [b"eight", 9]], [10],
                 # Values beyond 2 ** 31 - 1 also exercise the LONGINT and
                 # LONGNEG encodings.
                 sys.maxsize * 3, sys.maxsize * 2, sys.maxsize * -2]
          self.enc.sendEncoded(foo)
          self.feed(self.io.getvalue())
          self.assertEqual(self.result, foo)

      def feed(self, data):
          """
          Feed the data byte per byte to the receiver.

          @param data: The bytes to deliver.
          @type data: L{bytes}
          """
          for byte in iterbytes(data):
              self.enc.dataReceived(byte)

      def test_oversizedList(self):
          """
          A LIST with an enormous length prefix is rejected with
          L{banana.BananaError}.
          """
          # list(size=2 + 128 + 128 ** 2 + 128 ** 3 + 128 ** 4, about 2.7e8).
          # The length prefix is base-128, least significant digit first.
          data = b'\x02\x01\x01\x01\x01\x80'
          self.assertRaises(banana.BananaError, self.feed, data)

      def test_oversizedString(self):
          """
          A STRING with an enormous length prefix is rejected with
          L{banana.BananaError}.
          """
          # string(size=2 + 128 + 128 ** 2 + 128 ** 3 + 128 ** 4, about 2.7e8)
          data = b'\x02\x01\x01\x01\x01\x82'
          self.assertRaises(banana.BananaError, self.feed, data)

      def test_crashString(self):
          """
          A LIST whose length prefix is even larger than the one in
          L{test_oversizedList} is rejected with L{banana.BananaError}
          rather than crashing the process.
          """
          # list(size=4 * 128 ** 4 == 2 ** 30, about 1.07e9)
          crashString = b'\x00\x00\x00\x00\x04\x80'
          # Historically, cBanana mishandled huge list sizes: it ignored a
          # NULL return from PyList_New() and segfaulted when trying to free
          # the imaginary list. The previous version of this test also passed
          # when no error was raised at all, so require the error explicitly.
          self.assertRaises(
              banana.BananaError, self.enc.dataReceived, crashString)

      def test_crashNegativeLong(self):
          """
          The most negative 32-bit integer round-trips through Banana.
          """
          # There was a bug in cBanana which relied on negating a negative
          # integer always giving a positive result, but for the lowest
          # possible number in 2s-complement arithmetic, that's not true, i.e.
          # long x = -2147483648;
          # long y = -x;
          # x == y; /* true! */
          # (assuming 32-bit longs)
          self.enc.sendEncoded(-2147483648)
          self.enc.dataReceived(self.io.getvalue())
          self.assertEqual(self.result, -2147483648)

      def test_sizedIntegerTypes(self):
          """
          Test that integers below the maximum C{INT} token size cutoff are
          serialized as C{INT} or C{NEG} and that larger integers are
          serialized as C{LONGINT} or C{LONGNEG}.
          """
          baseIntIn = +2147483647
          baseNegIn = -2147483648

          baseIntOut = b'\x7f\x7f\x7f\x07\x81'
          self.assertEqual(self.encode(baseIntIn - 2), b'\x7d' + baseIntOut)
          self.assertEqual(self.encode(baseIntIn - 1), b'\x7e' + baseIntOut)
          self.assertEqual(self.encode(baseIntIn - 0), b'\x7f' + baseIntOut)

          baseLongIntOut = b'\x00\x00\x00\x08\x85'
          self.assertEqual(self.encode(baseIntIn + 1), b'\x00' + baseLongIntOut)
          self.assertEqual(self.encode(baseIntIn + 2), b'\x01' + baseLongIntOut)
          self.assertEqual(self.encode(baseIntIn + 3), b'\x02' + baseLongIntOut)

          baseNegOut = b'\x7f\x7f\x7f\x07\x83'
          self.assertEqual(self.encode(baseNegIn + 2), b'\x7e' + baseNegOut)
          self.assertEqual(self.encode(baseNegIn + 1), b'\x7f' + baseNegOut)
          self.assertEqual(self.encode(baseNegIn + 0), b'\x00\x00\x00\x00\x08\x83')

          baseLongNegOut = b'\x00\x00\x00\x08\x86'
          self.assertEqual(self.encode(baseNegIn - 1), b'\x01' + baseLongNegOut)
          self.assertEqual(self.encode(baseNegIn - 2), b'\x02' + baseLongNegOut)
          self.assertEqual(self.encode(baseNegIn - 3), b'\x03' + baseLongNegOut)
  class DialectTests(BananaTestBase):
      """
      Tests covering how Banana treats PB-dialect VOCAB items, both when
      receiving and when sending them.
      """
      vocab = b'remote'
      # The on-the-wire form of C{vocab}: its PB outgoing-vocabulary number
      # followed by the VOCAB type byte.
      legalPbItem = chr(banana.Banana.outgoingVocabulary[vocab]) + banana.VOCAB
      # A VOCAB item numbered 122, which the tests below expect PB to reject.
      illegalPbItem = chr(122) + banana.VOCAB

      def test_dialectNotSet(self):
          """
          While only the C{b"none"} dialect chosen by the fixture is in
          effect, receiving a PB VOCAB item raises L{NotImplementedError}.
          """
          deliver = self.enc.dataReceived
          self.assertRaises(NotImplementedError, deliver, self.legalPbItem)

      def test_receivePb(self):
          """
          Once the PB dialect is selected, a known PB VOCAB item decodes to
          the corresponding byte string.
          """
          enc = self.enc
          selectDialect(enc, b'pb')
          enc.dataReceived(self.legalPbItem)
          self.assertEqual(self.result, self.vocab)

      def test_receiveIllegalPb(self):
          """
          Once the PB dialect is selected, an unrecognized PB VOCAB item makes
          L{banana.Banana.dataReceived} raise L{KeyError}.
          """
          enc = self.enc
          selectDialect(enc, b'pb')
          self.assertRaises(KeyError, enc.dataReceived, self.illegalPbItem)

      def test_sendPb(self):
          """
          Once the PB dialect is selected, sending a PB vocabulary word writes
          its VOCAB encoding.
          """
          selectDialect(self.enc, b'pb')
          self.enc.sendEncoded(self.vocab)
          written = self.io.getvalue()
          self.assertEqual(self.legalPbItem, written)
  class GlobalCoderTests(unittest.TestCase):
      """
      Tests for the free functions L{banana.encode} and L{banana.decode}.
      """
      def test_statelessDecode(self):
          """
          Calls to L{banana.decode} are independent of each other.
          """
          # Banana LONGINT encoding of 128 ** 65 - 1 == 2 ** 455 - 1: 65
          # base-128 digits of 127 followed by the LONGINT type byte. It is
          # rejected, presumably because its 65-byte prefix is longer than
          # banana.decode accepts.
          undecodable = b'\x7f' * 65 + b'\x85'
          self.assertRaises(banana.BananaError, banana.decode, undecodable)

          # Banana encoding of 1. This should be decodable even though the
          # previous call passed un-decodable data and triggered an exception.
          decodable = b'\x01\x81'
          self.assertEqual(banana.decode(decodable), 1)