entrypoints.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. """Discover and load entry points from installed packages."""
  2. # Copyright (c) Thomas Kluyver and contributors
  3. # Distributed under the terms of the MIT license; see LICENSE file.
  4. from contextlib import contextmanager
  5. import glob
  6. from importlib import import_module
  7. import io
  8. import itertools
  9. import os.path as osp
  10. import re
  11. import sys
  12. import warnings
  13. import zipfile
  14. if sys.version_info[0] >= 3:
  15. import configparser
  16. else:
  17. from backports import configparser
  18. entry_point_pattern = re.compile(r"""
  19. (?P<modulename>\w+(\.\w+)*)
  20. (:(?P<objectname>\w+(\.\w+)*))?
  21. \s*
  22. (\[(?P<extras>.+)\])?
  23. $
  24. """, re.VERBOSE)
  25. file_in_zip_pattern = re.compile(r"""
  26. (?P<dist_version>[^/\\]+)\.(dist|egg)-info
  27. [/\\]entry_points.txt$
  28. """, re.VERBOSE)
  29. __version__ = '0.3'
  30. class BadEntryPoint(Exception):
  31. """Raised when an entry point can't be parsed.
  32. """
  33. def __init__(self, epstr):
  34. self.epstr = epstr
  35. def __str__(self):
  36. return "Couldn't parse entry point spec: %r" % self.epstr
  37. @staticmethod
  38. @contextmanager
  39. def err_to_warnings():
  40. try:
  41. yield
  42. except BadEntryPoint as e:
  43. warnings.warn(str(e))
  44. class NoSuchEntryPoint(Exception):
  45. """Raised by :func:`get_single` when no matching entry point is found."""
  46. def __init__(self, group, name):
  47. self.group = group
  48. self.name = name
  49. def __str__(self):
  50. return "No {!r} entry point found in group {!r}".format(self.name, self.group)
  51. class CaseSensitiveConfigParser(configparser.ConfigParser):
  52. optionxform = staticmethod(str)
  53. class EntryPoint(object):
  54. def __init__(self, name, module_name, object_name, extras=None, distro=None):
  55. self.name = name
  56. self.module_name = module_name
  57. self.object_name = object_name
  58. self.extras = extras
  59. self.distro = distro
  60. def __repr__(self):
  61. return "EntryPoint(%r, %r, %r, %r)" % \
  62. (self.name, self.module_name, self.object_name, self.distro)
  63. def load(self):
  64. """Load the object to which this entry point refers.
  65. """
  66. mod = import_module(self.module_name)
  67. obj = mod
  68. if self.object_name:
  69. for attr in self.object_name.split('.'):
  70. obj = getattr(obj, attr)
  71. return obj
  72. @classmethod
  73. def from_string(cls, epstr, name, distro=None):
  74. """Parse an entry point from the syntax in entry_points.txt
  75. :param str epstr: The entry point string (not including 'name =')
  76. :param str name: The name of this entry point
  77. :param Distribution distro: The distribution in which the entry point was found
  78. :rtype: EntryPoint
  79. :raises BadEntryPoint: if *epstr* can't be parsed as an entry point.
  80. """
  81. m = entry_point_pattern.match(epstr)
  82. if m:
  83. mod, obj, extras = m.group('modulename', 'objectname', 'extras')
  84. if extras is not None:
  85. extras = re.split(r',\s*', extras)
  86. return cls(name, mod, obj, extras, distro)
  87. else:
  88. raise BadEntryPoint(epstr)
  89. class Distribution(object):
  90. def __init__(self, name, version):
  91. self.name = name
  92. self.version = version
  93. def __repr__(self):
  94. return "Distribution(%r, %r)" % (self.name, self.version)
  95. def iter_files_distros(path=None, repeated_distro='first'):
  96. if path is None:
  97. path = sys.path
  98. # Distributions found earlier in path will shadow those with the same name
  99. # found later. If these distributions used different module names, it may
  100. # actually be possible to import both, but in most cases this shadowing
  101. # will be correct.
  102. distro_names_seen = set()
  103. for folder in path:
  104. if folder.rstrip('/\\').endswith('.egg'):
  105. # Gah, eggs
  106. egg_name = osp.basename(folder)
  107. if '-' in egg_name:
  108. distro = Distribution(*egg_name.split('-')[:2])
  109. if (repeated_distro == 'first') \
  110. and (distro.name in distro_names_seen):
  111. continue
  112. distro_names_seen.add(distro.name)
  113. else:
  114. distro = None
  115. if osp.isdir(folder):
  116. ep_path = osp.join(folder, 'EGG-INFO', 'entry_points.txt')
  117. if osp.isfile(ep_path):
  118. cp = CaseSensitiveConfigParser(delimiters=('=',))
  119. cp.read([ep_path])
  120. yield cp, distro
  121. elif zipfile.is_zipfile(folder):
  122. z = zipfile.ZipFile(folder)
  123. try:
  124. info = z.getinfo('EGG-INFO/entry_points.txt')
  125. except KeyError:
  126. continue
  127. cp = CaseSensitiveConfigParser(delimiters=('=',))
  128. with z.open(info) as f:
  129. fu = io.TextIOWrapper(f)
  130. cp.read_file(fu,
  131. source=osp.join(folder, 'EGG-INFO', 'entry_points.txt'))
  132. yield cp, distro
  133. # zip imports, not egg
  134. elif zipfile.is_zipfile(folder):
  135. with zipfile.ZipFile(folder) as zf:
  136. for info in zf.infolist():
  137. m = file_in_zip_pattern.match(info.filename)
  138. if not m:
  139. continue
  140. distro_name_version = m.group('dist_version')
  141. if '-' in distro_name_version:
  142. distro = Distribution(*distro_name_version.split('-', 1))
  143. if (repeated_distro == 'first') \
  144. and (distro.name in distro_names_seen):
  145. continue
  146. distro_names_seen.add(distro.name)
  147. else:
  148. distro = None
  149. cp = CaseSensitiveConfigParser(delimiters=('=',))
  150. with zf.open(info) as f:
  151. fu = io.TextIOWrapper(f)
  152. cp.read_file(fu, source=osp.join(folder, info.filename))
  153. yield cp, distro
  154. # Regular file imports (not egg, not zip file)
  155. for path in itertools.chain(
  156. glob.iglob(osp.join(folder, '*.dist-info', 'entry_points.txt')),
  157. glob.iglob(osp.join(folder, '*.egg-info', 'entry_points.txt'))
  158. ):
  159. distro_name_version = osp.splitext(osp.basename(osp.dirname(path)))[0]
  160. if '-' in distro_name_version:
  161. distro = Distribution(*distro_name_version.split('-', 1))
  162. if (repeated_distro == 'first') \
  163. and (distro.name in distro_names_seen):
  164. continue
  165. distro_names_seen.add(distro.name)
  166. else:
  167. distro = None
  168. cp = CaseSensitiveConfigParser(delimiters=('=',))
  169. cp.read([path])
  170. yield cp, distro
  171. def get_single(group, name, path=None):
  172. """Find a single entry point.
  173. Returns an :class:`EntryPoint` object, or raises :exc:`NoSuchEntryPoint`
  174. if no match is found.
  175. """
  176. for config, distro in iter_files_distros(path=path):
  177. if (group in config) and (name in config[group]):
  178. epstr = config[group][name]
  179. with BadEntryPoint.err_to_warnings():
  180. return EntryPoint.from_string(epstr, name, distro)
  181. raise NoSuchEntryPoint(group, name)
  182. def get_group_named(group, path=None):
  183. """Find a group of entry points with unique names.
  184. Returns a dictionary of names to :class:`EntryPoint` objects.
  185. """
  186. result = {}
  187. for ep in get_group_all(group, path=path):
  188. if ep.name not in result:
  189. result[ep.name] = ep
  190. return result
  191. def get_group_all(group, path=None):
  192. """Find all entry points in a group.
  193. Returns a list of :class:`EntryPoint` objects.
  194. """
  195. result = []
  196. for config, distro in iter_files_distros(path=path):
  197. if group in config:
  198. for name, epstr in config[group].items():
  199. with BadEntryPoint.err_to_warnings():
  200. result.append(EntryPoint.from_string(epstr, name, distro))
  201. return result
  202. if __name__ == '__main__':
  203. import pprint
  204. pprint.pprint(get_group_all('console_scripts'))