test_register_accessor.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. import contextlib
  2. import pytest
  3. import pandas as pd
  4. import pandas.util.testing as tm
  5. @contextlib.contextmanager
  6. def ensure_removed(obj, attr):
  7. """Ensure that an attribute added to 'obj' during the test is
  8. removed when we're done"""
  9. try:
  10. yield
  11. finally:
  12. try:
  13. delattr(obj, attr)
  14. except AttributeError:
  15. pass
  16. obj._accessors.discard(attr)
  17. class MyAccessor(object):
  18. def __init__(self, obj):
  19. self.obj = obj
  20. self.item = 'item'
  21. @property
  22. def prop(self):
  23. return self.item
  24. def method(self):
  25. return self.item
  26. @pytest.mark.parametrize('obj, registrar', [
  27. (pd.Series, pd.api.extensions.register_series_accessor),
  28. (pd.DataFrame, pd.api.extensions.register_dataframe_accessor),
  29. (pd.Index, pd.api.extensions.register_index_accessor)
  30. ])
  31. def test_register(obj, registrar):
  32. with ensure_removed(obj, 'mine'):
  33. before = set(dir(obj))
  34. registrar('mine')(MyAccessor)
  35. assert obj([]).mine.prop == 'item'
  36. after = set(dir(obj))
  37. assert (before ^ after) == {'mine'}
  38. assert 'mine' in obj._accessors
  39. def test_accessor_works():
  40. with ensure_removed(pd.Series, 'mine'):
  41. pd.api.extensions.register_series_accessor('mine')(MyAccessor)
  42. s = pd.Series([1, 2])
  43. assert s.mine.obj is s
  44. assert s.mine.prop == 'item'
  45. assert s.mine.method() == 'item'
  46. def test_overwrite_warns():
  47. # Need to restore mean
  48. mean = pd.Series.mean
  49. try:
  50. with tm.assert_produces_warning(UserWarning) as w:
  51. pd.api.extensions.register_series_accessor('mean')(MyAccessor)
  52. s = pd.Series([1, 2])
  53. assert s.mean.prop == 'item'
  54. msg = str(w[0].message)
  55. assert 'mean' in msg
  56. assert 'MyAccessor' in msg
  57. assert 'Series' in msg
  58. finally:
  59. pd.Series.mean = mean
  60. def test_raises_attribute_error():
  61. with ensure_removed(pd.Series, 'bad'):
  62. @pd.api.extensions.register_series_accessor("bad")
  63. class Bad(object):
  64. def __init__(self, data):
  65. raise AttributeError("whoops")
  66. with pytest.raises(AttributeError, match="whoops"):
  67. pd.Series([]).bad