Utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. #
  2. # Cython -- Things that don't belong
  3. # anywhere else in particular
  4. #
  5. from __future__ import absolute_import
  6. try:
  7. from __builtin__ import basestring
  8. except ImportError:
  9. basestring = str
  10. import os
  11. import sys
  12. import re
  13. import io
  14. import codecs
  15. import shutil
  16. from contextlib import contextmanager
# Module-level alias for os.path.getmtime, used by the helpers in this file
# to read a file's modification timestamp.
modification_time = os.path.getmtime
  18. def cached_function(f):
  19. cache = {}
  20. uncomputed = object()
  21. def wrapper(*args):
  22. res = cache.get(args, uncomputed)
  23. if res is uncomputed:
  24. res = cache[args] = f(*args)
  25. return res
  26. wrapper.uncached = f
  27. return wrapper
  28. def cached_method(f):
  29. cache_name = '__%s_cache' % f.__name__
  30. def wrapper(self, *args):
  31. cache = getattr(self, cache_name, None)
  32. if cache is None:
  33. cache = {}
  34. setattr(self, cache_name, cache)
  35. if args in cache:
  36. return cache[args]
  37. res = cache[args] = f(self, *args)
  38. return res
  39. return wrapper
  40. def replace_suffix(path, newsuf):
  41. base, _ = os.path.splitext(path)
  42. return base + newsuf
  43. def open_new_file(path):
  44. if os.path.exists(path):
  45. # Make sure to create a new file here so we can
  46. # safely hard link the output files.
  47. os.unlink(path)
  48. # we use the ISO-8859-1 encoding here because we only write pure
  49. # ASCII strings or (e.g. for file names) byte encoded strings as
  50. # Unicode, so we need a direct mapping from the first 256 Unicode
  51. # characters to a byte sequence, which ISO-8859-1 provides
  52. # note: can't use io.open() in Py2 as we may be writing str objects
  53. return codecs.open(path, "w", encoding="ISO-8859-1")
  54. def castrate_file(path, st):
  55. # Remove junk contents from an output file after a
  56. # failed compilation.
  57. # Also sets access and modification times back to
  58. # those specified by st (a stat struct).
  59. try:
  60. f = open_new_file(path)
  61. except EnvironmentError:
  62. pass
  63. else:
  64. f.write(
  65. "#error Do not use this file, it is the result of a failed Cython compilation.\n")
  66. f.close()
  67. if st:
  68. os.utime(path, (st.st_atime, st.st_mtime-1))
  69. def file_newer_than(path, time):
  70. ftime = modification_time(path)
  71. return ftime > time
  72. def safe_makedirs(path):
  73. try:
  74. os.makedirs(path)
  75. except OSError:
  76. if not os.path.isdir(path):
  77. raise
  78. def copy_file_to_dir_if_newer(sourcefile, destdir):
  79. """
  80. Copy file sourcefile to directory destdir (creating it if needed),
  81. preserving metadata. If the destination file exists and is not
  82. older than the source file, the copying is skipped.
  83. """
  84. destfile = os.path.join(destdir, os.path.basename(sourcefile))
  85. try:
  86. desttime = modification_time(destfile)
  87. except OSError:
  88. # New file does not exist, destdir may or may not exist
  89. safe_makedirs(destdir)
  90. else:
  91. # New file already exists
  92. if not file_newer_than(sourcefile, desttime):
  93. return
  94. shutil.copy2(sourcefile, destfile)
  95. @cached_function
  96. def search_include_directories(dirs, qualified_name, suffix, pos,
  97. include=False, sys_path=False):
  98. # Search the list of include directories for the given
  99. # file name. If a source file position is given, first
  100. # searches the directory containing that file. Returns
  101. # None if not found, but does not report an error.
  102. # The 'include' option will disable package dereferencing.
  103. # If 'sys_path' is True, also search sys.path.
  104. if sys_path:
  105. dirs = dirs + tuple(sys.path)
  106. if pos:
  107. file_desc = pos[0]
  108. from Cython.Compiler.Scanning import FileSourceDescriptor
  109. if not isinstance(file_desc, FileSourceDescriptor):
  110. raise RuntimeError("Only file sources for code supported")
  111. if include:
  112. dirs = (os.path.dirname(file_desc.filename),) + dirs
  113. else:
  114. dirs = (find_root_package_dir(file_desc.filename),) + dirs
  115. dotted_filename = qualified_name
  116. if suffix:
  117. dotted_filename += suffix
  118. if not include:
  119. names = qualified_name.split('.')
  120. package_names = tuple(names[:-1])
  121. module_name = names[-1]
  122. module_filename = module_name + suffix
  123. package_filename = "__init__" + suffix
  124. for dir in dirs:
  125. path = os.path.join(dir, dotted_filename)
  126. if path_exists(path):
  127. return path
  128. if not include:
  129. package_dir = check_package_dir(dir, package_names)
  130. if package_dir is not None:
  131. path = os.path.join(package_dir, module_filename)
  132. if path_exists(path):
  133. return path
  134. path = os.path.join(dir, package_dir, module_name,
  135. package_filename)
  136. if path_exists(path):
  137. return path
  138. return None
  139. @cached_function
  140. def find_root_package_dir(file_path):
  141. dir = os.path.dirname(file_path)
  142. if file_path == dir:
  143. return dir
  144. elif is_package_dir(dir):
  145. return find_root_package_dir(dir)
  146. else:
  147. return dir
  148. @cached_function
  149. def check_package_dir(dir, package_names):
  150. for dirname in package_names:
  151. dir = os.path.join(dir, dirname)
  152. if not is_package_dir(dir):
  153. return None
  154. return dir
  155. @cached_function
  156. def is_package_dir(dir_path):
  157. for filename in ("__init__.py",
  158. "__init__.pyc",
  159. "__init__.pyx",
  160. "__init__.pxd"):
  161. path = os.path.join(dir_path, filename)
  162. if path_exists(path):
  163. return 1
  164. @cached_function
  165. def path_exists(path):
  166. # try on the filesystem first
  167. if os.path.exists(path):
  168. return True
  169. # figure out if a PEP 302 loader is around
  170. try:
  171. loader = __loader__
  172. # XXX the code below assumes a 'zipimport.zipimporter' instance
  173. # XXX should be easy to generalize, but too lazy right now to write it
  174. archive_path = getattr(loader, 'archive', None)
  175. if archive_path:
  176. normpath = os.path.normpath(path)
  177. if normpath.startswith(archive_path):
  178. arcname = normpath[len(archive_path)+1:]
  179. try:
  180. loader.get_data(arcname)
  181. return True
  182. except IOError:
  183. return False
  184. except NameError:
  185. pass
  186. return False
  187. # file name encodings
  188. def decode_filename(filename):
  189. if isinstance(filename, bytes):
  190. try:
  191. filename_encoding = sys.getfilesystemencoding()
  192. if filename_encoding is None:
  193. filename_encoding = sys.getdefaultencoding()
  194. filename = filename.decode(filename_encoding)
  195. except UnicodeDecodeError:
  196. pass
  197. return filename
  198. # support for source file encoding detection
  199. _match_file_encoding = re.compile(u"coding[:=]\s*([-\w.]+)").search
  200. def detect_file_encoding(source_filename):
  201. f = open_source_file(source_filename, encoding="UTF-8", error_handling='ignore')
  202. try:
  203. return detect_opened_file_encoding(f)
  204. finally:
  205. f.close()
  206. def detect_opened_file_encoding(f):
  207. # PEPs 263 and 3120
  208. # Most of the time the first two lines fall in the first 250 chars,
  209. # and this bulk read/split is much faster.
  210. lines = f.read(250).split(u"\n")
  211. if len(lines) > 1:
  212. m = _match_file_encoding(lines[0])
  213. if m:
  214. return m.group(1)
  215. elif len(lines) > 2:
  216. m = _match_file_encoding(lines[1])
  217. if m:
  218. return m.group(1)
  219. else:
  220. return "UTF-8"
  221. # Fallback to one-char-at-a-time detection.
  222. f.seek(0)
  223. chars = []
  224. for i in range(2):
  225. c = f.read(1)
  226. while c and c != u'\n':
  227. chars.append(c)
  228. c = f.read(1)
  229. encoding = _match_file_encoding(u''.join(chars))
  230. if encoding:
  231. return encoding.group(1)
  232. return "UTF-8"
  233. def skip_bom(f):
  234. """
  235. Read past a BOM at the beginning of a source file.
  236. This could be added to the scanner, but it's *substantially* easier
  237. to keep it at this level.
  238. """
  239. if f.read(1) != u'\uFEFF':
  240. f.seek(0)
  241. def open_source_file(source_filename, mode="r",
  242. encoding=None, error_handling=None):
  243. if encoding is None:
  244. # Most of the time the coding is unspecified, so be optimistic that
  245. # it's UTF-8.
  246. f = open_source_file(source_filename, encoding="UTF-8", mode=mode, error_handling='ignore')
  247. encoding = detect_opened_file_encoding(f)
  248. if encoding == "UTF-8" and error_handling == 'ignore':
  249. f.seek(0)
  250. skip_bom(f)
  251. return f
  252. else:
  253. f.close()
  254. if not os.path.exists(source_filename):
  255. try:
  256. loader = __loader__
  257. if source_filename.startswith(loader.archive):
  258. return open_source_from_loader(
  259. loader, source_filename,
  260. encoding, error_handling)
  261. except (NameError, AttributeError):
  262. pass
  263. stream = io.open(source_filename, mode=mode,
  264. encoding=encoding, errors=error_handling)
  265. skip_bom(stream)
  266. return stream
  267. def open_source_from_loader(loader,
  268. source_filename,
  269. encoding=None, error_handling=None):
  270. nrmpath = os.path.normpath(source_filename)
  271. arcname = nrmpath[len(loader.archive)+1:]
  272. data = loader.get_data(arcname)
  273. return io.TextIOWrapper(io.BytesIO(data),
  274. encoding=encoding,
  275. errors=error_handling)
  276. def str_to_number(value):
  277. # note: this expects a string as input that was accepted by the
  278. # parser already, with an optional "-" sign in front
  279. is_neg = False
  280. if value[:1] == '-':
  281. is_neg = True
  282. value = value[1:]
  283. if len(value) < 2:
  284. value = int(value, 0)
  285. elif value[0] == '0':
  286. literal_type = value[1] # 0'o' - 0'b' - 0'x'
  287. if literal_type in 'xX':
  288. # hex notation ('0x1AF')
  289. value = int(value[2:], 16)
  290. elif literal_type in 'oO':
  291. # Py3 octal notation ('0o136')
  292. value = int(value[2:], 8)
  293. elif literal_type in 'bB':
  294. # Py3 binary notation ('0b101')
  295. value = int(value[2:], 2)
  296. else:
  297. # Py2 octal notation ('0136')
  298. value = int(value, 8)
  299. else:
  300. value = int(value, 0)
  301. return -value if is_neg else value
  302. def long_literal(value):
  303. if isinstance(value, basestring):
  304. value = str_to_number(value)
  305. return not -2**31 <= value < 2**31
  306. @cached_function
  307. def get_cython_cache_dir():
  308. """get the cython cache dir
  309. Priority:
  310. 1. CYTHON_CACHE_DIR
  311. 2. (OS X): ~/Library/Caches/Cython
  312. (posix not OS X): XDG_CACHE_HOME/cython if XDG_CACHE_HOME defined
  313. 3. ~/.cython
  314. """
  315. if 'CYTHON_CACHE_DIR' in os.environ:
  316. return os.environ['CYTHON_CACHE_DIR']
  317. parent = None
  318. if os.name == 'posix':
  319. if sys.platform == 'darwin':
  320. parent = os.path.expanduser('~/Library/Caches')
  321. else:
  322. # this could fallback on ~/.cache
  323. parent = os.environ.get('XDG_CACHE_HOME')
  324. if parent and os.path.isdir(parent):
  325. return os.path.join(parent, 'cython')
  326. # last fallback: ~/.cython
  327. return os.path.expanduser(os.path.join('~', '.cython'))
@contextmanager
def captured_fd(stream=2, encoding=None):
    """Context manager that redirects file descriptor *stream* into a pipe.

    While the context is active, everything written to the descriptor
    (default 2) is collected by a background thread.  The value bound by
    ``with ... as get_output`` is a callable returning the bytes collected
    so far, decoded with *encoding* if one is given.  On exit the original
    descriptor is restored and the reader thread is joined.
    """
    pipe_in = t = None
    orig_stream = os.dup(stream)  # keep copy of original stream
    try:
        pipe_in, pipe_out = os.pipe()
        os.dup2(pipe_out, stream)  # replace stream by copy of pipe
        try:
            os.close(pipe_out)  # close original pipe-out stream
            data = []  # chunks read from the pipe, in arrival order

            def copy():
                # Runs in the reader thread: drain the pipe in chunks of up
                # to 1000 bytes until an empty read, then close the read end.
                try:
                    while True:
                        d = os.read(pipe_in, 1000)
                        if d:
                            data.append(d)
                        else:
                            break
                finally:
                    os.close(pipe_in)

            def get_output():
                # Join whatever has been collected so far; decode only when
                # an encoding was requested.
                output = b''.join(data)
                if encoding:
                    output = output.decode(encoding)
                return output

            from threading import Thread
            t = Thread(target=copy)
            t.daemon = True  # just in case
            t.start()
            yield get_output
        finally:
            os.dup2(orig_stream, stream)  # restore original stream
            # NOTE(review): restoring the stream presumably closes the last
            # write end of the pipe so that the reader sees EOF -- confirm.
            if t is not None:
                t.join()
    finally:
        os.close(orig_stream)
  364. def print_bytes(s, end=b'\n', file=sys.stdout, flush=True):
  365. file.flush()
  366. try:
  367. out = file.buffer # Py3
  368. except AttributeError:
  369. out = file # Py2
  370. out.write(s)
  371. if end:
  372. out.write(end)
  373. if flush:
  374. out.flush()
  375. class LazyStr:
  376. def __init__(self, callback):
  377. self.callback = callback
  378. def __str__(self):
  379. return self.callback()
  380. def __repr__(self):
  381. return self.callback()
  382. def __add__(self, right):
  383. return self.callback() + right
  384. def __radd__(self, left):
  385. return left + self.callback()
  386. class OrderedSet(object):
  387. def __init__(self, elements=()):
  388. self._list = []
  389. self._set = set()
  390. self.update(elements)
  391. def __iter__(self):
  392. return iter(self._list)
  393. def update(self, elements):
  394. for e in elements:
  395. self.add(e)
  396. def add(self, e):
  397. if e not in self._set:
  398. self._list.append(e)
  399. self._set.add(e)
  400. # Class decorator that adds a metaclass and recreates the class with it.
  401. # Copied from 'six'.
  402. def add_metaclass(metaclass):
  403. """Class decorator for creating a class with a metaclass."""
  404. def wrapper(cls):
  405. orig_vars = cls.__dict__.copy()
  406. slots = orig_vars.get('__slots__')
  407. if slots is not None:
  408. if isinstance(slots, str):
  409. slots = [slots]
  410. for slots_var in slots:
  411. orig_vars.pop(slots_var)
  412. orig_vars.pop('__dict__', None)
  413. orig_vars.pop('__weakref__', None)
  414. return metaclass(cls.__name__, cls.__bases__, orig_vars)
  415. return wrapper