# workermanage.py (15 KB)
  1. from __future__ import print_function
  2. import fnmatch
  3. import os
  4. import re
  5. import threading
  6. import py
  7. import pytest
  8. import execnet
  9. import xdist.remote
  10. from _pytest import runner # XXX load dynamically
  def parse_spec_config(config):
      """Expand the ``--tx`` option values into a flat list of spec strings.

      Each value is either a plain execnet spec (e.g. ``"popen"``) or the
      shorthand ``"NUM*SPEC"``, which expands into ``NUM`` copies of
      ``SPEC``.  A value whose part before the first ``*`` is not an integer
      is kept verbatim.

      :param config: pytest config; only ``config.getvalue("tx")`` is read.
      :returns: list of spec strings in command-line order.
      :raises pytest.UsageError: if the expansion yields no spec at all
          (no ``--tx`` given, or only zero multipliers).
      """
      xspeclist = []
      for xspec in config.getvalue("tx"):
          count, star, rest = xspec.partition("*")
          if not star:
              # No "*": never try to parse a multiplier.  (With str.find()
              # returning -1, xspec[:-1] -- the spec minus its last
              # character -- would be parsed, turning e.g. "55" into 5 specs.)
              xspeclist.append(xspec)
              continue
          try:
              num = int(count)
          except ValueError:
              xspeclist.append(xspec)
          else:
              xspeclist.extend([rest] * num)
      if not xspeclist:
          raise pytest.UsageError(
              "MISSING test execution (tx) nodes: please specify --tx"
          )
      return xspeclist
  class NodeManager(object):
      """Set up, rsync to and tear down the execnet gateways for the ``--tx`` specs."""

      # Timeout handed to execnet.Group.terminate() in teardown_nodes().
      EXIT_TIMEOUT = 10
      # Glob patterns always excluded from rsync; extended by the
      # --rsyncignore option and the "rsyncignore" ini value.
      DEFAULT_IGNORES = [".*", "*.pyc", "*.pyo", "*~"]

      def __init__(self, config, specs=None, defaultchdir="pyexecnetcache"):
          """
          :param config: pytest config object.
          :param specs: iterable of ``execnet.XSpec`` objects or spec strings;
              when None, the ``--tx`` values are used.
          :param defaultchdir: ``chdir`` assigned to non-popen specs that do
              not set one themselves.
          """
          self.config = config
          self._nodesready = threading.Event()
          self.trace = self.config.trace.get("nodemanager")
          self.group = execnet.Group()
          if specs is None:
              specs = self._getxspecs()
          self.specs = []
          for spec in specs:
              if not isinstance(spec, execnet.XSpec):
                  spec = execnet.XSpec(spec)
              if not spec.chdir and not spec.popen:
                  spec.chdir = defaultchdir
              self.group.allocate_id(spec)
              self.specs.append(spec)
          self.roots = self._getrsyncdirs()
          self.rsyncoptions = self._getrsyncoptions()
          # (spec, source) pairs already rsynced, so each pair is sent once.
          self._rsynced_specs = set()

      def rsync_roots(self, gateway):
          """Rsync the set of roots to the node's gateway cwd."""
          for root in self.roots or ():
              self.rsync(gateway, root, **self.rsyncoptions)

      def setup_nodes(self, putevent):
          """Start a worker for every spec and return the WorkerControllers."""
          self.config.hook.pytest_xdist_setupnodes(config=self.config, specs=self.specs)
          self.trace("setting up nodes")
          return [self.setup_node(spec, putevent) for spec in self.specs]

      def setup_node(self, spec, putevent):
          """Open a gateway for *spec*, rsync the roots to it and start a worker."""
          gw = self.group.makegateway(spec)
          self.config.hook.pytest_xdist_newgateway(gateway=gw)
          self.rsync_roots(gw)
          node = WorkerController(self, gw, self.config, putevent)
          gw.node = node  # keep the node alive
          node.setup()
          self.trace("started node %r" % node)
          return node

      def teardown_nodes(self):
          """Terminate all gateways of the group, passing EXIT_TIMEOUT as timeout."""
          self.group.terminate(self.EXIT_TIMEOUT)

      def _getxspecs(self):
          """Return the ``--tx`` values as ``execnet.XSpec`` objects."""
          return [execnet.XSpec(x) for x in parse_spec_config(self.config)]

      def _getrsyncdirs(self):
          """Return the directories to rsync to the remote nodes.

          Returns an empty list when every spec is a popen spec without
          ``chdir``.  Otherwise returns the py lib, pytest and _pytest
          locations plus the ``--rsyncdir`` option and ``rsyncdirs`` ini
          entries, each resolved with ``realpath()``, de-duplicated in order.

          :raises pytest.UsageError: if a candidate path does not exist.
          """
          if all(spec.popen and not spec.chdir for spec in self.specs):
              return []
          import _pytest

          # Strip the "c"/"o" of a .pyc/.pyo file name to get the .py source.
          pytestpath = pytest.__file__.rstrip("co")
          pytestdir = py.path.local(_pytest.__file__).dirpath()
          config = self.config
          candidates = [py._pydir, pytestpath, pytestdir]
          candidates += config.option.rsyncdir
          rsyncroots = config.getini("rsyncdirs")
          if rsyncroots:
              candidates.extend(rsyncroots)
          roots = []
          for root in candidates:
              root = py.path.local(root).realpath()
              if not root.check():
                  raise pytest.UsageError("rsyncdir doesn't exist: %r" % (root,))
              if root not in roots:
                  roots.append(root)
          return roots

      def _getrsyncoptions(self):
          """Get options to be passed for rsync."""
          ignores = list(self.DEFAULT_IGNORES)  # copy: never mutate the class list
          ignores += self.config.option.rsyncignore
          ignores += self.config.getini("rsyncignore")
          return {"ignores": ignores, "verbose": self.config.option.verbose}

      def rsync(self, gateway, source, notify=None, verbose=False, ignores=None):
          """Perform rsync to remote hosts for node.

          For a popen spec without ``chdir`` nothing is copied; instead the
          parent directory of *source* is inserted into the remote
          ``sys.path``.  Otherwise *source* is sent once per
          ``(spec, source)`` pair, bracketed by the
          ``pytest_xdist_rsyncstart`` / ``pytest_xdist_rsyncfinish`` hooks.

          :param notify: optional callable, invoked as
              ``notify("rsyncrootready", spec, source)`` from the rsync
              finished-callback.
          """
          # XXX This changes the calling behaviour of
          # pytest_xdist_rsyncstart and pytest_xdist_rsyncfinish to
          # be called once per rsync target.
          spec = gateway.spec
          if spec.popen and not spec.chdir:
              # XXX This assumes that sources are python-packages
              # and that adding the basedir does not hurt.
              gateway.remote_exec(
                  """
  import sys ; sys.path.insert(0, %r)
  """
                  % os.path.dirname(str(source))
              ).waitclose()
              return
          if (spec, source) in self._rsynced_specs:
              return
          # Only build the syncer (which compiles every ignore pattern) once
          # we know a transfer will actually happen.
          rsync = HostRSync(source, verbose=verbose, ignores=ignores)

          def finished():
              if notify:
                  notify("rsyncrootready", spec, source)

          rsync.add_target_host(gateway, finished=finished)
          self._rsynced_specs.add((spec, source))
          self.config.hook.pytest_xdist_rsyncstart(source=source, gateways=[gateway])
          rsync.send()
          self.config.hook.pytest_xdist_rsyncfinish(source=source, gateways=[gateway])
  class HostRSync(execnet.RSync):
      """RSyncer that filters out common files.

      A path is skipped when its basename or its full path matches one of
      the ``ignores`` glob patterns.
      """

      def __init__(self, sourcedir, *args, **kwargs):
          """
          :param sourcedir: local directory to send.
          :param ignores: (keyword) iterable of glob patterns, given as
              strings or as objects with a ``strpath`` attribute.
          All other arguments are forwarded to ``execnet.RSync``.
          """
          self._synced = {}
          self._ignores = []
          ignores = kwargs.pop("ignores", None) or []
          for x in ignores:
              x = getattr(x, "strpath", x)
              self._ignores.append(re.compile(fnmatch.translate(x)))
          # Forward the positional arguments as well; accepting *args without
          # passing them on would silently drop them.
          super(HostRSync, self).__init__(sourcedir, *args, **kwargs)

      def filter(self, path):
          """Return False if *path* matches an ignore pattern, True otherwise."""
          path = py.path.local(path)
          return not any(
              cre.match(path.basename) or cre.match(path.strpath)
              for cre in self._ignores
          )

      def add_target_host(self, gateway, finished=None):
          """Add *gateway* as target, syncing into a directory named like the source dir.

          :param finished: optional callback passed to execnet as
              ``finishedcallback``.
          """
          remotepath = os.path.basename(self._sourcedir)
          super(HostRSync, self).add_target(
              gateway, remotepath, finishedcallback=finished, delete=True
          )

      def _report_send_file(self, gateway, modified_rel_path):
          """In verbose mode, print ``<spec>:<chdir> <= <basename>/<relpath>``."""
          if self._verbose:
              path = os.path.basename(self._sourcedir) + "/" + modified_rel_path
              remotepath = gateway.spec.chdir
              print("%s:%s <= %s" % (gateway.spec, remotepath, path))
  def make_reltoroot(roots, args):
      """Rewrite path arguments so they are relative to the rsync roots.

      Each argument may carry ``::``-separated node-id parts; only the part
      before the first ``::`` is rewritten, to
      ``<root basename>/<path relative to root>``.  Arguments whose path part
      does not exist locally (for instance package names used with
      ``--pyargs``) are passed through unchanged.

      :param roots: ``py.path.local`` rsync roots, tried in order.
      :param args: positional pytest arguments.
      :returns: list of rewritten arguments, in the same order.
      :raises ValueError: if an existing path lies under none of *roots*.
      """
      # XXX introduce/use public API for splitting pytest args
      splitcode = "::"
      result = []
      for arg in args:
          parts = arg.split(splitcode)
          fspath = py.path.local(parts[0])
          if not fspath.check():
              # Not a local filesystem path, so there is nothing to relocate.
              result.append(arg)
              continue
          for root in roots:
              x = fspath.relto(root)
              # relto() returns "" both for "not below root" and for root
              # itself, hence the explicit equality check.
              if x or fspath == root:
                  parts[0] = root.basename + "/" + x
                  break
          else:
              raise ValueError("arg %s not relative to an rsync root" % (arg,))
          result.append(splitcode.join(parts))
      return result
  173. class WorkerController(object):
  174. ENDMARK = -1
  175. def __init__(self, nodemanager, gateway, config, putevent):
  176. self.nodemanager = nodemanager
  177. self.putevent = putevent
  178. self.gateway = gateway
  179. self.config = config
  180. self.workerinput = {
  181. "workerid": gateway.id,
  182. "workercount": len(nodemanager.specs),
  183. "slaveid": gateway.id,
  184. "slavecount": len(nodemanager.specs),
  185. }
  186. # TODO: deprecated name, backward compatibility only. Remove it in future
  187. self.slaveinput = self.workerinput
  188. self._down = False
  189. self._shutdown_sent = False
  190. self.log = py.log.Producer("workerctl-%s" % gateway.id)
  191. if not self.config.option.debug:
  192. py.log.setconsumer(self.log._keywords, None)
  193. def __repr__(self):
  194. return "<%s %s>" % (self.__class__.__name__, self.gateway.id)
  195. @property
  196. def shutting_down(self):
  197. return self._down or self._shutdown_sent
  198. def setup(self):
  199. self.log("setting up worker session")
  200. spec = self.gateway.spec
  201. args = self.config.args
  202. if not spec.popen or spec.chdir:
  203. args = make_reltoroot(self.nodemanager.roots, args)
  204. option_dict = vars(self.config.option)
  205. if spec.popen:
  206. name = "popen-%s" % self.gateway.id
  207. if hasattr(self.config, "_tmpdirhandler"):
  208. basetemp = self.config._tmpdirhandler.getbasetemp()
  209. option_dict["basetemp"] = str(basetemp.join(name))
  210. self.config.hook.pytest_configure_node(node=self)
  211. self.channel = self.gateway.remote_exec(xdist.remote)
  212. self.channel.send((self.workerinput, args, option_dict))
  213. if self.putevent:
  214. self.channel.setcallback(self.process_from_remote, endmarker=self.ENDMARK)
  215. def ensure_teardown(self):
  216. if hasattr(self, "channel"):
  217. if not self.channel.isclosed():
  218. self.log("closing", self.channel)
  219. self.channel.close()
  220. # del self.channel
  221. if hasattr(self, "gateway"):
  222. self.log("exiting", self.gateway)
  223. self.gateway.exit()
  224. # del self.gateway
  225. def send_runtest_some(self, indices):
  226. self.sendcommand("runtests", indices=indices)
  227. def send_runtest_all(self):
  228. self.sendcommand("runtests_all")
  229. def shutdown(self):
  230. if not self._down:
  231. try:
  232. self.sendcommand("shutdown")
  233. except IOError:
  234. pass
  235. self._shutdown_sent = True
  236. def sendcommand(self, name, **kwargs):
  237. """ send a named parametrized command to the other side. """
  238. self.log("sending command %s(**%s)" % (name, kwargs))
  239. self.channel.send((name, kwargs))
  240. def notify_inproc(self, eventname, **kwargs):
  241. self.log("queuing %s(**%s)" % (eventname, kwargs))
  242. self.putevent((eventname, kwargs))
  243. def process_from_remote(self, eventcall): # noqa too complex
  244. """ this gets called for each object we receive from
  245. the other side and if the channel closes.
  246. Note that channel callbacks run in the receiver
  247. thread of execnet gateways - we need to
  248. avoid raising exceptions or doing heavy work.
  249. """
  250. try:
  251. if eventcall == self.ENDMARK:
  252. err = self.channel._getremoteerror()
  253. if not self._down:
  254. if not err or isinstance(err, EOFError):
  255. err = "Not properly terminated" # lost connection?
  256. self.notify_inproc("errordown", node=self, error=err)
  257. self._down = True
  258. return
  259. eventname, kwargs = eventcall
  260. if eventname in ("collectionstart",):
  261. self.log("ignoring %s(%s)" % (eventname, kwargs))
  262. elif eventname == "workerready":
  263. self.notify_inproc(eventname, node=self, **kwargs)
  264. elif eventname == "workerfinished":
  265. self._down = True
  266. self.workeroutput = self.slaveoutput = kwargs["workeroutput"]
  267. self.notify_inproc("workerfinished", node=self)
  268. elif eventname in ("logstart", "logfinish"):
  269. self.notify_inproc(eventname, node=self, **kwargs)
  270. elif eventname in ("testreport", "collectreport", "teardownreport"):
  271. item_index = kwargs.pop("item_index", None)
  272. rep = unserialize_report(eventname, kwargs["data"])
  273. if item_index is not None:
  274. rep.item_index = item_index
  275. self.notify_inproc(eventname, node=self, rep=rep)
  276. elif eventname == "collectionfinish":
  277. self.notify_inproc(eventname, node=self, ids=kwargs["ids"])
  278. elif eventname == "runtest_protocol_complete":
  279. self.notify_inproc(eventname, node=self, **kwargs)
  280. elif eventname == "logwarning":
  281. self.notify_inproc(
  282. eventname,
  283. message=kwargs["message"],
  284. code=kwargs["code"],
  285. nodeid=kwargs["nodeid"],
  286. fslocation=kwargs["nodeid"],
  287. )
  288. else:
  289. raise ValueError("unknown event: %s" % (eventname,))
  290. except KeyboardInterrupt:
  291. # should not land in receiver-thread
  292. raise
  293. except: # noqa
  294. from _pytest._code import ExceptionInfo
  295. excinfo = ExceptionInfo()
  296. print("!" * 20, excinfo)
  297. self.config.notify_exception(excinfo)
  298. self.shutdown()
  299. self.notify_inproc("errordown", node=self, error=excinfo)
  def unserialize_report(name, reportdict):
      """Rebuild a pytest report object from the dict sent by a worker.

      *reportdict* is modified in place: a dict ``longrepr`` carrying both
      ``reprcrash`` and ``reprtraceback`` is replaced by a
      ``ReprExceptionInfo`` before the report is constructed.

      :param name: ``"testreport"`` builds a ``runner.TestReport``,
          ``"collectreport"`` a ``runner.CollectReport``; any other name
          returns None.
      :param reportdict: constructor keyword arguments as serialized by the
          worker.
      :raises AssertionError: if a traceback entry has an unknown type.
      """

      def assembled_report(reportdict):
          from _pytest._code.code import (
              ReprEntry,
              ReprEntryNative,
              ReprExceptionInfo,
              ReprFileLocation,
              ReprFuncArgs,
              ReprLocals,
              ReprTraceback,
          )

          def unserialized_entry(entry_data):
              # Rebuild one item of reprtraceback["reprentries"].
              data = entry_data["data"]
              entry_type = entry_data["type"]
              if entry_type == "ReprEntry":
                  reprfuncargs = reprfileloc = reprlocals = None
                  if data["reprfuncargs"]:
                      reprfuncargs = ReprFuncArgs(**data["reprfuncargs"])
                  if data["reprfileloc"]:
                      reprfileloc = ReprFileLocation(**data["reprfileloc"])
                  if data["reprlocals"]:
                      reprlocals = ReprLocals(data["reprlocals"]["lines"])
                  return ReprEntry(
                      lines=data["lines"],
                      reprfuncargs=reprfuncargs,
                      reprlocals=reprlocals,
                      filelocrepr=reprfileloc,
                      style=data["style"],
                  )
              if entry_type == "ReprEntryNative":
                  return ReprEntryNative(data["lines"])
              report_unserialization_failure(entry_type, name, reportdict)
              # The reporter above signals failure via an assert, which
              # "python -O" strips; never fall through to a bogus entry.
              raise AssertionError(
                  "Unknown entry type returned: %s (report_name: %s)"
                  % (entry_type, name)
              )

          longrepr = reportdict["longrepr"]
          # A plain-string longrepr must not be mistaken for a serialized
          # traceback just because it contains these words.
          if (
              isinstance(longrepr, dict)
              and "reprcrash" in longrepr
              and "reprtraceback" in longrepr
          ):
              reprtraceback = longrepr["reprtraceback"]
              reprcrash = longrepr["reprcrash"]
              reprtraceback["reprentries"] = [
                  unserialized_entry(entry_data)
                  for entry_data in reprtraceback["reprentries"]
              ]
              exception_info = ReprExceptionInfo(
                  reprtraceback=ReprTraceback(**reprtraceback),
                  reprcrash=ReprFileLocation(**reprcrash),
              )
              for section in longrepr["sections"]:
                  exception_info.addsection(*section)
              reportdict["longrepr"] = exception_info
          return reportdict

      if name == "testreport":
          return runner.TestReport(**assembled_report(reportdict))
      elif name == "collectreport":
          return runner.CollectReport(**assembled_report(reportdict))
      return None
  def report_unserialization_failure(type_name, report_name, reportdict):
      """Abort with a diagnostic about an unknown serialized traceback entry.

      :param type_name: the unrecognized entry type.
      :param report_name: name of the report event being unserialized.
      :param reportdict: the serialized report, included in the message.
      :raises AssertionError: always; the message holds the details above
          and the URL for reporting the bug.
      """
      from pprint import pprint

      url = "https://github.com/pytest-dev/pytest-xdist/issues"
      stream = py.io.TextIO()
      pprint("-" * 100, stream=stream)
      pprint("INTERNALERROR: Unknown entry type returned: %s" % type_name, stream=stream)
      pprint("report_name: %s" % report_name, stream=stream)
      pprint(reportdict, stream=stream)
      pprint("Please report this bug at %s" % url, stream=stream)
      pprint("-" * 100, stream=stream)
      # Raise explicitly: an "assert 0, ..." statement is removed under
      # "python -O" and the failure would pass silently.
      raise AssertionError(stream.getvalue())