# _testutils.py (1.8 KB)
from __future__ import division, print_function, absolute_import

import numpy as np
  3. class _FakeMatrix(object):
  4. def __init__(self, data):
  5. self._data = data
  6. self.__array_interface__ = data.__array_interface__
  7. class _FakeMatrix2(object):
  8. def __init__(self, data):
  9. self._data = data
  10. def __array__(self):
  11. return self._data
  12. def _get_array(shape, dtype):
  13. """
  14. Get a test array of given shape and data type.
  15. Returned NxN matrices are posdef, and 2xN are banded-posdef.
  16. """
  17. if len(shape) == 2 and shape[0] == 2:
  18. # yield a banded positive definite one
  19. x = np.zeros(shape, dtype=dtype)
  20. x[0, 1:] = -1
  21. x[1] = 2
  22. return x
  23. elif len(shape) == 2 and shape[0] == shape[1]:
  24. # always yield a positive definite matrix
  25. x = np.zeros(shape, dtype=dtype)
  26. j = np.arange(shape[0])
  27. x[j, j] = 2
  28. x[j[:-1], j[:-1]+1] = -1
  29. x[j[:-1]+1, j[:-1]] = -1
  30. return x
  31. else:
  32. np.random.seed(1234)
  33. return np.random.randn(*shape).astype(dtype)
  34. def _id(x):
  35. return x
  36. def assert_no_overwrite(call, shapes, dtypes=None):
  37. """
  38. Test that a call does not overwrite its input arguments
  39. """
  40. if dtypes is None:
  41. dtypes = [np.float32, np.float64, np.complex64, np.complex128]
  42. for dtype in dtypes:
  43. for order in ["C", "F"]:
  44. for faker in [_id, _FakeMatrix, _FakeMatrix2]:
  45. orig_inputs = [_get_array(s, dtype) for s in shapes]
  46. inputs = [faker(x.copy(order)) for x in orig_inputs]
  47. call(*inputs)
  48. msg = "call modified inputs [%r, %r]" % (dtype, faker)
  49. for a, b in zip(inputs, orig_inputs):
  50. np.testing.assert_equal(a, b, err_msg=msg)