TorchMetrics能够为咱们提供一种简略、洁净、高效的形式来解决验证指标。TorchMetrics提供了许多现成的指标实现,如Accuracy, Dice, F1 Score, Recall, MAE等等,简直最常见的指标都能够在外面找到。torchmetrics目前曾经包好了80+工作评估指标。

TorchMetrics装置也非常简单,只须要PyPI装置最新版本:

 pip install torchmetrics

根本流程介绍

在训练时咱们都是应用微批次训练,对于TorchMetrics也是一样的,在一个批次前向传递实现后将目标值Y和预测值Y_PRED传递给torchmetrics的度量对象,度量对象会计算批次指标并保留它(在其外部被称为state)。

当所有的批次实现时(也就是训练的一个Epoch实现),咱们就能够从度量对象返回最终后果(这是对所有批计算的后果)。这里的每个度量对象都是从metric类继承,它蕴含了4个要害办法:

  • metric.forward(pred,target) - 更新度量状态并返回以后批次上计算的度量后果。如果您违心,也能够应用metric(pred, target),没有区别。
  • metric.update(pred,target) - 与forward雷同,然而不会返回计算结果,相当于是只将后果存入了state。如果不须要在以后批处理上计算出的度量后果,则优先应用这个办法,因为他不计算最终后果速度会很快。
  • metric.compute() - 返回在所有批次上计算的最终后果。也就是说其实forward相当于是update+compute。
  • metric.reset() - 重置状态,以便为下一个验证阶段做好筹备。

也就是说:在咱们训练的以后批次,取得了模型的输入后能够forward或update(倡议应用update)。在批次实现后,调用compute以获取最终后果。最初,在验证轮次(Epoch)或者启用新的轮次进行训练时您调用reset重置状态指标

例如上面的代码:

 import torch import torchmetrics  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = YourModel().to(device) metric = torchmetrics.Accuracy()  for batch_idx, (data, target) in enumerate(val_dataloader):     data, target = data.to(device), target.to(device)     output = model(data)     # metric on current batch     batch_acc = metric.update(preds, target)     print(f"Accuracy on batch {i}: {batch_acc}")  # metric on all batches using custom accumulation val_acc = metric.compute() print(f"Accuracy on all data: {val_acc}")  # Resetting internal state such that metric is ready for new data metric.reset()

MetricCollection

在下面的示例中,应用了单个指标进行计算,但个别状况下可能会蕴含多个指标。Torchmetrics提供了MetricCollection能够将多个指标包装成单个可调用类,其接口与下面的根本用法雷同。这样咱们就无需独自解决每个指标。

代码如下:

 import torch from torchmetrics import MetricCollection, Accuracy, Precision, Recall  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = YourModel().to(device) # collection of all validation metrics metric_collection = MetricCollection({     'acc': Accuracy(),     'prec': Precision(num_classes=10, average='macro'),     'rec': Recall(num_classes=10, average='macro') })  for batch_idx, (data, target) in enumerate(val_dataloader):     data, target = data.to(device), target.to(device)     output = model(data)     batch_metrics = metric_collection.forward(preds, target)     print(f"Metrics on batch {i}: {batch_metrics}")  val_metrics = metric_collection.compute() print(f"Metrics on all data: {val_metrics}") metric.reset()

也能够应用列表而不是字典,然而应用字典会更加清晰。

自定义指标

尽管Torchmetrics蕴含了很多常见的指标,然而有时咱们还须要本人定义一些不罕用的特定指标。咱们只须要继承 Metric 类并且实现 updatecomputing 办法就能够了,另外就是须要在类初始化的时候应用self.add_state(state_name, default)来初始化咱们的对象。

代码也很简略:

 import torch import torchmetrics  class MyAccuracy(Metric):     def __init__(self, delta):         super().__init__()         # to count the correct predictions         self.add_state('corrects', default=torch.tensor(0))         # to count the total predictions         self.add_state('total', default=torch.tensor(0))      def update(self, preds, target):         # update correct predictions count         self.correct += torch.sum(preds == target)         # update total count, numel() returns the total number of elements          self.total += target.numel()      def compute(self):         # final computation         return self.correct / self.total

总结

就是这样,Torchmetrics为咱们指标计算提供了非常简单疾速的解决形式,如果你想更多的理解它的用法,请参考官网文档:

https://avoid.overfit.cn/post/bdedfe4229e04da49049c4e7d56152d1

作者:Mattia Gatti