__init__.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import sys
  2. import re
  3. from functools import wraps
  4. from inspect import getmembers
  5. from unittest import TestCase
  6. from scrapy.http import Request
  7. from scrapy.utils.spider import iterate_spider_output
  8. from scrapy.utils.python import get_spec
  9. class ContractsManager(object):
  10. contracts = {}
  11. def __init__(self, contracts):
  12. for contract in contracts:
  13. self.contracts[contract.name] = contract
  14. def tested_methods_from_spidercls(self, spidercls):
  15. methods = []
  16. for key, value in getmembers(spidercls):
  17. if (callable(value) and value.__doc__ and
  18. re.search(r'^\s*@', value.__doc__, re.MULTILINE)):
  19. methods.append(key)
  20. return methods
  21. def extract_contracts(self, method):
  22. contracts = []
  23. for line in method.__doc__.split('\n'):
  24. line = line.strip()
  25. if line.startswith('@'):
  26. name, args = re.match(r'@(\w+)\s*(.*)', line).groups()
  27. args = re.split(r'\s+', args)
  28. contracts.append(self.contracts[name](method, *args))
  29. return contracts
  30. def from_spider(self, spider, results):
  31. requests = []
  32. for method in self.tested_methods_from_spidercls(type(spider)):
  33. bound_method = spider.__getattribute__(method)
  34. try:
  35. requests.append(self.from_method(bound_method, results))
  36. except Exception:
  37. case = _create_testcase(bound_method, 'contract')
  38. results.addError(case, sys.exc_info())
  39. return requests
  40. def from_method(self, method, results):
  41. contracts = self.extract_contracts(method)
  42. if contracts:
  43. request_cls = Request
  44. for contract in contracts:
  45. if contract.request_cls is not None:
  46. request_cls = contract.request_cls
  47. # calculate request args
  48. args, kwargs = get_spec(request_cls.__init__)
  49. # Don't filter requests to allow
  50. # testing different callbacks on the same URL.
  51. kwargs['dont_filter'] = True
  52. kwargs['callback'] = method
  53. for contract in contracts:
  54. kwargs = contract.adjust_request_args(kwargs)
  55. args.remove('self')
  56. # check if all positional arguments are defined in kwargs
  57. if set(args).issubset(set(kwargs)):
  58. request = request_cls(**kwargs)
  59. # execute pre and post hooks in order
  60. for contract in reversed(contracts):
  61. request = contract.add_pre_hook(request, results)
  62. for contract in contracts:
  63. request = contract.add_post_hook(request, results)
  64. self._clean_req(request, method, results)
  65. return request
  66. def _clean_req(self, request, method, results):
  67. """ stop the request from returning objects and records any errors """
  68. cb = request.callback
  69. @wraps(cb)
  70. def cb_wrapper(response, **cb_kwargs):
  71. try:
  72. output = cb(response, **cb_kwargs)
  73. output = list(iterate_spider_output(output))
  74. except Exception:
  75. case = _create_testcase(method, 'callback')
  76. results.addError(case, sys.exc_info())
  77. def eb_wrapper(failure):
  78. case = _create_testcase(method, 'errback')
  79. exc_info = failure.type, failure.value, failure.getTracebackObject()
  80. results.addError(case, exc_info)
  81. request.callback = cb_wrapper
  82. request.errback = eb_wrapper
  83. class Contract(object):
  84. """ Abstract class for contracts """
  85. request_cls = None
  86. def __init__(self, method, *args):
  87. self.testcase_pre = _create_testcase(method, '@%s pre-hook' % self.name)
  88. self.testcase_post = _create_testcase(method, '@%s post-hook' % self.name)
  89. self.args = args
  90. def add_pre_hook(self, request, results):
  91. if hasattr(self, 'pre_process'):
  92. cb = request.callback
  93. @wraps(cb)
  94. def wrapper(response, **cb_kwargs):
  95. try:
  96. results.startTest(self.testcase_pre)
  97. self.pre_process(response)
  98. results.stopTest(self.testcase_pre)
  99. except AssertionError:
  100. results.addFailure(self.testcase_pre, sys.exc_info())
  101. except Exception:
  102. results.addError(self.testcase_pre, sys.exc_info())
  103. else:
  104. results.addSuccess(self.testcase_pre)
  105. finally:
  106. return list(iterate_spider_output(cb(response, **cb_kwargs)))
  107. request.callback = wrapper
  108. return request
  109. def add_post_hook(self, request, results):
  110. if hasattr(self, 'post_process'):
  111. cb = request.callback
  112. @wraps(cb)
  113. def wrapper(response, **cb_kwargs):
  114. output = list(iterate_spider_output(cb(response, **cb_kwargs)))
  115. try:
  116. results.startTest(self.testcase_post)
  117. self.post_process(output)
  118. results.stopTest(self.testcase_post)
  119. except AssertionError:
  120. results.addFailure(self.testcase_post, sys.exc_info())
  121. except Exception:
  122. results.addError(self.testcase_post, sys.exc_info())
  123. else:
  124. results.addSuccess(self.testcase_post)
  125. finally:
  126. return output
  127. request.callback = wrapper
  128. return request
  129. def adjust_request_args(self, args):
  130. return args
  131. def _create_testcase(method, desc):
  132. spider = method.__self__.name
  133. class ContractTestCase(TestCase):
  134. def __str__(_self):
  135. return "[%s] %s (%s)" % (spider, method.__name__, desc)
  136. name = '%s_%s' % (spider, method.__name__)
  137. setattr(ContractTestCase, name, lambda x: x)
  138. return ContractTestCase(name)