test_wrapper.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. # Copyright (c) Twisted Matrix Laboratories.
  2. # See LICENSE for details.
  3. """
  4. Test cases for L{twisted.protocols.haproxy.HAProxyProtocol}.
  5. """
  6. from twisted.trial import unittest
  7. from twisted.internet import address
  8. from twisted.internet.protocol import Protocol, Factory
  9. from twisted.test.proto_helpers import StringTransportWithDisconnection
  10. from .._wrapper import HAProxyWrappingFactory
  11. class StaticProtocol(Protocol):
  12. """
  13. Protocol stand-in that maintains test state.
  14. """
  15. def __init__(self):
  16. self.source = None
  17. self.destination = None
  18. self.data = b''
  19. self.disconnected = False
  20. def dataReceived(self, data):
  21. self.source = self.transport.getPeer()
  22. self.destination = self.transport.getHost()
  23. self.data += data
  24. class HAProxyWrappingFactoryV1Tests(unittest.TestCase):
  25. """
  26. Test L{twisted.protocols.haproxy.HAProxyWrappingFactory} with v1 PROXY
  27. headers.
  28. """
  29. def test_invalidHeaderDisconnects(self):
  30. """
  31. Test if invalid headers result in connectionLost events.
  32. """
  33. factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
  34. proto = factory.buildProtocol(
  35. address.IPv4Address('TCP', b'127.1.1.1', 8080),
  36. )
  37. transport = StringTransportWithDisconnection()
  38. transport.protocol = proto
  39. proto.makeConnection(transport)
  40. proto.dataReceived(b'NOTPROXY anything can go here\r\n')
  41. self.assertFalse(transport.connected)
  42. def test_invalidPartialHeaderDisconnects(self):
  43. """
  44. Test if invalid headers result in connectionLost events.
  45. """
  46. factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
  47. proto = factory.buildProtocol(
  48. address.IPv4Address('TCP', b'127.1.1.1', 8080),
  49. )
  50. transport = StringTransportWithDisconnection()
  51. transport.protocol = proto
  52. proto.makeConnection(transport)
  53. proto.dataReceived(b'PROXY TCP4 1.1.1.1\r\n')
  54. proto.dataReceived(b'2.2.2.2 8080\r\n')
  55. self.assertFalse(transport.connected)
  56. def test_validIPv4HeaderResolves_getPeerHost(self):
  57. """
  58. Test if IPv4 headers result in the correct host and peer data.
  59. """
  60. factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
  61. proto = factory.buildProtocol(
  62. address.IPv4Address('TCP', b'127.0.0.1', 8080),
  63. )
  64. transport = StringTransportWithDisconnection()
  65. proto.makeConnection(transport)
  66. proto.dataReceived(b'PROXY TCP4 1.1.1.1 2.2.2.2 8080 8888\r\n')
  67. self.assertEqual(proto.getPeer().host, b'1.1.1.1')
  68. self.assertEqual(proto.getPeer().port, 8080)
  69. self.assertEqual(
  70. proto.wrappedProtocol.transport.getPeer().host,
  71. b'1.1.1.1',
  72. )
  73. self.assertEqual(
  74. proto.wrappedProtocol.transport.getPeer().port,
  75. 8080,
  76. )
  77. self.assertEqual(proto.getHost().host, b'2.2.2.2')
  78. self.assertEqual(proto.getHost().port, 8888)
  79. self.assertEqual(
  80. proto.wrappedProtocol.transport.getHost().host,
  81. b'2.2.2.2',
  82. )
  83. self.assertEqual(
  84. proto.wrappedProtocol.transport.getHost().port,
  85. 8888,
  86. )
  87. def test_validIPv6HeaderResolves_getPeerHost(self):
  88. """
  89. Test if IPv6 headers result in the correct host and peer data.
  90. """
  91. factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
  92. proto = factory.buildProtocol(
  93. address.IPv6Address('TCP', b'::1', 8080),
  94. )
  95. transport = StringTransportWithDisconnection()
  96. proto.makeConnection(transport)
  97. proto.dataReceived(b'PROXY TCP6 ::1 ::2 8080 8888\r\n')
  98. self.assertEqual(proto.getPeer().host, b'::1')
  99. self.assertEqual(proto.getPeer().port, 8080)
  100. self.assertEqual(
  101. proto.wrappedProtocol.transport.getPeer().host,
  102. b'::1',
  103. )
  104. self.assertEqual(
  105. proto.wrappedProtocol.transport.getPeer().port,
  106. 8080,
  107. )
  108. self.assertEqual(proto.getHost().host, b'::2')
  109. self.assertEqual(proto.getHost().port, 8888)
  110. self.assertEqual(
  111. proto.wrappedProtocol.transport.getHost().host,
  112. b'::2',
  113. )
  114. self.assertEqual(
  115. proto.wrappedProtocol.transport.getHost().port,
  116. 8888,
  117. )
  118. def test_overflowBytesSentToWrappedProtocol(self):
  119. """
  120. Test if non-header bytes are passed to the wrapped protocol.
  121. """
  122. factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
  123. proto = factory.buildProtocol(
  124. address.IPv6Address('TCP', b'::1', 8080),
  125. )
  126. transport = StringTransportWithDisconnection()
  127. proto.makeConnection(transport)
  128. proto.dataReceived(b'PROXY TCP6 ::1 ::2 8080 8888\r\nHTTP/1.1 / GET')
  129. self.assertEqual(proto.wrappedProtocol.data, b'HTTP/1.1 / GET')
  130. def test_overflowBytesSentToWrappedProtocolChunks(self):
  131. """
  132. Test if header streaming passes extra data appropriately.
  133. """
  134. factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
  135. proto = factory.buildProtocol(
  136. address.IPv6Address('TCP', b'::1', 8080),
  137. )
  138. transport = StringTransportWithDisconnection()
  139. proto.makeConnection(transport)
  140. proto.dataReceived(b'PROXY TCP6 ::1 ::2 ')
  141. proto.dataReceived(b'8080 8888\r\nHTTP/1.1 / GET')
  142. self.assertEqual(proto.wrappedProtocol.data, b'HTTP/1.1 / GET')
  143. def test_overflowBytesSentToWrappedProtocolAfter(self):
  144. """
  145. Test if wrapper writes all data to wrapped protocol after parsing.
  146. """
  147. factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
  148. proto = factory.buildProtocol(
  149. address.IPv6Address('TCP', b'::1', 8080),
  150. )
  151. transport = StringTransportWithDisconnection()
  152. proto.makeConnection(transport)
  153. proto.dataReceived(b'PROXY TCP6 ::1 ::2 ')
  154. proto.dataReceived(b'8080 8888\r\nHTTP/1.1 / GET')
  155. proto.dataReceived(b'\r\n\r\n')
  156. self.assertEqual(proto.wrappedProtocol.data, b'HTTP/1.1 / GET\r\n\r\n')
  157. class HAProxyWrappingFactoryV2Tests(unittest.TestCase):
  158. """
  159. Test L{twisted.protocols.haproxy.HAProxyWrappingFactory} with v2 PROXY
  160. headers.
  161. """
  162. IPV4HEADER = (
  163. # V2 Signature
  164. b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A'
  165. # V2 PROXY command
  166. b'\x21'
  167. # AF_INET/STREAM
  168. b'\x11'
  169. # 12 bytes for 2 IPv4 addresses and two ports
  170. b'\x00\x0C'
  171. # 127.0.0.1 for source and destination
  172. b'\x7F\x00\x00\x01\x7F\x00\x00\x01'
  173. # 8080 for source 8888 for destination
  174. b'\x1F\x90\x22\xB8'
  175. )
  176. IPV6HEADER = (
  177. # V2 Signature
  178. b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A'
  179. # V2 PROXY command
  180. b'\x21'
  181. # AF_INET6/STREAM
  182. b'\x21'
  183. # 16 bytes for 2 IPv6 addresses and two ports
  184. b'\x00\x24'
  185. # ::1 for source and destination
  186. b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01'
  187. b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01'
  188. # 8080 for source 8888 for destination
  189. b'\x1F\x90\x22\xB8'
  190. )
  191. _SOCK_PATH = (
  192. b'\x2F\x68\x6F\x6D\x65\x2F\x74\x65\x73\x74\x73\x2F\x6D\x79\x73\x6F'
  193. b'\x63\x6B\x65\x74\x73\x2F\x73\x6F\x63\x6B' + (b'\x00' * 82)
  194. )
  195. UNIXHEADER = (
  196. # V2 Signature
  197. b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A'
  198. # V2 PROXY command
  199. b'\x21'
  200. # AF_UNIX/STREAM
  201. b'\x31'
  202. # 108 bytes for 2 null terminated paths
  203. b'\x00\xD8'
  204. # /home/tests/mysockets/sock for source and destination paths
  205. ) + _SOCK_PATH + _SOCK_PATH
  206. def test_invalidHeaderDisconnects(self):
  207. """
  208. Test if invalid headers result in connectionLost events.
  209. """
  210. factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
  211. proto = factory.buildProtocol(
  212. address.IPv6Address('TCP', b'::1', 8080),
  213. )
  214. transport = StringTransportWithDisconnection()
  215. transport.protocol = proto
  216. proto.makeConnection(transport)
  217. proto.dataReceived(b'\x00' + self.IPV4HEADER[1:])
  218. self.assertFalse(transport.connected)
  219. def test_validIPv4HeaderResolves_getPeerHost(self):
  220. """
  221. Test if IPv4 headers result in the correct host and peer data.
  222. """
  223. factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
  224. proto = factory.buildProtocol(
  225. address.IPv4Address('TCP', b'127.0.0.1', 8080),
  226. )
  227. transport = StringTransportWithDisconnection()
  228. proto.makeConnection(transport)
  229. proto.dataReceived(self.IPV4HEADER)
  230. self.assertEqual(proto.getPeer().host, b'127.0.0.1')
  231. self.assertEqual(proto.getPeer().port, 8080)
  232. self.assertEqual(
  233. proto.wrappedProtocol.transport.getPeer().host,
  234. b'127.0.0.1',
  235. )
  236. self.assertEqual(
  237. proto.wrappedProtocol.transport.getPeer().port,
  238. 8080,
  239. )
  240. self.assertEqual(proto.getHost().host, b'127.0.0.1')
  241. self.assertEqual(proto.getHost().port, 8888)
  242. self.assertEqual(
  243. proto.wrappedProtocol.transport.getHost().host,
  244. b'127.0.0.1',
  245. )
  246. self.assertEqual(
  247. proto.wrappedProtocol.transport.getHost().port,
  248. 8888,
  249. )
  250. def test_validIPv6HeaderResolves_getPeerHost(self):
  251. """
  252. Test if IPv6 headers result in the correct host and peer data.
  253. """
  254. factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
  255. proto = factory.buildProtocol(
  256. address.IPv4Address('TCP', b'::1', 8080),
  257. )
  258. transport = StringTransportWithDisconnection()
  259. proto.makeConnection(transport)
  260. proto.dataReceived(self.IPV6HEADER)
  261. self.assertEqual(proto.getPeer().host, b'0:0:0:0:0:0:0:1')
  262. self.assertEqual(proto.getPeer().port, 8080)
  263. self.assertEqual(
  264. proto.wrappedProtocol.transport.getPeer().host,
  265. b'0:0:0:0:0:0:0:1',
  266. )
  267. self.assertEqual(
  268. proto.wrappedProtocol.transport.getPeer().port,
  269. 8080,
  270. )
  271. self.assertEqual(proto.getHost().host, b'0:0:0:0:0:0:0:1')
  272. self.assertEqual(proto.getHost().port, 8888)
  273. self.assertEqual(
  274. proto.wrappedProtocol.transport.getHost().host,
  275. b'0:0:0:0:0:0:0:1',
  276. )
  277. self.assertEqual(
  278. proto.wrappedProtocol.transport.getHost().port,
  279. 8888,
  280. )
  281. def test_validUNIXHeaderResolves_getPeerHost(self):
  282. """
  283. Test if UNIX headers result in the correct host and peer data.
  284. """
  285. factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
  286. proto = factory.buildProtocol(
  287. address.UNIXAddress(b'/home/test/sockets/server.sock'),
  288. )
  289. transport = StringTransportWithDisconnection()
  290. proto.makeConnection(transport)
  291. proto.dataReceived(self.UNIXHEADER)
  292. self.assertEqual(proto.getPeer().name, b'/home/tests/mysockets/sock')
  293. self.assertEqual(
  294. proto.wrappedProtocol.transport.getPeer().name,
  295. b'/home/tests/mysockets/sock',
  296. )
  297. self.assertEqual(proto.getHost().name, b'/home/tests/mysockets/sock')
  298. self.assertEqual(
  299. proto.wrappedProtocol.transport.getHost().name,
  300. b'/home/tests/mysockets/sock',
  301. )
  302. def test_overflowBytesSentToWrappedProtocol(self):
  303. """
  304. Test if non-header bytes are passed to the wrapped protocol.
  305. """
  306. factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
  307. proto = factory.buildProtocol(
  308. address.IPv6Address('TCP', b'::1', 8080),
  309. )
  310. transport = StringTransportWithDisconnection()
  311. proto.makeConnection(transport)
  312. proto.dataReceived(self.IPV6HEADER + b'HTTP/1.1 / GET')
  313. self.assertEqual(proto.wrappedProtocol.data, b'HTTP/1.1 / GET')
  314. def test_overflowBytesSentToWrappedProtocolChunks(self):
  315. """
  316. Test if header streaming passes extra data appropriately.
  317. """
  318. factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
  319. proto = factory.buildProtocol(
  320. address.IPv6Address('TCP', b'::1', 8080),
  321. )
  322. transport = StringTransportWithDisconnection()
  323. proto.makeConnection(transport)
  324. proto.dataReceived(self.IPV6HEADER[:18])
  325. proto.dataReceived(self.IPV6HEADER[18:] + b'HTTP/1.1 / GET')
  326. self.assertEqual(proto.wrappedProtocol.data, b'HTTP/1.1 / GET')