# Patch: let dask_xgboost train and predict on concrete (non-dask) inputs.
#
# dask_xgboost/core.py
#   * maybe_get_client() returns default_client(), or None when that call
#     raises ValueError.
#   * train() builds an xgb.DMatrix and calls xgb.train directly when client
#     is None or when either data or labels is not a dask collection.
#   * predict() calls model.predict on an xgb.DMatrix when client is None or
#     data is a numpy array / pandas DataFrame.
#   * The estimator fit()/predict() methods obtain their client through
#     maybe_get_client() instead of default_client().
# dask_xgboost/tests/test_core.py
#   * Adds test_concrete and test_dask_search_cv.
#
# Review notes:
#   * Fixed: the numpy branch computed np.unique(classes) inside
#     "if classes is None:", i.e. on None; it now uses np.unique(y).
#   * Fixed: dask.compute(classes) returns a tuple, so the result is
#     unpacked with [0] before it is stored in self.classes_.
#   * Because of those two fixes the post-image blob id on the core.py
#     "index" line (f39321f) no longer matches; this only matters for
#     3-way application.
#   * Leading whitespace in context lines was reconstructed from the hunk
#     line counts; verify against the target files or apply with a
#     whitespace-tolerant option.
#   * NOTE(review): the hunks assume np and pd are already imported in
#     core.py -- not visible here, confirm.
diff --git a/dask_xgboost/core.py b/dask_xgboost/core.py
index 6bf29d7..f39321f 100644
--- a/dask_xgboost/core.py
+++ b/dask_xgboost/core.py
@@ -14,6 +14,7 @@
     sparse = False
     ss = False
 
+import dask
 from dask import delayed
 from dask.distributed import wait, default_client
 import dask.dataframe as dd
@@ -26,6 +27,13 @@
 logger = logging.getLogger(__name__)
 
 
+def maybe_get_client():
+    try:
+        return default_client()
+    except ValueError:
+        return None
+
+
 def parse_host_port(address):
     if '://' in address:
         address = address.rsplit('://', 1)[1]
@@ -187,6 +195,12 @@ def train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
     --------
     predict
     """
+    if (client is None or
+            not dask.is_dask_collection(data) or
+            not dask.is_dask_collection(labels)):
+        dtrain = xgb.DMatrix(data, labels, **dmatrix_kwargs)
+        return xgb.train(params, dtrain, **kwargs)
+
     return client.sync(_train, client, params, data, labels,
                        dmatrix_kwargs, **kwargs)
 
@@ -233,7 +247,10 @@ def predict(client, model, data):
     --------
     train
     """
-    if isinstance(data, dd._Frame):
+    if client is None or isinstance(data, (np.ndarray, pd.DataFrame)):
+        dm = xgb.DMatrix(data)
+        result = model.predict(dm)
+    elif isinstance(data, dd._Frame):
         result = data.map_partitions(_predict_part, model=model)
         result = result.values
     elif isinstance(data, da.Array):
@@ -276,14 +293,14 @@ def fit(self, X, y=None):
         ``eval_metric``, ``early_stopping_rounds`` and ``verbose`` fit
         kwargs.
         """
-        client = default_client()
+        client = maybe_get_client()
         xgb_options = self.get_xgb_params()
         self._Booster = train(client, xgb_options, X, y,
                               num_boost_round=self.n_estimators)
         return self
 
     def predict(self, X):
-        client = default_client()
+        client = maybe_get_client()
         return predict(client, self._Booster, X)
 
 
@@ -316,14 +333,16 @@ def fit(self, X, y=None, classes=None):
         2. The labels are not automatically label-encoded
         3. The ``classes_`` and ``n_classes_`` attributes are not learned
         """
-        client = default_client()
+        client = maybe_get_client()
 
         if classes is None:
-            if isinstance(y, da.Array):
+            if isinstance(y, np.ndarray):
+                classes = np.unique(y)
+            elif isinstance(y, da.Array):
                 classes = da.unique(y)
             else:
                 classes = y.unique()
-            classes = classes.compute()
+            classes = dask.compute(classes)[0]
         else:
             classes = np.asarray(classes)
         self.classes_ = classes
@@ -346,16 +365,20 @@ def fit(self, X, y=None, classes=None):
         # TODO: auto label-encode y
         # that will require a dependency on dask-ml
         # TODO: sample weight
-
-        self._Booster = train(client, xgb_options, X, y,
-                              num_boost_round=self.n_estimators)
+        bst = train(client, xgb_options, X, y,
+                    num_boost_round=self.n_estimators)
+        self._Booster = bst
         return self
 
     def predict(self, X):
-        client = default_client()
+        client = maybe_get_client()
+
         class_probs = predict(client, self._Booster, X)
         if class_probs.ndim > 1:
-            cidx = da.argmax(class_probs, axis=1)
+            if isinstance(class_probs, (pd.DataFrame, np.ndarray)):
+                cidx = np.argmax(class_probs, axis=1)
+            else:
+                cidx = da.argmax(class_probs, axis=1)
         else:
             cidx = (class_probs > 0).astype(np.int64)
         return cidx
diff --git a/dask_xgboost/tests/test_core.py b/dask_xgboost/tests/test_core.py
index 22ca104..a9f3e63 100644
--- a/dask_xgboost/tests/test_core.py
+++ b/dask_xgboost/tests/test_core.py
@@ -269,3 +269,25 @@ def f(part):
         yield dxgb.train(c, param, df, df.x)
 
     assert 'foo' in str(info.value)
+
+
+@gen_cluster(client=True, timeout=None, check_new_threads=False)
+def test_concrete(c, s, a, b):
+    for est in [dxgb.XGBClassifier(), dxgb.XGBRegressor()]:
+        est.fit(X, y)
+        result = est.predict(X)
+        assert isinstance(result, np.ndarray)
+        est.score(X, y)
+
+
+def test_dask_search_cv(loop):  # noqa
+
+    with cluster() as (s, [a, b]):
+        with Client(s['address'], loop=loop):
+            model_selection = pytest.importorskip('dask_ml.model_selection')
+            est = dxgb.XGBClassifier()
+            cv = model_selection.RandomizedSearchCV(est,
+                                                    {'max_depth': [1, 10]})
+            dX = da.from_array(X, 5)
+            dy = da.from_array(y, 5)
+            cv.fit(dX, dy)