lambertw.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
"""Compute a Pade approximation for the principal branch of the
Lambert W function around 0 and compare it to various other
approximations.
"""
  5. from __future__ import division, print_function, absolute_import
  6. import numpy as np
  7. try:
  8. import mpmath
  9. import matplotlib.pyplot as plt
  10. except ImportError:
  11. pass
  12. def lambertw_pade():
  13. derivs = []
  14. for n in range(6):
  15. derivs.append(mpmath.diff(mpmath.lambertw, 0, n=n))
  16. p, q = mpmath.pade(derivs, 3, 2)
  17. return p, q
  18. def main():
  19. print(__doc__)
  20. with mpmath.workdps(50):
  21. p, q = lambertw_pade()
  22. p, q = p[::-1], q[::-1]
  23. print("p = {}".format(p))
  24. print("q = {}".format(q))
  25. x, y = np.linspace(-1.5, 1.5, 75), np.linspace(-1.5, 1.5, 75)
  26. x, y = np.meshgrid(x, y)
  27. z = x + 1j*y
  28. lambertw_std = []
  29. for z0 in z.flatten():
  30. lambertw_std.append(complex(mpmath.lambertw(z0)))
  31. lambertw_std = np.array(lambertw_std).reshape(x.shape)
  32. fig, axes = plt.subplots(nrows=3, ncols=1)
  33. # Compare Pade approximation to true result
  34. p = np.array([float(p0) for p0 in p])
  35. q = np.array([float(q0) for q0 in q])
  36. pade_approx = np.polyval(p, z)/np.polyval(q, z)
  37. pade_err = abs(pade_approx - lambertw_std)
  38. axes[0].pcolormesh(x, y, pade_err)
  39. # Compare two terms of asymptotic series to true result
  40. asy_approx = np.log(z) - np.log(np.log(z))
  41. asy_err = abs(asy_approx - lambertw_std)
  42. axes[1].pcolormesh(x, y, asy_err)
  43. # Compare two terms of the series around the branch point to the
  44. # true result
  45. p = np.sqrt(2*(np.exp(1)*z + 1))
  46. series_approx = -1 + p - p**2/3
  47. series_err = abs(series_approx - lambertw_std)
  48. im = axes[2].pcolormesh(x, y, series_err)
  49. fig.colorbar(im, ax=axes.ravel().tolist())
  50. plt.show()
  51. fig, ax = plt.subplots(nrows=1, ncols=1)
  52. pade_better = pade_err < asy_err
  53. im = ax.pcolormesh(x, y, pade_better)
  54. t = np.linspace(-0.3, 0.3)
  55. ax.plot(-2.5*abs(t) - 0.2, t, 'r')
  56. fig.colorbar(im, ax=ax)
  57. plt.show()
  58. if __name__ == '__main__':
  59. main()