test_feather.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. """ test feather-format compat """
  2. from distutils.version import LooseVersion
  3. import numpy as np
  4. import pytest
  5. import pandas as pd
  6. import pandas.util.testing as tm
  7. from pandas.util.testing import assert_frame_equal, ensure_clean
  8. from pandas.io.feather_format import read_feather, to_feather # noqa:E402
  9. pyarrow = pytest.importorskip('pyarrow')
  10. pyarrow_version = LooseVersion(pyarrow.__version__)
  11. @pytest.mark.single
  12. class TestFeather(object):
  13. def check_error_on_write(self, df, exc):
  14. # check that we are raising the exception
  15. # on writing
  16. with pytest.raises(exc):
  17. with ensure_clean() as path:
  18. to_feather(df, path)
  19. def check_round_trip(self, df, expected=None, **kwargs):
  20. if expected is None:
  21. expected = df
  22. with ensure_clean() as path:
  23. to_feather(df, path)
  24. result = read_feather(path, **kwargs)
  25. assert_frame_equal(result, expected)
  26. def test_error(self):
  27. for obj in [pd.Series([1, 2, 3]), 1, 'foo', pd.Timestamp('20130101'),
  28. np.array([1, 2, 3])]:
  29. self.check_error_on_write(obj, ValueError)
  30. def test_basic(self):
  31. df = pd.DataFrame({'string': list('abc'),
  32. 'int': list(range(1, 4)),
  33. 'uint': np.arange(3, 6).astype('u1'),
  34. 'float': np.arange(4.0, 7.0, dtype='float64'),
  35. 'float_with_null': [1., np.nan, 3],
  36. 'bool': [True, False, True],
  37. 'bool_with_null': [True, np.nan, False],
  38. 'cat': pd.Categorical(list('abc')),
  39. 'dt': pd.date_range('20130101', periods=3),
  40. 'dttz': pd.date_range('20130101', periods=3,
  41. tz='US/Eastern'),
  42. 'dt_with_null': [pd.Timestamp('20130101'), pd.NaT,
  43. pd.Timestamp('20130103')],
  44. 'dtns': pd.date_range('20130101', periods=3,
  45. freq='ns')})
  46. assert df.dttz.dtype.tz.zone == 'US/Eastern'
  47. self.check_round_trip(df)
  48. def test_duplicate_columns(self):
  49. # https://github.com/wesm/feather/issues/53
  50. # not currently able to handle duplicate columns
  51. df = pd.DataFrame(np.arange(12).reshape(4, 3),
  52. columns=list('aaa')).copy()
  53. self.check_error_on_write(df, ValueError)
  54. def test_stringify_columns(self):
  55. df = pd.DataFrame(np.arange(12).reshape(4, 3)).copy()
  56. self.check_error_on_write(df, ValueError)
  57. def test_read_columns(self):
  58. # GH 24025
  59. df = pd.DataFrame({'col1': list('abc'),
  60. 'col2': list(range(1, 4)),
  61. 'col3': list('xyz'),
  62. 'col4': list(range(4, 7))})
  63. columns = ['col1', 'col3']
  64. self.check_round_trip(df, expected=df[columns],
  65. columns=columns)
  66. def test_unsupported_other(self):
  67. # period
  68. df = pd.DataFrame({'a': pd.period_range('2013', freq='M', periods=3)})
  69. # Some versions raise ValueError, others raise ArrowInvalid.
  70. self.check_error_on_write(df, Exception)
  71. def test_rw_nthreads(self):
  72. df = pd.DataFrame({'A': np.arange(100000)})
  73. expected_warning = (
  74. "the 'nthreads' keyword is deprecated, "
  75. "use 'use_threads' instead"
  76. )
  77. # TODO: make the warning work with check_stacklevel=True
  78. with tm.assert_produces_warning(
  79. FutureWarning, check_stacklevel=False) as w:
  80. self.check_round_trip(df, nthreads=2)
  81. # we have an extra FutureWarning because of #GH23752
  82. assert any(expected_warning in str(x) for x in w)
  83. # TODO: make the warning work with check_stacklevel=True
  84. with tm.assert_produces_warning(
  85. FutureWarning, check_stacklevel=False) as w:
  86. self.check_round_trip(df, nthreads=1)
  87. # we have an extra FutureWarnings because of #GH23752
  88. assert any(expected_warning in str(x) for x in w)
  89. def test_rw_use_threads(self):
  90. df = pd.DataFrame({'A': np.arange(100000)})
  91. self.check_round_trip(df, use_threads=True)
  92. self.check_round_trip(df, use_threads=False)
  93. def test_write_with_index(self):
  94. df = pd.DataFrame({'A': [1, 2, 3]})
  95. self.check_round_trip(df)
  96. # non-default index
  97. for index in [[2, 3, 4],
  98. pd.date_range('20130101', periods=3),
  99. list('abc'),
  100. [1, 3, 4],
  101. pd.MultiIndex.from_tuples([('a', 1), ('a', 2),
  102. ('b', 1)]),
  103. ]:
  104. df.index = index
  105. self.check_error_on_write(df, ValueError)
  106. # index with meta-data
  107. df.index = [0, 1, 2]
  108. df.index.name = 'foo'
  109. self.check_error_on_write(df, ValueError)
  110. # column multi-index
  111. df.index = [0, 1, 2]
  112. df.columns = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1)]),
  113. self.check_error_on_write(df, ValueError)
  114. def test_path_pathlib(self):
  115. df = tm.makeDataFrame().reset_index()
  116. result = tm.round_trip_pathlib(df.to_feather, pd.read_feather)
  117. tm.assert_frame_equal(df, result)
  118. def test_path_localpath(self):
  119. df = tm.makeDataFrame().reset_index()
  120. result = tm.round_trip_localpath(df.to_feather, pd.read_feather)
  121. tm.assert_frame_equal(df, result)