_warnings.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # From scikit-image: https://github.com/scikit-image/scikit-image/blob/c2f8c4ab123ebe5f7b827bc495625a32bb225c10/skimage/_shared/_warnings.py
  2. # Licensed under modified BSD license
  3. __all__ = ['all_warnings', 'expected_warnings']
  4. from contextlib import contextmanager
  5. import sys
  6. import warnings
  7. import inspect
  8. import re
  9. @contextmanager
  10. def all_warnings():
  11. """
  12. Context for use in testing to ensure that all warnings are raised.
  13. Examples
  14. --------
  15. >>> import warnings
  16. >>> def foo():
  17. ... warnings.warn(RuntimeWarning("bar"))
  18. We raise the warning once, while the warning filter is set to "once".
  19. Hereafter, the warning is invisible, even with custom filters:
  20. >>> with warnings.catch_warnings():
  21. ... warnings.simplefilter('once')
  22. ... foo()
  23. We can now run ``foo()`` without a warning being raised:
  24. >>> from numpy.testing import assert_warns
  25. >>> foo()
  26. To catch the warning, we call in the help of ``all_warnings``:
  27. >>> with all_warnings():
  28. ... assert_warns(RuntimeWarning, foo)
  29. """
  30. # Whenever a warning is triggered, Python adds a __warningregistry__
  31. # member to the *calling* module. The exercize here is to find
  32. # and eradicate all those breadcrumbs that were left lying around.
  33. #
  34. # We proceed by first searching all parent calling frames and explicitly
  35. # clearing their warning registries (necessary for the doctests above to
  36. # pass). Then, we search for all submodules of skimage and clear theirs
  37. # as well (necessary for the skimage test suite to pass).
  38. frame = inspect.currentframe()
  39. if frame:
  40. for f in inspect.getouterframes(frame):
  41. f[0].f_locals['__warningregistry__'] = {}
  42. del frame
  43. for mod_name, mod in list(sys.modules.items()):
  44. if 'six.moves' in mod_name:
  45. continue
  46. try:
  47. mod.__warningregistry__.clear()
  48. except AttributeError:
  49. pass
  50. with warnings.catch_warnings(record=True) as w:
  51. warnings.simplefilter("always")
  52. yield w
  53. @contextmanager
  54. def expected_warnings(matching):
  55. """Context for use in testing to catch known warnings matching regexes
  56. Parameters
  57. ----------
  58. matching : list of strings or compiled regexes
  59. Regexes for the desired warning to catch
  60. Examples
  61. --------
  62. >>> from skimage import data, img_as_ubyte, img_as_float
  63. >>> with expected_warnings(['precision loss']):
  64. ... d = img_as_ubyte(img_as_float(data.coins()))
  65. Notes
  66. -----
  67. Uses `all_warnings` to ensure all warnings are raised.
  68. Upon exiting, it checks the recorded warnings for the desired matching
  69. pattern(s).
  70. Raises a ValueError if any match was not found or an unexpected
  71. warning was raised.
  72. Allows for three types of behaviors: "and", "or", and "optional" matches.
  73. This is done to accomodate different build enviroments or loop conditions
  74. that may produce different warnings. The behaviors can be combined.
  75. If you pass multiple patterns, you get an orderless "and", where all of the
  76. warnings must be raised.
  77. If you use the "|" operator in a pattern, you can catch one of several warnings.
  78. Finally, you can use "|\A\Z" in a pattern to signify it as optional.
  79. """
  80. with all_warnings() as w:
  81. # enter context
  82. yield w
  83. # exited user context, check the recorded warnings
  84. remaining = [m for m in matching if not '\A\Z' in m.split('|')]
  85. for warn in w:
  86. found = False
  87. for match in matching:
  88. if re.search(match, str(warn.message)) is not None:
  89. found = True
  90. if match in remaining:
  91. remaining.remove(match)
  92. if not found:
  93. raise ValueError('Unexpected warning: %s' % str(warn.message))
  94. if len(remaining) > 0:
  95. msg = 'No warning raised matching:\n%s' % '\n'.join(remaining)
  96. raise ValueError(msg)