
ColumnTransformer _hstack incompatible with scikit's version #1019

@avalanche-pwn

Description


Describe the issue:
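The _hstack method of dask-ml's ColumnTransformer has a different signature from scikit-learn's: it lacks the n_samples keyword argument, which scikit-learn 1.6.1 passes when ColumnTransformer.fit_transform calls self._hstack (see the traceback below).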
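The failure mode can be reproduced without dask-ml or scikit-learn installed. In the minimal sketch below, `Base` and `Child` are hypothetical stand-ins, not the real classes: `Base.fit_transform` calls `self._hstack(..., n_samples=...)` the way scikit-learn's `fit_transform` does in the traceback, while `Child` overrides `_hstack` with the older one-argument signature. The toy "stacking" just concatenates row-wise Python lists.

```python
class Base:
    """Hypothetical stand-in for scikit-learn's ColumnTransformer."""

    def fit_transform(self, Xs):
        n_samples = len(Xs[0]) if Xs else 0
        # Like scikit-learn 1.6.1, the base class passes n_samples by keyword.
        return self._hstack(list(Xs), n_samples=n_samples)

    def _hstack(self, Xs, *, n_samples):
        # Toy column-wise concatenation of row-wise lists.
        if not Xs:
            return [[] for _ in range(n_samples)]
        return [sum(rows, []) for rows in zip(*Xs)]


class Child(Base):
    """Hypothetical stand-in for dask-ml's ColumnTransformer."""

    def _hstack(self, Xs):  # older signature: no n_samples parameter
        return [sum(rows, []) for rows in zip(*Xs)]


def reproduce():
    """Return the TypeError message raised by the mismatched override."""
    try:
        Child().fit_transform([[[1], [2]], [[3], [4]]])
    except TypeError as exc:
        return str(exc)
    return None


print(reproduce())
```

Because the override drops a parameter that the base class always supplies, the error surfaces only at call time, which matches the TypeError in the traceback.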

Minimal Complete Verifiable Example:

from dask_ml.wrappers import Incremental
from dask_ml.feature_extraction.text import HashingVectorizer
import dask.dataframe as dd
import pandas as pd
from dask_ml.compose import ColumnTransformer

data = {
    "test1": ["example", "text"],
    "test2": ["lorem", "ipsum"]
}

df = pd.DataFrame(data)
df = dd.from_pandas(df).astype(str)

pipeline = ColumnTransformer([
    ("test1", HashingVectorizer(), "test1"),
    ("test2", HashingVectorizer(), "test2"),
])

pipeline.fit(df)

Anything else we need to know?:
Running the example above crashes with the following traceback:

Traceback (most recent call last):
  File "/home/antoni/Documents/projects/dask/reproducers/1/main.py", line 20, in <module>
    pipeline.fit(df)
    ~~~~~~~~~~~~^^^^
  File "/home/antoni/Documents/projects/dask/reproducers/1/.venv/lib/python3.13/site-packages/sklearn/compose/_column_transformer.py", line 947, in fit
    self.fit_transform(X, y=y, **params)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/antoni/Documents/projects/dask/reproducers/1/.venv/lib/python3.13/site-packages/sklearn/utils/_set_output.py", line 319, in wrapped
    data_to_wrap = f(self, X, *args, **kwargs)
  File "/home/antoni/Documents/projects/dask/reproducers/1/.venv/lib/python3.13/site-packages/sklearn/base.py", line 1389, in wrapper
    return fit_method(estimator, *args, **kwargs)
  File "/home/antoni/Documents/projects/dask/reproducers/1/.venv/lib/python3.13/site-packages/sklearn/compose/_column_transformer.py", line 1031, in fit_transform
    return self._hstack(list(Xs), n_samples=n_samples)
           ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: ColumnTransformer._hstack() got an unexpected keyword argument 'n_samples'
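To check whether a given installation is affected, one option is to inspect the method's signature rather than running a full fit. The helper below is a hypothetical sketch built on the standard library's `inspect.signature`; it reports whether a callable can accept a given keyword argument, counting a `**kwargs` catch-all as accepting it.

```python
import inspect


def accepts_keyword(func, name):
    """Return True if ``func`` can be called with keyword argument ``name``."""
    for param in inspect.signature(func).parameters.values():
        # A **kwargs catch-all accepts any keyword argument.
        if param.kind is inspect.Parameter.VAR_KEYWORD:
            return True
        # Positional-only parameters cannot be passed by keyword.
        if param.name == name and param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False
```

With both libraries installed, `accepts_keyword(dask_ml.compose.ColumnTransformer._hstack, "n_samples")` would be expected to return False for the versions listed under Environment, given the error above. Note that `_hstack` is a private method, so its signature can change between releases without notice.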

The fix seems simple: add the n_samples argument to dask-ml's _hstack and, before returning, perform a check similar to the one in scikit-learn's version. I can implement this; please let me know whether this kind of fix would be enough.
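The shape of the proposed fix can be sketched with hypothetical stand-in classes (`Base` and `PatchedChild` are illustrations, not dask-ml or scikit-learn code). The override gains a keyword-only `n_samples` parameter defaulting to None, so it works whether or not the caller passes it. Using `n_samples` to size an empty result is only one plausible use of the argument and is not necessarily the check scikit-learn performs.

```python
class Base:
    """Hypothetical stand-in for scikit-learn's ColumnTransformer."""

    def fit_transform(self, Xs):
        n_samples = len(Xs[0]) if Xs else 0
        # The caller always passes n_samples by keyword.
        return self._hstack(list(Xs), n_samples=n_samples)


class PatchedChild(Base):
    """Hypothetical stand-in for a patched dask-ml ColumnTransformer."""

    def _hstack(self, Xs, *, n_samples=None):
        # Defaulting to None keeps the method callable by callers that
        # pass n_samples and by older callers that do not.
        if not Xs:
            # With nothing to stack, n_samples is the only way to know
            # how many (empty) rows the result should have.
            return [[] for _ in range(n_samples or 0)]
        # Toy column-wise concatenation of row-wise lists.
        return [sum(rows, []) for rows in zip(*Xs)]


print(PatchedChild().fit_transform([[[1], [2]], [[3], [4]]]))
```

Applied to dask-ml, this would amount to giving ColumnTransformer._hstack the same keyword-only parameter; whether it should also mirror scikit-learn's extra check is a decision for the maintainers.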

Environment:

  • Dask version: 2025.5.0
  • Dask-ml version: 2025.1.0
  • scikit-learn: 1.6.1
  • Python version: 3.13.3
  • Operating System: Linux
  • Install method (conda, pip, source): pip
