test_parameter.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. from __future__ import division, absolute_import, print_function
  2. import os
  3. import pytest
  4. import numpy as np
  5. from numpy.testing import assert_raises, assert_equal
  6. from . import util
  def _path(*a):
      """Return the path made by joining the parts *a* onto this file's directory.

      Called with no arguments it returns ``os.path.dirname(__file__)`` itself.
      """
      here = os.path.dirname(__file__)
      return os.path.join(here, *a)
  class TestParameters(util.F2PyTest):
      """f2py tests for Fortran routines built from the ``src/parameter`` sources.

      Each test calls a wrapped routine from ``self.module`` on a contiguous
      array and checks the values left in that array afterwards. Most tests
      first check that a non-contiguous array is rejected with ``ValueError``.
      """

      # Fortran sources for the module under test. Presumably compiled by
      # ``util.F2PyTest`` and exposed as ``self.module`` -- confirm in util.
      sources = [
          _path('src', 'parameter', 'constant_real.f90'),
          _path('src', 'parameter', 'constant_integer.f90'),
          _path('src', 'parameter', 'constant_both.f90'),
          _path('src', 'parameter', 'constant_compound.f90'),
          _path('src', 'parameter', 'constant_non_compound.f90'),
      ]

      def _check_inplace(self, func, dtype, expected, check_strided=True):
          """Run *func* on arrays of *dtype* and check the array it updates.

          Parameters
          ----------
          func : callable
              Wrapped routine taken from ``self.module``.
          dtype : numpy scalar type
              Element type used for the test arrays.
          expected : sequence
              Values that the contiguous input ``np.arange(len(expected))``
              must hold after ``func`` has been called on it.
          check_strided : bool, optional
              If true (the default), first assert that calling ``func`` on a
              non-contiguous view of the same length raises ``ValueError``.
          """
          n = len(expected)
          if check_strided:
              # Every other element of arange(2*n): length n, non-contiguous.
              # The wrapper is expected to refuse it (presumably because the
              # Fortran argument is updated in place -- see the .f90 sources).
              x = np.arange(2 * n, dtype=dtype)[::2]
              assert_raises(ValueError, func, x)
          # Contiguous input: the routine writes its result back into x.
          x = np.arange(n, dtype=dtype)
          func(x)
          assert_equal(x, expected)

      @pytest.mark.slow
      def test_constant_real_single(self):
          self._check_inplace(
              self.module.foo_single, np.float32, [0 + 1 + 2*3, 1, 2])

      @pytest.mark.slow
      def test_constant_real_double(self):
          self._check_inplace(
              self.module.foo_double, np.float64, [0 + 1 + 2*3, 1, 2])

      @pytest.mark.slow
      def test_constant_compound_int(self):
          self._check_inplace(
              self.module.foo_compound_int, np.int32, [0 + 1 + 2*6, 1, 2])

      @pytest.mark.slow
      def test_constant_non_compound_int(self):
          # Only the contiguous case is exercised for this routine.
          self._check_inplace(
              self.module.foo_non_compound_int, np.int32,
              [0 + 1 + 2 + 3*4, 1, 2, 3], check_strided=False)

      @pytest.mark.slow
      def test_constant_integer_int(self):
          self._check_inplace(
              self.module.foo_int, np.int32, [0 + 1 + 2*3, 1, 2])

      @pytest.mark.slow
      def test_constant_integer_long(self):
          self._check_inplace(
              self.module.foo_long, np.int64, [0 + 1 + 2*3, 1, 2])

      @pytest.mark.slow
      def test_constant_both(self):
          self._check_inplace(
              self.module.foo, np.float64, [0 + 1*3*3 + 2*3*3, 1*3, 2*3])

      @pytest.mark.slow
      def test_constant_no(self):
          self._check_inplace(
              self.module.foo_no, np.float64, [0 + 1*3*3 + 2*3*3, 1*3, 2*3])

      @pytest.mark.slow
      def test_constant_sum(self):
          self._check_inplace(
              self.module.foo_sum, np.float64, [0 + 1*3*3 + 2*3*3, 1*3, 2*3])