# test_callback.py (3.9 KB, 165 lines)
  1. from __future__ import division, absolute_import, print_function
  2. import math
  3. import textwrap
  4. import sys
  5. import pytest
  6. import numpy as np
  7. from numpy.testing import assert_, assert_equal
  8. from . import util
  9. class TestF77Callback(util.F2PyTest):
  10. code = """
  11. subroutine t(fun,a)
  12. integer a
  13. cf2py intent(out) a
  14. external fun
  15. call fun(a)
  16. end
  17. subroutine func(a)
  18. cf2py intent(in,out) a
  19. integer a
  20. a = a + 11
  21. end
  22. subroutine func0(a)
  23. cf2py intent(out) a
  24. integer a
  25. a = 11
  26. end
  27. subroutine t2(a)
  28. cf2py intent(callback) fun
  29. integer a
  30. cf2py intent(out) a
  31. external fun
  32. call fun(a)
  33. end
  34. subroutine string_callback(callback, a)
  35. external callback
  36. double precision callback
  37. double precision a
  38. character*1 r
  39. cf2py intent(out) a
  40. r = 'r'
  41. a = callback(r)
  42. end
  43. subroutine string_callback_array(callback, cu, lencu, a)
  44. external callback
  45. integer callback
  46. integer lencu
  47. character*8 cu(lencu)
  48. integer a
  49. cf2py intent(out) a
  50. a = callback(cu, lencu)
  51. end
  52. """
  53. @pytest.mark.slow
  54. @pytest.mark.parametrize('name', 't,t2'.split(','))
  55. def test_all(self, name):
  56. self.check_function(name)
  57. @pytest.mark.slow
  58. def test_docstring(self):
  59. expected = """
  60. a = t(fun,[fun_extra_args])
  61. Wrapper for ``t``.
  62. Parameters
  63. ----------
  64. fun : call-back function
  65. Other Parameters
  66. ----------------
  67. fun_extra_args : input tuple, optional
  68. Default: ()
  69. Returns
  70. -------
  71. a : int
  72. Notes
  73. -----
  74. Call-back functions::
  75. def fun(): return a
  76. Return objects:
  77. a : int
  78. """
  79. assert_equal(self.module.t.__doc__, textwrap.dedent(expected).lstrip())
  80. def check_function(self, name):
  81. t = getattr(self.module, name)
  82. r = t(lambda: 4)
  83. assert_(r == 4, repr(r))
  84. r = t(lambda a: 5, fun_extra_args=(6,))
  85. assert_(r == 5, repr(r))
  86. r = t(lambda a: a, fun_extra_args=(6,))
  87. assert_(r == 6, repr(r))
  88. r = t(lambda a: 5 + a, fun_extra_args=(7,))
  89. assert_(r == 12, repr(r))
  90. r = t(lambda a: math.degrees(a), fun_extra_args=(math.pi,))
  91. assert_(r == 180, repr(r))
  92. r = t(math.degrees, fun_extra_args=(math.pi,))
  93. assert_(r == 180, repr(r))
  94. r = t(self.module.func, fun_extra_args=(6,))
  95. assert_(r == 17, repr(r))
  96. r = t(self.module.func0)
  97. assert_(r == 11, repr(r))
  98. r = t(self.module.func0._cpointer)
  99. assert_(r == 11, repr(r))
  100. class A(object):
  101. def __call__(self):
  102. return 7
  103. def mth(self):
  104. return 9
  105. a = A()
  106. r = t(a)
  107. assert_(r == 7, repr(r))
  108. r = t(a.mth)
  109. assert_(r == 9, repr(r))
  110. @pytest.mark.skipif(sys.platform=='win32',
  111. reason='Fails with MinGW64 Gfortran (Issue #9673)')
  112. def test_string_callback(self):
  113. def callback(code):
  114. if code == 'r':
  115. return 0
  116. else:
  117. return 1
  118. f = getattr(self.module, 'string_callback')
  119. r = f(callback)
  120. assert_(r == 0, repr(r))
  121. @pytest.mark.skipif(sys.platform=='win32',
  122. reason='Fails with MinGW64 Gfortran (Issue #9673)')
  123. def test_string_callback_array(self):
  124. # See gh-10027
  125. cu = np.zeros((1, 8), 'S1')
  126. def callback(cu, lencu):
  127. if cu.shape != (lencu, 8):
  128. return 1
  129. if cu.dtype != 'S1':
  130. return 2
  131. if not np.all(cu == b''):
  132. return 3
  133. return 0
  134. f = getattr(self.module, 'string_callback_array')
  135. res = f(callback, cu, len(cu))
  136. assert_(res == 0, repr(res))