test_ccallback.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. from __future__ import division, print_function, absolute_import
  2. from numpy.testing import assert_equal, assert_
  3. from pytest import raises as assert_raises
  4. import time
  5. import pytest
  6. import ctypes
  7. import threading
  8. from scipy._lib import _ccallback_c as _test_ccallback_cython
  9. from scipy._lib import _test_ccallback
  10. from scipy._lib._ccallback import LowLevelCallable
  11. try:
  12. import cffi
  13. HAVE_CFFI = True
  14. except ImportError:
  15. HAVE_CFFI = False
  16. ERROR_VALUE = 2.0
  17. def callback_python(a, user_data=None):
  18. if a == ERROR_VALUE:
  19. raise ValueError("bad value")
  20. if user_data is None:
  21. return a + 1
  22. else:
  23. return a + user_data
  24. def _get_cffi_func(base, signature):
  25. if not HAVE_CFFI:
  26. pytest.skip("cffi not installed")
  27. # Get function address
  28. voidp = ctypes.cast(base, ctypes.c_void_p)
  29. address = voidp.value
  30. # Create corresponding cffi handle
  31. ffi = cffi.FFI()
  32. func = ffi.cast(signature, address)
  33. return func
  34. def _get_ctypes_data():
  35. value = ctypes.c_double(2.0)
  36. return ctypes.cast(ctypes.pointer(value), ctypes.c_voidp)
  37. def _get_cffi_data():
  38. if not HAVE_CFFI:
  39. pytest.skip("cffi not installed")
  40. ffi = cffi.FFI()
  41. return ffi.new('double *', 2.0)
  42. CALLERS = {
  43. 'simple': _test_ccallback.test_call_simple,
  44. 'nodata': _test_ccallback.test_call_nodata,
  45. 'nonlocal': _test_ccallback.test_call_nonlocal,
  46. 'cython': _test_ccallback_cython.test_call_cython,
  47. }
  48. # These functions have signatures known to the callers
  49. FUNCS = {
  50. 'python': lambda: callback_python,
  51. 'capsule': lambda: _test_ccallback.test_get_plus1_capsule(),
  52. 'cython': lambda: LowLevelCallable.from_cython(_test_ccallback_cython, "plus1_cython"),
  53. 'ctypes': lambda: _test_ccallback_cython.plus1_ctypes,
  54. 'cffi': lambda: _get_cffi_func(_test_ccallback_cython.plus1_ctypes,
  55. 'double (*)(double, int *, void *)'),
  56. 'capsule_b': lambda: _test_ccallback.test_get_plus1b_capsule(),
  57. 'cython_b': lambda: LowLevelCallable.from_cython(_test_ccallback_cython, "plus1b_cython"),
  58. 'ctypes_b': lambda: _test_ccallback_cython.plus1b_ctypes,
  59. 'cffi_b': lambda: _get_cffi_func(_test_ccallback_cython.plus1b_ctypes,
  60. 'double (*)(double, double, int *, void *)'),
  61. }
  62. # These functions have signatures the callers don't know
  63. BAD_FUNCS = {
  64. 'capsule_bc': lambda: _test_ccallback.test_get_plus1bc_capsule(),
  65. 'cython_bc': lambda: LowLevelCallable.from_cython(_test_ccallback_cython, "plus1bc_cython"),
  66. 'ctypes_bc': lambda: _test_ccallback_cython.plus1bc_ctypes,
  67. 'cffi_bc': lambda: _get_cffi_func(_test_ccallback_cython.plus1bc_ctypes,
  68. 'double (*)(double, double, double, int *, void *)'),
  69. }
  70. USER_DATAS = {
  71. 'ctypes': _get_ctypes_data,
  72. 'cffi': _get_cffi_data,
  73. 'capsule': _test_ccallback.test_get_data_capsule,
  74. }
  75. def test_callbacks():
  76. def check(caller, func, user_data):
  77. caller = CALLERS[caller]
  78. func = FUNCS[func]()
  79. user_data = USER_DATAS[user_data]()
  80. if func is callback_python:
  81. func2 = lambda x: func(x, 2.0)
  82. else:
  83. func2 = LowLevelCallable(func, user_data)
  84. func = LowLevelCallable(func)
  85. # Test basic call
  86. assert_equal(caller(func, 1.0), 2.0)
  87. # Test 'bad' value resulting to an error
  88. assert_raises(ValueError, caller, func, ERROR_VALUE)
  89. # Test passing in user_data
  90. assert_equal(caller(func2, 1.0), 3.0)
  91. for caller in sorted(CALLERS.keys()):
  92. for func in sorted(FUNCS.keys()):
  93. for user_data in sorted(USER_DATAS.keys()):
  94. check(caller, func, user_data)
  95. def test_bad_callbacks():
  96. def check(caller, func, user_data):
  97. caller = CALLERS[caller]
  98. user_data = USER_DATAS[user_data]()
  99. func = BAD_FUNCS[func]()
  100. if func is callback_python:
  101. func2 = lambda x: func(x, 2.0)
  102. else:
  103. func2 = LowLevelCallable(func, user_data)
  104. func = LowLevelCallable(func)
  105. # Test that basic call fails
  106. assert_raises(ValueError, caller, LowLevelCallable(func), 1.0)
  107. # Test that passing in user_data also fails
  108. assert_raises(ValueError, caller, func2, 1.0)
  109. # Test error message
  110. llfunc = LowLevelCallable(func)
  111. try:
  112. caller(llfunc, 1.0)
  113. except ValueError as err:
  114. msg = str(err)
  115. assert_(llfunc.signature in msg, msg)
  116. assert_('double (double, double, int *, void *)' in msg, msg)
  117. for caller in sorted(CALLERS.keys()):
  118. for func in sorted(BAD_FUNCS.keys()):
  119. for user_data in sorted(USER_DATAS.keys()):
  120. check(caller, func, user_data)
  121. def test_signature_override():
  122. caller = _test_ccallback.test_call_simple
  123. func = _test_ccallback.test_get_plus1_capsule()
  124. llcallable = LowLevelCallable(func, signature="bad signature")
  125. assert_equal(llcallable.signature, "bad signature")
  126. assert_raises(ValueError, caller, llcallable, 3)
  127. llcallable = LowLevelCallable(func, signature="double (double, int *, void *)")
  128. assert_equal(llcallable.signature, "double (double, int *, void *)")
  129. assert_equal(caller(llcallable, 3), 4)
  130. def test_threadsafety():
  131. def callback(a, caller):
  132. if a <= 0:
  133. return 1
  134. else:
  135. res = caller(lambda x: callback(x, caller), a - 1)
  136. return 2*res
  137. def check(caller):
  138. caller = CALLERS[caller]
  139. results = []
  140. count = 10
  141. def run():
  142. time.sleep(0.01)
  143. r = caller(lambda x: callback(x, caller), count)
  144. results.append(r)
  145. threads = [threading.Thread(target=run) for j in range(20)]
  146. for thread in threads:
  147. thread.start()
  148. for thread in threads:
  149. thread.join()
  150. assert_equal(results, [2.0**count]*len(threads))
  151. for caller in CALLERS.keys():
  152. check(caller)