forked from TensorLayer/tensorlayer3
99 lines
1.8 KiB
Python
99 lines
1.8 KiB
Python
#! /usr/bin/python
|
|
# -*- coding: utf-8 -*-
|
|
|
|
import tensorflow as tf
|
|
from tensorflow.keras.metrics import Metric
|
|
|
|
# Public names exported by ``from <module> import *``: the four metric wrappers.
__all__ = ['Accuracy', 'Auc', 'Precision', 'Recall']
|
|
|
|
|
|
class Accuracy(object):
    """Classification accuracy, optionally top-k, backed by Keras metrics.

    Exposes a small ``update`` / ``result`` / ``reset`` interface over a
    ``tf.keras.metrics`` object.

    Parameters
    ----------
    topk : int
        If 1, ``y_pred`` is reduced to predicted class ids with ``argmax`` over
        its last axis and compared to ``y_true`` via
        ``tf.keras.metrics.Accuracy``. Otherwise
        ``tf.keras.metrics.SparseTopKCategoricalAccuracy(k=topk)`` is used, so
        ``y_true`` presumably holds integer class indices rather than one-hot
        vectors -- TODO confirm against callers.
    """

    def __init__(self, topk=1):
        self.topk = topk
        # NOTE: the attribute name ``accuary`` (sic) is kept unchanged so any
        # external code that reaches into it keeps working.
        if topk == 1:
            self.accuary = tf.keras.metrics.Accuracy()
        else:
            self.accuary = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=topk)

    def update(self, y_pred, y_true, sample_weight=None):
        """Accumulate one batch of predictions into the running metric.

        Parameters
        ----------
        y_pred : tensor-like
            Per-class scores; the class dimension is taken to be the last axis
            (e.g. ``(batch, num_classes)``).
        y_true : tensor-like
            Ground-truth labels.
        sample_weight : tensor-like, optional
            Forwarded to the Keras ``update_state``; ``None`` (the default)
            leaves the metric unweighted, matching the previous behaviour.
        """
        if self.topk == 1:
            # Reduce scores to class ids over the class axis. ``axis=-1`` equals
            # the former ``axis=1`` for 2-D input and agrees with the top-k
            # branch, where Keras also treats the last axis as classes.
            y_pred = tf.argmax(y_pred, axis=-1)
        # Keras takes (y_true, y_pred); this wrapper's public order is reversed.
        self.accuary.update_state(y_true, y_pred, sample_weight=sample_weight)

    def result(self):
        """Return the accuracy accumulated since the last reset."""
        return self.accuary.result()

    def reset(self):
        """Clear all accumulated state."""
        metric = self.accuary
        # ``reset_states`` was deprecated in favour of ``reset_state`` (TF 2.5)
        # and is absent in Keras 3; use whichever this Keras version provides.
        if hasattr(metric, 'reset_state'):
            metric.reset_state()
        else:
            metric.reset_states()
|
|
|
|
|
|
class Auc(object):
    """Area under a curve, backed by ``tf.keras.metrics.AUC``.

    Parameters
    ----------
    curve : str
        Forwarded as ``curve`` to ``tf.keras.metrics.AUC`` (default ``'ROC'``;
        Keras also documents ``'PR'``).
    num_thresholds : int
        Number of thresholds used to discretise the curve; forwarded unchanged.
    """

    def __init__(
        self,
        curve='ROC',
        num_thresholds=200,
    ):
        self.auc = tf.keras.metrics.AUC(num_thresholds=num_thresholds, curve=curve)

    def update(self, y_pred, y_true, sample_weight=None):
        """Accumulate one batch into the running AUC statistics.

        Parameters
        ----------
        y_pred : tensor-like
            Model predictions.
        y_true : tensor-like
            Ground-truth labels.
        sample_weight : tensor-like, optional
            Forwarded to the Keras ``update_state``; ``None`` (the default)
            leaves the metric unweighted, matching the previous behaviour.
        """
        # Keras takes (y_true, y_pred); this wrapper's public order is reversed.
        self.auc.update_state(y_true, y_pred, sample_weight=sample_weight)

    def result(self):
        """Return the AUC accumulated since the last reset."""
        return self.auc.result()

    def reset(self):
        """Clear all accumulated state."""
        metric = self.auc
        # ``reset_states`` was deprecated in favour of ``reset_state`` (TF 2.5)
        # and is absent in Keras 3; use whichever this Keras version provides.
        if hasattr(metric, 'reset_state'):
            metric.reset_state()
        else:
            metric.reset_states()
|
|
|
|
|
|
class Precision(object):
    """Precision metric backed by ``tf.keras.metrics.Precision``.

    The underlying Keras metric is constructed with its default arguments
    (see the Keras documentation for the default decision threshold).
    """

    def __init__(self):
        self.precision = tf.keras.metrics.Precision()

    def update(self, y_pred, y_true, sample_weight=None):
        """Accumulate one batch into the running precision counts.

        Parameters
        ----------
        y_pred : tensor-like
            Model predictions.
        y_true : tensor-like
            Ground-truth labels.
        sample_weight : tensor-like, optional
            Forwarded to the Keras ``update_state``; ``None`` (the default)
            leaves the metric unweighted, matching the previous behaviour.
        """
        # Keras takes (y_true, y_pred); this wrapper's public order is reversed.
        self.precision.update_state(y_true, y_pred, sample_weight=sample_weight)

    def result(self):
        """Return the precision accumulated since the last reset."""
        return self.precision.result()

    def reset(self):
        """Clear all accumulated state."""
        metric = self.precision
        # ``reset_states`` was deprecated in favour of ``reset_state`` (TF 2.5)
        # and is absent in Keras 3; use whichever this Keras version provides.
        if hasattr(metric, 'reset_state'):
            metric.reset_state()
        else:
            metric.reset_states()
|
|
|
|
|
|
class Recall(object):
    """Recall metric backed by ``tf.keras.metrics.Recall``.

    The underlying Keras metric is constructed with its default arguments
    (see the Keras documentation for the default decision threshold).
    """

    def __init__(self):
        self.recall = tf.keras.metrics.Recall()

    def update(self, y_pred, y_true, sample_weight=None):
        """Accumulate one batch into the running recall counts.

        Parameters
        ----------
        y_pred : tensor-like
            Model predictions.
        y_true : tensor-like
            Ground-truth labels.
        sample_weight : tensor-like, optional
            Forwarded to the Keras ``update_state``; ``None`` (the default)
            leaves the metric unweighted, matching the previous behaviour.
        """
        # Keras takes (y_true, y_pred); this wrapper's public order is reversed.
        self.recall.update_state(y_true, y_pred, sample_weight=sample_weight)

    def result(self):
        """Return the recall accumulated since the last reset."""
        return self.recall.result()

    def reset(self):
        """Clear all accumulated state."""
        metric = self.recall
        # ``reset_states`` was deprecated in favour of ``reset_state`` (TF 2.5)
        # and is absent in Keras 3; use whichever this Keras version provides.
        if hasattr(metric, 'reset_state'):
            metric.reset_state()
        else:
            metric.reset_states()
|