test_array.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. import datetime
  2. import decimal
  3. import numpy as np
  4. import pytest
  5. import pytz
  6. from pandas.core.dtypes.dtypes import registry
  7. import pandas as pd
  8. from pandas.api.extensions import register_extension_dtype
  9. from pandas.core.arrays import PandasArray, integer_array, period_array
  10. from pandas.tests.extension.decimal import (
  11. DecimalArray, DecimalDtype, to_decimal)
  12. import pandas.util.testing as tm
@pytest.mark.parametrize("data, dtype, expected", [
    # Basic NumPy defaults.
    ([1, 2], None, PandasArray(np.array([1, 2]))),
    ([1, 2], object, PandasArray(np.array([1, 2], dtype=object))),
    ([1, 2], np.dtype('float32'),
     PandasArray(np.array([1., 2.0], dtype=np.dtype('float32')))),
    (np.array([1, 2]), None, PandasArray(np.array([1, 2]))),
    # String alias passes through to NumPy
    ([1, 2], 'float32', PandasArray(np.array([1, 2], dtype='float32'))),
    # Period alias
    ([pd.Period('2000', 'D'), pd.Period('2001', 'D')], 'Period[D]',
     period_array(['2000', '2001'], freq='D')),
    # Period dtype
    ([pd.Period('2000', 'D')], pd.PeriodDtype('D'),
     period_array(['2000'], freq='D')),
    # Datetime (naive)
    ([1, 2], np.dtype('datetime64[ns]'),
     pd.arrays.DatetimeArray._from_sequence(
         np.array([1, 2], dtype='datetime64[ns]'))),
    (np.array([1, 2], dtype='datetime64[ns]'), None,
     pd.arrays.DatetimeArray._from_sequence(
         np.array([1, 2], dtype='datetime64[ns]'))),
    (pd.DatetimeIndex(['2000', '2001']), np.dtype('datetime64[ns]'),
     pd.arrays.DatetimeArray._from_sequence(['2000', '2001'])),
    (pd.DatetimeIndex(['2000', '2001']), None,
     pd.arrays.DatetimeArray._from_sequence(['2000', '2001'])),
    (['2000', '2001'], np.dtype('datetime64[ns]'),
     pd.arrays.DatetimeArray._from_sequence(['2000', '2001'])),
    # Datetime (tz-aware)
    (['2000', '2001'], pd.DatetimeTZDtype(tz="CET"),
     pd.arrays.DatetimeArray._from_sequence(
         ['2000', '2001'], dtype=pd.DatetimeTZDtype(tz="CET"))),
    # Timedelta
    (['1H', '2H'], np.dtype('timedelta64[ns]'),
     pd.arrays.TimedeltaArray._from_sequence(['1H', '2H'])),
    (pd.TimedeltaIndex(['1H', '2H']), np.dtype('timedelta64[ns]'),
     pd.arrays.TimedeltaArray._from_sequence(['1H', '2H'])),
    (pd.TimedeltaIndex(['1H', '2H']), None,
     pd.arrays.TimedeltaArray._from_sequence(['1H', '2H'])),
    # Category
    (['a', 'b'], 'category', pd.Categorical(['a', 'b'])),
    (['a', 'b'], pd.CategoricalDtype(None, ordered=True),
     pd.Categorical(['a', 'b'], ordered=True)),
    # Interval
    ([pd.Interval(1, 2), pd.Interval(3, 4)], 'interval',
     pd.arrays.IntervalArray.from_tuples([(1, 2), (3, 4)])),
    # Sparse
    ([0, 1], 'Sparse[int64]', pd.SparseArray([0, 1], dtype='int64')),
    # IntegerNA
    ([1, None], 'Int16', integer_array([1, None], dtype='Int16')),
    # Series (backed by a plain NumPy array) -> PandasArray
    (pd.Series([1, 2]), None,
     PandasArray(np.array([1, 2], dtype=np.int64))),
    # Index
    (pd.Index([1, 2]), None,
     PandasArray(np.array([1, 2], dtype=np.int64))),
    # Series[EA] returns the EA
    (pd.Series(pd.Categorical(['a', 'b'], categories=['a', 'b', 'c'])),
     None,
     pd.Categorical(['a', 'b'], categories=['a', 'b', 'c'])),
    # "3rd party" EAs work
    ([decimal.Decimal(0), decimal.Decimal(1)], 'decimal',
     to_decimal([0, 1])),
    # pass an ExtensionArray, but a different dtype
    (period_array(['2000', '2001'], freq='D'),
     'category',
     pd.Categorical([pd.Period('2000', 'D'), pd.Period('2001', 'D')])),
])
def test_array(data, dtype, expected):
    """Check ``pd.array(data, dtype=dtype)`` against ``expected``.

    Each parametrized case pairs an input (list, ndarray, Index, Series or
    ExtensionArray) and a requested dtype (``None``, a NumPy dtype, a string
    alias or an ExtensionDtype) with the array ``pd.array`` should build.
    Equality is judged by ``tm.assert_equal``.
    """
    result = pd.array(data, dtype=dtype)
    tm.assert_equal(result, expected)
  80. def test_array_copy():
  81. a = np.array([1, 2])
  82. # default is to copy
  83. b = pd.array(a)
  84. assert np.shares_memory(a, b._ndarray) is False
  85. # copy=True
  86. b = pd.array(a, copy=True)
  87. assert np.shares_memory(a, b._ndarray) is False
  88. # copy=False
  89. b = pd.array(a, copy=False)
  90. assert np.shares_memory(a, b._ndarray) is True
# pytz "CET" zone, attached via ``tzinfo=`` in the tz-aware
# ``datetime.datetime`` inference case below.
cet = pytz.timezone("CET")


@pytest.mark.parametrize('data, expected', [
    # period
    ([pd.Period("2000", "D"), pd.Period("2001", "D")],
     period_array(["2000", "2001"], freq="D")),
    # interval
    ([pd.Interval(0, 1), pd.Interval(1, 2)],
     pd.arrays.IntervalArray.from_breaks([0, 1, 2])),
    # datetime
    ([pd.Timestamp('2000',), pd.Timestamp('2001')],
     pd.arrays.DatetimeArray._from_sequence(['2000', '2001'])),
    ([datetime.datetime(2000, 1, 1), datetime.datetime(2001, 1, 1)],
     pd.arrays.DatetimeArray._from_sequence(['2000', '2001'])),
    (np.array([1, 2], dtype='M8[ns]'),
     pd.arrays.DatetimeArray(np.array([1, 2], dtype='M8[ns]'))),
    # Microsecond input is expected back in nanoseconds (1us -> 1000ns).
    (np.array([1, 2], dtype='M8[us]'),
     pd.arrays.DatetimeArray(np.array([1000, 2000], dtype='M8[ns]'))),
    # datetimetz
    ([pd.Timestamp('2000', tz='CET'), pd.Timestamp('2001', tz='CET')],
     pd.arrays.DatetimeArray._from_sequence(
         ['2000', '2001'], dtype=pd.DatetimeTZDtype(tz='CET'))),
    # NOTE(review): pytz documents that passing a pytz zone as ``tzinfo=``
    # can give an unexpected offset and recommends ``cet.localize(...)``;
    # this case relies on the zone's default offset being correct for
    # Jan 1 -- confirm if the tz source ever changes.
    ([datetime.datetime(2000, 1, 1, tzinfo=cet),
      datetime.datetime(2001, 1, 1, tzinfo=cet)],
     pd.arrays.DatetimeArray._from_sequence(['2000', '2001'],
                                            tz=cet)),
    # timedelta
    ([pd.Timedelta('1H'), pd.Timedelta('2H')],
     pd.arrays.TimedeltaArray._from_sequence(['1H', '2H'])),
    (np.array([1, 2], dtype='m8[ns]'),
     pd.arrays.TimedeltaArray(np.array([1, 2], dtype='m8[ns]'))),
    # Microsecond input is expected back in nanoseconds (1us -> 1000ns).
    (np.array([1, 2], dtype='m8[us]'),
     pd.arrays.TimedeltaArray(np.array([1000, 2000], dtype='m8[ns]'))),
])
def test_array_inference(data, expected):
    """Check the array type ``pd.array(data)`` infers when no dtype is given.

    Covers Period, Interval, naive and tz-aware datetime, and timedelta
    inputs given as scalars lists or as NumPy datetime64/timedelta64 arrays.
    """
    result = pd.array(data)
    tm.assert_equal(result, expected)
  127. @pytest.mark.parametrize('data', [
  128. # mix of frequencies
  129. [pd.Period("2000", "D"), pd.Period("2001", "A")],
  130. # mix of closed
  131. [pd.Interval(0, 1, closed='left'), pd.Interval(1, 2, closed='right')],
  132. # Mix of timezones
  133. [pd.Timestamp("2000", tz="CET"), pd.Timestamp("2000", tz="UTC")],
  134. # Mix of tz-aware and tz-naive
  135. [pd.Timestamp("2000", tz="CET"), pd.Timestamp("2000")],
  136. np.array([pd.Timestamp('2000'), pd.Timestamp('2000', tz='CET')]),
  137. ])
  138. def test_array_inference_fails(data):
  139. result = pd.array(data)
  140. expected = PandasArray(np.array(data, dtype=object))
  141. tm.assert_extension_array_equal(result, expected)
  142. @pytest.mark.parametrize("data", [
  143. np.array([[1, 2], [3, 4]]),
  144. [[1, 2], [3, 4]],
  145. ])
  146. def test_nd_raises(data):
  147. with pytest.raises(ValueError, match='PandasArray must be 1-dimensional'):
  148. pd.array(data)
  149. def test_scalar_raises():
  150. with pytest.raises(ValueError,
  151. match="Cannot pass scalar '1'"):
  152. pd.array(1)
# ---------------------------------------------------------------------------
# A couple of dummy classes to ensure that Series and Index objects are
# unboxed before they reach the extension array (EA) classes.
  156. @register_extension_dtype
  157. class DecimalDtype2(DecimalDtype):
  158. name = 'decimal2'
  159. @classmethod
  160. def construct_array_type(cls):
  161. return DecimalArray2
  162. class DecimalArray2(DecimalArray):
  163. @classmethod
  164. def _from_sequence(cls, scalars, dtype=None, copy=False):
  165. if isinstance(scalars, (pd.Series, pd.Index)):
  166. raise TypeError
  167. return super(DecimalArray2, cls)._from_sequence(
  168. scalars, dtype=dtype, copy=copy
  169. )
  170. @pytest.mark.parametrize("box", [pd.Series, pd.Index])
  171. def test_array_unboxes(box):
  172. data = box([decimal.Decimal('1'), decimal.Decimal('2')])
  173. # make sure it works
  174. with pytest.raises(TypeError):
  175. DecimalArray2._from_sequence(data)
  176. result = pd.array(data, dtype='decimal2')
  177. expected = DecimalArray2._from_sequence(data.values)
  178. tm.assert_equal(result, expected)
  179. @pytest.fixture
  180. def registry_without_decimal():
  181. idx = registry.dtypes.index(DecimalDtype)
  182. registry.dtypes.pop(idx)
  183. yield
  184. registry.dtypes.append(DecimalDtype)
  185. def test_array_not_registered(registry_without_decimal):
  186. # check we aren't on it
  187. assert registry.find('decimal') is None
  188. data = [decimal.Decimal('1'), decimal.Decimal('2')]
  189. result = pd.array(data, dtype=DecimalDtype)
  190. expected = DecimalArray._from_sequence(data)
  191. tm.assert_equal(result, expected)