# test_proxy.py (20 KB)
  1. # Copyright (c) Twisted Matrix Laboratories.
  2. # See LICENSE for details.
  3. """
  4. Test for L{twisted.web.proxy}.
  5. """
  6. from twisted.trial.unittest import TestCase
  7. from twisted.test.proto_helpers import StringTransportWithDisconnection
  8. from twisted.test.proto_helpers import MemoryReactor
  9. from twisted.web.resource import Resource
  10. from twisted.web.server import Site
  11. from twisted.web.proxy import ReverseProxyResource, ProxyClientFactory
  12. from twisted.web.proxy import ProxyClient, ProxyRequest, ReverseProxyRequest
  13. from twisted.web.test.test_web import DummyRequest
  14. class ReverseProxyResourceTests(TestCase):
  15. """
  16. Tests for L{ReverseProxyResource}.
  17. """
  18. def _testRender(self, uri, expectedURI):
  19. """
  20. Check that a request pointing at C{uri} produce a new proxy connection,
  21. with the path of this request pointing at C{expectedURI}.
  22. """
  23. root = Resource()
  24. reactor = MemoryReactor()
  25. resource = ReverseProxyResource(u"127.0.0.1", 1234, b"/path", reactor)
  26. root.putChild(b'index', resource)
  27. site = Site(root)
  28. transport = StringTransportWithDisconnection()
  29. channel = site.buildProtocol(None)
  30. channel.makeConnection(transport)
  31. # Clear the timeout if the tests failed
  32. self.addCleanup(channel.connectionLost, None)
  33. channel.dataReceived(b"GET " +
  34. uri +
  35. b" HTTP/1.1\r\nAccept: text/html\r\n\r\n")
  36. # Check that one connection has been created, to the good host/port
  37. self.assertEqual(len(reactor.tcpClients), 1)
  38. self.assertEqual(reactor.tcpClients[0][0], u"127.0.0.1")
  39. self.assertEqual(reactor.tcpClients[0][1], 1234)
  40. # Check the factory passed to the connect, and its given path
  41. factory = reactor.tcpClients[0][2]
  42. self.assertIsInstance(factory, ProxyClientFactory)
  43. self.assertEqual(factory.rest, expectedURI)
  44. self.assertEqual(factory.headers[b"host"], b"127.0.0.1:1234")
  45. def test_render(self):
  46. """
  47. Test that L{ReverseProxyResource.render} initiates a connection to the
  48. given server with a L{ProxyClientFactory} as parameter.
  49. """
  50. return self._testRender(b"/index", b"/path")
  51. def test_renderWithQuery(self):
  52. """
  53. Test that L{ReverseProxyResource.render} passes query parameters to the
  54. created factory.
  55. """
  56. return self._testRender(b"/index?foo=bar", b"/path?foo=bar")
  57. def test_getChild(self):
  58. """
  59. The L{ReverseProxyResource.getChild} method should return a resource
  60. instance with the same class as the originating resource, forward
  61. port, host, and reactor values, and update the path value with the
  62. value passed.
  63. """
  64. reactor = MemoryReactor()
  65. resource = ReverseProxyResource(u"127.0.0.1", 1234, b"/path", reactor)
  66. child = resource.getChild(b'foo', None)
  67. # The child should keep the same class
  68. self.assertIsInstance(child, ReverseProxyResource)
  69. self.assertEqual(child.path, b"/path/foo")
  70. self.assertEqual(child.port, 1234)
  71. self.assertEqual(child.host, u"127.0.0.1")
  72. self.assertIdentical(child.reactor, resource.reactor)
  73. def test_getChildWithSpecial(self):
  74. """
  75. The L{ReverseProxyResource} return by C{getChild} has a path which has
  76. already been quoted.
  77. """
  78. resource = ReverseProxyResource(u"127.0.0.1", 1234, b"/path")
  79. child = resource.getChild(b' /%', None)
  80. self.assertEqual(child.path, b"/path/%20%2F%25")
  81. class DummyChannel(object):
  82. """
  83. A dummy HTTP channel, that does nothing but holds a transport and saves
  84. connection lost.
  85. @ivar transport: the transport used by the client.
  86. @ivar lostReason: the reason saved at connection lost.
  87. """
  88. def __init__(self, transport):
  89. """
  90. Hold a reference to the transport.
  91. """
  92. self.transport = transport
  93. self.lostReason = None
  94. def connectionLost(self, reason):
  95. """
  96. Keep track of the connection lost reason.
  97. """
  98. self.lostReason = reason
  99. def getPeer(self):
  100. """
  101. Get peer information from the transport.
  102. """
  103. return self.transport.getPeer()
  104. def getHost(self):
  105. """
  106. Get host information from the transport.
  107. """
  108. return self.transport.getHost()
  109. class ProxyClientTests(TestCase):
  110. """
  111. Tests for L{ProxyClient}.
  112. """
  113. def _parseOutHeaders(self, content):
  114. """
  115. Parse the headers out of some web content.
  116. @param content: Bytes received from a web server.
  117. @return: A tuple of (requestLine, headers, body). C{headers} is a dict
  118. of headers, C{requestLine} is the first line (e.g. "POST /foo ...")
  119. and C{body} is whatever is left.
  120. """
  121. headers, body = content.split(b'\r\n\r\n')
  122. headers = headers.split(b'\r\n')
  123. requestLine = headers.pop(0)
  124. return (
  125. requestLine, dict(header.split(b': ') for header in headers), body)
  126. def makeRequest(self, path):
  127. """
  128. Make a dummy request object for the URL path.
  129. @param path: A URL path, beginning with a slash.
  130. @return: A L{DummyRequest}.
  131. """
  132. return DummyRequest(path)
  133. def makeProxyClient(self, request, method=b"GET", headers=None,
  134. requestBody=b""):
  135. """
  136. Make a L{ProxyClient} object used for testing.
  137. @param request: The request to use.
  138. @param method: The HTTP method to use, GET by default.
  139. @param headers: The HTTP headers to use expressed as a dict. If not
  140. provided, defaults to {'accept': 'text/html'}.
  141. @param requestBody: The body of the request. Defaults to the empty
  142. string.
  143. @return: A L{ProxyClient}
  144. """
  145. if headers is None:
  146. headers = {b"accept": b"text/html"}
  147. path = b'/' + request.postpath
  148. return ProxyClient(
  149. method, path, b'HTTP/1.0', headers, requestBody, request)
  150. def connectProxy(self, proxyClient):
  151. """
  152. Connect a proxy client to a L{StringTransportWithDisconnection}.
  153. @param proxyClient: A L{ProxyClient}.
  154. @return: The L{StringTransportWithDisconnection}.
  155. """
  156. clientTransport = StringTransportWithDisconnection()
  157. clientTransport.protocol = proxyClient
  158. proxyClient.makeConnection(clientTransport)
  159. return clientTransport
  160. def assertForwardsHeaders(self, proxyClient, requestLine, headers):
  161. """
  162. Assert that C{proxyClient} sends C{headers} when it connects.
  163. @param proxyClient: A L{ProxyClient}.
  164. @param requestLine: The request line we expect to be sent.
  165. @param headers: A dict of headers we expect to be sent.
  166. @return: If the assertion is successful, return the request body as
  167. bytes.
  168. """
  169. self.connectProxy(proxyClient)
  170. requestContent = proxyClient.transport.value()
  171. receivedLine, receivedHeaders, body = self._parseOutHeaders(
  172. requestContent)
  173. self.assertEqual(receivedLine, requestLine)
  174. self.assertEqual(receivedHeaders, headers)
  175. return body
  176. def makeResponseBytes(self, code, message, headers, body):
  177. lines = [b"HTTP/1.0 " + str(code).encode('ascii') + b' ' + message]
  178. for header, values in headers:
  179. for value in values:
  180. lines.append(header + b': ' + value)
  181. lines.extend([b'', body])
  182. return b'\r\n'.join(lines)
  183. def assertForwardsResponse(self, request, code, message, headers, body):
  184. """
  185. Assert that C{request} has forwarded a response from the server.
  186. @param request: A L{DummyRequest}.
  187. @param code: The expected HTTP response code.
  188. @param message: The expected HTTP message.
  189. @param headers: The expected HTTP headers.
  190. @param body: The expected response body.
  191. """
  192. self.assertEqual(request.responseCode, code)
  193. self.assertEqual(request.responseMessage, message)
  194. receivedHeaders = list(request.responseHeaders.getAllRawHeaders())
  195. receivedHeaders.sort()
  196. expectedHeaders = headers[:]
  197. expectedHeaders.sort()
  198. self.assertEqual(receivedHeaders, expectedHeaders)
  199. self.assertEqual(b''.join(request.written), body)
  200. def _testDataForward(self, code, message, headers, body, method=b"GET",
  201. requestBody=b"", loseConnection=True):
  202. """
  203. Build a fake proxy connection, and send C{data} over it, checking that
  204. it's forwarded to the originating request.
  205. """
  206. request = self.makeRequest(b'foo')
  207. client = self.makeProxyClient(
  208. request, method, {b'accept': b'text/html'}, requestBody)
  209. receivedBody = self.assertForwardsHeaders(
  210. client, method + b' /foo HTTP/1.0',
  211. {b'connection': b'close', b'accept': b'text/html'})
  212. self.assertEqual(receivedBody, requestBody)
  213. # Fake an answer
  214. client.dataReceived(
  215. self.makeResponseBytes(code, message, headers, body))
  216. # Check that the response data has been forwarded back to the original
  217. # requester.
  218. self.assertForwardsResponse(request, code, message, headers, body)
  219. # Check that when the response is done, the request is finished.
  220. if loseConnection:
  221. client.transport.loseConnection()
  222. # Even if we didn't call loseConnection, the transport should be
  223. # disconnected. This lets us not rely on the server to close our
  224. # sockets for us.
  225. self.assertFalse(client.transport.connected)
  226. self.assertEqual(request.finished, 1)
  227. def test_forward(self):
  228. """
  229. When connected to the server, L{ProxyClient} should send the saved
  230. request, with modifications of the headers, and then forward the result
  231. to the parent request.
  232. """
  233. return self._testDataForward(
  234. 200, b"OK", [(b"Foo", [b"bar", b"baz"])], b"Some data\r\n")
  235. def test_postData(self):
  236. """
  237. Try to post content in the request, and check that the proxy client
  238. forward the body of the request.
  239. """
  240. return self._testDataForward(
  241. 200, b"OK", [(b"Foo", [b"bar"])], b"Some data\r\n", b"POST", b"Some content")
  242. def test_statusWithMessage(self):
  243. """
  244. If the response contains a status with a message, it should be
  245. forwarded to the parent request with all the information.
  246. """
  247. return self._testDataForward(
  248. 404, b"Not Found", [], b"")
  249. def test_contentLength(self):
  250. """
  251. If the response contains a I{Content-Length} header, the inbound
  252. request object should still only have C{finish} called on it once.
  253. """
  254. data = b"foo bar baz"
  255. return self._testDataForward(
  256. 200,
  257. b"OK",
  258. [(b"Content-Length", [str(len(data)).encode('ascii')])],
  259. data)
  260. def test_losesConnection(self):
  261. """
  262. If the response contains a I{Content-Length} header, the outgoing
  263. connection is closed when all response body data has been received.
  264. """
  265. data = b"foo bar baz"
  266. return self._testDataForward(
  267. 200,
  268. b"OK",
  269. [(b"Content-Length", [str(len(data)).encode('ascii')])],
  270. data,
  271. loseConnection=False)
  272. def test_headersCleanups(self):
  273. """
  274. The headers given at initialization should be modified:
  275. B{proxy-connection} should be removed if present, and B{connection}
  276. should be added.
  277. """
  278. client = ProxyClient(b'GET', b'/foo', b'HTTP/1.0',
  279. {b"accept": b"text/html", b"proxy-connection": b"foo"}, b'', None)
  280. self.assertEqual(client.headers,
  281. {b"accept": b"text/html", b"connection": b"close"})
  282. def test_keepaliveNotForwarded(self):
  283. """
  284. The proxy doesn't really know what to do with keepalive things from
  285. the remote server, so we stomp over any keepalive header we get from
  286. the client.
  287. """
  288. headers = {
  289. b"accept": b"text/html",
  290. b'keep-alive': b'300',
  291. b'connection': b'keep-alive',
  292. }
  293. expectedHeaders = headers.copy()
  294. expectedHeaders[b'connection'] = b'close'
  295. del expectedHeaders[b'keep-alive']
  296. client = ProxyClient(b'GET', b'/foo', b'HTTP/1.0', headers, b'', None)
  297. self.assertForwardsHeaders(
  298. client, b'GET /foo HTTP/1.0', expectedHeaders)
  299. def test_defaultHeadersOverridden(self):
  300. """
  301. L{server.Request} within the proxy sets certain response headers by
  302. default. When we get these headers back from the remote server, the
  303. defaults are overridden rather than simply appended.
  304. """
  305. request = self.makeRequest(b'foo')
  306. request.responseHeaders.setRawHeaders(b'server', [b'old-bar'])
  307. request.responseHeaders.setRawHeaders(b'date', [b'old-baz'])
  308. request.responseHeaders.setRawHeaders(b'content-type', [b"old/qux"])
  309. client = self.makeProxyClient(request, headers={b'accept': b'text/html'})
  310. self.connectProxy(client)
  311. headers = {
  312. b'Server': [b'bar'],
  313. b'Date': [b'2010-01-01'],
  314. b'Content-Type': [b'application/x-baz'],
  315. }
  316. client.dataReceived(
  317. self.makeResponseBytes(200, b"OK", headers.items(), b''))
  318. self.assertForwardsResponse(
  319. request, 200, b'OK', list(headers.items()), b'')
  320. class ProxyClientFactoryTests(TestCase):
  321. """
  322. Tests for L{ProxyClientFactory}.
  323. """
  324. def test_connectionFailed(self):
  325. """
  326. Check that L{ProxyClientFactory.clientConnectionFailed} produces
  327. a B{501} response to the parent request.
  328. """
  329. request = DummyRequest([b'foo'])
  330. factory = ProxyClientFactory(b'GET', b'/foo', b'HTTP/1.0',
  331. {b"accept": b"text/html"}, '', request)
  332. factory.clientConnectionFailed(None, None)
  333. self.assertEqual(request.responseCode, 501)
  334. self.assertEqual(request.responseMessage, b"Gateway error")
  335. self.assertEqual(
  336. list(request.responseHeaders.getAllRawHeaders()),
  337. [(b"Content-Type", [b"text/html"])])
  338. self.assertEqual(
  339. b''.join(request.written),
  340. b"<H1>Could not connect</H1>")
  341. self.assertEqual(request.finished, 1)
  342. def test_buildProtocol(self):
  343. """
  344. L{ProxyClientFactory.buildProtocol} should produce a L{ProxyClient}
  345. with the same values of attributes (with updates on the headers).
  346. """
  347. factory = ProxyClientFactory(b'GET', b'/foo', b'HTTP/1.0',
  348. {b"accept": b"text/html"}, b'Some data',
  349. None)
  350. proto = factory.buildProtocol(None)
  351. self.assertIsInstance(proto, ProxyClient)
  352. self.assertEqual(proto.command, b'GET')
  353. self.assertEqual(proto.rest, b'/foo')
  354. self.assertEqual(proto.data, b'Some data')
  355. self.assertEqual(proto.headers,
  356. {b"accept": b"text/html", b"connection": b"close"})
  357. class ProxyRequestTests(TestCase):
  358. """
  359. Tests for L{ProxyRequest}.
  360. """
  361. def _testProcess(self, uri, expectedURI, method=b"GET", data=b""):
  362. """
  363. Build a request pointing at C{uri}, and check that a proxied request
  364. is created, pointing a C{expectedURI}.
  365. """
  366. transport = StringTransportWithDisconnection()
  367. channel = DummyChannel(transport)
  368. reactor = MemoryReactor()
  369. request = ProxyRequest(channel, False, reactor)
  370. request.gotLength(len(data))
  371. request.handleContentChunk(data)
  372. request.requestReceived(method, b'http://example.com' + uri,
  373. b'HTTP/1.0')
  374. self.assertEqual(len(reactor.tcpClients), 1)
  375. self.assertEqual(reactor.tcpClients[0][0], u"example.com")
  376. self.assertEqual(reactor.tcpClients[0][1], 80)
  377. factory = reactor.tcpClients[0][2]
  378. self.assertIsInstance(factory, ProxyClientFactory)
  379. self.assertEqual(factory.command, method)
  380. self.assertEqual(factory.version, b'HTTP/1.0')
  381. self.assertEqual(factory.headers, {b'host': b'example.com'})
  382. self.assertEqual(factory.data, data)
  383. self.assertEqual(factory.rest, expectedURI)
  384. self.assertEqual(factory.father, request)
  385. def test_process(self):
  386. """
  387. L{ProxyRequest.process} should create a connection to the given server,
  388. with a L{ProxyClientFactory} as connection factory, with the correct
  389. parameters:
  390. - forward comment, version and data values
  391. - update headers with the B{host} value
  392. - remove the host from the URL
  393. - pass the request as parent request
  394. """
  395. return self._testProcess(b"/foo/bar", b"/foo/bar")
  396. def test_processWithoutTrailingSlash(self):
  397. """
  398. If the incoming request doesn't contain a slash,
  399. L{ProxyRequest.process} should add one when instantiating
  400. L{ProxyClientFactory}.
  401. """
  402. return self._testProcess(b"", b"/")
  403. def test_processWithData(self):
  404. """
  405. L{ProxyRequest.process} should be able to retrieve request body and
  406. to forward it.
  407. """
  408. return self._testProcess(
  409. b"/foo/bar", b"/foo/bar", b"POST", b"Some content")
  410. def test_processWithPort(self):
  411. """
  412. Check that L{ProxyRequest.process} correctly parse port in the incoming
  413. URL, and create an outgoing connection with this port.
  414. """
  415. transport = StringTransportWithDisconnection()
  416. channel = DummyChannel(transport)
  417. reactor = MemoryReactor()
  418. request = ProxyRequest(channel, False, reactor)
  419. request.gotLength(0)
  420. request.requestReceived(b'GET', b'http://example.com:1234/foo/bar',
  421. b'HTTP/1.0')
  422. # That should create one connection, with the port parsed from the URL
  423. self.assertEqual(len(reactor.tcpClients), 1)
  424. self.assertEqual(reactor.tcpClients[0][0], u"example.com")
  425. self.assertEqual(reactor.tcpClients[0][1], 1234)
  426. class DummyFactory(object):
  427. """
  428. A simple holder for C{host} and C{port} information.
  429. """
  430. def __init__(self, host, port):
  431. self.host = host
  432. self.port = port
  433. class ReverseProxyRequestTests(TestCase):
  434. """
  435. Tests for L{ReverseProxyRequest}.
  436. """
  437. def test_process(self):
  438. """
  439. L{ReverseProxyRequest.process} should create a connection to its
  440. factory host/port, using a L{ProxyClientFactory} instantiated with the
  441. correct parameters, and particularly set the B{host} header to the
  442. factory host.
  443. """
  444. transport = StringTransportWithDisconnection()
  445. channel = DummyChannel(transport)
  446. reactor = MemoryReactor()
  447. request = ReverseProxyRequest(channel, False, reactor)
  448. request.factory = DummyFactory(u"example.com", 1234)
  449. request.gotLength(0)
  450. request.requestReceived(b'GET', b'/foo/bar', b'HTTP/1.0')
  451. # Check that one connection has been created, to the good host/port
  452. self.assertEqual(len(reactor.tcpClients), 1)
  453. self.assertEqual(reactor.tcpClients[0][0], u"example.com")
  454. self.assertEqual(reactor.tcpClients[0][1], 1234)
  455. # Check the factory passed to the connect, and its headers
  456. factory = reactor.tcpClients[0][2]
  457. self.assertIsInstance(factory, ProxyClientFactory)
  458. self.assertEqual(factory.headers, {b'host': b'example.com'})