# test_pickle.py
  1. # pylint: disable=E1101,E1103,W0232
  2. """
  3. manage legacy pickle tests
  4. How to add pickle tests:
  5. 1. Install pandas version intended to output the pickle.
  6. 2. Execute "generate_legacy_storage_files.py" to create the pickle.
  7. $ python generate_legacy_storage_files.py <output_dir> pickle
  8. 3. Move the created pickle to "data/legacy_pickle/<version>" directory.
  9. """
  10. from distutils.version import LooseVersion
  11. import glob
  12. import os
  13. import shutil
  14. from warnings import catch_warnings, simplefilter
  15. import pytest
  16. from pandas.compat import PY3, is_platform_little_endian
  17. import pandas.util._test_decorators as td
  18. import pandas as pd
  19. from pandas import Index
  20. import pandas.util.testing as tm
  21. from pandas.tseries.offsets import Day, MonthEnd
@pytest.fixture(scope='module')
def current_pickle_data():
    """Pickle data generated by the currently-installed pandas version.

    The generator module is imported lazily so it is only loaded when a
    test actually requests this fixture.
    """
    # our current version pickle data
    from pandas.tests.io.generate_legacy_storage_files import (
        create_pickle_data)
    return create_pickle_data()
  28. # ---------------------
  29. # comparison functions
  30. # ---------------------
  31. def compare_element(result, expected, typ, version=None):
  32. if isinstance(expected, Index):
  33. tm.assert_index_equal(expected, result)
  34. return
  35. if typ.startswith('sp_'):
  36. comparator = getattr(tm, "assert_%s_equal" % typ)
  37. comparator(result, expected, exact_indices=False)
  38. elif typ == 'timestamp':
  39. if expected is pd.NaT:
  40. assert result is pd.NaT
  41. else:
  42. assert result == expected
  43. assert result.freq == expected.freq
  44. else:
  45. comparator = getattr(tm, "assert_%s_equal" %
  46. typ, tm.assert_almost_equal)
  47. comparator(result, expected)
  48. def compare(data, vf, version):
  49. # py3 compat when reading py2 pickle
  50. try:
  51. data = pd.read_pickle(vf)
  52. except (ValueError) as e:
  53. if 'unsupported pickle protocol:' in str(e):
  54. # trying to read a py3 pickle in py2
  55. return
  56. else:
  57. raise
  58. m = globals()
  59. for typ, dv in data.items():
  60. for dt, result in dv.items():
  61. try:
  62. expected = data[typ][dt]
  63. except (KeyError):
  64. if version in ('0.10.1', '0.11.0') and dt == 'reg':
  65. break
  66. else:
  67. raise
  68. # use a specific comparator
  69. # if available
  70. comparator = "compare_{typ}_{dt}".format(typ=typ, dt=dt)
  71. comparator = m.get(comparator, m['compare_element'])
  72. comparator(result, expected, typ, version)
  73. return data
  74. def compare_sp_series_ts(res, exp, typ, version):
  75. # SparseTimeSeries integrated into SparseSeries in 0.12.0
  76. # and deprecated in 0.17.0
  77. if version and LooseVersion(version) <= LooseVersion("0.12.0"):
  78. tm.assert_sp_series_equal(res, exp, check_series_type=False)
  79. else:
  80. tm.assert_sp_series_equal(res, exp)
  81. def compare_series_ts(result, expected, typ, version):
  82. # GH 7748
  83. tm.assert_series_equal(result, expected)
  84. assert result.index.freq == expected.index.freq
  85. assert not result.index.freq.normalize
  86. tm.assert_series_equal(result > 0, expected > 0)
  87. # GH 9291
  88. freq = result.index.freq
  89. assert freq + Day(1) == Day(2)
  90. res = freq + pd.Timedelta(hours=1)
  91. assert isinstance(res, pd.Timedelta)
  92. assert res == pd.Timedelta(days=1, hours=1)
  93. res = freq + pd.Timedelta(nanoseconds=1)
  94. assert isinstance(res, pd.Timedelta)
  95. assert res == pd.Timedelta(days=1, nanoseconds=1)
  96. def compare_series_dt_tz(result, expected, typ, version):
  97. # 8260
  98. # dtype is object < 0.17.0
  99. if LooseVersion(version) < LooseVersion('0.17.0'):
  100. expected = expected.astype(object)
  101. tm.assert_series_equal(result, expected)
  102. else:
  103. tm.assert_series_equal(result, expected)
  104. def compare_series_cat(result, expected, typ, version):
  105. # Categorical dtype is added in 0.15.0
  106. # ordered is changed in 0.16.0
  107. if LooseVersion(version) < LooseVersion('0.15.0'):
  108. tm.assert_series_equal(result, expected, check_dtype=False,
  109. check_categorical=False)
  110. elif LooseVersion(version) < LooseVersion('0.16.0'):
  111. tm.assert_series_equal(result, expected, check_categorical=False)
  112. else:
  113. tm.assert_series_equal(result, expected)
  114. def compare_frame_dt_mixed_tzs(result, expected, typ, version):
  115. # 8260
  116. # dtype is object < 0.17.0
  117. if LooseVersion(version) < LooseVersion('0.17.0'):
  118. expected = expected.astype(object)
  119. tm.assert_frame_equal(result, expected)
  120. else:
  121. tm.assert_frame_equal(result, expected)
  122. def compare_frame_cat_onecol(result, expected, typ, version):
  123. # Categorical dtype is added in 0.15.0
  124. # ordered is changed in 0.16.0
  125. if LooseVersion(version) < LooseVersion('0.15.0'):
  126. tm.assert_frame_equal(result, expected, check_dtype=False,
  127. check_categorical=False)
  128. elif LooseVersion(version) < LooseVersion('0.16.0'):
  129. tm.assert_frame_equal(result, expected, check_categorical=False)
  130. else:
  131. tm.assert_frame_equal(result, expected)
def compare_frame_cat_and_float(result, expected, typ, version):
    # same version-dependent relaxations as the one-column categorical case
    compare_frame_cat_onecol(result, expected, typ, version)
  134. def compare_index_period(result, expected, typ, version):
  135. tm.assert_index_equal(result, expected)
  136. assert isinstance(result.freq, MonthEnd)
  137. assert result.freq == MonthEnd()
  138. assert result.freqstr == 'M'
  139. tm.assert_index_equal(result.shift(2), expected.shift(2))
  140. def compare_sp_frame_float(result, expected, typ, version):
  141. if LooseVersion(version) <= LooseVersion('0.18.1'):
  142. tm.assert_sp_frame_equal(result, expected, exact_indices=False,
  143. check_dtype=False)
  144. else:
  145. tm.assert_sp_frame_equal(result, expected)
# every pickle written by generate_legacy_storage_files.py, one per version
# subdirectory; collected at import time so pytest can parametrize on them
files = glob.glob(os.path.join(os.path.dirname(__file__), "data",
                               "legacy_pickle", "*", "*.pickle"))


@pytest.fixture(params=files)
def legacy_pickle(request, datapath):
    # resolve each collected pickle path through the datapath fixture
    return datapath(request.param)
  151. # ---------------------
  152. # tests
  153. # ---------------------
  154. def test_pickles(current_pickle_data, legacy_pickle):
  155. if not is_platform_little_endian():
  156. pytest.skip("known failure on non-little endian")
  157. version = os.path.basename(os.path.dirname(legacy_pickle))
  158. with catch_warnings(record=True):
  159. simplefilter("ignore")
  160. compare(current_pickle_data, legacy_pickle, version)
def test_round_trip_current(current_pickle_data):
    """Round-trip every current-version object through each pickler.

    Objects are written with pandas.to_pickle, cPickle (Python 2 only)
    and the stdlib pickle module, read back with pd.read_pickle and each
    raw unpickler, and compared element-wise.
    """
    try:
        import cPickle as c_pickle

        def c_pickler(obj, path):
            # write obj with the highest protocol cPickle supports
            with open(path, 'wb') as fh:
                c_pickle.dump(obj, fh, protocol=-1)

        def c_unpickler(path):
            with open(path, 'rb') as fh:
                fh.seek(0)
                return c_pickle.load(fh)
    except ImportError:
        # cPickle only exists on Python 2; skip those picklers elsewhere
        c_pickler = None
        c_unpickler = None

    import pickle as python_pickle

    def python_pickler(obj, path):
        # stdlib equivalent of c_pickler above
        with open(path, 'wb') as fh:
            python_pickle.dump(obj, fh, protocol=-1)

    def python_unpickler(path):
        with open(path, 'rb') as fh:
            fh.seek(0)
            return python_pickle.load(fh)

    data = current_pickle_data
    for typ, dv in data.items():
        for dt, expected in dv.items():
            for writer in [pd.to_pickle, c_pickler, python_pickler]:
                if writer is None:
                    # cPickle unavailable (Python 3)
                    continue
                with tm.ensure_clean() as path:
                    # test writing with each pickler
                    writer(expected, path)
                    # test reading with each unpickler
                    result = pd.read_pickle(path)
                    compare_element(result, expected, typ)
                    if c_unpickler is not None:
                        result = c_unpickler(path)
                        compare_element(result, expected, typ)
                    result = python_unpickler(path)
                    compare_element(result, expected, typ)
  199. def test_pickle_v0_14_1(datapath):
  200. cat = pd.Categorical(values=['a', 'b', 'c'], ordered=False,
  201. categories=['a', 'b', 'c', 'd'])
  202. pickle_path = datapath('io', 'data', 'categorical_0_14_1.pickle')
  203. # This code was executed once on v0.14.1 to generate the pickle:
  204. #
  205. # cat = Categorical(labels=np.arange(3), levels=['a', 'b', 'c', 'd'],
  206. # name='foobar')
  207. # with open(pickle_path, 'wb') as f: pickle.dump(cat, f)
  208. #
  209. tm.assert_categorical_equal(cat, pd.read_pickle(pickle_path))
  210. def test_pickle_v0_15_2(datapath):
  211. # ordered -> _ordered
  212. # GH 9347
  213. cat = pd.Categorical(values=['a', 'b', 'c'], ordered=False,
  214. categories=['a', 'b', 'c', 'd'])
  215. pickle_path = datapath('io', 'data', 'categorical_0_15_2.pickle')
  216. # This code was executed once on v0.15.2 to generate the pickle:
  217. #
  218. # cat = Categorical(labels=np.arange(3), levels=['a', 'b', 'c', 'd'],
  219. # name='foobar')
  220. # with open(pickle_path, 'wb') as f: pickle.dump(cat, f)
  221. #
  222. tm.assert_categorical_equal(cat, pd.read_pickle(pickle_path))
  223. def test_pickle_path_pathlib():
  224. df = tm.makeDataFrame()
  225. result = tm.round_trip_pathlib(df.to_pickle, pd.read_pickle)
  226. tm.assert_frame_equal(df, result)
  227. def test_pickle_path_localpath():
  228. df = tm.makeDataFrame()
  229. result = tm.round_trip_localpath(df.to_pickle, pd.read_pickle)
  230. tm.assert_frame_equal(df, result)
  231. # ---------------------
  232. # test pickle compression
  233. # ---------------------
  234. @pytest.fixture
  235. def get_random_path():
  236. return u'__%s__.pickle' % tm.rands(10)
class TestCompression(object):
    """Round-trip pickles through each supported compression scheme."""

    # maps the ``compression`` argument to the file extension that
    # to_pickle/read_pickle infer it from
    _compression_to_extension = {
        None: ".none",
        'gzip': '.gz',
        'bz2': '.bz2',
        'zip': '.zip',
        'xz': '.xz',
    }

    def compress_file(self, src_path, dest_path, compression):
        """Compress ``src_path`` into ``dest_path`` using ``compression``.

        Raises ValueError for an unrecognized compression name.
        """
        if compression is None:
            shutil.copyfile(src_path, dest_path)
            return
        if compression == 'gzip':
            import gzip
            f = gzip.open(dest_path, "w")
        elif compression == 'bz2':
            import bz2
            f = bz2.BZ2File(dest_path, "w")
        elif compression == 'zip':
            import zipfile
            # zip archives store the source as a named member; written here
            # rather than in the shared code path below
            with zipfile.ZipFile(dest_path, "w",
                                 compression=zipfile.ZIP_DEFLATED) as f:
                f.write(src_path, os.path.basename(src_path))
        elif compression == 'xz':
            lzma = pd.compat.import_lzma()
            f = lzma.LZMAFile(dest_path, "w")
        else:
            msg = 'Unrecognized compression type: {}'.format(compression)
            raise ValueError(msg)
        if compression != "zip":
            # stream the raw bytes through the compressor; including ``f``
            # in the with-statement also closes the compressed handle
            with open(src_path, "rb") as fh, f:
                f.write(fh.read())

    def test_write_explicit(self, compression, get_random_path):
        # write with an explicit compression arg, decompress by hand,
        # then compare against the original frame
        base = get_random_path
        path1 = base + ".compressed"
        path2 = base + ".raw"
        with tm.ensure_clean(path1) as p1, tm.ensure_clean(path2) as p2:
            df = tm.makeDataFrame()
            # write to compressed file
            df.to_pickle(p1, compression=compression)
            # decompress
            with tm.decompress_file(p1, compression=compression) as f:
                with open(p2, "wb") as fh:
                    fh.write(f.read())
            # read decompressed file
            df2 = pd.read_pickle(p2, compression=None)
            tm.assert_frame_equal(df, df2)

    @pytest.mark.parametrize('compression', ['', 'None', 'bad', '7z'])
    def test_write_explicit_bad(self, compression, get_random_path):
        # unknown compression names must raise, not silently write
        with pytest.raises(ValueError, match="Unrecognized compression type"):
            with tm.ensure_clean(get_random_path) as path:
                df = tm.makeDataFrame()
                df.to_pickle(path, compression=compression)

    @pytest.mark.parametrize('ext', [
        '', '.gz', '.bz2', '.no_compress',
        pytest.param('.xz', marks=td.skip_if_no_lzma)
    ])
    def test_write_infer(self, ext, get_random_path):
        # compression should be inferred from the file extension on write
        base = get_random_path
        path1 = base + ext
        path2 = base + ".raw"
        # look up which compression this extension implies
        compression = None
        for c in self._compression_to_extension:
            if self._compression_to_extension[c] == ext:
                compression = c
                break
        with tm.ensure_clean(path1) as p1, tm.ensure_clean(path2) as p2:
            df = tm.makeDataFrame()
            # write to compressed file by inferred compression method
            df.to_pickle(p1)
            # decompress
            with tm.decompress_file(p1, compression=compression) as f:
                with open(p2, "wb") as fh:
                    fh.write(f.read())
            # read decompressed file
            df2 = pd.read_pickle(p2, compression=None)
            tm.assert_frame_equal(df, df2)

    def test_read_explicit(self, compression, get_random_path):
        # compress by hand, then read with an explicit compression arg
        base = get_random_path
        path1 = base + ".raw"
        path2 = base + ".compressed"
        with tm.ensure_clean(path1) as p1, tm.ensure_clean(path2) as p2:
            df = tm.makeDataFrame()
            # write to uncompressed file
            df.to_pickle(p1, compression=None)
            # compress
            self.compress_file(p1, p2, compression=compression)
            # read compressed file
            df2 = pd.read_pickle(p2, compression=compression)
            tm.assert_frame_equal(df, df2)

    @pytest.mark.parametrize('ext', [
        '', '.gz', '.bz2', '.zip', '.no_compress',
        pytest.param('.xz', marks=td.skip_if_no_lzma)
    ])
    def test_read_infer(self, ext, get_random_path):
        # compression should be inferred from the file extension on read
        base = get_random_path
        path1 = base + ".raw"
        path2 = base + ext
        # look up which compression this extension implies
        compression = None
        for c in self._compression_to_extension:
            if self._compression_to_extension[c] == ext:
                compression = c
                break
        with tm.ensure_clean(path1) as p1, tm.ensure_clean(path2) as p2:
            df = tm.makeDataFrame()
            # write to uncompressed file
            df.to_pickle(p1, compression=None)
            # compress
            self.compress_file(p1, p2, compression=compression)
            # read compressed file by inferred compression method
            df2 = pd.read_pickle(p2)
            tm.assert_frame_equal(df, df2)
# ---------------------
# test pickle protocol
# ---------------------
  352. class TestProtocol(object):
  353. @pytest.mark.parametrize('protocol', [-1, 0, 1, 2])
  354. def test_read(self, protocol, get_random_path):
  355. with tm.ensure_clean(get_random_path) as path:
  356. df = tm.makeDataFrame()
  357. df.to_pickle(path, protocol=protocol)
  358. df2 = pd.read_pickle(path)
  359. tm.assert_frame_equal(df, df2)
  360. @pytest.mark.parametrize('protocol', [3, 4])
  361. @pytest.mark.skipif(PY3, reason="Testing invalid parameters for Python 2")
  362. def test_read_bad_versions(self, protocol, get_random_path):
  363. # For Python 2, HIGHEST_PROTOCOL should be 2.
  364. msg = ("pickle protocol {protocol} asked for; the highest available "
  365. "protocol is 2").format(protocol=protocol)
  366. with pytest.raises(ValueError, match=msg):
  367. with tm.ensure_clean(get_random_path) as path:
  368. df = tm.makeDataFrame()
  369. df.to_pickle(path, protocol=protocol)