cookies.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. import time
  2. from six.moves.http_cookiejar import (
  3. CookieJar as _CookieJar, DefaultCookiePolicy, IPV4_RE
  4. )
  5. from scrapy.utils.httpobj import urlparse_cached
  6. from scrapy.utils.python import to_native_str
  class CookieJar(object):
      """Scrapy-facing wrapper around a standard-library cookie jar.

      Requests and responses are adapted with ``WrappedRequest`` and
      ``WrappedResponse`` so the underlying jar can read them. Cookie lookup
      in ``add_cookie_header`` is restricted to the domains that could match
      the request host instead of iterating over every stored domain.
      """

      def __init__(self, policy=None, check_expired_frequency=10000):
          """Create the wrapper and its underlying jar.

          :param policy: cookie policy to use; a ``DefaultCookiePolicy`` is
              created when ``policy`` is falsy.
          :param check_expired_frequency: expired cookies are purged once
              every this many ``add_cookie_header`` calls (that reached the
              lookup stage). A falsy value disables the periodic purge.
          """
          self.policy = policy or DefaultCookiePolicy()
          self.jar = _CookieJar(self.policy)
          # Replace the jar's internal lock with a no-op one.
          # NOTE(review): presumably because the jar is only used from a
          # single thread -- confirm before sharing it across threads.
          self.jar._cookies_lock = _DummyLock()
          self.check_expired_frequency = check_expired_frequency
          # Count of add_cookie_header calls whose request had a hostname.
          self.processed = 0

      def extract_cookies(self, response, request):
          """Store in the jar the cookies that ``response`` sets for ``request``."""
          wreq = WrappedRequest(request)
          wrsp = WrappedResponse(response)
          return self.jar.extract_cookies(wrsp, wreq)

      def add_cookie_header(self, request):
          """Add a ``Cookie`` header built from the matching stored cookies.

          Nothing is added when the request URL has no hostname, when no
          stored cookie matches, or when the request already carries a
          ``Cookie`` header.
          """
          wreq = WrappedRequest(request)
          self.policy._now = self.jar._now = int(time.time())

          # the cookiejar implementation iterates through all domains
          # instead we restrict to potential matches on the domain
          req_host = urlparse_cached(request).hostname
          if not req_host:
              return

          if IPV4_RE.search(req_host):
              # IPv4 literal: only the exact host is looked up.
              hosts = [req_host]
          else:
              hosts = potential_domain_matches(req_host)
              if '.' not in req_host:
                  # Dot-less hosts are also looked up under "<host>.local"
                  # (presumably mirroring the stdlib's effective request-host
                  # rule -- confirm).
                  hosts.append(req_host + ".local")

          cookies = []
          for host in hosts:
              if host in self.jar._cookies:
                  cookies.extend(self.jar._cookies_for_domain(host, wreq))

          attrs = self.jar._cookie_attrs(cookies)
          if attrs and not wreq.has_header("Cookie"):
              wreq.add_unredirected_header("Cookie", "; ".join(attrs))

          self.processed += 1
          frequency = self.check_expired_frequency
          # A falsy frequency disables the purge instead of dividing by zero.
          if frequency and self.processed % frequency == 0:
              # This is still quite inefficient for large number of cookies
              self.jar.clear_expired_cookies()

      @property
      def _cookies(self):
          """The underlying jar's cookie storage, returned as-is."""
          return self.jar._cookies

      def clear_session_cookies(self, *args, **kwargs):
          """Delegate to the underlying jar's ``clear_session_cookies``."""
          return self.jar.clear_session_cookies(*args, **kwargs)

      def clear(self, domain=None, path=None, name=None):
          """Delegate to the underlying jar's ``clear``."""
          return self.jar.clear(domain, path, name)

      def __iter__(self):
          return iter(self.jar)

      def __len__(self):
          return len(self.jar)

      def set_policy(self, pol):
          """Switch the cookie policy used by this jar.

          ``self.policy`` is updated together with the underlying jar's
          policy. ``add_cookie_header`` stamps the current time on
          ``self.policy``, so it must be the policy the jar actually uses.
          """
          self.policy = pol
          return self.jar.set_policy(pol)

      def make_cookies(self, response, request):
          """Return the cookies ``response`` sets for ``request``, without storing them."""
          wreq = WrappedRequest(request)
          wrsp = WrappedResponse(response)
          return self.jar.make_cookies(wrsp, wreq)

      def set_cookie(self, cookie):
          """Store ``cookie`` unconditionally."""
          self.jar.set_cookie(cookie)

      def set_cookie_if_ok(self, cookie, request):
          """Store ``cookie`` if the jar's policy accepts it for ``request``."""
          self.jar.set_cookie_if_ok(cookie, WrappedRequest(request))
  def potential_domain_matches(domain):
      """List the domain keys under which cookies for ``domain`` may be stored.

      The result starts with ``domain`` itself. Next come the suffixes that
      begin right after one of its dots, as long as such a suffix still
      contains a dot beyond its first character, so a bare final label like
      ``com`` is never produced. The same entries then follow with a leading
      dot.

      >>> potential_domain_matches('www.example.com')
      ['www.example.com', 'example.com', '.www.example.com', '.example.com']
      """
      last_dot = domain.rfind('.')
      suffixes = [
          domain[pos + 1:]
          for pos, char in enumerate(domain)
          if char == '.' and pos + 1 < last_dot
      ]
      candidates = [domain] + suffixes
      return candidates + ['.' + candidate for candidate in candidates]
  class _DummyLock(object):
      """No-op stand-in for the lock a cookie jar uses internally.

      ``CookieJar.__init__`` installs it in place of the jar's own lock.
      ``acquire`` accepts the optional ``blocking``/``timeout`` arguments and
      reports success, and the object also works in a ``with`` statement, so
      either form of the lock protocol can be used on it.
      """

      def acquire(self, *args, **kwargs):
          """Pretend to take the lock; always succeeds immediately."""
          return True

      def release(self):
          """Pretend to release the lock."""

      def __enter__(self):
          return self.acquire()

      def __exit__(self, exc_type, exc_value, traceback):
          self.release()
          # Never suppress exceptions raised inside the ``with`` body.
          return False
  class WrappedRequest(object):
      """Adapt a Scrapy request to the ``urllib`` Request interface used by cookie jars.

      Both the method-style accessors (``get_full_url()``, ``get_host()``, ...)
      and the attribute-style ones (``full_url``, ``host``, ...) are provided.
      See https://docs.python.org/3/library/urllib.request.html#request-objects
      """

      def __init__(self, request):
          self.request = request

      def get_full_url(self):
          return self.request.url

      def get_host(self):
          # netloc, so an explicit port in the URL is included
          return urlparse_cached(self.request).netloc

      def get_type(self):
          return urlparse_cached(self.request).scheme

      def is_unverifiable(self):
          """Return whether the request is unverifiable, as defined by RFC 2965.

          Read from ``request.meta['is_unverifiable']`` and defaulting to
          False. An unverifiable request is one whose URL the user had no
          option to approve, e.g. an image embedded in an HTML document that
          is fetched automatically.
          """
          return self.request.meta.get('is_unverifiable', False)

      def get_origin_req_host(self):
          return urlparse_cached(self.request).hostname

      # python3 uses attributes instead of methods
      @property
      def full_url(self):
          return self.get_full_url()

      @property
      def host(self):
          return self.get_host()

      @property
      def type(self):
          return self.get_type()

      @property
      def unverifiable(self):
          return self.is_unverifiable()

      @property
      def origin_req_host(self):
          return self.get_origin_req_host()

      def has_header(self, name):
          return name in self.request.headers

      def get_header(self, name, default=None):
          """Return header ``name`` as a native string, or ``default`` if it is absent.

          A missing header with no default yields ``None``.
          """
          value = self.request.headers.get(name, default)
          if value is None:
              # Don't hand None to to_native_str. NOTE(review): it expects
              # bytes/str and rejects None in the scrapy versions checked.
              return None
          return to_native_str(value, errors='replace')

      def header_items(self):
          """Return ``(name, [values])`` pairs for every request header, as native strings."""
          return [
              (to_native_str(k, errors='replace'),
               [to_native_str(x, errors='replace') for x in v])
              for k, v in self.request.headers.items()
          ]

      def add_unredirected_header(self, name, value):
          """Append ``value`` to header ``name`` on the wrapped request."""
          self.request.headers.appendlist(name, value)
  class WrappedResponse(object):
      """Adapt a Scrapy response to the header-reading interface cookie jars expect."""

      def __init__(self, response):
          self.response = response

      def info(self):
          """Return the object the jar reads headers from -- this wrapper itself."""
          return self

      # python3 cookiejars calls get_all
      def get_all(self, name, default=None):
          """Return every value of header ``name`` as native strings.

          ``default`` is accepted only for signature compatibility and is
          ignored: a missing header always yields an empty list.
          """
          decoded = []
          for raw in self.response.headers.getlist(name):
              decoded.append(to_native_str(raw, errors='replace'))
          return decoded

      # python2 cookiejars calls getheaders
      getheaders = get_all