test_fblas.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611
  1. # Test interfaces to fortran blas.
  2. #
  3. # The tests exercise the Python interface more than the underlying BLAS.
  4. # Only very small matrices checked -- N=3 or so.
  5. #
  6. # !! Complex calculations aren't checked very carefully.
  7. # !! Most tests use real-valued complex numbers; ?gemv also scales by 1+1j.
  8. from __future__ import division, print_function, absolute_import
  9. from numpy import float32, float64, complex64, complex128, arange, array, \
  10. zeros, shape, transpose, newaxis, common_type, conjugate
  11. from scipy.linalg import _fblas as fblas
  12. from scipy._lib.six import xrange
  13. from numpy.testing import assert_array_equal, \
  14. assert_allclose, assert_array_almost_equal, assert_
  15. import pytest
  16. # decimal accuracy to require between Python and LAPACK/BLAS calculations
  17. accuracy = 5
  18. # Since numpy.dot likely uses the same blas, use this routine
  19. # to check.
  20. def matrixmultiply(a, b):
  21. if len(b.shape) == 1:
  22. b_is_vector = True
  23. b = b[:, newaxis]
  24. else:
  25. b_is_vector = False
  26. assert_(a.shape[1] == b.shape[0])
  27. c = zeros((a.shape[0], b.shape[1]), common_type(a, b))
  28. for i in xrange(a.shape[0]):
  29. for j in xrange(b.shape[1]):
  30. s = 0
  31. for k in xrange(a.shape[1]):
  32. s += a[i, k] * b[k, j]
  33. c[i, j] = s
  34. if b_is_vector:
  35. c = c.reshape((a.shape[0],))
  36. return c
  37. ##################################################
  38. # Test blas ?axpy
  39. class BaseAxpy(object):
  40. ''' Mixin class for axpy tests '''
  41. def test_default_a(self):
  42. x = arange(3., dtype=self.dtype)
  43. y = arange(3., dtype=x.dtype)
  44. real_y = x*1.+y
  45. y = self.blas_func(x, y)
  46. assert_array_equal(real_y, y)
  47. def test_simple(self):
  48. x = arange(3., dtype=self.dtype)
  49. y = arange(3., dtype=x.dtype)
  50. real_y = x*3.+y
  51. y = self.blas_func(x, y, a=3.)
  52. assert_array_equal(real_y, y)
  53. def test_x_stride(self):
  54. x = arange(6., dtype=self.dtype)
  55. y = zeros(3, x.dtype)
  56. y = arange(3., dtype=x.dtype)
  57. real_y = x[::2]*3.+y
  58. y = self.blas_func(x, y, a=3., n=3, incx=2)
  59. assert_array_equal(real_y, y)
  60. def test_y_stride(self):
  61. x = arange(3., dtype=self.dtype)
  62. y = zeros(6, x.dtype)
  63. real_y = x*3.+y[::2]
  64. y = self.blas_func(x, y, a=3., n=3, incy=2)
  65. assert_array_equal(real_y, y[::2])
  66. def test_x_and_y_stride(self):
  67. x = arange(12., dtype=self.dtype)
  68. y = zeros(6, x.dtype)
  69. real_y = x[::4]*3.+y[::2]
  70. y = self.blas_func(x, y, a=3., n=3, incx=4, incy=2)
  71. assert_array_equal(real_y, y[::2])
  72. def test_x_bad_size(self):
  73. x = arange(12., dtype=self.dtype)
  74. y = zeros(6, x.dtype)
  75. with pytest.raises(Exception, match='failed for 1st keyword'):
  76. self.blas_func(x, y, n=4, incx=5)
  77. def test_y_bad_size(self):
  78. x = arange(12., dtype=self.dtype)
  79. y = zeros(6, x.dtype)
  80. with pytest.raises(Exception, match='failed for 1st keyword'):
  81. self.blas_func(x, y, n=3, incy=5)
  82. try:
  83. class TestSaxpy(BaseAxpy):
  84. blas_func = fblas.saxpy
  85. dtype = float32
  86. except AttributeError:
  87. class TestSaxpy:
  88. pass
  89. class TestDaxpy(BaseAxpy):
  90. blas_func = fblas.daxpy
  91. dtype = float64
  92. try:
  93. class TestCaxpy(BaseAxpy):
  94. blas_func = fblas.caxpy
  95. dtype = complex64
  96. except AttributeError:
  97. class TestCaxpy:
  98. pass
  99. class TestZaxpy(BaseAxpy):
  100. blas_func = fblas.zaxpy
  101. dtype = complex128
  102. ##################################################
  103. # Test blas ?scal
  104. class BaseScal(object):
  105. ''' Mixin class for scal testing '''
  106. def test_simple(self):
  107. x = arange(3., dtype=self.dtype)
  108. real_x = x*3.
  109. x = self.blas_func(3., x)
  110. assert_array_equal(real_x, x)
  111. def test_x_stride(self):
  112. x = arange(6., dtype=self.dtype)
  113. real_x = x.copy()
  114. real_x[::2] = x[::2]*array(3., self.dtype)
  115. x = self.blas_func(3., x, n=3, incx=2)
  116. assert_array_equal(real_x, x)
  117. def test_x_bad_size(self):
  118. x = arange(12., dtype=self.dtype)
  119. with pytest.raises(Exception, match='failed for 1st keyword'):
  120. self.blas_func(2., x, n=4, incx=5)
  121. try:
  122. class TestSscal(BaseScal):
  123. blas_func = fblas.sscal
  124. dtype = float32
  125. except AttributeError:
  126. class TestSscal:
  127. pass
  128. class TestDscal(BaseScal):
  129. blas_func = fblas.dscal
  130. dtype = float64
  131. try:
  132. class TestCscal(BaseScal):
  133. blas_func = fblas.cscal
  134. dtype = complex64
  135. except AttributeError:
  136. class TestCscal:
  137. pass
  138. class TestZscal(BaseScal):
  139. blas_func = fblas.zscal
  140. dtype = complex128
  141. ##################################################
  142. # Test blas ?copy
  143. class BaseCopy(object):
  144. ''' Mixin class for copy testing '''
  145. def test_simple(self):
  146. x = arange(3., dtype=self.dtype)
  147. y = zeros(shape(x), x.dtype)
  148. y = self.blas_func(x, y)
  149. assert_array_equal(x, y)
  150. def test_x_stride(self):
  151. x = arange(6., dtype=self.dtype)
  152. y = zeros(3, x.dtype)
  153. y = self.blas_func(x, y, n=3, incx=2)
  154. assert_array_equal(x[::2], y)
  155. def test_y_stride(self):
  156. x = arange(3., dtype=self.dtype)
  157. y = zeros(6, x.dtype)
  158. y = self.blas_func(x, y, n=3, incy=2)
  159. assert_array_equal(x, y[::2])
  160. def test_x_and_y_stride(self):
  161. x = arange(12., dtype=self.dtype)
  162. y = zeros(6, x.dtype)
  163. y = self.blas_func(x, y, n=3, incx=4, incy=2)
  164. assert_array_equal(x[::4], y[::2])
  165. def test_x_bad_size(self):
  166. x = arange(12., dtype=self.dtype)
  167. y = zeros(6, x.dtype)
  168. with pytest.raises(Exception, match='failed for 1st keyword'):
  169. self.blas_func(x, y, n=4, incx=5)
  170. def test_y_bad_size(self):
  171. x = arange(12., dtype=self.dtype)
  172. y = zeros(6, x.dtype)
  173. with pytest.raises(Exception, match='failed for 1st keyword'):
  174. self.blas_func(x, y, n=3, incy=5)
  175. # def test_y_bad_type(self):
  176. ## Hmmm. Should this work? What should be the output.
  177. # x = arange(3.,dtype=self.dtype)
  178. # y = zeros(shape(x))
  179. # self.blas_func(x,y)
  180. # assert_array_equal(x,y)
  181. try:
  182. class TestScopy(BaseCopy):
  183. blas_func = fblas.scopy
  184. dtype = float32
  185. except AttributeError:
  186. class TestScopy:
  187. pass
  188. class TestDcopy(BaseCopy):
  189. blas_func = fblas.dcopy
  190. dtype = float64
  191. try:
  192. class TestCcopy(BaseCopy):
  193. blas_func = fblas.ccopy
  194. dtype = complex64
  195. except AttributeError:
  196. class TestCcopy:
  197. pass
  198. class TestZcopy(BaseCopy):
  199. blas_func = fblas.zcopy
  200. dtype = complex128
  201. ##################################################
  202. # Test blas ?swap
  203. class BaseSwap(object):
  204. ''' Mixin class for swap tests '''
  205. def test_simple(self):
  206. x = arange(3., dtype=self.dtype)
  207. y = zeros(shape(x), x.dtype)
  208. desired_x = y.copy()
  209. desired_y = x.copy()
  210. x, y = self.blas_func(x, y)
  211. assert_array_equal(desired_x, x)
  212. assert_array_equal(desired_y, y)
  213. def test_x_stride(self):
  214. x = arange(6., dtype=self.dtype)
  215. y = zeros(3, x.dtype)
  216. desired_x = y.copy()
  217. desired_y = x.copy()[::2]
  218. x, y = self.blas_func(x, y, n=3, incx=2)
  219. assert_array_equal(desired_x, x[::2])
  220. assert_array_equal(desired_y, y)
  221. def test_y_stride(self):
  222. x = arange(3., dtype=self.dtype)
  223. y = zeros(6, x.dtype)
  224. desired_x = y.copy()[::2]
  225. desired_y = x.copy()
  226. x, y = self.blas_func(x, y, n=3, incy=2)
  227. assert_array_equal(desired_x, x)
  228. assert_array_equal(desired_y, y[::2])
  229. def test_x_and_y_stride(self):
  230. x = arange(12., dtype=self.dtype)
  231. y = zeros(6, x.dtype)
  232. desired_x = y.copy()[::2]
  233. desired_y = x.copy()[::4]
  234. x, y = self.blas_func(x, y, n=3, incx=4, incy=2)
  235. assert_array_equal(desired_x, x[::4])
  236. assert_array_equal(desired_y, y[::2])
  237. def test_x_bad_size(self):
  238. x = arange(12., dtype=self.dtype)
  239. y = zeros(6, x.dtype)
  240. with pytest.raises(Exception, match='failed for 1st keyword'):
  241. self.blas_func(x, y, n=4, incx=5)
  242. def test_y_bad_size(self):
  243. x = arange(12., dtype=self.dtype)
  244. y = zeros(6, x.dtype)
  245. with pytest.raises(Exception, match='failed for 1st keyword'):
  246. self.blas_func(x, y, n=3, incy=5)
  247. try:
  248. class TestSswap(BaseSwap):
  249. blas_func = fblas.sswap
  250. dtype = float32
  251. except AttributeError:
  252. class TestSswap:
  253. pass
  254. class TestDswap(BaseSwap):
  255. blas_func = fblas.dswap
  256. dtype = float64
  257. try:
  258. class TestCswap(BaseSwap):
  259. blas_func = fblas.cswap
  260. dtype = complex64
  261. except AttributeError:
  262. class TestCswap:
  263. pass
  264. class TestZswap(BaseSwap):
  265. blas_func = fblas.zswap
  266. dtype = complex128
  267. ##################################################
  268. # Test blas ?gemv
  269. # This will be a mess to test all cases.
  270. class BaseGemv(object):
  271. ''' Mixin class for gemv tests '''
  272. def get_data(self, x_stride=1, y_stride=1):
  273. mult = array(1, dtype=self.dtype)
  274. if self.dtype in [complex64, complex128]:
  275. mult = array(1+1j, dtype=self.dtype)
  276. from numpy.random import normal, seed
  277. seed(1234)
  278. alpha = array(1., dtype=self.dtype) * mult
  279. beta = array(1., dtype=self.dtype) * mult
  280. a = normal(0., 1., (3, 3)).astype(self.dtype) * mult
  281. x = arange(shape(a)[0]*x_stride, dtype=self.dtype) * mult
  282. y = arange(shape(a)[1]*y_stride, dtype=self.dtype) * mult
  283. return alpha, beta, a, x, y
  284. def test_simple(self):
  285. alpha, beta, a, x, y = self.get_data()
  286. desired_y = alpha*matrixmultiply(a, x)+beta*y
  287. y = self.blas_func(alpha, a, x, beta, y)
  288. assert_array_almost_equal(desired_y, y)
  289. def test_default_beta_y(self):
  290. alpha, beta, a, x, y = self.get_data()
  291. desired_y = matrixmultiply(a, x)
  292. y = self.blas_func(1, a, x)
  293. assert_array_almost_equal(desired_y, y)
  294. def test_simple_transpose(self):
  295. alpha, beta, a, x, y = self.get_data()
  296. desired_y = alpha*matrixmultiply(transpose(a), x)+beta*y
  297. y = self.blas_func(alpha, a, x, beta, y, trans=1)
  298. assert_array_almost_equal(desired_y, y)
  299. def test_simple_transpose_conj(self):
  300. alpha, beta, a, x, y = self.get_data()
  301. desired_y = alpha*matrixmultiply(transpose(conjugate(a)), x)+beta*y
  302. y = self.blas_func(alpha, a, x, beta, y, trans=2)
  303. assert_array_almost_equal(desired_y, y)
  304. def test_x_stride(self):
  305. alpha, beta, a, x, y = self.get_data(x_stride=2)
  306. desired_y = alpha*matrixmultiply(a, x[::2])+beta*y
  307. y = self.blas_func(alpha, a, x, beta, y, incx=2)
  308. assert_array_almost_equal(desired_y, y)
  309. def test_x_stride_transpose(self):
  310. alpha, beta, a, x, y = self.get_data(x_stride=2)
  311. desired_y = alpha*matrixmultiply(transpose(a), x[::2])+beta*y
  312. y = self.blas_func(alpha, a, x, beta, y, trans=1, incx=2)
  313. assert_array_almost_equal(desired_y, y)
  314. def test_x_stride_assert(self):
  315. # What is the use of this test?
  316. alpha, beta, a, x, y = self.get_data(x_stride=2)
  317. with pytest.raises(Exception, match='failed for 3rd argument'):
  318. y = self.blas_func(1, a, x, 1, y, trans=0, incx=3)
  319. with pytest.raises(Exception, match='failed for 3rd argument'):
  320. y = self.blas_func(1, a, x, 1, y, trans=1, incx=3)
  321. def test_y_stride(self):
  322. alpha, beta, a, x, y = self.get_data(y_stride=2)
  323. desired_y = y.copy()
  324. desired_y[::2] = alpha*matrixmultiply(a, x)+beta*y[::2]
  325. y = self.blas_func(alpha, a, x, beta, y, incy=2)
  326. assert_array_almost_equal(desired_y, y)
  327. def test_y_stride_transpose(self):
  328. alpha, beta, a, x, y = self.get_data(y_stride=2)
  329. desired_y = y.copy()
  330. desired_y[::2] = alpha*matrixmultiply(transpose(a), x)+beta*y[::2]
  331. y = self.blas_func(alpha, a, x, beta, y, trans=1, incy=2)
  332. assert_array_almost_equal(desired_y, y)
  333. def test_y_stride_assert(self):
  334. # What is the use of this test?
  335. alpha, beta, a, x, y = self.get_data(y_stride=2)
  336. with pytest.raises(Exception, match='failed for 2nd keyword'):
  337. y = self.blas_func(1, a, x, 1, y, trans=0, incy=3)
  338. with pytest.raises(Exception, match='failed for 2nd keyword'):
  339. y = self.blas_func(1, a, x, 1, y, trans=1, incy=3)
  340. try:
  341. class TestSgemv(BaseGemv):
  342. blas_func = fblas.sgemv
  343. dtype = float32
  344. def test_sgemv_on_osx(self):
  345. from itertools import product
  346. import sys
  347. import numpy as np
  348. if sys.platform != 'darwin':
  349. return
  350. def aligned_array(shape, align, dtype, order='C'):
  351. # Make array shape `shape` with aligned at `align` bytes
  352. d = dtype()
  353. # Make array of correct size with `align` extra bytes
  354. N = np.prod(shape)
  355. tmp = np.zeros(N * d.nbytes + align, dtype=np.uint8)
  356. address = tmp.__array_interface__["data"][0]
  357. # Find offset into array giving desired alignment
  358. for offset in range(align):
  359. if (address + offset) % align == 0:
  360. break
  361. tmp = tmp[offset:offset+N*d.nbytes].view(dtype=dtype)
  362. return tmp.reshape(shape, order=order)
  363. def as_aligned(arr, align, dtype, order='C'):
  364. # Copy `arr` into an aligned array with same shape
  365. aligned = aligned_array(arr.shape, align, dtype, order)
  366. aligned[:] = arr[:]
  367. return aligned
  368. def assert_dot_close(A, X, desired):
  369. assert_allclose(self.blas_func(1.0, A, X), desired,
  370. rtol=1e-5, atol=1e-7)
  371. testdata = product((15, 32), (10000,), (200, 89), ('C', 'F'))
  372. for align, m, n, a_order in testdata:
  373. A_d = np.random.rand(m, n)
  374. X_d = np.random.rand(n)
  375. desired = np.dot(A_d, X_d)
  376. # Calculation with aligned single precision
  377. A_f = as_aligned(A_d, align, np.float32, order=a_order)
  378. X_f = as_aligned(X_d, align, np.float32, order=a_order)
  379. assert_dot_close(A_f, X_f, desired)
  380. except AttributeError:
  381. class TestSgemv:
  382. pass
  383. class TestDgemv(BaseGemv):
  384. blas_func = fblas.dgemv
  385. dtype = float64
  386. try:
  387. class TestCgemv(BaseGemv):
  388. blas_func = fblas.cgemv
  389. dtype = complex64
  390. except AttributeError:
  391. class TestCgemv:
  392. pass
  393. class TestZgemv(BaseGemv):
  394. blas_func = fblas.zgemv
  395. dtype = complex128
# NOTE(review): the ?ger tests below are disabled -- the whole section is a
# bare string literal, so none of it is defined or collected.  It looks
# unfinished:
#   - both *_stride_assert tests expect a ValueError matching 'foo';
#   - test_y_stride_assert passes (a, x, y), whereas the other tests pass
#     (x, y, a).
# Review before re-enabling.
  396. """
  397. ##################################################
  398. ### Test blas ?ger
  399. ### This will be a mess to test all cases.
  400. class BaseGer(object):
  401. def get_data(self,x_stride=1,y_stride=1):
  402. from numpy.random import normal, seed
  403. seed(1234)
  404. alpha = array(1., dtype = self.dtype)
  405. a = normal(0.,1.,(3,3)).astype(self.dtype)
  406. x = arange(shape(a)[0]*x_stride,dtype=self.dtype)
  407. y = arange(shape(a)[1]*y_stride,dtype=self.dtype)
  408. return alpha,a,x,y
  409. def test_simple(self):
  410. alpha,a,x,y = self.get_data()
  411. # tranpose takes care of Fortran vs. C(and Python) memory layout
  412. desired_a = alpha*transpose(x[:,newaxis]*y) + a
  413. self.blas_func(x,y,a)
  414. assert_array_almost_equal(desired_a,a)
  415. def test_x_stride(self):
  416. alpha,a,x,y = self.get_data(x_stride=2)
  417. desired_a = alpha*transpose(x[::2,newaxis]*y) + a
  418. self.blas_func(x,y,a,incx=2)
  419. assert_array_almost_equal(desired_a,a)
  420. def test_x_stride_assert(self):
  421. alpha,a,x,y = self.get_data(x_stride=2)
  422. with pytest.raises(ValueError, match='foo'):
  423. self.blas_func(x,y,a,incx=3)
  424. def test_y_stride(self):
  425. alpha,a,x,y = self.get_data(y_stride=2)
  426. desired_a = alpha*transpose(x[:,newaxis]*y[::2]) + a
  427. self.blas_func(x,y,a,incy=2)
  428. assert_array_almost_equal(desired_a,a)
  429. def test_y_stride_assert(self):
  430. alpha,a,x,y = self.get_data(y_stride=2)
  431. with pytest.raises(ValueError, match='foo'):
  432. self.blas_func(a,x,y,incy=3)
  433. class TestSger(BaseGer):
  434. blas_func = fblas.sger
  435. dtype = float32
  436. class TestDger(BaseGer):
  437. blas_func = fblas.dger
  438. dtype = float64
  439. """
  440. ##################################################
  441. # Test blas ?gerc
  442. # This will be a mess to test all cases.
# NOTE(review): the ?geru/?gerc tests below are likewise disabled by a bare
# string literal.  They need rework before they can be re-enabled:
#   - BaseGerComplex.test_simple zeroes `a` before the call;
#   - it calls fblas.cgeru directly rather than self.blas_func;
#   - the stride tests are commented out;
#   - it subclasses BaseGer, which is itself inside a disabled string.
  443. """
  444. class BaseGerComplex(BaseGer):
  445. def get_data(self,x_stride=1,y_stride=1):
  446. from numpy.random import normal, seed
  447. seed(1234)
  448. alpha = array(1+1j, dtype = self.dtype)
  449. a = normal(0.,1.,(3,3)).astype(self.dtype)
  450. a = a + normal(0.,1.,(3,3)) * array(1j, dtype = self.dtype)
  451. x = normal(0.,1.,shape(a)[0]*x_stride).astype(self.dtype)
  452. x = x + x * array(1j, dtype = self.dtype)
  453. y = normal(0.,1.,shape(a)[1]*y_stride).astype(self.dtype)
  454. y = y + y * array(1j, dtype = self.dtype)
  455. return alpha,a,x,y
  456. def test_simple(self):
  457. alpha,a,x,y = self.get_data()
  458. # tranpose takes care of Fortran vs. C(and Python) memory layout
  459. a = a * array(0.,dtype = self.dtype)
  460. #desired_a = alpha*transpose(x[:,newaxis]*self.transform(y)) + a
  461. desired_a = alpha*transpose(x[:,newaxis]*y) + a
  462. #self.blas_func(x,y,a,alpha = alpha)
  463. fblas.cgeru(x,y,a,alpha = alpha)
  464. assert_array_almost_equal(desired_a,a)
  465. #def test_x_stride(self):
  466. # alpha,a,x,y = self.get_data(x_stride=2)
  467. # desired_a = alpha*transpose(x[::2,newaxis]*self.transform(y)) + a
  468. # self.blas_func(x,y,a,incx=2)
  469. # assert_array_almost_equal(desired_a,a)
  470. #def test_y_stride(self):
  471. # alpha,a,x,y = self.get_data(y_stride=2)
  472. # desired_a = alpha*transpose(x[:,newaxis]*self.transform(y[::2])) + a
  473. # self.blas_func(x,y,a,incy=2)
  474. # assert_array_almost_equal(desired_a,a)
  475. class TestCgeru(BaseGerComplex):
  476. blas_func = fblas.cgeru
  477. dtype = complex64
  478. def transform(self,x):
  479. return x
  480. class TestZgeru(BaseGerComplex):
  481. blas_func = fblas.zgeru
  482. dtype = complex128
  483. def transform(self,x):
  484. return x
  485. class TestCgerc(BaseGerComplex):
  486. blas_func = fblas.cgerc
  487. dtype = complex64
  488. def transform(self,x):
  489. return conjugate(x)
  490. class TestZgerc(BaseGerComplex):
  491. blas_func = fblas.zgerc
  492. dtype = complex128
  493. def transform(self,x):
  494. return conjugate(x)
  495. """