diff --git a/twml/twml/contrib/metrics/metrics.py b/twml/twml/contrib/metrics/metrics.py index dea1a5273..10693dd37 100644 --- a/twml/twml/contrib/metrics/metrics.py +++ b/twml/twml/contrib/metrics/metrics.py @@ -31,7 +31,8 @@ def get_eval_metric_ops(graph_output, labels, weights): else: predcol_list=list(predcols) for col in predcol_list: - assert 0 <= col < graph_output['output'].shape[class_dim], 'Invalid Prediction Column Index !' + if not (0 <= col < graph_output['output'].shape[class_dim]): + raise ValueError('Invalid Prediction Column Index !') preds = tf.gather(graph_output['output'], indices=predcol_list, axis=class_dim) # [batchSz, num_col] labels = tf.gather(labels, indices=predcol_list, axis=class_dim) # [batchSz, num_col] @@ -55,7 +56,8 @@ def mean_numeric_label_topK(labels, predictions, weights, name, topK_id): return tf.metrics.mean(values=top_k_labels, name=name) def mean_gated_numeric_label_topK(labels, predictions, weights, name, topK_id, bar=2.0): - assert isinstance(bar, int) or isinstance(bar, float), "bar must be int or float" + if not (isinstance(bar, int) or isinstance(bar, float)): + raise TypeError("bar must be int or float") top_k_labels = tf.gather(params=labels, indices=topK_id, axis=0) # [topK, 1] gated_top_k_labels = tf.cast(top_k_labels > bar*1.0, tf.int32) return tf.metrics.mean(values=gated_top_k_labels, name=name) @@ -108,8 +110,10 @@ def get_eval_metric_ops(graph_output, labels, weights): if predcol is None: pred = graph_output['output'] else: - assert 0 <= predcol < graph_output['output'].shape[1], 'Invalid Prediction Column Index !' - assert labelcol is not None + if not (0 <= predcol < graph_output['output'].shape[1]): + raise ValueError('Invalid Prediction Column Index !') + if labelcol is None: + raise ValueError('labelcol must be provided when predcol is set') pred = tf.reshape(graph_output['output'][:, predcol], shape=[-1, 1]) labels = tf.reshape(labels[:, labelcol], shape=[-1, 1]) numOut = graph_output['output'].shape[1]