Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions twml/twml/contrib/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def get_eval_metric_ops(graph_output, labels, weights):
else:
predcol_list=list(predcols)
for col in predcol_list:
assert 0 <= col < graph_output['output'].shape[class_dim], 'Invalid Prediction Column Index !'
if not (0 <= col < graph_output['output'].shape[class_dim]):
raise ValueError('Invalid Prediction Column Index !')
preds = tf.gather(graph_output['output'], indices=predcol_list, axis=class_dim) # [batchSz, num_col]
labels = tf.gather(labels, indices=predcol_list, axis=class_dim) # [batchSz, num_col]

Expand All @@ -55,7 +56,8 @@ def mean_numeric_label_topK(labels, predictions, weights, name, topK_id):
return tf.metrics.mean(values=top_k_labels, name=name)

def mean_gated_numeric_label_topK(labels, predictions, weights, name, topK_id, bar=2.0):
assert isinstance(bar, int) or isinstance(bar, float), "bar must be int or float"
if not (isinstance(bar, int) or isinstance(bar, float)):
raise TypeError("bar must be int or float")
top_k_labels = tf.gather(params=labels, indices=topK_id, axis=0) # [topK, 1]
gated_top_k_labels = tf.cast(top_k_labels > bar*1.0, tf.int32)
return tf.metrics.mean(values=gated_top_k_labels, name=name)
Expand Down Expand Up @@ -108,8 +110,10 @@ def get_eval_metric_ops(graph_output, labels, weights):
if predcol is None:
pred = graph_output['output']
else:
assert 0 <= predcol < graph_output['output'].shape[1], 'Invalid Prediction Column Index !'
assert labelcol is not None
if not (0 <= predcol < graph_output['output'].shape[1]):
raise ValueError('Invalid Prediction Column Index !')
if labelcol is None:
raise ValueError('labelcol must be provided when predcol is set')
pred = tf.reshape(graph_output['output'][:, predcol], shape=[-1, 1])
labels = tf.reshape(labels[:, labelcol], shape=[-1, 1])
numOut = graph_output['output'].shape[1]
Expand Down
Loading