# message.py
  1. # -*- coding: utf-8 -
  2. #
  3. # This file is part of gunicorn released under the MIT license.
  4. # See the NOTICE for more information.
  5. import re
  6. import socket
  7. from errno import ENOTCONN
  8. from gunicorn._compat import bytes_to_str
  9. from gunicorn.http.unreader import SocketUnreader
  10. from gunicorn.http.body import ChunkedReader, LengthReader, EOFReader, Body
  11. from gunicorn.http.errors import (InvalidHeader, InvalidHeaderName, NoMoreData,
  12. InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion,
  13. LimitRequestLine, LimitRequestHeaders)
  14. from gunicorn.http.errors import InvalidProxyLine, ForbiddenProxyRequest
  15. from gunicorn.six import BytesIO
  16. from gunicorn._compat import urlsplit
# Hard ceilings used to sanity-check the user-configurable limits below.
MAX_REQUEST_LINE = 8190
MAX_HEADERS = 32768
DEFAULT_MAX_HEADERFIELD_SIZE = 8190

# Characters forbidden in a header field name (control characters and
# the RFC 2616 separators).
HEADER_RE = re.compile("[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]")
# Request method token, 3-20 chars.  NOTE(review): "$-_" inside the class
# is a character *range* ($ through _), not three literal characters —
# presumably intentional, but worth confirming against the HTTP token grammar.
METH_RE = re.compile(r"[A-Z0-9$-_.]{3,20}")
# HTTP version, e.g. "HTTP/1.1" -> groups ("1", "1").
VERSION_RE = re.compile(r"HTTP/(\d+)\.(\d+)")
  23. class Message(object):
  24. def __init__(self, cfg, unreader):
  25. self.cfg = cfg
  26. self.unreader = unreader
  27. self.version = None
  28. self.headers = []
  29. self.trailers = []
  30. self.body = None
  31. # set headers limits
  32. self.limit_request_fields = cfg.limit_request_fields
  33. if (self.limit_request_fields <= 0
  34. or self.limit_request_fields > MAX_HEADERS):
  35. self.limit_request_fields = MAX_HEADERS
  36. self.limit_request_field_size = cfg.limit_request_field_size
  37. if self.limit_request_field_size < 0:
  38. self.limit_request_field_size = DEFAULT_MAX_HEADERFIELD_SIZE
  39. # set max header buffer size
  40. max_header_field_size = self.limit_request_field_size or DEFAULT_MAX_HEADERFIELD_SIZE
  41. self.max_buffer_headers = self.limit_request_fields * \
  42. (max_header_field_size + 2) + 4
  43. unused = self.parse(self.unreader)
  44. self.unreader.unread(unused)
  45. self.set_body_reader()
  46. def parse(self):
  47. raise NotImplementedError()
  48. def parse_headers(self, data):
  49. headers = []
  50. # Split lines on \r\n keeping the \r\n on each line
  51. lines = [bytes_to_str(line) + "\r\n" for line in data.split(b"\r\n")]
  52. # Parse headers into key/value pairs paying attention
  53. # to continuation lines.
  54. while len(lines):
  55. if len(headers) >= self.limit_request_fields:
  56. raise LimitRequestHeaders("limit request headers fields")
  57. # Parse initial header name : value pair.
  58. curr = lines.pop(0)
  59. header_length = len(curr)
  60. if curr.find(":") < 0:
  61. raise InvalidHeader(curr.strip())
  62. name, value = curr.split(":", 1)
  63. name = name.rstrip(" \t").upper()
  64. if HEADER_RE.search(name):
  65. raise InvalidHeaderName(name)
  66. name, value = name.strip(), [value.lstrip()]
  67. # Consume value continuation lines
  68. while len(lines) and lines[0].startswith((" ", "\t")):
  69. curr = lines.pop(0)
  70. header_length += len(curr)
  71. if header_length > self.limit_request_field_size > 0:
  72. raise LimitRequestHeaders("limit request headers "
  73. + "fields size")
  74. value.append(curr)
  75. value = ''.join(value).rstrip()
  76. if header_length > self.limit_request_field_size > 0:
  77. raise LimitRequestHeaders("limit request headers fields size")
  78. headers.append((name, value))
  79. return headers
  80. def set_body_reader(self):
  81. chunked = False
  82. content_length = None
  83. for (name, value) in self.headers:
  84. if name == "CONTENT-LENGTH":
  85. content_length = value
  86. elif name == "TRANSFER-ENCODING":
  87. chunked = value.lower() == "chunked"
  88. elif name == "SEC-WEBSOCKET-KEY1":
  89. content_length = 8
  90. if chunked:
  91. self.body = Body(ChunkedReader(self, self.unreader))
  92. elif content_length is not None:
  93. try:
  94. content_length = int(content_length)
  95. except ValueError:
  96. raise InvalidHeader("CONTENT-LENGTH", req=self)
  97. if content_length < 0:
  98. raise InvalidHeader("CONTENT-LENGTH", req=self)
  99. self.body = Body(LengthReader(self.unreader, content_length))
  100. else:
  101. self.body = Body(EOFReader(self.unreader))
  102. def should_close(self):
  103. for (h, v) in self.headers:
  104. if h == "CONNECTION":
  105. v = v.lower().strip()
  106. if v == "close":
  107. return True
  108. elif v == "keep-alive":
  109. return False
  110. break
  111. return self.version <= (1, 0)
  112. class Request(Message):
  113. def __init__(self, cfg, unreader, req_number=1):
  114. self.method = None
  115. self.uri = None
  116. self.path = None
  117. self.query = None
  118. self.fragment = None
  119. # get max request line size
  120. self.limit_request_line = cfg.limit_request_line
  121. if (self.limit_request_line < 0
  122. or self.limit_request_line >= MAX_REQUEST_LINE):
  123. self.limit_request_line = MAX_REQUEST_LINE
  124. self.req_number = req_number
  125. self.proxy_protocol_info = None
  126. super(Request, self).__init__(cfg, unreader)
  127. def get_data(self, unreader, buf, stop=False):
  128. data = unreader.read()
  129. if not data:
  130. if stop:
  131. raise StopIteration()
  132. raise NoMoreData(buf.getvalue())
  133. buf.write(data)
  134. def parse(self, unreader):
  135. buf = BytesIO()
  136. self.get_data(unreader, buf, stop=True)
  137. # get request line
  138. line, rbuf = self.read_line(unreader, buf, self.limit_request_line)
  139. # proxy protocol
  140. if self.proxy_protocol(bytes_to_str(line)):
  141. # get next request line
  142. buf = BytesIO()
  143. buf.write(rbuf)
  144. line, rbuf = self.read_line(unreader, buf, self.limit_request_line)
  145. self.parse_request_line(bytes_to_str(line))
  146. buf = BytesIO()
  147. buf.write(rbuf)
  148. # Headers
  149. data = buf.getvalue()
  150. idx = data.find(b"\r\n\r\n")
  151. done = data[:2] == b"\r\n"
  152. while True:
  153. idx = data.find(b"\r\n\r\n")
  154. done = data[:2] == b"\r\n"
  155. if idx < 0 and not done:
  156. self.get_data(unreader, buf)
  157. data = buf.getvalue()
  158. if len(data) > self.max_buffer_headers:
  159. raise LimitRequestHeaders("max buffer headers")
  160. else:
  161. break
  162. if done:
  163. self.unreader.unread(data[2:])
  164. return b""
  165. self.headers = self.parse_headers(data[:idx])
  166. ret = data[idx + 4:]
  167. buf = BytesIO()
  168. return ret
  169. def read_line(self, unreader, buf, limit=0):
  170. data = buf.getvalue()
  171. while True:
  172. idx = data.find(b"\r\n")
  173. if idx >= 0:
  174. # check if the request line is too large
  175. if idx > limit > 0:
  176. raise LimitRequestLine(idx, limit)
  177. break
  178. elif len(data) - 2 > limit > 0:
  179. raise LimitRequestLine(len(data), limit)
  180. self.get_data(unreader, buf)
  181. data = buf.getvalue()
  182. return (data[:idx], # request line,
  183. data[idx + 2:]) # residue in the buffer, skip \r\n
  184. def proxy_protocol(self, line):
  185. """\
  186. Detect, check and parse proxy protocol.
  187. :raises: ForbiddenProxyRequest, InvalidProxyLine.
  188. :return: True for proxy protocol line else False
  189. """
  190. if not self.cfg.proxy_protocol:
  191. return False
  192. if self.req_number != 1:
  193. return False
  194. if not line.startswith("PROXY"):
  195. return False
  196. self.proxy_protocol_access_check()
  197. self.parse_proxy_protocol(line)
  198. return True
  199. def proxy_protocol_access_check(self):
  200. # check in allow list
  201. if isinstance(self.unreader, SocketUnreader):
  202. try:
  203. remote_host = self.unreader.sock.getpeername()[0]
  204. except socket.error as e:
  205. if e.args[0] == ENOTCONN:
  206. raise ForbiddenProxyRequest("UNKNOW")
  207. raise
  208. if ("*" not in self.cfg.proxy_allow_ips and
  209. remote_host not in self.cfg.proxy_allow_ips):
  210. raise ForbiddenProxyRequest(remote_host)
  211. def parse_proxy_protocol(self, line):
  212. bits = line.split()
  213. if len(bits) != 6:
  214. raise InvalidProxyLine(line)
  215. # Extract data
  216. proto = bits[1]
  217. s_addr = bits[2]
  218. d_addr = bits[3]
  219. # Validation
  220. if proto not in ["TCP4", "TCP6"]:
  221. raise InvalidProxyLine("protocol '%s' not supported" % proto)
  222. if proto == "TCP4":
  223. try:
  224. socket.inet_pton(socket.AF_INET, s_addr)
  225. socket.inet_pton(socket.AF_INET, d_addr)
  226. except socket.error:
  227. raise InvalidProxyLine(line)
  228. elif proto == "TCP6":
  229. try:
  230. socket.inet_pton(socket.AF_INET6, s_addr)
  231. socket.inet_pton(socket.AF_INET6, d_addr)
  232. except socket.error:
  233. raise InvalidProxyLine(line)
  234. try:
  235. s_port = int(bits[4])
  236. d_port = int(bits[5])
  237. except ValueError:
  238. raise InvalidProxyLine("invalid port %s" % line)
  239. if not ((0 <= s_port <= 65535) and (0 <= d_port <= 65535)):
  240. raise InvalidProxyLine("invalid port %s" % line)
  241. # Set data
  242. self.proxy_protocol_info = {
  243. "proxy_protocol": proto,
  244. "client_addr": s_addr,
  245. "client_port": s_port,
  246. "proxy_addr": d_addr,
  247. "proxy_port": d_port
  248. }
  249. def parse_request_line(self, line):
  250. bits = line.split(None, 2)
  251. if len(bits) != 3:
  252. raise InvalidRequestLine(line)
  253. # Method
  254. if not METH_RE.match(bits[0]):
  255. raise InvalidRequestMethod(bits[0])
  256. self.method = bits[0].upper()
  257. # URI
  258. # When the path starts with //, urlsplit considers it as a
  259. # relative uri while the RDF says it shouldnt
  260. # http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2
  261. # considers it as an absolute url.
  262. # fix issue #297
  263. if bits[1].startswith("//"):
  264. self.uri = bits[1][1:]
  265. else:
  266. self.uri = bits[1]
  267. try:
  268. parts = urlsplit(self.uri)
  269. except ValueError:
  270. raise InvalidRequestLine(line)
  271. self.path = parts.path or ""
  272. self.query = parts.query or ""
  273. self.fragment = parts.fragment or ""
  274. # Version
  275. match = VERSION_RE.match(bits[2])
  276. if match is None:
  277. raise InvalidHTTPVersion(bits[2])
  278. self.version = (int(match.group(1)), int(match.group(2)))
  279. def set_body_reader(self):
  280. super(Request, self).set_body_reader()
  281. if isinstance(self.body.reader, EOFReader):
  282. self.body = Body(LengthReader(self.unreader, 0))