einsumfunc.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422
  1. """
  2. Implementation of optimized einsum.
  3. """
  4. from __future__ import division, absolute_import, print_function
  5. import itertools
  6. from numpy.compat import basestring
  7. from numpy.core.multiarray import c_einsum
  8. from numpy.core.numeric import asanyarray, tensordot
  9. from numpy.core.overrides import array_function_dispatch
# Public names exported by this module.
__all__ = ['einsum', 'einsum_path']
# The 52 single-character labels accepted in einsum subscripts. In the
# list-based calling convention an integer label ``i`` is translated to
# ``einsum_symbols[i]`` (see ``_parse_einsum_input``).
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
# Set form of ``einsum_symbols``; used for set arithmetic when picking
# unused labels to stand in for ellipsis dimensions.
einsum_symbols_set = set(einsum_symbols)
  13. def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
  14. """
  15. Computes the number of FLOPS in the contraction.
  16. Parameters
  17. ----------
  18. idx_contraction : iterable
  19. The indices involved in the contraction
  20. inner : bool
  21. Does this contraction require an inner product?
  22. num_terms : int
  23. The number of terms in a contraction
  24. size_dictionary : dict
  25. The size of each of the indices in idx_contraction
  26. Returns
  27. -------
  28. flop_count : int
  29. The total number of FLOPS required for the contraction.
  30. Examples
  31. --------
  32. >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
  33. 90
  34. >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
  35. 270
  36. """
  37. overall_size = _compute_size_by_dict(idx_contraction, size_dictionary)
  38. op_factor = max(1, num_terms - 1)
  39. if inner:
  40. op_factor += 1
  41. return overall_size * op_factor
  42. def _compute_size_by_dict(indices, idx_dict):
  43. """
  44. Computes the product of the elements in indices based on the dictionary
  45. idx_dict.
  46. Parameters
  47. ----------
  48. indices : iterable
  49. Indices to base the product on.
  50. idx_dict : dictionary
  51. Dictionary of index sizes
  52. Returns
  53. -------
  54. ret : int
  55. The resulting product.
  56. Examples
  57. --------
  58. >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
  59. 90
  60. """
  61. ret = 1
  62. for i in indices:
  63. ret *= idx_dict[i]
  64. return ret
  65. def _find_contraction(positions, input_sets, output_set):
  66. """
  67. Finds the contraction for a given set of input and output sets.
  68. Parameters
  69. ----------
  70. positions : iterable
  71. Integer positions of terms used in the contraction.
  72. input_sets : list
  73. List of sets that represent the lhs side of the einsum subscript
  74. output_set : set
  75. Set that represents the rhs side of the overall einsum subscript
  76. Returns
  77. -------
  78. new_result : set
  79. The indices of the resulting contraction
  80. remaining : list
  81. List of sets that have not been contracted, the new set is appended to
  82. the end of this list
  83. idx_removed : set
  84. Indices removed from the entire contraction
  85. idx_contraction : set
  86. The indices used in the current contraction
  87. Examples
  88. --------
  89. # A simple dot product test case
  90. >>> pos = (0, 1)
  91. >>> isets = [set('ab'), set('bc')]
  92. >>> oset = set('ac')
  93. >>> _find_contraction(pos, isets, oset)
  94. ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
  95. # A more complex case with additional terms in the contraction
  96. >>> pos = (0, 2)
  97. >>> isets = [set('abd'), set('ac'), set('bdc')]
  98. >>> oset = set('ac')
  99. >>> _find_contraction(pos, isets, oset)
  100. ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
  101. """
  102. idx_contract = set()
  103. idx_remain = output_set.copy()
  104. remaining = []
  105. for ind, value in enumerate(input_sets):
  106. if ind in positions:
  107. idx_contract |= value
  108. else:
  109. remaining.append(value)
  110. idx_remain |= value
  111. new_result = idx_remain & idx_contract
  112. idx_removed = (idx_contract - new_result)
  113. remaining.append(new_result)
  114. return (new_result, remaining, idx_removed, idx_contract)
  115. def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
  116. """
  117. Computes all possible pair contractions, sieves the results based
  118. on ``memory_limit`` and returns the lowest cost path. This algorithm
  119. scales factorial with respect to the elements in the list ``input_sets``.
  120. Parameters
  121. ----------
  122. input_sets : list
  123. List of sets that represent the lhs side of the einsum subscript
  124. output_set : set
  125. Set that represents the rhs side of the overall einsum subscript
  126. idx_dict : dictionary
  127. Dictionary of index sizes
  128. memory_limit : int
  129. The maximum number of elements in a temporary array
  130. Returns
  131. -------
  132. path : list
  133. The optimal contraction order within the memory limit constraint.
  134. Examples
  135. --------
  136. >>> isets = [set('abd'), set('ac'), set('bdc')]
  137. >>> oset = set()
  138. >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
  139. >>> _path__optimal_path(isets, oset, idx_sizes, 5000)
  140. [(0, 2), (0, 1)]
  141. """
  142. full_results = [(0, [], input_sets)]
  143. for iteration in range(len(input_sets) - 1):
  144. iter_results = []
  145. # Compute all unique pairs
  146. for curr in full_results:
  147. cost, positions, remaining = curr
  148. for con in itertools.combinations(range(len(input_sets) - iteration), 2):
  149. # Find the contraction
  150. cont = _find_contraction(con, remaining, output_set)
  151. new_result, new_input_sets, idx_removed, idx_contract = cont
  152. # Sieve the results based on memory_limit
  153. new_size = _compute_size_by_dict(new_result, idx_dict)
  154. if new_size > memory_limit:
  155. continue
  156. # Build (total_cost, positions, indices_remaining)
  157. total_cost = cost + _flop_count(idx_contract, idx_removed, len(con), idx_dict)
  158. new_pos = positions + [con]
  159. iter_results.append((total_cost, new_pos, new_input_sets))
  160. # Update combinatorial list, if we did not find anything return best
  161. # path + remaining contractions
  162. if iter_results:
  163. full_results = iter_results
  164. else:
  165. path = min(full_results, key=lambda x: x[0])[1]
  166. path += [tuple(range(len(input_sets) - iteration))]
  167. return path
  168. # If we have not found anything return single einsum contraction
  169. if len(full_results) == 0:
  170. return [tuple(range(len(input_sets)))]
  171. path = min(full_results, key=lambda x: x[0])[1]
  172. return path
  173. def _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost, naive_cost):
  174. """Compute the cost (removed size + flops) and resultant indices for
  175. performing the contraction specified by ``positions``.
  176. Parameters
  177. ----------
  178. positions : tuple of int
  179. The locations of the proposed tensors to contract.
  180. input_sets : list of sets
  181. The indices found on each tensors.
  182. output_set : set
  183. The output indices of the expression.
  184. idx_dict : dict
  185. Mapping of each index to its size.
  186. memory_limit : int
  187. The total allowed size for an intermediary tensor.
  188. path_cost : int
  189. The contraction cost so far.
  190. naive_cost : int
  191. The cost of the unoptimized expression.
  192. Returns
  193. -------
  194. cost : (int, int)
  195. A tuple containing the size of any indices removed, and the flop cost.
  196. positions : tuple of int
  197. The locations of the proposed tensors to contract.
  198. new_input_sets : list of sets
  199. The resulting new list of indices if this proposed contraction is performed.
  200. """
  201. # Find the contraction
  202. contract = _find_contraction(positions, input_sets, output_set)
  203. idx_result, new_input_sets, idx_removed, idx_contract = contract
  204. # Sieve the results based on memory_limit
  205. new_size = _compute_size_by_dict(idx_result, idx_dict)
  206. if new_size > memory_limit:
  207. return None
  208. # Build sort tuple
  209. old_sizes = (_compute_size_by_dict(input_sets[p], idx_dict) for p in positions)
  210. removed_size = sum(old_sizes) - new_size
  211. # NB: removed_size used to be just the size of any removed indices i.e.:
  212. # helpers.compute_size_by_dict(idx_removed, idx_dict)
  213. cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)
  214. sort = (-removed_size, cost)
  215. # Sieve based on total cost as well
  216. if (path_cost + cost) > naive_cost:
  217. return None
  218. # Add contraction to possible choices
  219. return [sort, positions, new_input_sets]
  220. def _update_other_results(results, best):
  221. """Update the positions and provisional input_sets of ``results`` based on
  222. performing the contraction result ``best``. Remove any involving the tensors
  223. contracted.
  224. Parameters
  225. ----------
  226. results : list
  227. List of contraction results produced by ``_parse_possible_contraction``.
  228. best : list
  229. The best contraction of ``results`` i.e. the one that will be performed.
  230. Returns
  231. -------
  232. mod_results : list
  233. The list of modifed results, updated with outcome of ``best`` contraction.
  234. """
  235. best_con = best[1]
  236. bx, by = best_con
  237. mod_results = []
  238. for cost, (x, y), con_sets in results:
  239. # Ignore results involving tensors just contracted
  240. if x in best_con or y in best_con:
  241. continue
  242. # Update the input_sets
  243. del con_sets[by - int(by > x) - int(by > y)]
  244. del con_sets[bx - int(bx > x) - int(bx > y)]
  245. con_sets.insert(-1, best[2][-1])
  246. # Update the position indices
  247. mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
  248. mod_results.append((cost, mod_con, con_sets))
  249. return mod_results
  250. def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
  251. """
  252. Finds the path by contracting the best pair until the input list is
  253. exhausted. The best pair is found by minimizing the tuple
  254. ``(-prod(indices_removed), cost)``. What this amounts to is prioritizing
  255. matrix multiplication or inner product operations, then Hadamard like
  256. operations, and finally outer operations. Outer products are limited by
  257. ``memory_limit``. This algorithm scales cubically with respect to the
  258. number of elements in the list ``input_sets``.
  259. Parameters
  260. ----------
  261. input_sets : list
  262. List of sets that represent the lhs side of the einsum subscript
  263. output_set : set
  264. Set that represents the rhs side of the overall einsum subscript
  265. idx_dict : dictionary
  266. Dictionary of index sizes
  267. memory_limit_limit : int
  268. The maximum number of elements in a temporary array
  269. Returns
  270. -------
  271. path : list
  272. The greedy contraction order within the memory limit constraint.
  273. Examples
  274. --------
  275. >>> isets = [set('abd'), set('ac'), set('bdc')]
  276. >>> oset = set()
  277. >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
  278. >>> _path__greedy_path(isets, oset, idx_sizes, 5000)
  279. [(0, 2), (0, 1)]
  280. """
  281. # Handle trivial cases that leaked through
  282. if len(input_sets) == 1:
  283. return [(0,)]
  284. elif len(input_sets) == 2:
  285. return [(0, 1)]
  286. # Build up a naive cost
  287. contract = _find_contraction(range(len(input_sets)), input_sets, output_set)
  288. idx_result, new_input_sets, idx_removed, idx_contract = contract
  289. naive_cost = _flop_count(idx_contract, idx_removed, len(input_sets), idx_dict)
  290. # Initially iterate over all pairs
  291. comb_iter = itertools.combinations(range(len(input_sets)), 2)
  292. known_contractions = []
  293. path_cost = 0
  294. path = []
  295. for iteration in range(len(input_sets) - 1):
  296. # Iterate over all pairs on first step, only previously found pairs on subsequent steps
  297. for positions in comb_iter:
  298. # Always initially ignore outer products
  299. if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
  300. continue
  301. result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost,
  302. naive_cost)
  303. if result is not None:
  304. known_contractions.append(result)
  305. # If we do not have a inner contraction, rescan pairs including outer products
  306. if len(known_contractions) == 0:
  307. # Then check the outer products
  308. for positions in itertools.combinations(range(len(input_sets)), 2):
  309. result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit,
  310. path_cost, naive_cost)
  311. if result is not None:
  312. known_contractions.append(result)
  313. # If we still did not find any remaining contractions, default back to einsum like behavior
  314. if len(known_contractions) == 0:
  315. path.append(tuple(range(len(input_sets))))
  316. break
  317. # Sort based on first index
  318. best = min(known_contractions, key=lambda x: x[0])
  319. # Now propagate as many unused contractions as possible to next iteration
  320. known_contractions = _update_other_results(known_contractions, best)
  321. # Next iteration only compute contractions with the new tensor
  322. # All other contractions have been accounted for
  323. input_sets = best[2]
  324. new_tensor_pos = len(input_sets) - 1
  325. comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))
  326. # Update path and total cost
  327. path.append(best[1])
  328. path_cost += best[0][1]
  329. return path
  330. def _can_dot(inputs, result, idx_removed):
  331. """
  332. Checks if we can use BLAS (np.tensordot) call and its beneficial to do so.
  333. Parameters
  334. ----------
  335. inputs : list of str
  336. Specifies the subscripts for summation.
  337. result : str
  338. Resulting summation.
  339. idx_removed : set
  340. Indices that are removed in the summation
  341. Returns
  342. -------
  343. type : bool
  344. Returns true if BLAS should and can be used, else False
  345. Notes
  346. -----
  347. If the operations is BLAS level 1 or 2 and is not already aligned
  348. we default back to einsum as the memory movement to copy is more
  349. costly than the operation itself.
  350. Examples
  351. --------
  352. # Standard GEMM operation
  353. >>> _can_dot(['ij', 'jk'], 'ik', set('j'))
  354. True
  355. # Can use the standard BLAS, but requires odd data movement
  356. >>> _can_dot(['ijj', 'jk'], 'ik', set('j'))
  357. False
  358. # DDOT where the memory is not aligned
  359. >>> _can_dot(['ijk', 'ikj'], '', set('ijk'))
  360. False
  361. """
  362. # All `dot` calls remove indices
  363. if len(idx_removed) == 0:
  364. return False
  365. # BLAS can only handle two operands
  366. if len(inputs) != 2:
  367. return False
  368. input_left, input_right = inputs
  369. for c in set(input_left + input_right):
  370. # can't deal with repeated indices on same input or more than 2 total
  371. nl, nr = input_left.count(c), input_right.count(c)
  372. if (nl > 1) or (nr > 1) or (nl + nr > 2):
  373. return False
  374. # can't do implicit summation or dimension collapse e.g.
  375. # "ab,bc->c" (implicitly sum over 'a')
  376. # "ab,ca->ca" (take diagonal of 'a')
  377. if nl + nr - 1 == int(c in result):
  378. return False
  379. # Build a few temporaries
  380. set_left = set(input_left)
  381. set_right = set(input_right)
  382. keep_left = set_left - idx_removed
  383. keep_right = set_right - idx_removed
  384. rs = len(idx_removed)
  385. # At this point we are a DOT, GEMV, or GEMM operation
  386. # Handle inner products
  387. # DDOT with aligned data
  388. if input_left == input_right:
  389. return True
  390. # DDOT without aligned data (better to use einsum)
  391. if set_left == set_right:
  392. return False
  393. # Handle the 4 possible (aligned) GEMV or GEMM cases
  394. # GEMM or GEMV no transpose
  395. if input_left[-rs:] == input_right[:rs]:
  396. return True
  397. # GEMM or GEMV transpose both
  398. if input_left[:rs] == input_right[-rs:]:
  399. return True
  400. # GEMM or GEMV transpose right
  401. if input_left[-rs:] == input_right[-rs:]:
  402. return True
  403. # GEMM or GEMV transpose left
  404. if input_left[:rs] == input_right[:rs]:
  405. return True
  406. # Einsum is faster than GEMV if we have to copy data
  407. if not keep_left or not keep_right:
  408. return False
  409. # We are a matrix-matrix product, but we need to copy data
  410. return True
  411. def _parse_einsum_input(operands):
  412. """
  413. A reproduction of einsum c side einsum parsing in python.
  414. Returns
  415. -------
  416. input_strings : str
  417. Parsed input strings
  418. output_string : str
  419. Parsed output string
  420. operands : list of array_like
  421. The operands to use in the numpy contraction
  422. Examples
  423. --------
  424. The operand list is simplified to reduce printing:
  425. >>> a = np.random.rand(4, 4)
  426. >>> b = np.random.rand(4, 4, 4)
  427. >>> __parse_einsum_input(('...a,...a->...', a, b))
  428. ('za,xza', 'xz', [a, b])
  429. >>> __parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
  430. ('za,xza', 'xz', [a, b])
  431. """
  432. if len(operands) == 0:
  433. raise ValueError("No input operands")
  434. if isinstance(operands[0], basestring):
  435. subscripts = operands[0].replace(" ", "")
  436. operands = [asanyarray(v) for v in operands[1:]]
  437. # Ensure all characters are valid
  438. for s in subscripts:
  439. if s in '.,->':
  440. continue
  441. if s not in einsum_symbols:
  442. raise ValueError("Character %s is not a valid symbol." % s)
  443. else:
  444. tmp_operands = list(operands)
  445. operand_list = []
  446. subscript_list = []
  447. for p in range(len(operands) // 2):
  448. operand_list.append(tmp_operands.pop(0))
  449. subscript_list.append(tmp_operands.pop(0))
  450. output_list = tmp_operands[-1] if len(tmp_operands) else None
  451. operands = [asanyarray(v) for v in operand_list]
  452. subscripts = ""
  453. last = len(subscript_list) - 1
  454. for num, sub in enumerate(subscript_list):
  455. for s in sub:
  456. if s is Ellipsis:
  457. subscripts += "..."
  458. elif isinstance(s, int):
  459. subscripts += einsum_symbols[s]
  460. else:
  461. raise TypeError("For this input type lists must contain "
  462. "either int or Ellipsis")
  463. if num != last:
  464. subscripts += ","
  465. if output_list is not None:
  466. subscripts += "->"
  467. for s in output_list:
  468. if s is Ellipsis:
  469. subscripts += "..."
  470. elif isinstance(s, int):
  471. subscripts += einsum_symbols[s]
  472. else:
  473. raise TypeError("For this input type lists must contain "
  474. "either int or Ellipsis")
  475. # Check for proper "->"
  476. if ("-" in subscripts) or (">" in subscripts):
  477. invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
  478. if invalid or (subscripts.count("->") != 1):
  479. raise ValueError("Subscripts can only contain one '->'.")
  480. # Parse ellipses
  481. if "." in subscripts:
  482. used = subscripts.replace(".", "").replace(",", "").replace("->", "")
  483. unused = list(einsum_symbols_set - set(used))
  484. ellipse_inds = "".join(unused)
  485. longest = 0
  486. if "->" in subscripts:
  487. input_tmp, output_sub = subscripts.split("->")
  488. split_subscripts = input_tmp.split(",")
  489. out_sub = True
  490. else:
  491. split_subscripts = subscripts.split(',')
  492. out_sub = False
  493. for num, sub in enumerate(split_subscripts):
  494. if "." in sub:
  495. if (sub.count(".") != 3) or (sub.count("...") != 1):
  496. raise ValueError("Invalid Ellipses.")
  497. # Take into account numerical values
  498. if operands[num].shape == ():
  499. ellipse_count = 0
  500. else:
  501. ellipse_count = max(operands[num].ndim, 1)
  502. ellipse_count -= (len(sub) - 3)
  503. if ellipse_count > longest:
  504. longest = ellipse_count
  505. if ellipse_count < 0:
  506. raise ValueError("Ellipses lengths do not match.")
  507. elif ellipse_count == 0:
  508. split_subscripts[num] = sub.replace('...', '')
  509. else:
  510. rep_inds = ellipse_inds[-ellipse_count:]
  511. split_subscripts[num] = sub.replace('...', rep_inds)
  512. subscripts = ",".join(split_subscripts)
  513. if longest == 0:
  514. out_ellipse = ""
  515. else:
  516. out_ellipse = ellipse_inds[-longest:]
  517. if out_sub:
  518. subscripts += "->" + output_sub.replace("...", out_ellipse)
  519. else:
  520. # Special care for outputless ellipses
  521. output_subscript = ""
  522. tmp_subscripts = subscripts.replace(",", "")
  523. for s in sorted(set(tmp_subscripts)):
  524. if s not in (einsum_symbols):
  525. raise ValueError("Character %s is not a valid symbol." % s)
  526. if tmp_subscripts.count(s) == 1:
  527. output_subscript += s
  528. normal_inds = ''.join(sorted(set(output_subscript) -
  529. set(out_ellipse)))
  530. subscripts += "->" + out_ellipse + normal_inds
  531. # Build output string if does not exist
  532. if "->" in subscripts:
  533. input_subscripts, output_subscript = subscripts.split("->")
  534. else:
  535. input_subscripts = subscripts
  536. # Build output subscripts
  537. tmp_subscripts = subscripts.replace(",", "")
  538. output_subscript = ""
  539. for s in sorted(set(tmp_subscripts)):
  540. if s not in einsum_symbols:
  541. raise ValueError("Character %s is not a valid symbol." % s)
  542. if tmp_subscripts.count(s) == 1:
  543. output_subscript += s
  544. # Make sure output subscripts are in the input
  545. for char in output_subscript:
  546. if char not in input_subscripts:
  547. raise ValueError("Output character %s did not appear in the input"
  548. % char)
  549. # Make sure number operands is equivalent to the number of terms
  550. if len(input_subscripts.split(',')) != len(operands):
  551. raise ValueError("Number of einsum subscripts must be equal to the "
  552. "number of operands.")
  553. return (input_subscripts, output_subscript, operands)
  554. def _einsum_path_dispatcher(*operands, **kwargs):
  555. # NOTE: technically, we should only dispatch on array-like arguments, not
  556. # subscripts (given as strings). But separating operands into
  557. # arrays/subscripts is a little tricky/slow (given einsum's two supported
  558. # signatures), so as a practical shortcut we dispatch on everything.
  559. # Strings will be ignored for dispatching since they don't define
  560. # __array_function__.
  561. return operands
  562. @array_function_dispatch(_einsum_path_dispatcher, module='numpy')
  563. def einsum_path(*operands, **kwargs):
  564. """
  565. einsum_path(subscripts, *operands, optimize='greedy')
  566. Evaluates the lowest cost contraction order for an einsum expression by
  567. considering the creation of intermediate arrays.
  568. Parameters
  569. ----------
  570. subscripts : str
  571. Specifies the subscripts for summation.
  572. *operands : list of array_like
  573. These are the arrays for the operation.
  574. optimize : {bool, list, tuple, 'greedy', 'optimal'}
  575. Choose the type of path. If a tuple is provided, the second argument is
  576. assumed to be the maximum intermediate size created. If only a single
  577. argument is provided the largest input or output array size is used
  578. as a maximum intermediate size.
  579. * if a list is given that starts with ``einsum_path``, uses this as the
  580. contraction path
  581. * if False no optimization is taken
  582. * if True defaults to the 'greedy' algorithm
  583. * 'optimal' An algorithm that combinatorially explores all possible
  584. ways of contracting the listed tensors and choosest the least costly
  585. path. Scales exponentially with the number of terms in the
  586. contraction.
  587. * 'greedy' An algorithm that chooses the best pair contraction
  588. at each step. Effectively, this algorithm searches the largest inner,
  589. Hadamard, and then outer products at each step. Scales cubically with
  590. the number of terms in the contraction. Equivalent to the 'optimal'
  591. path for most contractions.
  592. Default is 'greedy'.
  593. Returns
  594. -------
  595. path : list of tuples
  596. A list representation of the einsum path.
  597. string_repr : str
  598. A printable representation of the einsum path.
  599. Notes
  600. -----
  601. The resulting path indicates which terms of the input contraction should be
  602. contracted first, the result of this contraction is then appended to the
  603. end of the contraction list. This list can then be iterated over until all
  604. intermediate contractions are complete.
  605. See Also
  606. --------
  607. einsum, linalg.multi_dot
  608. Examples
  609. --------
  610. We can begin with a chain dot example. In this case, it is optimal to
  611. contract the ``b`` and ``c`` tensors first as represented by the first
  612. element of the path ``(1, 2)``. The resulting tensor is added to the end
  613. of the contraction and the remaining contraction ``(0, 1)`` is then
  614. completed.
  615. >>> a = np.random.rand(2, 2)
  616. >>> b = np.random.rand(2, 5)
  617. >>> c = np.random.rand(5, 2)
  618. >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
  619. >>> print(path_info[0])
  620. ['einsum_path', (1, 2), (0, 1)]
  621. >>> print(path_info[1])
  622. Complete contraction: ij,jk,kl->il
  623. Naive scaling: 4
  624. Optimized scaling: 3
  625. Naive FLOP count: 1.600e+02
  626. Optimized FLOP count: 5.600e+01
  627. Theoretical speedup: 2.857
  628. Largest intermediate: 4.000e+00 elements
  629. -------------------------------------------------------------------------
  630. scaling current remaining
  631. -------------------------------------------------------------------------
  632. 3 kl,jk->jl ij,jl->il
  633. 3 jl,ij->il il->il
  634. A more complex index transformation example.
  635. >>> I = np.random.rand(10, 10, 10, 10)
  636. >>> C = np.random.rand(10, 10)
  637. >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
  638. optimize='greedy')
  639. >>> print(path_info[0])
  640. ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
  641. >>> print(path_info[1])
  642. Complete contraction: ea,fb,abcd,gc,hd->efgh
  643. Naive scaling: 8
  644. Optimized scaling: 5
  645. Naive FLOP count: 8.000e+08
  646. Optimized FLOP count: 8.000e+05
  647. Theoretical speedup: 1000.000
  648. Largest intermediate: 1.000e+04 elements
  649. --------------------------------------------------------------------------
  650. scaling current remaining
  651. --------------------------------------------------------------------------
  652. 5 abcd,ea->bcde fb,gc,hd,bcde->efgh
  653. 5 bcde,fb->cdef gc,hd,cdef->efgh
  654. 5 cdef,gc->defg hd,defg->efgh
  655. 5 defg,hd->efgh efgh->efgh
  656. """
  657. # Make sure all keywords are valid
  658. valid_contract_kwargs = ['optimize', 'einsum_call']
  659. unknown_kwargs = [k for (k, v) in kwargs.items() if k
  660. not in valid_contract_kwargs]
  661. if len(unknown_kwargs):
  662. raise TypeError("Did not understand the following kwargs:"
  663. " %s" % unknown_kwargs)
  664. # Figure out what the path really is
  665. path_type = kwargs.pop('optimize', True)
  666. if path_type is True:
  667. path_type = 'greedy'
  668. if path_type is None:
  669. path_type = False
  670. memory_limit = None
  671. # No optimization or a named path algorithm
  672. if (path_type is False) or isinstance(path_type, basestring):
  673. pass
  674. # Given an explicit path
  675. elif len(path_type) and (path_type[0] == 'einsum_path'):
  676. pass
  677. # Path tuple with memory limit
  678. elif ((len(path_type) == 2) and isinstance(path_type[0], basestring) and
  679. isinstance(path_type[1], (int, float))):
  680. memory_limit = int(path_type[1])
  681. path_type = path_type[0]
  682. else:
  683. raise TypeError("Did not understand the path: %s" % str(path_type))
  684. # Hidden option, only einsum should call this
  685. einsum_call_arg = kwargs.pop("einsum_call", False)
  686. # Python side parsing
  687. input_subscripts, output_subscript, operands = _parse_einsum_input(operands)
  688. # Build a few useful list and sets
  689. input_list = input_subscripts.split(',')
  690. input_sets = [set(x) for x in input_list]
  691. output_set = set(output_subscript)
  692. indices = set(input_subscripts.replace(',', ''))
  693. # Get length of each unique dimension and ensure all dimensions are correct
  694. dimension_dict = {}
  695. broadcast_indices = [[] for x in range(len(input_list))]
  696. for tnum, term in enumerate(input_list):
  697. sh = operands[tnum].shape
  698. if len(sh) != len(term):
  699. raise ValueError("Einstein sum subscript %s does not contain the "
  700. "correct number of indices for operand %d."
  701. % (input_subscripts[tnum], tnum))
  702. for cnum, char in enumerate(term):
  703. dim = sh[cnum]
  704. # Build out broadcast indices
  705. if dim == 1:
  706. broadcast_indices[tnum].append(char)
  707. if char in dimension_dict.keys():
  708. # For broadcasting cases we always want the largest dim size
  709. if dimension_dict[char] == 1:
  710. dimension_dict[char] = dim
  711. elif dim not in (1, dimension_dict[char]):
  712. raise ValueError("Size of label '%s' for operand %d (%d) "
  713. "does not match previous terms (%d)."
  714. % (char, tnum, dimension_dict[char], dim))
  715. else:
  716. dimension_dict[char] = dim
  717. # Convert broadcast inds to sets
  718. broadcast_indices = [set(x) for x in broadcast_indices]
  719. # Compute size of each input array plus the output array
  720. size_list = [_compute_size_by_dict(term, dimension_dict)
  721. for term in input_list + [output_subscript]]
  722. max_size = max(size_list)
  723. if memory_limit is None:
  724. memory_arg = max_size
  725. else:
  726. memory_arg = memory_limit
  727. # Compute naive cost
  728. # This isn't quite right, need to look into exactly how einsum does this
  729. inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
  730. naive_cost = _flop_count(indices, inner_product, len(input_list), dimension_dict)
  731. # Compute the path
  732. if (path_type is False) or (len(input_list) in [1, 2]) or (indices == output_set):
  733. # Nothing to be optimized, leave it to einsum
  734. path = [tuple(range(len(input_list)))]
  735. elif path_type == "greedy":
  736. path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg)
  737. elif path_type == "optimal":
  738. path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg)
  739. elif path_type[0] == 'einsum_path':
  740. path = path_type[1:]
  741. else:
  742. raise KeyError("Path name %s not found", path_type)
  743. cost_list, scale_list, size_list, contraction_list = [], [], [], []
  744. # Build contraction tuple (positions, gemm, einsum_str, remaining)
  745. for cnum, contract_inds in enumerate(path):
  746. # Make sure we remove inds from right to left
  747. contract_inds = tuple(sorted(list(contract_inds), reverse=True))
  748. contract = _find_contraction(contract_inds, input_sets, output_set)
  749. out_inds, input_sets, idx_removed, idx_contract = contract
  750. cost = _flop_count(idx_contract, idx_removed, len(contract_inds), dimension_dict)
  751. cost_list.append(cost)
  752. scale_list.append(len(idx_contract))
  753. size_list.append(_compute_size_by_dict(out_inds, dimension_dict))
  754. bcast = set()
  755. tmp_inputs = []
  756. for x in contract_inds:
  757. tmp_inputs.append(input_list.pop(x))
  758. bcast |= broadcast_indices.pop(x)
  759. new_bcast_inds = bcast - idx_removed
  760. # If we're broadcasting, nix blas
  761. if not len(idx_removed & bcast):
  762. do_blas = _can_dot(tmp_inputs, out_inds, idx_removed)
  763. else:
  764. do_blas = False
  765. # Last contraction
  766. if (cnum - len(path)) == -1:
  767. idx_result = output_subscript
  768. else:
  769. sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
  770. idx_result = "".join([x[1] for x in sorted(sort_result)])
  771. input_list.append(idx_result)
  772. broadcast_indices.append(new_bcast_inds)
  773. einsum_str = ",".join(tmp_inputs) + "->" + idx_result
  774. contraction = (contract_inds, idx_removed, einsum_str, input_list[:], do_blas)
  775. contraction_list.append(contraction)
  776. opt_cost = sum(cost_list) + 1
  777. if einsum_call_arg:
  778. return (operands, contraction_list)
  779. # Return the path along with a nice string representation
  780. overall_contraction = input_subscripts + "->" + output_subscript
  781. header = ("scaling", "current", "remaining")
  782. speedup = naive_cost / opt_cost
  783. max_i = max(size_list)
  784. path_print = " Complete contraction: %s\n" % overall_contraction
  785. path_print += " Naive scaling: %d\n" % len(indices)
  786. path_print += " Optimized scaling: %d\n" % max(scale_list)
  787. path_print += " Naive FLOP count: %.3e\n" % naive_cost
  788. path_print += " Optimized FLOP count: %.3e\n" % opt_cost
  789. path_print += " Theoretical speedup: %3.3f\n" % speedup
  790. path_print += " Largest intermediate: %.3e elements\n" % max_i
  791. path_print += "-" * 74 + "\n"
  792. path_print += "%6s %24s %40s\n" % header
  793. path_print += "-" * 74
  794. for n, contraction in enumerate(contraction_list):
  795. inds, idx_rm, einsum_str, remaining, blas = contraction
  796. remaining_str = ",".join(remaining) + "->" + output_subscript
  797. path_run = (scale_list[n], einsum_str, remaining_str)
  798. path_print += "\n%4d %24s %40s" % path_run
  799. path = ['einsum_path'] + path
  800. return (path, path_print)
  801. def _einsum_dispatcher(*operands, **kwargs):
  802. # Arguably we dispatch on more arguments that we really should; see note in
  803. # _einsum_path_dispatcher for why.
  804. for op in operands:
  805. yield op
  806. yield kwargs.get('out')
  807. # Rewrite einsum to handle different cases
  808. @array_function_dispatch(_einsum_dispatcher, module='numpy')
  809. def einsum(*operands, **kwargs):
  810. """
  811. einsum(subscripts, *operands, out=None, dtype=None, order='K',
  812. casting='safe', optimize=False)
  813. Evaluates the Einstein summation convention on the operands.
  814. Using the Einstein summation convention, many common multi-dimensional,
  815. linear algebraic array operations can be represented in a simple fashion.
  816. In *implicit* mode `einsum` computes these values.
  817. In *explicit* mode, `einsum` provides further flexibility to compute
  818. other array operations that might not be considered classical Einstein
  819. summation operations, by disabling, or forcing summation over specified
  820. subscript labels.
  821. See the notes and examples for clarification.
  822. Parameters
  823. ----------
  824. subscripts : str
  825. Specifies the subscripts for summation as comma separated list of
  826. subscript labels. An implicit (classical Einstein summation)
  827. calculation is performed unless the explicit indicator '->' is
  828. included as well as subscript labels of the precise output form.
  829. operands : list of array_like
  830. These are the arrays for the operation.
  831. out : ndarray, optional
  832. If provided, the calculation is done into this array.
  833. dtype : {data-type, None}, optional
  834. If provided, forces the calculation to use the data type specified.
  835. Note that you may have to also give a more liberal `casting`
  836. parameter to allow the conversions. Default is None.
  837. order : {'C', 'F', 'A', 'K'}, optional
  838. Controls the memory layout of the output. 'C' means it should
  839. be C contiguous. 'F' means it should be Fortran contiguous,
  840. 'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
  841. 'K' means it should be as close to the layout as the inputs as
  842. is possible, including arbitrarily permuted axes.
  843. Default is 'K'.
  844. casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
  845. Controls what kind of data casting may occur. Setting this to
  846. 'unsafe' is not recommended, as it can adversely affect accumulations.
  847. * 'no' means the data types should not be cast at all.
  848. * 'equiv' means only byte-order changes are allowed.
  849. * 'safe' means only casts which can preserve values are allowed.
  850. * 'same_kind' means only safe casts or casts within a kind,
  851. like float64 to float32, are allowed.
  852. * 'unsafe' means any data conversions may be done.
  853. Default is 'safe'.
  854. optimize : {False, True, 'greedy', 'optimal'}, optional
  855. Controls if intermediate optimization should occur. No optimization
  856. will occur if False and True will default to the 'greedy' algorithm.
  857. Also accepts an explicit contraction list from the ``np.einsum_path``
  858. function. See ``np.einsum_path`` for more details. Defaults to False.
  859. Returns
  860. -------
  861. output : ndarray
  862. The calculation based on the Einstein summation convention.
  863. See Also
  864. --------
  865. einsum_path, dot, inner, outer, tensordot, linalg.multi_dot
  866. Notes
  867. -----
  868. .. versionadded:: 1.6.0
  869. The Einstein summation convention can be used to compute
  870. many multi-dimensional, linear algebraic array operations. `einsum`
  871. provides a succinct way of representing these.
  872. A non-exhaustive list of these operations,
  873. which can be computed by `einsum`, is shown below along with examples:
  874. * Trace of an array, :py:func:`numpy.trace`.
  875. * Return a diagonal, :py:func:`numpy.diag`.
  876. * Array axis summations, :py:func:`numpy.sum`.
  877. * Transpositions and permutations, :py:func:`numpy.transpose`.
  878. * Matrix multiplication and dot product, :py:func:`numpy.matmul` :py:func:`numpy.dot`.
  879. * Vector inner and outer products, :py:func:`numpy.inner` :py:func:`numpy.outer`.
  880. * Broadcasting, element-wise and scalar multiplication, :py:func:`numpy.multiply`.
  881. * Tensor contractions, :py:func:`numpy.tensordot`.
  882. * Chained array operations, in efficient calculation order, :py:func:`numpy.einsum_path`.
  883. The subscripts string is a comma-separated list of subscript labels,
  884. where each label refers to a dimension of the corresponding operand.
  885. Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
  886. is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
  887. appears only once, it is not summed, so ``np.einsum('i', a)`` produces a
  888. view of ``a`` with no changes. A further example ``np.einsum('ij,jk', a, b)``
  889. describes traditional matrix multiplication and is equivalent to
  890. :py:func:`np.matmul(a,b) <numpy.matmul>`. Repeated subscript labels in one
  891. operand take the diagonal. For example, ``np.einsum('ii', a)`` is equivalent
  892. to :py:func:`np.trace(a) <numpy.trace>`.
  893. In *implicit mode*, the chosen subscripts are important
  894. since the axes of the output are reordered alphabetically. This
  895. means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
  896. ``np.einsum('ji', a)`` takes its transpose. Additionally,
  897. ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
  898. ``np.einsum('ij,jh', a, b)`` returns the transpose of the
  899. multiplication since subscript 'h' precedes subscript 'i'.
  900. In *explicit mode* the output can be directly controlled by
  901. specifying output subscript labels. This requires the
  902. identifier '->' as well as the list of output subscript labels.
  903. This feature increases the flexibility of the function since
  904. summing can be disabled or forced when required. The call
  905. ``np.einsum('i->', a)`` is like :py:func:`np.sum(a, axis=-1) <numpy.sum>`,
  906. and ``np.einsum('ii->i', a)`` is like :py:func:`np.diag(a) <numpy.diag>`.
  907. The difference is that `einsum` does not allow broadcasting by default.
  908. Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
  909. order of the output subscript labels and therefore returns matrix
  910. multiplication, unlike the example above in implicit mode.
  911. To enable and control broadcasting, use an ellipsis. Default
  912. NumPy-style broadcasting is done by adding an ellipsis
  913. to the left of each term, like ``np.einsum('...ii->...i', a)``.
  914. To take the trace along the first and last axes,
  915. you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
  916. product with the left-most indices instead of rightmost, one can do
  917. ``np.einsum('ij...,jk...->ik...', a, b)``.
  918. When there is only one operand, no axes are summed, and no output
  919. parameter is provided, a view into the operand is returned instead
  920. of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)``
  921. produces a view (changed in version 1.10.0).
  922. `einsum` also provides an alternative way to provide the subscripts
  923. and operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
  924. If the output shape is not provided in this format `einsum` will be
  925. calculated in implicit mode, otherwise it will be performed explicitly.
  926. The examples below have corresponding `einsum` calls with the two
  927. parameter methods.
  928. .. versionadded:: 1.10.0
  929. Views returned from einsum are now writeable whenever the input array
  930. is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
  931. have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
  932. and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
  933. of a 2D array.
  934. .. versionadded:: 1.12.0
  935. Added the ``optimize`` argument which will optimize the contraction order
  936. of an einsum expression. For a contraction with three or more operands this
  937. can greatly increase the computational efficiency at the cost of a larger
  938. memory footprint during computation.
  939. Typically a 'greedy' algorithm is applied which empirical tests have shown
  940. returns the optimal path in the majority of cases. In some cases 'optimal'
  941. will return the superlative path through a more expensive, exhaustive search.
  942. For iterative calculations it may be advisable to calculate the optimal path
  943. once and reuse that path by supplying it as an argument. An example is given
  944. below.
  945. See :py:func:`numpy.einsum_path` for more details.
  946. Examples
  947. --------
  948. >>> a = np.arange(25).reshape(5,5)
  949. >>> b = np.arange(5)
  950. >>> c = np.arange(6).reshape(2,3)
  951. Trace of a matrix:
  952. >>> np.einsum('ii', a)
  953. 60
  954. >>> np.einsum(a, [0,0])
  955. 60
  956. >>> np.trace(a)
  957. 60
  958. Extract the diagonal (requires explicit form):
  959. >>> np.einsum('ii->i', a)
  960. array([ 0, 6, 12, 18, 24])
  961. >>> np.einsum(a, [0,0], [0])
  962. array([ 0, 6, 12, 18, 24])
  963. >>> np.diag(a)
  964. array([ 0, 6, 12, 18, 24])
  965. Sum over an axis (requires explicit form):
  966. >>> np.einsum('ij->i', a)
  967. array([ 10, 35, 60, 85, 110])
  968. >>> np.einsum(a, [0,1], [0])
  969. array([ 10, 35, 60, 85, 110])
  970. >>> np.sum(a, axis=1)
  971. array([ 10, 35, 60, 85, 110])
  972. For higher dimensional arrays summing a single axis can be done with ellipsis:
  973. >>> np.einsum('...j->...', a)
  974. array([ 10, 35, 60, 85, 110])
  975. >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
  976. array([ 10, 35, 60, 85, 110])
  977. Compute a matrix transpose, or reorder any number of axes:
  978. >>> np.einsum('ji', c)
  979. array([[0, 3],
  980. [1, 4],
  981. [2, 5]])
  982. >>> np.einsum('ij->ji', c)
  983. array([[0, 3],
  984. [1, 4],
  985. [2, 5]])
  986. >>> np.einsum(c, [1,0])
  987. array([[0, 3],
  988. [1, 4],
  989. [2, 5]])
  990. >>> np.transpose(c)
  991. array([[0, 3],
  992. [1, 4],
  993. [2, 5]])
  994. Vector inner products:
  995. >>> np.einsum('i,i', b, b)
  996. 30
  997. >>> np.einsum(b, [0], b, [0])
  998. 30
  999. >>> np.inner(b,b)
  1000. 30
  1001. Matrix vector multiplication:
  1002. >>> np.einsum('ij,j', a, b)
  1003. array([ 30, 80, 130, 180, 230])
  1004. >>> np.einsum(a, [0,1], b, [1])
  1005. array([ 30, 80, 130, 180, 230])
  1006. >>> np.dot(a, b)
  1007. array([ 30, 80, 130, 180, 230])
  1008. >>> np.einsum('...j,j', a, b)
  1009. array([ 30, 80, 130, 180, 230])
  1010. Broadcasting and scalar multiplication:
  1011. >>> np.einsum('..., ...', 3, c)
  1012. array([[ 0, 3, 6],
  1013. [ 9, 12, 15]])
  1014. >>> np.einsum(',ij', 3, c)
  1015. array([[ 0, 3, 6],
  1016. [ 9, 12, 15]])
  1017. >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
  1018. array([[ 0, 3, 6],
  1019. [ 9, 12, 15]])
  1020. >>> np.multiply(3, c)
  1021. array([[ 0, 3, 6],
  1022. [ 9, 12, 15]])
  1023. Vector outer product:
  1024. >>> np.einsum('i,j', np.arange(2)+1, b)
  1025. array([[0, 1, 2, 3, 4],
  1026. [0, 2, 4, 6, 8]])
  1027. >>> np.einsum(np.arange(2)+1, [0], b, [1])
  1028. array([[0, 1, 2, 3, 4],
  1029. [0, 2, 4, 6, 8]])
  1030. >>> np.outer(np.arange(2)+1, b)
  1031. array([[0, 1, 2, 3, 4],
  1032. [0, 2, 4, 6, 8]])
  1033. Tensor contraction:
  1034. >>> a = np.arange(60.).reshape(3,4,5)
  1035. >>> b = np.arange(24.).reshape(4,3,2)
  1036. >>> np.einsum('ijk,jil->kl', a, b)
  1037. array([[ 4400., 4730.],
  1038. [ 4532., 4874.],
  1039. [ 4664., 5018.],
  1040. [ 4796., 5162.],
  1041. [ 4928., 5306.]])
  1042. >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
  1043. array([[ 4400., 4730.],
  1044. [ 4532., 4874.],
  1045. [ 4664., 5018.],
  1046. [ 4796., 5162.],
  1047. [ 4928., 5306.]])
  1048. >>> np.tensordot(a,b, axes=([1,0],[0,1]))
  1049. array([[ 4400., 4730.],
  1050. [ 4532., 4874.],
  1051. [ 4664., 5018.],
  1052. [ 4796., 5162.],
  1053. [ 4928., 5306.]])
  1054. Writeable returned arrays (since version 1.10.0):
  1055. >>> a = np.zeros((3, 3))
  1056. >>> np.einsum('ii->i', a)[:] = 1
  1057. >>> a
  1058. array([[ 1., 0., 0.],
  1059. [ 0., 1., 0.],
  1060. [ 0., 0., 1.]])
  1061. Example of ellipsis use:
  1062. >>> a = np.arange(6).reshape((3,2))
  1063. >>> b = np.arange(12).reshape((4,3))
  1064. >>> np.einsum('ki,jk->ij', a, b)
  1065. array([[10, 28, 46, 64],
  1066. [13, 40, 67, 94]])
  1067. >>> np.einsum('ki,...k->i...', a, b)
  1068. array([[10, 28, 46, 64],
  1069. [13, 40, 67, 94]])
  1070. >>> np.einsum('k...,jk', a, b)
  1071. array([[10, 28, 46, 64],
  1072. [13, 40, 67, 94]])
  1073. Chained array operations. For more complicated contractions, speed ups
  1074. might be achieved by repeatedly computing a 'greedy' path or pre-computing the
  1075. 'optimal' path and repeatedly applying it, using an
  1076. `einsum_path` insertion (since version 1.12.0). Performance improvements can be
  1077. particularly significant with larger arrays:
  1078. >>> a = np.ones(64).reshape(2,4,8)
  1079. # Basic `einsum`: ~1520ms (benchmarked on 3.1GHz Intel i5.)
  1080. >>> for iteration in range(500):
  1081. ... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)
  1082. # Sub-optimal `einsum` (due to repeated path calculation time): ~330ms
  1083. >>> for iteration in range(500):
  1084. ... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')
  1085. # Greedy `einsum` (faster optimal path approximation): ~160ms
  1086. >>> for iteration in range(500):
  1087. ... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')
  1088. # Optimal `einsum` (best usage pattern in some use cases): ~110ms
  1089. >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')[0]
  1090. >>> for iteration in range(500):
  1091. ... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)
  1092. """
  1093. # Grab non-einsum kwargs; do not optimize by default.
  1094. optimize_arg = kwargs.pop('optimize', False)
  1095. # If no optimization, run pure einsum
  1096. if optimize_arg is False:
  1097. return c_einsum(*operands, **kwargs)
  1098. valid_einsum_kwargs = ['out', 'dtype', 'order', 'casting']
  1099. einsum_kwargs = {k: v for (k, v) in kwargs.items() if
  1100. k in valid_einsum_kwargs}
  1101. # Make sure all keywords are valid
  1102. valid_contract_kwargs = ['optimize'] + valid_einsum_kwargs
  1103. unknown_kwargs = [k for (k, v) in kwargs.items() if
  1104. k not in valid_contract_kwargs]
  1105. if len(unknown_kwargs):
  1106. raise TypeError("Did not understand the following kwargs: %s"
  1107. % unknown_kwargs)
  1108. # Special handeling if out is specified
  1109. specified_out = False
  1110. out_array = einsum_kwargs.pop('out', None)
  1111. if out_array is not None:
  1112. specified_out = True
  1113. # Build the contraction list and operand
  1114. operands, contraction_list = einsum_path(*operands, optimize=optimize_arg,
  1115. einsum_call=True)
  1116. handle_out = False
  1117. # Start contraction loop
  1118. for num, contraction in enumerate(contraction_list):
  1119. inds, idx_rm, einsum_str, remaining, blas = contraction
  1120. tmp_operands = [operands.pop(x) for x in inds]
  1121. # Do we need to deal with the output?
  1122. handle_out = specified_out and ((num + 1) == len(contraction_list))
  1123. # Call tensordot if still possible
  1124. if blas:
  1125. # Checks have already been handled
  1126. input_str, results_index = einsum_str.split('->')
  1127. input_left, input_right = input_str.split(',')
  1128. tensor_result = input_left + input_right
  1129. for s in idx_rm:
  1130. tensor_result = tensor_result.replace(s, "")
  1131. # Find indices to contract over
  1132. left_pos, right_pos = [], []
  1133. for s in sorted(idx_rm):
  1134. left_pos.append(input_left.find(s))
  1135. right_pos.append(input_right.find(s))
  1136. # Contract!
  1137. new_view = tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos)))
  1138. # Build a new view if needed
  1139. if (tensor_result != results_index) or handle_out:
  1140. if handle_out:
  1141. einsum_kwargs["out"] = out_array
  1142. new_view = c_einsum(tensor_result + '->' + results_index, new_view, **einsum_kwargs)
  1143. # Call einsum
  1144. else:
  1145. # If out was specified
  1146. if handle_out:
  1147. einsum_kwargs["out"] = out_array
  1148. # Do the contraction
  1149. new_view = c_einsum(einsum_str, *tmp_operands, **einsum_kwargs)
  1150. # Append new items and dereference what we can
  1151. operands.append(new_view)
  1152. del tmp_operands, new_view
  1153. if specified_out:
  1154. return out_array
  1155. else:
  1156. return operands[0]