System information
- CentOS 7.0
- Python 3.4
- TensorFlow 1.2.1
Describe the problem
When I call `tensorflow.contrib.learn.DNNRegressor.fit(x_train_dict, y_train, steps=1000)`, where `x_train_dict` is a dict and `y_train` is a NumPy array, the program throws the following exception:
```
  File "/home/star/yuce.ddxq.mobi/zhuge/management/commands/forecast_product_sale.py", line 148, in tflearn_dnn_train2
    regressor.fit(x_train_dict, y_train, steps=10000, batch_size=10)
  File "/usr/lib/python3.4/site-packages/tensorflow/python/util/deprecation.py", line 289, in new_func
    return func(*args, **kwargs)
  File "/usr/lib/python3.4/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 439, in fit
    SKCompat(self).fit(x, y, batch_size, steps, max_steps, monitors)
  File "/usr/lib/python3.4/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1340, in fit
    epochs=None)
  File "/usr/lib/python3.4/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 137, in _get_input_fn
    epochs=epochs)
  File "/usr/lib/python3.4/site-packages/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py", line 152, in setup_train_data_feeder
    x, y, n_classes, batch_size, shuffle=shuffle, epochs=epochs)
  File "/usr/lib/python3.4/site-packages/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py", line 326, in __init__
    dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if x_is_dict else check_array(y, y.dtype)
AttributeError: 'numpy.ndarray' object has no attribute 'items'
```
The related code in `data_feeder.py` (around line 326) is:
```python
x_is_dict, y_is_dict = isinstance(x, dict), y is not None and isinstance(
    y, dict)
if isinstance(y, list):
  y = np.array(y)
self._x = dict([(k, check_array(v, v.dtype)) for k, v in list(x.items())
               ]) if x_is_dict else check_array(x, x.dtype)
self._y = None if y is None else \
    dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if x_is_dict else check_array(y, y.dtype)
```
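The last line looks wrong: it checks `x_is_dict` when deciding how to convert `y`, so when `x` is a dict and `y` is an array, it calls `y.items()` on the array. Shouldn't it use `y_is_dict` instead of `x_is_dict`?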
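To isolate the problem from TensorFlow, the original `self._y` expression can be mirrored in plain Python. This is a sketch, not TensorFlow code: `FakeArray` is a hypothetical stand-in for a NumPy array (it has a `dtype` but no `items()`), and `check_array` is a hypothetical stand-in that returns its input unchanged. Under those assumptions, the original condition fails for both mismatched combinations: a dict `x` with an array `y` (the error in the traceback), and also an array `x` with a dict `y`, because a dict has no `dtype`.

```python
class FakeArray:
    """Hypothetical stand-in for a NumPy array: it has a dtype but no items()."""

    def __init__(self, values, dtype="float32"):
        self.values = list(values)
        self.dtype = dtype


def check_array(value, dtype):
    """Hypothetical stand-in for data_feeder.check_array; returns value unchanged."""
    return value


def convert_y_original(x, y):
    """Mirrors the original line: how y is converted depends on whether x is a dict."""
    x_is_dict = isinstance(x, dict)
    return None if y is None else \
        dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if x_is_dict else check_array(y, y.dtype)


def try_convert(x, y):
    """Returns 'ok' on success, otherwise the AttributeError message."""
    try:
        convert_y_original(x, y)
        return "ok"
    except AttributeError as err:
        return str(err)


x_dict = {"feature": FakeArray([1.0, 2.0])}
x_arr = FakeArray([1.0, 2.0])
y_dict = {"label": FakeArray([3.0, 4.0])}
y_arr = FakeArray([3.0, 4.0])

for name, x, y in [("dict x, array y", x_dict, y_arr),
                   ("array x, dict y", x_arr, y_dict),
                   ("dict x, dict y", x_dict, y_dict),
                   ("array x, array y", x_arr, y_arr)]:
    print(name, "->", try_convert(x, y))
```

Only the two combinations where `x` and `y` have the same kind print `ok`; the other two print `'FakeArray' object has no attribute 'items'` and `'dict' object has no attribute 'dtype'`.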
I changed that line to:

```python
self._y = None if y is None else \
    dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if y_is_dict else check_array(y, y.dtype)
```

and with that change, `fit()` runs without the exception.
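The one-line change above can be checked the same way with a plain-Python sketch. As before, `FakeArray` and `check_array` are hypothetical stand-ins rather than TensorFlow code, and `convert_y_patched` is a hypothetical name for the mirrored expression. With `y_is_dict` in the condition, the branch taken for `y` depends only on `y`'s own type, so every combination of dict or array `x` with dict, array, or missing `y` converts without error.

```python
class FakeArray:
    """Hypothetical stand-in for a NumPy array: it has a dtype but no items()."""

    def __init__(self, values, dtype="float32"):
        self.values = list(values)
        self.dtype = dtype


def check_array(value, dtype):
    """Hypothetical stand-in for data_feeder.check_array; returns value unchanged."""
    return value


def convert_y_patched(x, y):
    """Mirrors the patched line; x is accepted only to match the original and is unused."""
    y_is_dict = y is not None and isinstance(y, dict)
    return None if y is None else \
        dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if y_is_dict else check_array(y, y.dtype)


x_dict = {"feature": FakeArray([1.0, 2.0])}
x_arr = FakeArray([1.0, 2.0])
y_dict = {"label": FakeArray([3.0, 4.0])}
y_arr = FakeArray([3.0, 4.0])

# Record the type of the converted y for every x/y combination.
results = {}
for x_name, x in [("dict x", x_dict), ("array x", x_arr)]:
    for y_name, y in [("dict y", y_dict), ("array y", y_arr), ("no y", None)]:
        results[(x_name, y_name)] = type(convert_y_patched(x, y)).__name__

for key, kind in results.items():
    print(key, "->", kind)
```

Each dict `y` comes back as a `dict`, each array `y` comes back as the same `FakeArray`, and a missing `y` stays `None`, regardless of whether `x` is a dict.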