diff --git a/dask_xgboost/core.py b/dask_xgboost/core.py index 6bf29d7..fb26fa3 100644 --- a/dask_xgboost/core.py +++ b/dask_xgboost/core.py @@ -14,6 +14,7 @@ sparse = False ss = False +import dask from dask import delayed from dask.distributed import wait, default_client import dask.dataframe as dd @@ -107,6 +108,7 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs): -------- train """ + # Break apart Dask.array/dataframe into chunks/parts data_parts = data.to_delayed() label_parts = labels.to_delayed() @@ -158,6 +160,58 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs): raise gen.Return(result) +def compute_array_chunks(arr): + assert isinstance(arr, da.Array) + parts = arr.to_delayed() + if isinstance(parts, np.ndarray): + parts = parts.flatten().tolist() + chunks = tuple([part.shape[0].compute() for part in parts]) + return chunks + + +def align_training_data(client, data, labels): + """Aligns training data and labels + + Parameters + ---------- + client: dask.distributed.Client + data: dask Array or dask DataFrame + Training features + labels: dask Array or dask DataFrame + Training target + + Returns + ------- + data : dask Array or dask DataFrame + labels : dask Array or dask DataFrame + """ + with dask.config.set(scheduler=client): + # Compute data chunk/partition sizes + if isinstance(data, dd._Frame): + data_chunks = tuple(data.map_partitions(len).compute()) + elif isinstance(data, da.Array): + if any(np.isnan(sum(c)) for c in data.chunks): + data_chunks = compute_array_chunks(data) + else: + data_chunks = data.chunks[0] + + # Re-chunk/partition labels to match data + # Only rechunk if there is a size mismatch betwen data and labels + if isinstance(labels, dd._Frame): + labels_arr = labels.to_dask_array(lengths=True) + if labels_arr.chunks != (data_chunks,): + labels_arr = labels_arr.rechunk({0: data_chunks}) + labels = labels_arr.to_dask_dataframe() + elif isinstance(labels, da.Array): + if any(np.isnan(sum(c)) for c in labels.chunks): + labels_chunks = compute_array_chunks(labels) + labels._chunks = (labels_chunks,) + if labels.chunks != data_chunks: + labels = labels.rechunk({0: data_chunks}) + + return data, labels + + def train(client, params, data, labels, dmatrix_kwargs={}, **kwargs): """ Train an XGBoost model on a Dask Cluster @@ -187,6 +241,7 @@ def train(client, params, data, labels, dmatrix_kwargs={}, **kwargs): -------- predict """ + data, labels = align_training_data(client, data, labels) return client.sync(_train, client, params, data, labels, dmatrix_kwargs, **kwargs) diff --git a/dask_xgboost/tests/test_core.py b/dask_xgboost/tests/test_core.py index 22ca104..ab76bf1 100644 --- a/dask_xgboost/tests/test_core.py +++ b/dask_xgboost/tests/test_core.py @@ -15,6 +15,7 @@ from distributed.utils_test import gen_cluster, loop, cluster # noqa import dask_xgboost as dxgb +from dask_xgboost.core import align_training_data # Workaround for conflict with distributed 1.23.0 # https://github.com/dask/dask-xgboost/pull/27#issuecomment-417474734 @@ -158,6 +159,23 @@ def test_basic(c, s, a, b): assert ((predictions > 0.5) != labels).sum() < 2 +@pytest.mark.parametrize('X, y', [ # noqa + (dd.from_pandas(df, chunksize=5), + dd.from_pandas(labels, chunksize=6)), + (dd.from_pandas(df, chunksize=5).values, + dd.from_pandas(labels, chunksize=6)), + (dd.from_pandas(df, chunksize=5), + dd.from_pandas(labels, chunksize=6).values), + (dd.from_pandas(df, chunksize=5).values, + dd.from_pandas(labels, chunksize=6).values), +]) +def test_unequal_partition_lengths(loop, X, y): # noqa + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop): + clf = dxgb.XGBClassifier() + clf.fit(X, y) + + @gen_cluster(client=True, timeout=None, check_new_threads=False) def test_dmatrix_kwargs(c, s, a, b): xgb.rabit.init() # workaround for "Doing rabit call after Finalize" @@ -269,3 +287,38 @@ def f(part): yield dxgb.train(c, param, df, df.x) assert 'foo' in str(info.value) + + +def test_align_training_data_dataframe(loop): # noqa + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as client: + X = dd.from_pandas(df, chunksize=5) + y = dd.from_pandas(labels, chunksize=6) + + X_partition_lengths = tuple(X.map_partitions(len).compute()) + y_partition_lengths = tuple(y.map_partitions(len).compute()) + assert X_partition_lengths != y_partition_lengths + + X_align, y_align = align_training_data(client, X, y) + assert isinstance(X_align, dd._Frame) + assert isinstance(y_align, dd._Frame) + + X_partition_lengths = tuple(X_align.map_partitions(len).compute()) + y_partition_lengths = tuple(y_align.map_partitions(len).compute()) + assert X_partition_lengths == y_partition_lengths + + +@pytest.mark.parametrize('equal_partitions', [True, False]) # noqa +def test_align_training_data_rechunk(loop, equal_partitions): # noqa + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as client: + X = dd.from_pandas(df, chunksize=5) + if equal_partitions: + y = dd.from_pandas(labels, chunksize=5) + else: + y = dd.from_pandas(labels, chunksize=6) + + X_align, y_align = align_training_data(client, X, y) + assert X_align is X + if equal_partitions: + assert y_align is y