middleware.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. import re
  2. from django import http
  3. from django.apps import apps
  4. from django.utils.cache import patch_vary_headers
  5. from django.utils.six.moves.urllib.parse import urlparse
  6. from .compat import MiddlewareMixin
  7. from .conf import conf
  8. from .signals import check_request_enabled
  # CORS response header names written by CorsMiddleware.process_response.
  # Sent on every allowed cross-origin response:
  ACCESS_CONTROL_ALLOW_ORIGIN = 'Access-Control-Allow-Origin'
  ACCESS_CONTROL_ALLOW_CREDENTIALS = 'Access-Control-Allow-Credentials'
  ACCESS_CONTROL_EXPOSE_HEADERS = 'Access-Control-Expose-Headers'
  # Only added to OPTIONS responses:
  ACCESS_CONTROL_ALLOW_HEADERS = 'Access-Control-Allow-Headers'
  ACCESS_CONTROL_ALLOW_METHODS = 'Access-Control-Allow-Methods'
  ACCESS_CONTROL_MAX_AGE = 'Access-Control-Max-Age'
  class CorsPostCsrfMiddleware(MiddlewareMixin):
      """
      Restore the Referer header that CorsMiddleware rewrote.

      CorsMiddleware stashes the client's real referer under
      ``request.META['ORIGINAL_HTTP_REFERER']`` before pointing HTTP_REFERER
      at this host. This middleware copies the stashed value back and removes
      the temporary key, both when the request enters and in process_view.
      Presumably it is meant to be listed after Django's CSRF middleware
      (hence the name) -- confirm against the project's MIDDLEWARE setting.
      """

      def _https_referer_replace_reverse(self, request):
          """
          Undo the referer swap.

          HTTP_REFERER gets its original value back and the
          ORIGINAL_HTTP_REFERER scratch key is dropped. Does nothing unless
          CORS_REPLACE_HTTPS_REFERER is enabled and the scratch key exists.
          """
          if not conf.CORS_REPLACE_HTTPS_REFERER:
              return
          meta = request.META
          if 'ORIGINAL_HTTP_REFERER' not in meta:
              return
          # pop() both reads the stashed value and deletes the scratch key.
          meta['HTTP_REFERER'] = meta.pop('ORIGINAL_HTTP_REFERER')

      def process_request(self, request):
          """Restore the referer as the request enters; never short-circuits."""
          self._https_referer_replace_reverse(request)
          return None

      def process_view(self, request, callback, callback_args, callback_kwargs):
          """Restore the referer again before the view runs; never short-circuits."""
          self._https_referer_replace_reverse(request)
          return None
  class CorsMiddleware(MiddlewareMixin):
      """
      Add CORS headers to responses and answer CORS preflight requests.

      CORS handling is enabled for a request when its path matches
      ``conf.CORS_URLS_REGEX`` or a ``check_request_enabled`` receiver returns
      a truthy value. Optionally (CORS_REPLACE_HTTPS_REFERER) the Referer is
      rewritten so Django's HTTPS CSRF referer check passes for allowed origins.
      """

      def _https_referer_replace(self, request):
          """
          Rewrite HTTP_REFERER to ``https://<HTTP_HOST>/`` for allowed origins.

          When https is enabled, Django CSRF checking includes referer checking,
          which breaks when using CORS. This makes the referer match HTTP_HOST,
          provided the request's Origin passes the whitelist checks.

          The original referer is stashed in
          ``request.META['ORIGINAL_HTTP_REFERER']`` so that
          CorsPostCsrfMiddleware can restore it. The presence of that key also
          makes a second call for the same request a no-op.
          """
          origin = request.META.get('HTTP_ORIGIN')
          if not request.is_secure() or not origin:
              return
          if 'ORIGINAL_HTTP_REFERER' in request.META:
              # Already rewritten (process_request runs before process_view);
              # don't clobber the stashed original value.
              return
          url = urlparse(origin)
          # NOTE(review): only the settings whitelists are consulted here, while
          # process_response also accepts CORS_MODEL rows and signal receivers.
          if not conf.CORS_ORIGIN_ALLOW_ALL and not self.origin_found_in_white_lists(origin, url):
              return
          try:
              http_referer = request.META['HTTP_REFERER']
              http_host = "https://%s/" % request.META['HTTP_HOST']
          except KeyError:
              # No Referer or Host header: nothing to rewrite.
              return
          # Replace META with a copy before mutating it, as the original code
          # does (presumably to avoid mutating a shared mapping -- confirm).
          request.META = request.META.copy()
          request.META['ORIGINAL_HTTP_REFERER'] = http_referer
          request.META['HTTP_REFERER'] = http_host

      def _request_cors_enabled(self, request):
          """Return the flag cached by process_request, computing it if absent."""
          enabled = getattr(request, '_cors_enabled', None)
          if enabled is None:
              enabled = self.is_enabled(request)
          return enabled

      def process_request(self, request):
          """
          Cache whether CORS applies and short-circuit preflight requests.

          If CORS is enabled and the request is a preflight (an OPTIONS request
          carrying Access-Control-Request-Method), return an empty-body
          HttpResponse. Django then skips the view and view/exception
          middleware, but response middleware (including process_response
          below) still runs. Otherwise return None.
          """
          request._cors_enabled = self.is_enabled(request)
          if request._cors_enabled:
              if conf.CORS_REPLACE_HTTPS_REFERER:
                  self._https_referer_replace(request)
              if (
                  request.method == 'OPTIONS' and
                  'HTTP_ACCESS_CONTROL_REQUEST_METHOD' in request.META
              ):
                  return http.HttpResponse()
          return None

      def process_view(self, request, callback, callback_args, callback_kwargs):
          """
          Repeat the referer replacement just before the view runs.

          Uses the same defensive flag lookup as process_response, so a request
          that lacks ``_cors_enabled`` does not raise AttributeError.
          """
          if conf.CORS_REPLACE_HTTPS_REFERER and self._request_cors_enabled(request):
              self._https_referer_replace(request)
          return None

      def process_response(self, request, response):
          """
          Add the CORS headers to ``response`` when the request's Origin is allowed.

          An Origin is allowed when CORS_ORIGIN_ALLOW_ALL is set, or when it
          matches the settings whitelists, a CORS_MODEL row, or a
          ``check_request_enabled`` receiver. Responses to requests with no
          Origin header, with CORS disabled, or from a rejected origin are
          returned without any CORS headers.

          Returns the (possibly modified) response.
          """
          origin = request.META.get('HTTP_ORIGIN')
          if not origin:
              return response

          if not self._request_cors_enabled(request):
              return response

          # todo: check hostname from db instead
          url = urlparse(origin)

          if (
              not conf.CORS_ORIGIN_ALLOW_ALL and
              not self.origin_found_in_white_lists(origin, url) and
              not self.origin_found_in_model(url) and
              not self.check_signal(request)
          ):
              return response

          if conf.CORS_ORIGIN_ALLOW_ALL and not conf.CORS_ALLOW_CREDENTIALS:
              # A "*" origin cannot be combined with credentials, so the
              # wildcard is only used when credentials are off.
              response[ACCESS_CONTROL_ALLOW_ORIGIN] = "*"
          else:
              response[ACCESS_CONTROL_ALLOW_ORIGIN] = origin
              # The value echoes the request Origin, so caches must key on it.
              patch_vary_headers(response, ['Origin'])

          # Advertise credential support only to origins that were allowed above.
          if conf.CORS_ALLOW_CREDENTIALS:
              response[ACCESS_CONTROL_ALLOW_CREDENTIALS] = 'true'

          if conf.CORS_EXPOSE_HEADERS:
              response[ACCESS_CONTROL_EXPOSE_HEADERS] = ', '.join(conf.CORS_EXPOSE_HEADERS)

          if request.method == 'OPTIONS':
              response[ACCESS_CONTROL_ALLOW_HEADERS] = ', '.join(conf.CORS_ALLOW_HEADERS)
              response[ACCESS_CONTROL_ALLOW_METHODS] = ', '.join(conf.CORS_ALLOW_METHODS)
              if conf.CORS_PREFLIGHT_MAX_AGE:
                  response[ACCESS_CONTROL_MAX_AGE] = conf.CORS_PREFLIGHT_MAX_AGE

          return response

      def origin_found_in_white_lists(self, origin, url):
          """
          Return a truthy value if ``origin`` is whitelisted in settings.

          Checks, in order:
          - the parsed host[:port] (``url.netloc``) against CORS_ORIGIN_WHITELIST;
          - the literal ``'null'`` origin, if listed there;
          - the CORS_ORIGIN_REGEX_WHITELIST patterns.
          """
          return (
              url.netloc in conf.CORS_ORIGIN_WHITELIST or
              (origin == 'null' and origin in conf.CORS_ORIGIN_WHITELIST) or
              self.regex_domain_match(origin)
          )

      def regex_domain_match(self, origin):
          """
          Return ``origin`` if it matches a CORS_ORIGIN_REGEX_WHITELIST pattern, else None.

          NOTE(review): ``re.match`` anchors only at the start of the string,
          so patterns should end with ``$``. Otherwise e.g. a trusted-domain
          prefix could also match a longer attacker-controlled origin.
          """
          for domain_pattern in conf.CORS_ORIGIN_REGEX_WHITELIST:
              if re.match(domain_pattern, origin):
                  return origin
          return None

      def origin_found_in_model(self, url):
          """
          Return True if a CORS_MODEL row has ``cors`` equal to ``url.netloc``.

          Returns False when CORS_MODEL is None. CORS_MODEL is split on '.' and
          passed to ``apps.get_model``, so it is presumably an
          "app_label.ModelName" string (confirm against settings).
          """
          if conf.CORS_MODEL is None:
              return False
          model = apps.get_model(*conf.CORS_MODEL.split('.'))
          return model.objects.filter(cors=url.netloc).exists()

      def is_enabled(self, request):
          """
          Return a truthy value if CORS handling applies to ``request``.

          That is, ``request.path`` matches CORS_URLS_REGEX (the result is a
          match object) or a ``check_request_enabled`` receiver returns a
          truthy value.
          """
          return (
              re.match(conf.CORS_URLS_REGEX, request.path) or
              self.check_signal(request)
          )

      def check_signal(self, request):
          """Return True if any ``check_request_enabled`` receiver returns a truthy value."""
          signal_responses = check_request_enabled.send(
              sender=None,
              request=request,
          )
          # send() yields (receiver, return_value) pairs; only the values matter.
          return any(
              return_value for _receiver, return_value in signal_responses
          )