_util.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. from __future__ import division, print_function, absolute_import
  2. import functools
  3. import operator
  4. import sys
  5. import warnings
  6. import numbers
  7. from collections import namedtuple
  8. from multiprocessing import Pool
  9. import inspect
  10. import numpy as np
  11. def _broadcast_arrays(a, b):
  12. """
  13. Same as np.broadcast_arrays(a, b) but old writeability rules.
  14. Numpy >= 1.17.0 transitions broadcast_arrays to return
  15. read-only arrays. Set writeability explicitly to avoid warnings.
  16. Retain the old writeability rules, as our Cython code assumes
  17. the old behavior.
  18. """
  19. # backport based on gh-10379
  20. x, y = np.broadcast_arrays(a, b)
  21. x.flags.writeable = a.flags.writeable
  22. y.flags.writeable = b.flags.writeable
  23. return x, y
  24. def _valarray(shape, value=np.nan, typecode=None):
  25. """Return an array of all value.
  26. """
  27. out = np.ones(shape, dtype=bool) * value
  28. if typecode is not None:
  29. out = out.astype(typecode)
  30. if not isinstance(out, np.ndarray):
  31. out = np.asarray(out)
  32. return out
  33. def _lazywhere(cond, arrays, f, fillvalue=None, f2=None):
  34. """
  35. np.where(cond, x, fillvalue) always evaluates x even where cond is False.
  36. This one only evaluates f(arr1[cond], arr2[cond], ...).
  37. For example,
  38. >>> a, b = np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8])
  39. >>> def f(a, b):
  40. return a*b
  41. >>> _lazywhere(a > 2, (a, b), f, np.nan)
  42. array([ nan, nan, 21., 32.])
  43. Notice it assumes that all `arrays` are of the same shape, or can be
  44. broadcasted together.
  45. """
  46. if fillvalue is None:
  47. if f2 is None:
  48. raise ValueError("One of (fillvalue, f2) must be given.")
  49. else:
  50. fillvalue = np.nan
  51. else:
  52. if f2 is not None:
  53. raise ValueError("Only one of (fillvalue, f2) can be given.")
  54. arrays = np.broadcast_arrays(*arrays)
  55. temp = tuple(np.extract(cond, arr) for arr in arrays)
  56. tcode = np.mintypecode([a.dtype.char for a in arrays])
  57. out = _valarray(np.shape(arrays[0]), value=fillvalue, typecode=tcode)
  58. np.place(out, cond, f(*temp))
  59. if f2 is not None:
  60. temp = tuple(np.extract(~cond, arr) for arr in arrays)
  61. np.place(out, ~cond, f2(*temp))
  62. return out
  63. def _lazyselect(condlist, choicelist, arrays, default=0):
  64. """
  65. Mimic `np.select(condlist, choicelist)`.
  66. Notice it assumes that all `arrays` are of the same shape, or can be
  67. broadcasted together.
  68. All functions in `choicelist` must accept array arguments in the order
  69. given in `arrays` and must return an array of the same shape as broadcasted
  70. `arrays`.
  71. Examples
  72. --------
  73. >>> x = np.arange(6)
  74. >>> np.select([x <3, x > 3], [x**2, x**3], default=0)
  75. array([ 0, 1, 4, 0, 64, 125])
  76. >>> _lazyselect([x < 3, x > 3], [lambda x: x**2, lambda x: x**3], (x,))
  77. array([ 0., 1., 4., 0., 64., 125.])
  78. >>> a = -np.ones_like(x)
  79. >>> _lazyselect([x < 3, x > 3],
  80. ... [lambda x, a: x**2, lambda x, a: a * x**3],
  81. ... (x, a), default=np.nan)
  82. array([ 0., 1., 4., nan, -64., -125.])
  83. """
  84. arrays = np.broadcast_arrays(*arrays)
  85. tcode = np.mintypecode([a.dtype.char for a in arrays])
  86. out = _valarray(np.shape(arrays[0]), value=default, typecode=tcode)
  87. for index in range(len(condlist)):
  88. func, cond = choicelist[index], condlist[index]
  89. if np.all(cond is False):
  90. continue
  91. cond, _ = np.broadcast_arrays(cond, arrays[0])
  92. temp = tuple(np.extract(cond, arr) for arr in arrays)
  93. np.place(out, cond, func(*temp))
  94. return out
  95. def _aligned_zeros(shape, dtype=float, order="C", align=None):
  96. """Allocate a new ndarray with aligned memory.
  97. Primary use case for this currently is working around a f2py issue
  98. in Numpy 1.9.1, where dtype.alignment is such that np.zeros() does
  99. not necessarily create arrays aligned up to it.
  100. """
  101. dtype = np.dtype(dtype)
  102. if align is None:
  103. align = dtype.alignment
  104. if not hasattr(shape, '__len__'):
  105. shape = (shape,)
  106. size = functools.reduce(operator.mul, shape) * dtype.itemsize
  107. buf = np.empty(size + align + 1, np.uint8)
  108. offset = buf.__array_interface__['data'][0] % align
  109. if offset != 0:
  110. offset = align - offset
  111. # Note: slices producing 0-size arrays do not necessarily change
  112. # data pointer --- so we use and allocate size+1
  113. buf = buf[offset:offset+size+1][:-1]
  114. data = np.ndarray(shape, dtype, buf, order=order)
  115. data.fill(0)
  116. return data
  117. def _prune_array(array):
  118. """Return an array equivalent to the input array. If the input
  119. array is a view of a much larger array, copy its contents to a
  120. newly allocated array. Otherwise, return the input unchanged.
  121. """
  122. if array.base is not None and array.size < array.base.size // 2:
  123. return array.copy()
  124. return array
  125. class DeprecatedImport(object):
  126. """
  127. Deprecated import, with redirection + warning.
  128. Examples
  129. --------
  130. Suppose you previously had in some module::
  131. from foo import spam
  132. If this has to be deprecated, do::
  133. spam = DeprecatedImport("foo.spam", "baz")
  134. to redirect users to use "baz" module instead.
  135. """
  136. def __init__(self, old_module_name, new_module_name):
  137. self._old_name = old_module_name
  138. self._new_name = new_module_name
  139. __import__(self._new_name)
  140. self._mod = sys.modules[self._new_name]
  141. def __dir__(self):
  142. return dir(self._mod)
  143. def __getattr__(self, name):
  144. warnings.warn("Module %s is deprecated, use %s instead"
  145. % (self._old_name, self._new_name),
  146. DeprecationWarning)
  147. return getattr(self._mod, name)
  148. # copy-pasted from scikit-learn utils/validation.py
  149. def check_random_state(seed):
  150. """Turn seed into a np.random.RandomState instance
  151. If seed is None (or np.random), return the RandomState singleton used
  152. by np.random.
  153. If seed is an int, return a new RandomState instance seeded with seed.
  154. If seed is already a RandomState instance, return it.
  155. Otherwise raise ValueError.
  156. """
  157. if seed is None or seed is np.random:
  158. return np.random.mtrand._rand
  159. if isinstance(seed, (numbers.Integral, np.integer)):
  160. return np.random.RandomState(seed)
  161. if isinstance(seed, np.random.RandomState):
  162. return seed
  163. raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
  164. ' instance' % seed)
  165. def _asarray_validated(a, check_finite=True,
  166. sparse_ok=False, objects_ok=False, mask_ok=False,
  167. as_inexact=False):
  168. """
  169. Helper function for scipy argument validation.
  170. Many scipy linear algebra functions do support arbitrary array-like
  171. input arguments. Examples of commonly unsupported inputs include
  172. matrices containing inf/nan, sparse matrix representations, and
  173. matrices with complicated elements.
  174. Parameters
  175. ----------
  176. a : array_like
  177. The array-like input.
  178. check_finite : bool, optional
  179. Whether to check that the input matrices contain only finite numbers.
  180. Disabling may give a performance gain, but may result in problems
  181. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  182. Default: True
  183. sparse_ok : bool, optional
  184. True if scipy sparse matrices are allowed.
  185. objects_ok : bool, optional
  186. True if arrays with dype('O') are allowed.
  187. mask_ok : bool, optional
  188. True if masked arrays are allowed.
  189. as_inexact : bool, optional
  190. True to convert the input array to a np.inexact dtype.
  191. Returns
  192. -------
  193. ret : ndarray
  194. The converted validated array.
  195. """
  196. if not sparse_ok:
  197. import scipy.sparse
  198. if scipy.sparse.issparse(a):
  199. msg = ('Sparse matrices are not supported by this function. '
  200. 'Perhaps one of the scipy.sparse.linalg functions '
  201. 'would work instead.')
  202. raise ValueError(msg)
  203. if not mask_ok:
  204. if np.ma.isMaskedArray(a):
  205. raise ValueError('masked arrays are not supported')
  206. toarray = np.asarray_chkfinite if check_finite else np.asarray
  207. a = toarray(a)
  208. if not objects_ok:
  209. if a.dtype is np.dtype('O'):
  210. raise ValueError('object arrays are not supported')
  211. if as_inexact:
  212. if not np.issubdtype(a.dtype, np.inexact):
  213. a = toarray(a, dtype=np.float_)
  214. return a
  215. # Add a replacement for inspect.getargspec() which is deprecated in python 3.5
  216. # The version below is borrowed from Django,
  217. # https://github.com/django/django/pull/4846
  218. # Note an inconsistency between inspect.getargspec(func) and
  219. # inspect.signature(func). If `func` is a bound method, the latter does *not*
  220. # list `self` as a first argument, while the former *does*.
  221. # Hence cook up a common ground replacement: `getargspec_no_self` which
  222. # mimics `inspect.getargspec` but does not list `self`.
  223. #
  224. # This way, the caller code does not need to know whether it uses a legacy
  225. # .getargspec or bright and shiny .signature.
  226. try:
  227. # is it python 3.3 or higher?
  228. inspect.signature
  229. # Apparently, yes. Wrap inspect.signature
  230. ArgSpec = namedtuple('ArgSpec', ['args', 'varargs', 'keywords', 'defaults'])
  231. def getargspec_no_self(func):
  232. """inspect.getargspec replacement using inspect.signature.
  233. inspect.getargspec is deprecated in python 3. This is a replacement
  234. based on the (new in python 3.3) `inspect.signature`.
  235. Parameters
  236. ----------
  237. func : callable
  238. A callable to inspect
  239. Returns
  240. -------
  241. argspec : ArgSpec(args, varargs, varkw, defaults)
  242. This is similar to the result of inspect.getargspec(func) under
  243. python 2.x.
  244. NOTE: if the first argument of `func` is self, it is *not*, I repeat
  245. *not* included in argspec.args.
  246. This is done for consistency between inspect.getargspec() under
  247. python 2.x, and inspect.signature() under python 3.x.
  248. """
  249. sig = inspect.signature(func)
  250. args = [
  251. p.name for p in sig.parameters.values()
  252. if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
  253. ]
  254. varargs = [
  255. p.name for p in sig.parameters.values()
  256. if p.kind == inspect.Parameter.VAR_POSITIONAL
  257. ]
  258. varargs = varargs[0] if varargs else None
  259. varkw = [
  260. p.name for p in sig.parameters.values()
  261. if p.kind == inspect.Parameter.VAR_KEYWORD
  262. ]
  263. varkw = varkw[0] if varkw else None
  264. defaults = [
  265. p.default for p in sig.parameters.values()
  266. if (p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and
  267. p.default is not p.empty)
  268. ] or None
  269. return ArgSpec(args, varargs, varkw, defaults)
  270. except AttributeError:
  271. # python 2.x
  272. def getargspec_no_self(func):
  273. """inspect.getargspec replacement for compatibility with python 3.x.
  274. inspect.getargspec is deprecated in python 3. This wraps it, and
  275. *removes* `self` from the argument list of `func`, if present.
  276. This is done for forward compatibility with python 3.
  277. Parameters
  278. ----------
  279. func : callable
  280. A callable to inspect
  281. Returns
  282. -------
  283. argspec : ArgSpec(args, varargs, varkw, defaults)
  284. This is similar to the result of inspect.getargspec(func) under
  285. python 2.x.
  286. NOTE: if the first argument of `func` is self, it is *not*, I repeat
  287. *not* included in argspec.args.
  288. This is done for consistency between inspect.getargspec() under
  289. python 2.x, and inspect.signature() under python 3.x.
  290. """
  291. argspec = inspect.getargspec(func)
  292. if argspec.args[0] == 'self':
  293. argspec.args.pop(0)
  294. return argspec
  295. class MapWrapper(object):
  296. """
  297. Parallelisation wrapper for working with map-like callables, such as
  298. `multiprocessing.Pool.map`.
  299. Parameters
  300. ----------
  301. pool : int or map-like callable
  302. If `pool` is an integer, then it specifies the number of threads to
  303. use for parallelization. If ``int(pool) == 1``, then no parallel
  304. processing is used and the map builtin is used.
  305. If ``pool == -1``, then the pool will utilise all available CPUs.
  306. If `pool` is a map-like callable that follows the same
  307. calling sequence as the built-in map function, then this callable is
  308. used for parallelisation.
  309. """
  310. def __init__(self, pool=1):
  311. self.pool = None
  312. self._mapfunc = map
  313. self._own_pool = False
  314. if callable(pool):
  315. self.pool = pool
  316. self._mapfunc = self.pool
  317. else:
  318. # user supplies a number
  319. if int(pool) == -1:
  320. # use as many processors as possible
  321. self.pool = Pool()
  322. self._mapfunc = self.pool.map
  323. self._own_pool = True
  324. elif int(pool) == 1:
  325. pass
  326. elif int(pool) > 1:
  327. # use the number of processors requested
  328. self.pool = Pool(processes=int(pool))
  329. self._mapfunc = self.pool.map
  330. self._own_pool = True
  331. else:
  332. raise RuntimeError("Number of workers specified must be -1,"
  333. " an int >= 1, or an object with a 'map' method")
  334. def __enter__(self):
  335. return self
  336. def __del__(self):
  337. self.close()
  338. self.terminate()
  339. def terminate(self):
  340. if self._own_pool:
  341. self.pool.terminate()
  342. def join(self):
  343. if self._own_pool:
  344. self.pool.join()
  345. def close(self):
  346. if self._own_pool:
  347. self.pool.close()
  348. def __exit__(self, exc_type, exc_value, traceback):
  349. if self._own_pool:
  350. self.pool.close()
  351. self.pool.terminate()
  352. def __call__(self, func, iterable):
  353. # only accept one iterable because that's all Pool.map accepts
  354. try:
  355. return self._mapfunc(func, iterable)
  356. except TypeError:
  357. # wrong number of arguments
  358. raise TypeError("The map-like callable must be of the"
  359. " form f(func, iterable)")