pytest_concurrent.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. # -*- coding: utf-8 -*-
  2. import os
  3. import sys
  4. import time
  5. import multiprocessing
  6. import concurrent.futures
  7. import collections
  8. import psutil
  9. import py
  10. import pytest
  11. from _pytest.junitxml import LogXML
  12. from _pytest.terminal import TerminalReporter
  13. from _pytest.junitxml import Junit
  14. from _pytest.junitxml import _NodeReporter
  15. from _pytest.junitxml import bin_xml_escape
  16. from _pytest.junitxml import mangle_test_address
  17. # Manager for the shared variables being used by in multiprocess mode
  18. MANAGER = multiprocessing.Manager()
  19. # to override the variable self.stats from LogXML
  20. XMLSTATS = MANAGER.dict()
  21. XMLSTATS['error'] = 0
  22. XMLSTATS['passed'] = 0
  23. XMLSTATS['failure'] = 0
  24. XMLSTATS['skipped'] = 0
  25. # ensures that XMLSTATS is not being modified simultaneously
  26. XMLLOCK = multiprocessing.Lock()
  27. XMLREPORTER = MANAGER.dict()
  28. # XMLREPORTER_ORDERED = MANAGER.list()
  29. NODELOCK = multiprocessing.Lock()
  30. NODEREPORTS = MANAGER.list()
  31. # to keep track of the log for TerminalReporter
  32. DICTIONARY = MANAGER.dict()
  33. # to override the variable self.stats from TerminalReporter
  34. STATS = MANAGER.dict()
  35. # ensures that STATS is not being modified simultaneously
  36. LOCK = multiprocessing.Lock()
  37. def pytest_addoption(parser):
  38. group = parser.getgroup('concurrent')
  39. group.addoption(
  40. '--concmode',
  41. action='store',
  42. dest='concurrent_mode',
  43. default=None,
  44. help='Set the concurrent mode (mthread, mproc, asyncnet)'
  45. )
  46. group.addoption(
  47. '--concworkers',
  48. action='store',
  49. dest='concurrent_workers',
  50. default=None,
  51. help='Set the concurrent worker amount (default to maximum)'
  52. )
  53. parser.addini('concurrent_mode', 'Set the concurrent mode (mthread, mproc, asyncnet)')
  54. parser.addini('concurrent_workers', 'Set the concurrent worker amount (default to maximum)')
  55. def pytest_runtestloop(session):
  56. '''Initialize a single test session'''
  57. if (session.testsfailed and
  58. not session.config.option.continue_on_collection_errors):
  59. raise session.Interrupted(
  60. "%d errors during collection" % session.testsfailed)
  61. if session.config.option.collectonly:
  62. return True
  63. mode = session.config.option.concurrent_mode if session.config.option.concurrent_mode \
  64. else session.config.getini('concurrent_mode')
  65. if mode and mode not in ['mproc', 'mthread', 'asyncnet']:
  66. raise NotImplementedError('Concurrent mode %s is not supported (available: mproc, mthread, asyncnet).' % mode)
  67. try:
  68. workers_raw = session.config.option.concurrent_workers if session.config.option.concurrent_workers else session.config.getini('concurrent_workers')
  69. # set worker amount to the collected test amount
  70. if workers_raw == 'max':
  71. workers_raw = len(session.items)
  72. workers = int(workers_raw) if workers_raw else None
  73. if sys.version_info < (3, 5) and sys.version_info > (3, 0):
  74. # backport max worker: https://github.com/python/cpython/blob/3.5/Lib/concurrent/futures/thread.py#L91-L94
  75. cpu_counter = os if sys.version_info > (3, 4) else psutil
  76. workers = (cpu_counter.cpu_count() or 1) * 5
  77. except ValueError:
  78. raise ValueError('Concurrent workers can only be integer.')
  79. # group collected tests into different lists
  80. groups = collections.defaultdict(list)
  81. ungrouped_items = list()
  82. for item in session.items:
  83. concurrent_group_marker = item.get_marker('concgroup')
  84. concurrent_group = None
  85. if concurrent_group_marker is not None:
  86. if 'args' in dir(concurrent_group_marker) \
  87. and concurrent_group_marker.args:
  88. concurrent_group = concurrent_group_marker.args[0]
  89. if 'kwargs' in dir(concurrent_group_marker) \
  90. and 'group' in concurrent_group_marker.kwargs:
  91. # kwargs beat args
  92. concurrent_group = concurrent_group_marker.kwargs['group']
  93. if concurrent_group:
  94. if not isinstance(concurrent_group, int):
  95. raise TypeError('Concurrent Group needs to be an integer')
  96. groups[concurrent_group].append(item)
  97. else:
  98. ungrouped_items.append(item)
  99. for group in sorted(groups):
  100. _run_items(mode=mode, items=groups[group], session=session, workers=workers)
  101. if ungrouped_items:
  102. _run_items(mode=mode, items=ungrouped_items, session=session, workers=workers)
  103. return True
  104. def _run_items(mode, items, session, workers=None):
  105. ''' Multiprocess is not compatible with Windows !!! '''
  106. if mode == "mproc":
  107. '''Using ThreadPoolExecutor as managers to control the lifecycle of processes.
  108. Each thread will spawn a process and terminates when the process joins.
  109. '''
  110. def run_task_in_proc(item, index):
  111. proc = multiprocessing.Process(target=_run_next_item, args=(session, item, index))
  112. proc.start()
  113. proc.join()
  114. with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
  115. for index, item in enumerate(items):
  116. executor.submit(run_task_in_proc, item, index)
  117. elif mode == "mthread":
  118. with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
  119. for index, item in enumerate(items):
  120. executor.submit(_run_next_item, session, item, index)
  121. elif mode == "asyncnet":
  122. import gevent
  123. import gevent.monkey
  124. import gevent.pool
  125. gevent.monkey.patch_all()
  126. pool = gevent.pool.Pool(size=workers)
  127. for index, item in enumerate(items):
  128. pool.spawn(_run_next_item, session, item, index)
  129. pool.join()
  130. else:
  131. for i, item in enumerate(items):
  132. nextitem = items[i + 1] if i + 1 < len(items) else None
  133. item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
  134. if session.shouldstop:
  135. raise session.Interrupted(session.shouldstop)
  136. def _run_next_item(session, item, i):
  137. nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
  138. item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
  139. if session.shouldstop:
  140. raise session.Interrupted(session.shouldstop)
  141. @pytest.mark.trylast
  142. def pytest_configure(config):
  143. config.addinivalue_line(
  144. 'markers',
  145. 'concgroup(group: int): concurrent group number to run tests in groups (smaller numbers are executed earlier)')
  146. if (config.option.concurrent_mode and config.option.concurrent_mode == 'mproc') or \
  147. config.getini('concurrent_mode') == 'mproc':
  148. standard_reporter = config.pluginmanager.getplugin('terminalreporter')
  149. concurrent_reporter = ConcurrentTerminalReporter(standard_reporter)
  150. config.pluginmanager.unregister(standard_reporter)
  151. config.pluginmanager.register(concurrent_reporter, 'terminalreporter')
  152. if config.option.xmlpath is not None:
  153. xmlpath = config.option.xmlpath
  154. config.pluginmanager.unregister(config._xml)
  155. config._xml = ConcurrentLogXML(xmlpath, config.option.junitprefix, config.getini("junit_suite_name"))
  156. config.pluginmanager.register(config._xml)
  157. class ConcurrentNodeReporter(_NodeReporter):
  158. '''to provide Node Reporting for multiprocess mode'''
  159. def __init__(self, nodeid, xml):
  160. self.id = nodeid
  161. self.xml = xml
  162. self.add_stats = self.xml.add_stats
  163. self.duration = 0
  164. self.properties = []
  165. self.nodes = []
  166. self.testcase = None
  167. self.attrs = {}
  168. def to_xml(self): # overriden
  169. testcase = Junit.testcase(time=self.duration, **self.attrs)
  170. testcase.append(self.make_properties_node())
  171. for node in self.nodes:
  172. testcase.append(node)
  173. return str(testcase.unicode(indent=0))
  174. def record_testreport(self, testreport):
  175. assert not self.testcase
  176. names = mangle_test_address(testreport.nodeid)
  177. classnames = names[:-1]
  178. if self.xml.prefix:
  179. classnames.insert(0, self.xml.prefix)
  180. attrs = {
  181. "classname": ".".join(classnames),
  182. "name": bin_xml_escape(names[-1]),
  183. "file": testreport.location[0],
  184. }
  185. if testreport.location[1] is not None:
  186. attrs["line"] = testreport.location[1]
  187. if hasattr(testreport, "url"):
  188. attrs["url"] = testreport.url
  189. self.attrs = attrs
  190. def finalize(self):
  191. data = self.to_xml() # .unicode(indent=0)
  192. self.__dict__.clear()
  193. self.to_xml = lambda: py.xml.raw(data)
  194. NODEREPORTS.append(data)
  195. class ConcurrentLogXML(LogXML):
  196. '''to provide XML reporting for multiprocess mode'''
  197. def __init__(self, logfile, prefix, suite_name="pytest"):
  198. logfile = logfile
  199. logfile = os.path.expanduser(os.path.expandvars(logfile))
  200. self.logfile = os.path.normpath(os.path.abspath(logfile))
  201. self.prefix = prefix
  202. self.suite_name = suite_name
  203. self.stats = XMLSTATS
  204. self.node_reporters = {} # XMLREPORTER # nodeid -> _NodeReporter
  205. self.node_reporters_ordered = []
  206. self.global_properties = []
  207. # List of reports that failed on call but teardown is pending.
  208. self.open_reports = []
  209. self.cnt_double_fail_tests = 0
  210. def pytest_sessionfinish(self):
  211. dirname = os.path.dirname(os.path.abspath(self.logfile))
  212. if not os.path.isdir(dirname):
  213. os.makedirs(dirname)
  214. logfile = open(self.logfile, 'w', encoding='utf-8')
  215. suite_stop_time = time.time()
  216. suite_time_delta = suite_stop_time - self.suite_start_time
  217. numtests = (self.stats['passed'] + self.stats['failure'] +
  218. self.stats['skipped'] + self.stats['error'] -
  219. self.cnt_double_fail_tests)
  220. # print("NODE REPORTS: " + str(NODEREPORTS))
  221. logfile.write('<?xml version="1.0" encoding="utf-8"?>')
  222. logfile.write(Junit.testsuite(
  223. self._get_global_properties_node(),
  224. [concurrent_log_to_xml(x) for x in NODEREPORTS],
  225. name=self.suite_name,
  226. errors=self.stats['error'],
  227. failures=self.stats['failure'],
  228. skips=self.stats['skipped'],
  229. tests=numtests,
  230. time="%.3f" % suite_time_delta, ).unicode(indent=0))
  231. logfile.close()
  232. def add_stats(self, key):
  233. XMLLOCK.acquire()
  234. if key in self.stats:
  235. self.stats[key] += 1
  236. XMLLOCK.release()
  237. def node_reporter(self, report):
  238. nodeid = getattr(report, 'nodeid', report)
  239. # local hack to handle xdist report order
  240. slavenode = getattr(report, 'node', None)
  241. key = nodeid, slavenode
  242. # NODELOCK.acquire()
  243. if key in self.node_reporters:
  244. # TODO: breasks for --dist=each
  245. return self.node_reporters[key]
  246. reporter = ConcurrentNodeReporter(nodeid, self)
  247. self.node_reporters[key] = reporter
  248. # NODEREPORTS.append(reporter.to_xml())
  249. return reporter
  250. def pytest_terminal_summary(self, terminalreporter):
  251. terminalreporter.write_sep("-",
  252. "generated xml file: %s" % (self.logfile))
  253. class ConcurrentTerminalReporter(TerminalReporter):
  254. '''to provide terminal reporting for multiprocess mode'''
  255. def __init__(self, reporter):
  256. TerminalReporter.__init__(self, reporter.config)
  257. self._tw = reporter._tw
  258. self.stats = STATS
  259. def add_stats(self, key):
  260. if key in self.stats:
  261. self.stats[key] += 1
  262. def pytest_runtest_logreport(self, report):
  263. rep = report
  264. res = self.config.hook.pytest_report_teststatus(report=rep)
  265. cat, letter, word = res
  266. append_list(self.stats, cat, rep)
  267. if report.when == 'call':
  268. DICTIONARY[report.nodeid] = report
  269. self._tests_ran = True
  270. if not letter and not word:
  271. # probably passed setup/teardown
  272. return
  273. if self.verbosity <= 0:
  274. if not hasattr(rep, 'node') and self.showfspath:
  275. self.write_fspath_result(rep.nodeid, letter)
  276. else:
  277. self._tw.write(letter)
  278. else:
  279. if isinstance(word, tuple):
  280. word, markup = word
  281. else:
  282. if rep.passed:
  283. markup = {'green': True}
  284. elif rep.failed:
  285. markup = {'red': True}
  286. elif rep.skipped:
  287. markup = {'yellow': True}
  288. line = self._locationline(rep.nodeid, *rep.location)
  289. if not hasattr(rep, 'node'):
  290. self.write_ensure_prefix(line, word, **markup)
  291. # self._tw.write(word, **markup)
  292. else:
  293. self.ensure_newline()
  294. if hasattr(rep, 'node'):
  295. self._tw.write("[%s] " % rep.node.gateway.id)
  296. self._tw.write(word, **markup)
  297. self._tw.write(" " + line)
  298. self.currentfspath = -2
  299. def append_list(stats, cat, rep):
  300. LOCK.acquire()
  301. cat_string = str(cat)
  302. if stats.get(cat_string) is None:
  303. stats[cat_string] = MANAGER.list()
  304. mylist = stats.get(cat_string)
  305. mylist.append(rep)
  306. stats[cat] = mylist
  307. LOCK.release()
  308. def concurrent_log_to_xml(log):
  309. return py.xml.raw(log)