test_interval.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. """
  2. This file contains a minimal set of tests for compliance with the extension
  3. array interface test suite, and should contain no other tests.
  4. The test suite for the full functionality of the array is located in
  5. `pandas/tests/arrays/`.
  6. The tests in this file are inherited from the BaseExtensionTests, and only
  7. minimal tweaks should be applied to get the tests passing (by overwriting a
  8. parent method).
  9. Additional tests should either be added to one of the BaseExtensionTests
  10. classes (if they are relevant for the extension interface for all dtypes), or
  11. be added to the array-specific tests in `pandas/tests/arrays/`.
  12. """
  13. import numpy as np
  14. import pytest
  15. from pandas.core.dtypes.dtypes import IntervalDtype
  16. from pandas import Interval
  17. from pandas.core.arrays import IntervalArray
  18. from pandas.tests.extension import base
  19. def make_data():
  20. N = 100
  21. left = np.random.uniform(size=N).cumsum()
  22. right = left + np.random.uniform(size=N)
  23. return [Interval(l, r) for l, r in zip(left, right)]
  24. @pytest.fixture
  25. def dtype():
  26. return IntervalDtype()
  27. @pytest.fixture
  28. def data():
  29. """Length-100 PeriodArray for semantics test."""
  30. return IntervalArray(make_data())
  31. @pytest.fixture
  32. def data_missing():
  33. """Length 2 array with [NA, Valid]"""
  34. return IntervalArray.from_tuples([None, (0, 1)])
  35. @pytest.fixture
  36. def data_for_sorting():
  37. return IntervalArray.from_tuples([(1, 2), (2, 3), (0, 1)])
  38. @pytest.fixture
  39. def data_missing_for_sorting():
  40. return IntervalArray.from_tuples([(1, 2), None, (0, 1)])
  41. @pytest.fixture
  42. def na_value():
  43. return np.nan
  44. @pytest.fixture
  45. def data_for_grouping():
  46. a = (0, 1)
  47. b = (1, 2)
  48. c = (2, 3)
  49. return IntervalArray.from_tuples([b, b, None, None, a, a, b, c])
  50. class BaseInterval(object):
  51. pass
  52. class TestDtype(BaseInterval, base.BaseDtypeTests):
  53. pass
  54. class TestCasting(BaseInterval, base.BaseCastingTests):
  55. pass
  56. class TestConstructors(BaseInterval, base.BaseConstructorsTests):
  57. pass
  58. class TestGetitem(BaseInterval, base.BaseGetitemTests):
  59. pass
  60. class TestGrouping(BaseInterval, base.BaseGroupbyTests):
  61. pass
  62. class TestInterface(BaseInterval, base.BaseInterfaceTests):
  63. pass
  64. class TestReduce(base.BaseNoReduceTests):
  65. pass
  66. class TestMethods(BaseInterval, base.BaseMethodsTests):
  67. @pytest.mark.skip(reason='addition is not defined for intervals')
  68. def test_combine_add(self, data_repeated):
  69. pass
  70. @pytest.mark.skip(reason="Not Applicable")
  71. def test_fillna_length_mismatch(self, data_missing):
  72. pass
  73. class TestMissing(BaseInterval, base.BaseMissingTests):
  74. # Index.fillna only accepts scalar `value`, so we have to skip all
  75. # non-scalar fill tests.
  76. unsupported_fill = pytest.mark.skip("Unsupported fillna option.")
  77. @unsupported_fill
  78. def test_fillna_limit_pad(self):
  79. pass
  80. @unsupported_fill
  81. def test_fillna_series_method(self):
  82. pass
  83. @unsupported_fill
  84. def test_fillna_limit_backfill(self):
  85. pass
  86. @unsupported_fill
  87. def test_fillna_series(self):
  88. pass
  89. def test_non_scalar_raises(self, data_missing):
  90. msg = "Got a 'list' instead."
  91. with pytest.raises(TypeError, match=msg):
  92. data_missing.fillna([1, 1])
  93. class TestReshaping(BaseInterval, base.BaseReshapingTests):
  94. pass
  95. class TestSetitem(BaseInterval, base.BaseSetitemTests):
  96. pass
  97. class TestPrinting(BaseInterval, base.BasePrintingTests):
  98. @pytest.mark.skip(reason="custom repr")
  99. def test_array_repr(self, data, size):
  100. pass
  101. class TestParsing(BaseInterval, base.BaseParsingTests):
  102. @pytest.mark.parametrize('engine', ['c', 'python'])
  103. def test_EA_types(self, engine, data):
  104. expected_msg = r'.*must implement _from_sequence_of_strings.*'
  105. with pytest.raises(NotImplementedError, match=expected_msg):
  106. super(TestParsing, self).test_EA_types(engine, data)