test__threadsafety.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from __future__ import division, print_function, absolute_import
  2. import threading
  3. import time
  4. import traceback
  5. from numpy.testing import assert_
  6. from pytest import raises as assert_raises
  7. from scipy._lib._threadsafety import ReentrancyLock, non_reentrant, ReentrancyError
  8. def test_parallel_threads():
  9. # Check that ReentrancyLock serializes work in parallel threads.
  10. #
  11. # The test is not fully deterministic, and may succeed falsely if
  12. # the timings go wrong.
  13. lock = ReentrancyLock("failure")
  14. failflag = [False]
  15. exceptions_raised = []
  16. def worker(k):
  17. try:
  18. with lock:
  19. assert_(not failflag[0])
  20. failflag[0] = True
  21. time.sleep(0.1 * k)
  22. assert_(failflag[0])
  23. failflag[0] = False
  24. except Exception:
  25. exceptions_raised.append(traceback.format_exc(2))
  26. threads = [threading.Thread(target=lambda k=k: worker(k))
  27. for k in range(3)]
  28. for t in threads:
  29. t.start()
  30. for t in threads:
  31. t.join()
  32. exceptions_raised = "\n".join(exceptions_raised)
  33. assert_(not exceptions_raised, exceptions_raised)
  34. def test_reentering():
  35. # Check that ReentrancyLock prevents re-entering from the same thread.
  36. @non_reentrant()
  37. def func(x):
  38. return func(x)
  39. assert_raises(ReentrancyError, func, 0)