Skip to content
Snippets Groups Projects
Unverified Commit 05f7ef9a authored by Francesco Bruzzesi's avatar Francesco Bruzzesi Committed by GitHub
Browse files

BUG: Fix `ListAccessor` methods to preserve original name (#60527)

* fix: preserve series name in ListAccessor

* formatting

* add whatsnew v3.0.0 entry
parent 59f947ff
No related branches found
No related tags found
No related merge requests found
......@@ -798,6 +798,7 @@ Other
- Bug in :meth:`read_csv` where a chained fsspec TAR file combined with ``compression="infer"`` fails with ``tarfile.ReadError`` (:issue:`60028`)
- Bug in the DataFrame Interchange Protocol implementation returning incorrect results for the dtype associated with data buffers of string and datetime columns (:issue:`54781`)
- Bug in ``Series.list`` methods not preserving the original :class:`Index` (:issue:`58425`)
- Bug in ``Series.list`` methods not preserving the original name (:issue:`60522`)
- Bug where printing a :class:`DataFrame` with a :class:`DataFrame` stored in :attr:`DataFrame.attrs` raised a ``ValueError`` (:issue:`60455`)
.. ***DO NOT USE THIS SECTION***
......
......@@ -117,7 +117,10 @@ class ListAccessor(ArrowAccessor):
value_lengths = pc.list_value_length(self._pa_array)
return Series(
value_lengths, dtype=ArrowDtype(value_lengths.type), index=self._data.index
value_lengths,
dtype=ArrowDtype(value_lengths.type),
index=self._data.index,
name=self._data.name,
)
def __getitem__(self, key: int | slice) -> Series:
......@@ -162,7 +165,10 @@ class ListAccessor(ArrowAccessor):
# key = pc.add(key, pc.list_value_length(self._pa_array))
element = pc.list_element(self._pa_array, key)
return Series(
element, dtype=ArrowDtype(element.type), index=self._data.index
element,
dtype=ArrowDtype(element.type),
index=self._data.index,
name=self._data.name,
)
elif isinstance(key, slice):
if pa_version_under11p0:
......@@ -181,7 +187,12 @@ class ListAccessor(ArrowAccessor):
if step is None:
step = 1
sliced = pc.list_slice(self._pa_array, start, stop, step)
return Series(sliced, dtype=ArrowDtype(sliced.type), index=self._data.index)
return Series(
sliced,
dtype=ArrowDtype(sliced.type),
index=self._data.index,
name=self._data.name,
)
else:
raise ValueError(f"key must be an int or slice, got {type(key).__name__}")
......@@ -223,7 +234,12 @@ class ListAccessor(ArrowAccessor):
counts = pa.compute.list_value_length(self._pa_array)
flattened = pa.compute.list_flatten(self._pa_array)
index = self._data.index.repeat(counts.fill_null(pa.scalar(0, counts.type)))
return Series(flattened, dtype=ArrowDtype(flattened.type), index=index)
return Series(
flattened,
dtype=ArrowDtype(flattened.type),
index=index,
name=self._data.name,
)
class StructAccessor(ArrowAccessor):
......
......@@ -25,9 +25,10 @@ def test_list_getitem(list_dtype):
ser = Series(
[[1, 2, 3], [4, None, 5], None],
dtype=ArrowDtype(list_dtype),
name="a",
)
actual = ser.list[1]
expected = Series([2, None, None], dtype="int64[pyarrow]")
expected = Series([2, None, None], dtype="int64[pyarrow]", name="a")
tm.assert_series_equal(actual, expected)
......@@ -37,9 +38,15 @@ def test_list_getitem_index():
[[1, 2, 3], [4, None, 5], None],
dtype=ArrowDtype(pa.list_(pa.int64())),
index=[1, 3, 7],
name="a",
)
actual = ser.list[1]
expected = Series([2, None, None], dtype="int64[pyarrow]", index=[1, 3, 7])
expected = Series(
[2, None, None],
dtype="int64[pyarrow]",
index=[1, 3, 7],
name="a",
)
tm.assert_series_equal(actual, expected)
......@@ -48,6 +55,7 @@ def test_list_getitem_slice():
[[1, 2, 3], [4, None, 5], None],
dtype=ArrowDtype(pa.list_(pa.int64())),
index=[1, 3, 7],
name="a",
)
if pa_version_under11p0:
with pytest.raises(
......@@ -60,6 +68,7 @@ def test_list_getitem_slice():
[[2, 3], [None, 5], None],
dtype=ArrowDtype(pa.list_(pa.int64())),
index=[1, 3, 7],
name="a",
)
tm.assert_series_equal(actual, expected)
......@@ -68,9 +77,10 @@ def test_list_len():
ser = Series(
[[1, 2, 3], [4, None], None],
dtype=ArrowDtype(pa.list_(pa.int64())),
name="a",
)
actual = ser.list.len()
expected = Series([3, 2, None], dtype=ArrowDtype(pa.int32()))
expected = Series([3, 2, None], dtype=ArrowDtype(pa.int32()), name="a")
tm.assert_series_equal(actual, expected)
......@@ -78,12 +88,14 @@ def test_list_flatten():
ser = Series(
[[1, 2, 3], None, [4, None], [], [7, 8]],
dtype=ArrowDtype(pa.list_(pa.int64())),
name="a",
)
actual = ser.list.flatten()
expected = Series(
[1, 2, 3, 4, None, 7, 8],
dtype=ArrowDtype(pa.int64()),
index=[0, 0, 0, 2, 2, 4, 4],
name="a",
)
tm.assert_series_equal(actual, expected)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment