  1. """
  2. ====================================================================
  3. K-means clustering and vector quantization (:mod:`scipy.cluster.vq`)
  4. ====================================================================
  5. Provides routines for k-means clustering, generating code books
  6. from k-means models, and quantizing vectors by comparing them with
  7. centroids in a code book.
  8. .. autosummary::
  9. :toctree: generated/
  10. whiten -- Normalize a group of observations so each feature has unit variance
  11. vq -- Calculate code book membership of a set of observation vectors
  12. kmeans -- Performs k-means on a set of observation vectors forming k clusters
  13. kmeans2 -- A different implementation of k-means with more methods
  14. -- for initializing centroids
  15. Background information
  16. ======================
  17. The k-means algorithm takes as input the number of clusters to
  18. generate, k, and a set of observation vectors to cluster. It
  19. returns a set of centroids, one for each of the k clusters. An
  20. observation vector is classified with the cluster number or
  21. centroid index of the centroid closest to it.
  22. A vector v belongs to cluster i if it is closer to centroid i than
  23. any other centroids. If v belongs to i, we say centroid i is the
  24. dominating centroid of v. The k-means algorithm tries to
  25. minimize distortion, which is defined as the sum of the squared distances
  26. between each observation vector and its dominating centroid.
  27. The minimization is achieved by iteratively reclassifying
  28. the observations into clusters and recalculating the centroids until
  29. a configuration is reached in which the centroids are stable. One can
  30. also define a maximum number of iterations.
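
As a minimal sketch of the quantity being minimized (the toy observations
and centroid below are illustrative assumptions, not part of this module):

>>> import numpy as np
>>> obs = np.array([[0., 0.], [1., 1.]])
>>> centroid = np.array([0.5, 0.5])  # dominating centroid of both observations
>>> float(((obs - centroid) ** 2).sum())  # sum of squared distances
1.0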

Since vector quantization is a natural application for k-means,
information theory terminology is often used. The centroid index
or cluster index is also referred to as a "code" and the table
mapping codes to centroids and vice versa is often referred to as a
"code book". The result of k-means, a set of centroids, can be
used to quantize vectors. Quantization aims to find an encoding of
vectors that reduces the expected distortion.

All routines expect obs to be an M by N array, where the rows are
the observation vectors. The codebook is a k by N array, where the
i'th row is the centroid of code word i. The observation vectors
and centroids have the same feature dimension.
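
A minimal sketch of this shape convention (the arrays below are
illustrative assumptions):

>>> import numpy as np
>>> from scipy.cluster.vq import vq
>>> obs = np.array([[1., 1.], [2., 2.], [3., 3.]])  # M = 3, N = 2
>>> code_book = np.array([[1., 1.], [3., 3.]])      # k = 2, N = 2
>>> codes, dists = vq(obs, code_book)
>>> codes.shape, dists.shape
((3,), (3,))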

As an example, suppose we wish to compress a 24-bit color image
(each pixel is represented by one byte for red, one for blue, and
one for green) before sending it over the web. By using a smaller
8-bit encoding, we can reduce the amount of data by two
thirds. Ideally, the colors for each of the 256 possible 8-bit
encoding values should be chosen to minimize distortion of the
color. Running k-means with k=256 generates a code book of 256
codes, which fills up all possible 8-bit sequences. Instead of
sending a 3-byte value for each pixel, the 8-bit centroid index
(or code word) of the dominating centroid is transmitted. The code
book is also sent over the wire so each 8-bit code can be
translated back to a 24-bit pixel value representation. If the
image of interest was of an ocean, we would expect many 24-bit
blues to be represented by 8-bit codes. If it was an image of a
human face, more flesh tone colors would be represented in the
code book.
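
A hedged sketch of this pipeline (the random pixel array and k=16 are
illustrative assumptions; a real image would also typically be whitened
first):

>>> import numpy as np
>>> from scipy.cluster.vq import kmeans, vq
>>> pixels = np.random.rand(100, 3)      # stand-in for M RGB pixels
>>> code_book, _ = kmeans(pixels, 16)    # build a small code book
>>> codes, _ = vq(pixels, code_book)     # one small code per pixel
>>> reconstructed = code_book[codes]     # receiver decodes codes to colors
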
  58. """
  59. from __future__ import division, print_function, absolute_import
  60. import warnings
  61. import numpy as np
  62. from collections import deque
  63. from scipy._lib._util import _asarray_validated
  64. from scipy._lib.six import xrange
  65. from scipy.spatial.distance import cdist
  66. from . import _vq
  67. __docformat__ = 'restructuredtext'
  68. __all__ = ['whiten', 'vq', 'kmeans', 'kmeans2']
  69. class ClusterError(Exception):
  70. pass


def whiten(obs, check_finite=True):
    """
    Normalize a group of observations on a per feature basis.

    Before running k-means, it is beneficial to rescale each feature
    dimension of the observation set with whitening. Each feature is
    divided by its standard deviation across all observations to give
    it unit variance.

    Parameters
    ----------
    obs : ndarray
        Each row of the array is an observation. The
        columns are the features seen during each observation.

        >>> #         f0    f1    f2
        >>> obs = [[  1.,   1.,   1.],  # o0
        ...        [  2.,   2.,   2.],  # o1
        ...        [  3.,   3.,   3.],  # o2
        ...        [  4.,   4.,   4.]]  # o3

    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
        Default: True

    Returns
    -------
    result : ndarray
        Contains the values in `obs` scaled by the standard deviation
        of each column.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.cluster.vq import whiten
    >>> features = np.array([[1.9, 2.3, 1.7],
    ...                      [1.5, 2.5, 2.2],
    ...                      [0.8, 0.6, 1.7]])
    >>> whiten(features)
    array([[ 4.17944278,  2.69811351,  7.21248917],
           [ 3.29956009,  2.93273208,  9.33380951],
           [ 1.75976538,  0.7038557 ,  7.21248917]])

    """
    obs = _asarray_validated(obs, check_finite=check_finite)
    std_dev = obs.std(axis=0)
    zero_std_mask = std_dev == 0
    if zero_std_mask.any():
        # Avoid dividing by zero: leave constant (zero-variance) columns as-is.
        std_dev[zero_std_mask] = 1.0
        warnings.warn("Some columns have standard deviation zero. "
                      "The values of these columns will not change.",
                      RuntimeWarning)
    return obs / std_dev


def vq(obs, code_book, check_finite=True):
    """
    Assign codes from a code book to observations.

    Assigns a code from a code book to each observation. Each
    observation vector in the 'M' by 'N' `obs` array is compared with the
    centroids in the code book and assigned the code of the closest
    centroid.

    The features in `obs` should have unit variance, which can be
    achieved by passing them through the whiten function. The code
    book can be created with the k-means algorithm or a different
    encoding algorithm.

    Parameters
    ----------
    obs : ndarray
        Each row of the 'M' x 'N' array is an observation. The columns are
        the "features" seen during each observation. The features must be
        whitened first using the whiten function or something equivalent.
    code_book : ndarray
        The code book is usually generated using the k-means algorithm.
        Each row of the array holds a different code, and the columns are
        the features of the code.

        >>> #              f0    f1    f2    f3
        >>> code_book = [
        ...             [  1.,   2.,   3.,   4.],  # c0
        ...             [  1.,   2.,   3.,   4.],  # c1
        ...             [  1.,   2.,   3.,   4.]]  # c2

    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
        Default: True

    Returns
    -------
    code : ndarray
        A length M array holding the code book index for each observation.
    dist : ndarray
        The distortion (distance) between the observation and its nearest
        code.

    Examples
    --------
    >>> from numpy import array
    >>> from scipy.cluster.vq import vq
    >>> code_book = array([[1., 1., 1.],
    ...                    [2., 2., 2.]])
    >>> features = array([[1.9, 2.3, 1.7],
    ...                   [1.5, 2.5, 2.2],
    ...                   [0.8, 0.6, 1.7]])
    >>> vq(features, code_book)
    (array([1, 1, 0], dtype=int32), array([ 0.43588989,  0.73484692,  0.83066239]))

    """
    obs = _asarray_validated(obs, check_finite=check_finite)
    code_book = _asarray_validated(code_book, check_finite=check_finite)
    ct = np.common_type(obs, code_book)

    c_obs = obs.astype(ct, copy=False)
    c_code_book = code_book.astype(ct, copy=False)

    if np.issubdtype(ct, np.float64) or np.issubdtype(ct, np.float32):
        return _vq.vq(c_obs, c_code_book)
    return py_vq(obs, code_book, check_finite=False)


def py_vq(obs, code_book, check_finite=True):
    """ Python version of vq algorithm.

    The algorithm computes the Euclidean distance between each
    observation and every frame in the code_book.

    Parameters
    ----------
    obs : ndarray
        Expects a rank 2 array. Each row is one observation.
    code_book : ndarray
        Code book to use. Same format as obs. Should have the same number
        of features (e.g., columns) as obs.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
        Default: True

    Returns
    -------
    code : ndarray
        code[i] gives the label of the ith observation; its code is
        code_book[code[i]].
    min_dist : ndarray
        min_dist[i] gives the distance between the ith observation and its
        corresponding code.

    Notes
    -----
    This function is slower than the C version but works for
    all input types. If the inputs have the wrong types for the
    C versions of the function, this one is called as a last resort.
    It is about 20 times slower than the C version.
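
    Examples
    --------
    A minimal sketch (the toy arrays below are illustrative assumptions):

    >>> import numpy as np
    >>> from scipy.cluster.vq import py_vq
    >>> obs = np.array([[1.2, 1.2], [3., 3.]])
    >>> code_book = np.array([[0., 0.], [2., 2.], [4., 4.]])
    >>> code, min_dist = py_vq(obs, code_book)
    >>> code.tolist()  # index of the nearest code for each observation
    [1, 1]
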
  206. """
  207. obs = _asarray_validated(obs, check_finite=check_finite)
  208. code_book = _asarray_validated(code_book, check_finite=check_finite)
  209. if obs.ndim != code_book.ndim:
  210. raise ValueError("Observation and code_book should have the same rank")
  211. if obs.ndim == 1:
  212. obs = obs[:, np.newaxis]
  213. code_book = code_book[:, np.newaxis]
  214. dist = cdist(obs, code_book)
  215. code = dist.argmin(axis=1)
  216. min_dist = dist[np.arange(len(code)), code]
  217. return code, min_dist
  218. # py_vq2 was equivalent to py_vq
  219. py_vq2 = np.deprecate(py_vq, old_name='py_vq2', new_name='py_vq')


def _kmeans(obs, guess, thresh=1e-5):
    """ "raw" version of k-means.

    Returns
    -------
    code_book
        The lowest distortion codebook found.
    avg_dist
        The average distance an observation is from a code in the book.
        Lower means the code_book matches the data better.

    See Also
    --------
    kmeans : wrapper around k-means

    Examples
    --------
    Note: not whitened in this example.

    >>> from numpy import array
    >>> from scipy.cluster.vq import _kmeans
    >>> features = array([[ 1.9, 2.3],
    ...                   [ 1.5, 2.5],
    ...                   [ 0.8, 0.6],
    ...                   [ 0.4, 1.8],
    ...                   [ 1.0, 1.0]])
    >>> book = array((features[0], features[2]))
    >>> _kmeans(features, book)
    (array([[ 1.7       ,  2.4       ],
           [ 0.73333333,  1.13333333]]), 0.40563916697728591)

    """
    code_book = np.asarray(guess)
    diff = np.inf
    prev_avg_dists = deque([diff], maxlen=2)
    while diff > thresh:
        # compute membership and distances between obs and code_book
        obs_code, distort = vq(obs, code_book, check_finite=False)
        prev_avg_dists.append(distort.mean(axis=-1))
        # recalc code_book as centroids of associated obs
        code_book, has_members = _vq.update_cluster_means(obs, obs_code,
                                                          code_book.shape[0])
        code_book = code_book[has_members]
        diff = prev_avg_dists[0] - prev_avg_dists[1]

    return code_book, prev_avg_dists[1]


def kmeans(obs, k_or_guess, iter=20, thresh=1e-5, check_finite=True):
    """
    Performs k-means on a set of observation vectors forming k clusters.

    The k-means algorithm adjusts the classification of the observations
    into clusters and updates the cluster centroids until the position of
    the centroids is stable over successive iterations. In this
    implementation of the algorithm, the stability of the centroids is
    determined by comparing the absolute value of the change in the average
    Euclidean distance between the observations and their corresponding
    centroids against a threshold. This yields
    a code book mapping centroids to codes and vice versa.

    Parameters
    ----------
    obs : ndarray
        Each row of the M by N array is an observation vector. The
        columns are the features seen during each observation.
        The features must be whitened first with the `whiten` function.
    k_or_guess : int or ndarray
        The number of centroids to generate. A code is assigned to
        each centroid, which is also the row index of the centroid
        in the code_book matrix generated.

        The initial k centroids are chosen by randomly selecting
        observations from the observation matrix. Alternatively,
        passing a k by N array specifies the initial k centroids.
    iter : int, optional
        The number of times to run k-means, returning the codebook
        with the lowest distortion. This argument is ignored if
        initial centroids are specified with an array for the
        ``k_or_guess`` parameter. This parameter does not represent the
        number of iterations of the k-means algorithm.
    thresh : float, optional
        Terminates the k-means algorithm if the change in
        distortion since the last k-means iteration is less than
        or equal to thresh.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
        Default: True

    Returns
    -------
    codebook : ndarray
        A k by N array of k centroids. The i'th centroid
        codebook[i] is represented with the code i. The centroids
        and codes generated represent the lowest distortion seen,
        not necessarily the globally minimal distortion.
    distortion : float
        The mean (non-squared) Euclidean distance between the observations
        passed and the centroids generated. Note the difference to the standard
        definition of distortion in the context of the k-means algorithm, which
        is the sum of the squared distances.

    See Also
    --------
    kmeans2 : a different implementation of k-means clustering
        with more methods for generating initial centroids but without
        using a distortion change threshold as a stopping criterion.
    whiten : must be called prior to passing an observation matrix
        to kmeans.

    Examples
    --------
    >>> import numpy as np
    >>> from numpy import array
    >>> from scipy.cluster.vq import vq, kmeans, whiten
    >>> import matplotlib.pyplot as plt
    >>> features = array([[ 1.9, 2.3],
    ...                   [ 1.5, 2.5],
    ...                   [ 0.8, 0.6],
    ...                   [ 0.4, 1.8],
    ...                   [ 0.1, 0.1],
    ...                   [ 0.2, 1.8],
    ...                   [ 2.0, 0.5],
    ...                   [ 0.3, 1.5],
    ...                   [ 1.0, 1.0]])
    >>> whitened = whiten(features)
    >>> book = np.array((whitened[0], whitened[2]))
    >>> kmeans(whitened, book)
    (array([[ 2.3110306 ,  2.86287398],    # random
           [ 0.93218041,  1.24398691]]), 0.85684700941625547)

    >>> from numpy import random
    >>> random.seed((1000, 2000))
    >>> codes = 3
    >>> kmeans(whitened, codes)
    (array([[ 2.3110306 ,  2.86287398],    # random
           [ 1.32544402,  0.65607529],
           [ 0.40782893,  2.02786907]]), 0.5196582527686241)

    >>> # Create 50 datapoints in two clusters a and b
    >>> pts = 50
    >>> a = np.random.multivariate_normal([0, 0], [[4, 1], [1, 4]], size=pts)
    >>> b = np.random.multivariate_normal([30, 10],
    ...                                   [[10, 2], [2, 1]],
    ...                                   size=pts)
    >>> features = np.concatenate((a, b))
    >>> # Whiten data
    >>> whitened = whiten(features)
    >>> # Find 2 clusters in the data
    >>> codebook, distortion = kmeans(whitened, 2)
    >>> # Plot whitened data and cluster centers in red
    >>> plt.scatter(whitened[:, 0], whitened[:, 1])
    >>> plt.scatter(codebook[:, 0], codebook[:, 1], c='r')
    >>> plt.show()

    """
    obs = _asarray_validated(obs, check_finite=check_finite)
    if iter < 1:
        raise ValueError("iter must be at least 1, got %s" % iter)

    # Determine whether a count (scalar) or an initial guess (array) was
    # passed.
    if not np.isscalar(k_or_guess):
        guess = _asarray_validated(k_or_guess, check_finite=check_finite)
        if guess.size < 1:
            raise ValueError("Asked for 0 clusters. Initial book was %s" %
                             guess)
        return _kmeans(obs, guess, thresh=thresh)

    # k_or_guess is a scalar, now verify that it's an integer
    k = int(k_or_guess)
    if k != k_or_guess:
        raise ValueError("If k_or_guess is a scalar, it must be an integer.")
    if k < 1:
        raise ValueError("Asked for %d clusters." % k)

    # initialize best distance value to a large value
    best_dist = np.inf
    for i in xrange(iter):
        # the initial code book is randomly selected from observations
        guess = _kpoints(obs, k)
        book, dist = _kmeans(obs, guess, thresh=thresh)
        if dist < best_dist:
            best_book = book
            best_dist = dist
    return best_book, best_dist


def _kpoints(data, k):
    """Pick k points at random in data (one row = one observation).

    Parameters
    ----------
    data : ndarray
        Expects a rank 1 or 2 array. Rank 1 arrays are assumed to describe
        1-D data, rank 2 multidimensional data, in which case one
        row is one observation.
    k : int
        Number of samples to generate.

    Returns
    -------
    x : ndarray
        A 'k' by 'N' array containing the initial centroids
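
    Examples
    --------
    A sketch (the data below are illustrative; the rows picked are random):

    >>> import numpy as np
    >>> from scipy.cluster.vq import _kpoints
    >>> data = np.arange(10.).reshape(5, 2)
    >>> _kpoints(data, 3).shape  # three rows drawn from data
    (3, 2)
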
  400. """
  401. idx = np.random.choice(data.shape[0], size=k, replace=False)
  402. return data[idx]


def _krandinit(data, k):
    """Returns k samples of a random variable whose parameters depend on data.

    More precisely, it returns k observations sampled from a Gaussian random
    variable whose mean and covariance are estimated from the data.

    Parameters
    ----------
    data : ndarray
        Expects a rank 1 or 2 array. Rank 1 arrays are assumed to describe
        1-D data, rank 2 multidimensional data, in which case one
        row is one observation.
    k : int
        Number of samples to generate.

    Returns
    -------
    x : ndarray
        A 'k' by 'N' array containing the initial centroids
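
    Examples
    --------
    A sketch (the data below are illustrative; the draws are random):

    >>> import numpy as np
    >>> from scipy.cluster.vq import _krandinit
    >>> data = np.random.randn(100, 2)
    >>> _krandinit(data, 4).shape  # four draws from the fitted Gaussian
    (4, 2)
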
  419. """
  420. mu = data.mean(axis=0)
  421. if data.ndim == 1:
  422. cov = np.cov(data)
  423. x = np.random.randn(k)
  424. x *= np.sqrt(cov)
  425. elif data.shape[1] > data.shape[0]:
  426. # initialize when the covariance matrix is rank deficient
  427. _, s, vh = np.linalg.svd(data - mu, full_matrices=False)
  428. x = np.random.randn(k, s.size)
  429. sVh = s[:, None] * vh / np.sqrt(data.shape[0] - 1)
  430. x = x.dot(sVh)
  431. else:
  432. cov = np.atleast_2d(np.cov(data, rowvar=False))
  433. # k rows, d cols (one row = one obs)
  434. # Generate k sample of a random variable ~ Gaussian(mu, cov)
  435. x = np.random.randn(k, mu.size)
  436. x = x.dot(np.linalg.cholesky(cov).T)
  437. x += mu
  438. return x


def _kpp(data, k):
    """ Picks k points in data based on the kmeans++ method.

    Parameters
    ----------
    data : ndarray
        Expects a rank 1 or 2 array. Rank 1 arrays are assumed to describe
        1-D data, rank 2 multidimensional data, in which case one
        row is one observation.
    k : int
        Number of samples to generate.

    Returns
    -------
    init : ndarray
        A 'k' by 'N' array containing the initial centroids

    References
    ----------
    .. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
       careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
       on Discrete Algorithms, 2007.
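
    Examples
    --------
    A sketch (the data below are illustrative; the seeding is random):

    >>> import numpy as np
    >>> from scipy.cluster.vq import _kpp
    >>> data = np.random.rand(50, 3)
    >>> _kpp(data, 4).shape  # four well-spread initial centroids
    (4, 3)
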
  458. """
  459. dims = data.shape[1] if len(data.shape) > 1 else 1
  460. init = np.ndarray((k, dims))
  461. for i in range(k):
  462. if i == 0:
  463. init[i, :] = data[np.random.randint(dims)]
  464. else:
  465. D2 = np.array([min(
  466. [np.inner(init[j]-x, init[j]-x) for j in range(i)]
  467. ) for x in data])
  468. probs = D2/D2.sum()
  469. cumprobs = probs.cumsum()
  470. r = np.random.rand()
  471. init[i, :] = data[np.searchsorted(cumprobs, r)]
  472. return init


_valid_init_meth = {'random': _krandinit, 'points': _kpoints, '++': _kpp}


def _missing_warn():
    """Print a warning when called."""
    warnings.warn("One of the clusters is empty. "
                  "Re-run kmeans with a different initialization.")


def _missing_raise():
    """Raise a ClusterError when called."""
    raise ClusterError("One of the clusters is empty. "
                       "Re-run kmeans with a different initialization.")


_valid_miss_meth = {'warn': _missing_warn, 'raise': _missing_raise}


def kmeans2(data, k, iter=10, thresh=1e-5, minit='random',
            missing='warn', check_finite=True):
    """
    Classify a set of observations into k clusters using the k-means
    algorithm.

    The algorithm attempts to minimize the Euclidean distance between
    observations and centroids. Several initialization methods are
    included.

    Parameters
    ----------
    data : ndarray
        A 'M' by 'N' array of 'M' observations in 'N' dimensions or a length
        'M' array of 'M' 1-D observations.
    k : int or ndarray
        The number of clusters to form as well as the number of
        centroids to generate. If `minit` initialization string is
        'matrix', or if an ndarray is given instead, it is
        interpreted as the initial clusters to use.
    iter : int, optional
        Number of iterations of the k-means algorithm to run. Note
        that this differs in meaning from the iter parameter to
        the kmeans function.
    thresh : float, optional
        (not used yet)
    minit : str, optional
        Method for initialization. Available methods are 'random',
        'points', '++' and 'matrix':

        'random': generate k centroids from a Gaussian with mean and
        variance estimated from the data.

        'points': choose k observations (rows) at random from data for
        the initial centroids.

        '++': choose k observations according to the kmeans++ method
        (careful seeding)

        'matrix': interpret the k parameter as a k by N (or length k
        array for 1-D data) array of initial centroids.
    missing : str, optional
        Method to deal with empty clusters. Available methods are
        'warn' and 'raise':

        'warn': give a warning and continue.

        'raise': raise a ClusterError and terminate the algorithm.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
        Default: True

    Returns
    -------
    centroid : ndarray
        A 'k' by 'N' array of centroids found at the last iteration of
        k-means.
    label : ndarray
        label[i] is the code or index of the centroid the
        i'th observation is closest to.

    References
    ----------
    .. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
       careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
       on Discrete Algorithms, 2007.
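
    Examples
    --------
    A minimal sketch (the random data and seed below are illustrative
    assumptions):

    >>> import numpy as np
    >>> from scipy.cluster.vq import kmeans2
    >>> np.random.seed(12345678)
    >>> data = np.random.randn(100, 2)
    >>> centroid, label = kmeans2(data, 3, minit='points')
    >>> centroid.shape, label.shape
    ((3, 2), (100,))
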
  540. """
  541. if int(iter) < 1:
  542. raise ValueError("Invalid iter (%s), "
  543. "must be a positive integer." % iter)
  544. try:
  545. miss_meth = _valid_miss_meth[missing]
  546. except KeyError:
  547. raise ValueError("Unknown missing method %r" % (missing,))
  548. data = _asarray_validated(data, check_finite=check_finite)
  549. if data.ndim == 1:
  550. d = 1
  551. elif data.ndim == 2:
  552. d = data.shape[1]
  553. else:
  554. raise ValueError("Input of rank > 2 is not supported.")
  555. if data.size < 1:
  556. raise ValueError("Empty input is not supported.")
  557. # If k is not a single value it should be compatible with data's shape
  558. if minit == 'matrix' or not np.isscalar(k):
  559. code_book = np.array(k, copy=True)
  560. if data.ndim != code_book.ndim:
  561. raise ValueError("k array doesn't match data rank")
  562. nc = len(code_book)
  563. if data.ndim > 1 and code_book.shape[1] != d:
  564. raise ValueError("k array doesn't match data dimension")
  565. else:
  566. nc = int(k)
  567. if nc < 1:
  568. raise ValueError("Cannot ask kmeans2 for %d clusters"
  569. " (k was %s)" % (nc, k))
  570. elif nc != k:
  571. warnings.warn("k was not an integer, was converted.")
  572. try:
  573. init_meth = _valid_init_meth[minit]
  574. except KeyError:
  575. raise ValueError("Unknown init method %r" % (minit,))
  576. else:
  577. code_book = init_meth(data, k)
  578. for i in xrange(iter):
  579. # Compute the nearest neighbor for each obs using the current code book
  580. label = vq(data, code_book)[0]
  581. # Update the code book by computing centroids
  582. new_code_book, has_members = _vq.update_cluster_means(data, label, nc)
  583. if not has_members.all():
  584. miss_meth()
  585. # Set the empty clusters to their previous positions
  586. new_code_book[~has_members] = code_book[~has_members]
  587. code_book = new_code_book
  588. return code_book, label