Skip to content

Commit

Permalink
Support drop_first in get_dummies (#16795)
Browse files Browse the repository at this point in the history
closes #16791

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Matthew Murray (https://github.com/Matt711)

URL: #16795
  • Loading branch information
mroeschke authored Sep 17, 2024
1 parent f8d5063 commit 7285efb
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
11 changes: 7 additions & 4 deletions python/cudf/cudf/core/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,8 @@ def get_dummies(
sparse : boolean, optional
Right now this is NON-FUNCTIONAL argument in rapids.
drop_first : boolean, optional
Right now this is NON-FUNCTIONAL argument in rapids.
Whether to get k-1 dummies out of k categorical levels by removing the
first level.
columns : sequence of str, optional
Names of columns to encode. If not provided, will attempt to encode all
columns. Note this is different from pandas default behavior, which
Expand Down Expand Up @@ -806,9 +807,6 @@ def get_dummies(
if sparse:
raise NotImplementedError("sparse is not supported yet")

if drop_first:
raise NotImplementedError("drop_first is not supported yet")

if isinstance(data, cudf.DataFrame):
encode_fallback_dtypes = ["object", "category"]

Expand Down Expand Up @@ -862,6 +860,7 @@ def get_dummies(
prefix=prefix_map.get(name, prefix),
prefix_sep=prefix_sep_map.get(name, prefix_sep),
dtype=dtype,
drop_first=drop_first,
)
result_data.update(col_enc_data)
return cudf.DataFrame._from_data(result_data, index=data.index)
Expand All @@ -874,6 +873,7 @@ def get_dummies(
prefix=prefix,
prefix_sep=prefix_sep,
dtype=dtype,
drop_first=drop_first,
)
return cudf.DataFrame._from_data(data, index=ser.index)

Expand Down Expand Up @@ -1256,6 +1256,7 @@ def _one_hot_encode_column(
prefix: str | None,
prefix_sep: str | None,
dtype: Dtype | None,
drop_first: bool,
) -> dict[str, ColumnBase]:
"""Encode a single column with one hot encoding. The return dictionary
contains pairs of (category, encodings). The keys may be prefixed with
Expand All @@ -1276,6 +1277,8 @@ def _one_hot_encode_column(
)
data = one_hot_encode(column, categories)

if drop_first and len(data):
data.pop(next(iter(data)))
if prefix is not None and prefix_sep is not None:
data = {f"{prefix}{prefix_sep}{col}": enc for col, enc in data.items()}
if dtype:
Expand Down
17 changes: 17 additions & 0 deletions python/cudf/cudf/tests/test_onehot.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,20 @@ def test_get_dummies_cats_deprecated():
df = cudf.DataFrame(range(3))
with pytest.warns(FutureWarning):
cudf.get_dummies(df, cats={0: [0, 1, 2]})


def test_get_dummies_drop_first_series():
    """Compare cudf and pandas ``get_dummies(..., drop_first=True)`` on a
    string Series."""
    letters = list("abcaa")
    # Build both inputs from the same values so only the library differs.
    gpu_input = cudf.Series(letters)
    cpu_input = pd.Series(letters)

    result = cudf.get_dummies(gpu_input, drop_first=True)
    expected = pd.get_dummies(cpu_input, drop_first=True)

    assert_eq(result, expected)


def test_get_dummies_drop_first_dataframe():
    """Compare cudf and pandas ``get_dummies(..., drop_first=True)`` on a
    two-column string DataFrame."""
    columns = {"A": list("abcaa"), "B": list("bcaab")}
    # Build both frames from the same mapping so only the library differs.
    gpu_frame = cudf.DataFrame(columns)
    cpu_frame = pd.DataFrame(columns)

    result = cudf.get_dummies(gpu_frame, drop_first=True)
    expected = pd.get_dummies(cpu_frame, drop_first=True)

    assert_eq(result, expected)

0 comments on commit 7285efb

Please sign in to comment.