worker.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
# -*- test-case-name: twisted.trial._dist.test.test_worker -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
This module implements the worker classes.

@since: 12.3
"""
  9. import os
  10. from zope.interface import implementer
  11. from twisted.internet.protocol import ProcessProtocol
  12. from twisted.internet.interfaces import ITransport, IAddress
  13. from twisted.internet.defer import Deferred
  14. from twisted.protocols.amp import AMP
  15. from twisted.python.compat import _PY3, unicode
  16. from twisted.python.failure import Failure
  17. from twisted.python.reflect import namedObject
  18. from twisted.trial.unittest import Todo
  19. from twisted.trial.runner import TrialSuite, TestLoader
  20. from twisted.trial._dist import workercommands, managercommands
  21. from twisted.trial._dist import _WORKER_AMP_STDIN, _WORKER_AMP_STDOUT
  22. from twisted.trial._dist.workerreporter import WorkerReporter
  23. class WorkerProtocol(AMP):
  24. """
  25. The worker-side trial distributed protocol.
  26. """
  27. def __init__(self, forceGarbageCollection=False):
  28. self._loader = TestLoader()
  29. self._result = WorkerReporter(self)
  30. self._forceGarbageCollection = forceGarbageCollection
  31. def run(self, testCase):
  32. """
  33. Run a test case by name.
  34. """
  35. if _PY3:
  36. testCase = testCase.decode("utf-8")
  37. case = self._loader.loadByName(testCase)
  38. suite = TrialSuite([case], self._forceGarbageCollection)
  39. suite.run(self._result)
  40. return {'success': True}
  41. workercommands.Run.responder(run)
  42. def start(self, directory):
  43. """
  44. Set up the worker, moving into given directory for tests to run in
  45. them.
  46. """
  47. os.chdir(directory)
  48. return {'success': True}
  49. workercommands.Start.responder(start)
  50. class LocalWorkerAMP(AMP):
  51. """
  52. Local implementation of the manager commands.
  53. """
  54. def addSuccess(self, testName):
  55. """
  56. Add a success to the reporter.
  57. """
  58. self._result.addSuccess(self._testCase)
  59. return {'success': True}
  60. managercommands.AddSuccess.responder(addSuccess)
  61. def _buildFailure(self, error, errorClass, frames):
  62. """
  63. Helper to build a C{Failure} with some traceback.
  64. @param error: An C{Exception} instance.
  65. @param error: The class name of the C{error} class.
  66. @param frames: A flat list of strings representing the information need
  67. to approximatively rebuild C{Failure} frames.
  68. @return: A L{Failure} instance with enough information about a test
  69. error.
  70. """
  71. if _PY3:
  72. errorClass = errorClass.decode("utf-8")
  73. errorType = namedObject(errorClass)
  74. failure = Failure(error, errorType)
  75. for i in range(0, len(frames), 3):
  76. failure.frames.append(
  77. (frames[i], frames[i + 1], int(frames[i + 2]), [], []))
  78. return failure
  79. def addError(self, testName, error, errorClass, frames):
  80. """
  81. Add an error to the reporter.
  82. """
  83. failure = self._buildFailure(error, errorClass, frames)
  84. self._result.addError(self._testCase, failure)
  85. return {'success': True}
  86. managercommands.AddError.responder(addError)
  87. def addFailure(self, testName, fail, failClass, frames):
  88. """
  89. Add a failure to the reporter.
  90. """
  91. failure = self._buildFailure(fail, failClass, frames)
  92. self._result.addFailure(self._testCase, failure)
  93. return {'success': True}
  94. managercommands.AddFailure.responder(addFailure)
  95. def addSkip(self, testName, reason):
  96. """
  97. Add a skip to the reporter.
  98. """
  99. self._result.addSkip(self._testCase, reason)
  100. return {'success': True}
  101. managercommands.AddSkip.responder(addSkip)
  102. def addExpectedFailure(self, testName, error, todo):
  103. """
  104. Add an expected failure to the reporter.
  105. """
  106. _todo = Todo(todo)
  107. self._result.addExpectedFailure(self._testCase, error, _todo)
  108. return {'success': True}
  109. managercommands.AddExpectedFailure.responder(addExpectedFailure)
  110. def addUnexpectedSuccess(self, testName, todo):
  111. """
  112. Add an unexpected success to the reporter.
  113. """
  114. self._result.addUnexpectedSuccess(self._testCase, todo)
  115. return {'success': True}
  116. managercommands.AddUnexpectedSuccess.responder(addUnexpectedSuccess)
  117. def testWrite(self, out):
  118. """
  119. Print test output from the worker.
  120. """
  121. if _PY3 and isinstance(out, bytes):
  122. out = out.decode("utf-8")
  123. self._testStream.write(out + '\n')
  124. self._testStream.flush()
  125. return {'success': True}
  126. managercommands.TestWrite.responder(testWrite)
  127. def _stopTest(self, result):
  128. """
  129. Stop the current running test case, forwarding the result.
  130. """
  131. self._result.stopTest(self._testCase)
  132. return result
  133. def run(self, testCase, result):
  134. """
  135. Run a test.
  136. """
  137. self._testCase = testCase
  138. self._result = result
  139. self._result.startTest(testCase)
  140. d = self.callRemote(workercommands.Run, testCase=testCase.id())
  141. return d.addCallback(self._stopTest)
  142. def setTestStream(self, stream):
  143. """
  144. Set the stream used to log output from tests.
  145. """
  146. self._testStream = stream
  147. @implementer(IAddress)
  148. class LocalWorkerAddress(object):
  149. """
  150. A L{IAddress} implementation meant to provide stub addresses for
  151. L{ITransport.getPeer} and L{ITransport.getHost}.
  152. """
  153. @implementer(ITransport)
  154. class LocalWorkerTransport(object):
  155. """
  156. A stub transport implementation used to support L{AMP} over a
  157. L{ProcessProtocol} transport.
  158. """
  159. def __init__(self, transport):
  160. self._transport = transport
  161. def write(self, data):
  162. """
  163. Forward data to transport.
  164. """
  165. self._transport.writeToChild(_WORKER_AMP_STDIN, data)
  166. def writeSequence(self, sequence):
  167. """
  168. Emulate C{writeSequence} by iterating data in the C{sequence}.
  169. """
  170. for data in sequence:
  171. self._transport.writeToChild(_WORKER_AMP_STDIN, data)
  172. def loseConnection(self):
  173. """
  174. Closes the transport.
  175. """
  176. self._transport.loseConnection()
  177. def getHost(self):
  178. """
  179. Return a L{LocalWorkerAddress} instance.
  180. """
  181. return LocalWorkerAddress()
  182. def getPeer(self):
  183. """
  184. Return a L{LocalWorkerAddress} instance.
  185. """
  186. return LocalWorkerAddress()
  187. class LocalWorker(ProcessProtocol):
  188. """
  189. Local process worker protocol. This worker runs as a local process and
  190. communicates via stdin/out.
  191. @ivar _ampProtocol: The L{AMP} protocol instance used to communicate with
  192. the worker.
  193. @ivar _logDirectory: The directory where logs will reside.
  194. @ivar _logFile: The name of the main log file for tests output.
  195. """
  196. def __init__(self, ampProtocol, logDirectory, logFile):
  197. self._ampProtocol = ampProtocol
  198. self._logDirectory = logDirectory
  199. self._logFile = logFile
  200. self.endDeferred = Deferred()
  201. def connectionMade(self):
  202. """
  203. When connection is made, create the AMP protocol instance.
  204. """
  205. self._ampProtocol.makeConnection(LocalWorkerTransport(self.transport))
  206. if not os.path.exists(self._logDirectory):
  207. os.makedirs(self._logDirectory)
  208. self._outLog = open(os.path.join(self._logDirectory, 'out.log'), 'w')
  209. self._errLog = open(os.path.join(self._logDirectory, 'err.log'), 'w')
  210. testLog = open(os.path.join(self._logDirectory, self._logFile), 'w')
  211. self._ampProtocol.setTestStream(testLog)
  212. logDirectory = self._logDirectory
  213. if isinstance(logDirectory, unicode):
  214. logDirectory = logDirectory.encode("utf-8")
  215. d = self._ampProtocol.callRemote(workercommands.Start,
  216. directory=logDirectory)
  217. # Ignore the potential errors, the test suite will fail properly and it
  218. # would just print garbage.
  219. d.addErrback(lambda x: None)
  220. def connectionLost(self, reason):
  221. """
  222. On connection lost, close the log files that we're managing for stdin
  223. and stdout.
  224. """
  225. self._outLog.close()
  226. self._errLog.close()
  227. def processEnded(self, reason):
  228. """
  229. When the process closes, call C{connectionLost} for cleanup purposes
  230. and forward the information to the C{_ampProtocol}.
  231. """
  232. self.connectionLost(reason)
  233. self._ampProtocol.connectionLost(reason)
  234. self.endDeferred.callback(reason)
  235. def outReceived(self, data):
  236. """
  237. Send data received from stdout to log.
  238. """
  239. self._outLog.write(data)
  240. def errReceived(self, data):
  241. """
  242. Write error data to log.
  243. """
  244. self._errLog.write(data)
  245. def childDataReceived(self, childFD, data):
  246. """
  247. Handle data received on the specific pipe for the C{_ampProtocol}.
  248. """
  249. if childFD == _WORKER_AMP_STDOUT:
  250. self._ampProtocol.dataReceived(data)
  251. else:
  252. ProcessProtocol.childDataReceived(self, childFD, data)