diff --git a/dask_xgboost/core.py b/dask_xgboost/core.py index 6bf29d7..6bf442b 100644 --- a/dask_xgboost/core.py +++ b/dask_xgboost/core.py @@ -252,6 +252,11 @@ def predict(client, model, data): result = data.map_blocks(_predict_part, model=model, dtype=np.float32, **kwargs) + else: + raise TypeError( + "Got unexpected input type %s, expected Dask array or dataframe" + % str(data) + ) return result diff --git a/dask_xgboost/tests/test_core.py b/dask_xgboost/tests/test_core.py index 22ca104..15cd300 100644 --- a/dask_xgboost/tests/test_core.py +++ b/dask_xgboost/tests/test_core.py @@ -269,3 +269,10 @@ def f(part): yield dxgb.train(c, param, df, df.x) assert 'foo' in str(info.value) + + +def test_predict_type_error(): + with pytest.raises(TypeError) as info: + dxgb.predict(None, None, 'foo') + assert 'foo' in str(info.value) + assert 'dask array' in str(info.value).lower()