Machine learning metrics for distributed, scalable PyTorch applications.

PyTorchLightning. Last update: Jun 05, 2022

Installation

Simple installation from PyPI

pip install torchmetrics
Other installations

Install using conda

conda install -c conda-forge torchmetrics

Pip from source

# with git
pip install git+https://github.com/PytorchLightning/metrics.git@release/latest

Pip from archive

pip install https://github.com/PyTorchLightning/metrics/archive/refs/heads/release/latest.zip

Extra dependencies for specialized metrics:

pip install torchmetrics[audio]
pip install torchmetrics[image]
pip install torchmetrics[text]
pip install torchmetrics[all]  # install all of the above

Install latest developer version

pip install https://github.com/PyTorchLightning/metrics/archive/master.zip

What is TorchMetrics

TorchMetrics is a collection of 80+ PyTorch metric implementations and an easy-to-use API to create custom metrics. It offers:

  • A standardized interface to increase reproducibility
  • Reduced boilerplate
  • Automatic accumulation over batches
  • Metrics optimized for distributed training
  • Automatic synchronization between multiple devices
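Why accumulate states over batches instead of averaging per-batch results? A small worked example in plain Python (no TorchMetrics involved; the batch counts are made up for illustration): when batches differ in size, the mean of per-batch accuracies is not the accuracy over all samples. Summing the `correct` and `total` counts first and dividing once at the end gives the latter.

```python
# Illustrative per-batch counts (not real results): (number correct, batch size)
batches = [(9, 10), (1, 2)]

# Naive approach: average the per-batch accuracies
per_batch = [correct / size for correct, size in batches]
naive_mean = sum(per_batch) / len(per_batch)

# Accumulating approach: sum the counts first, divide once at the end
total_correct = sum(correct for correct, _ in batches)
total_seen = sum(size for _, size in batches)
accumulated = total_correct / total_seen

print(f"mean of batch accuracies:  {naive_mean:.3f}")
print(f"accuracy over all samples: {accumulated:.3f}")
```

The small second batch drags the naive mean down to 0.700, while 10 of 12 samples (about 0.833) were actually classified correctly.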

You can use TorchMetrics with any PyTorch model or with PyTorch Lightning to enjoy additional features such as:

  • Module metrics are automatically placed on the correct device.
  • Native support for logging metrics in Lightning to reduce even more boilerplate.

Using TorchMetrics

Module metrics

The module-based metrics contain internal metric states (similar to the parameters of the PyTorch module) that automate accumulation and synchronization across devices!

  • Automatic accumulation over multiple batches
  • Automatic synchronization between multiple devices
  • Metric arithmetic

Module metrics can run on CPU, a single GPU, or multiple GPUs!

For the single GPU/CPU case:

import torch

# import our library
import torchmetrics

# initialize metric
metric = torchmetrics.Accuracy()

# move the metric to the device where the computations should take place
device = "cuda" if torch.cuda.is_available() else "cpu"
metric.to(device)

n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1).to(device)
    target = torch.randint(5, (10,)).to(device)

    # metric on current batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")

# metric on all batches using custom accumulation
acc = metric.compute()
print(f"Accuracy on all data: {acc}")

Module metric usage remains the same when using multiple GPUs or multiple nodes.

Example using DDP
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

import torchmetrics


def metric_ddp(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # create default process group; NCCL is used because the metric states
    # live on the GPU and must be synchronized across processes
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # initialize metric
    metric = torchmetrics.Accuracy()

    # define a model and append your metric to it
    # this allows metric states to be placed on correct accelerators when
    # .to(device) is called on the model
    model = nn.Linear(10, 10)
    model.metric = metric
    model = model.to(rank)

    # initialize DDP
    model = DDP(model, device_ids=[rank])

    n_epochs = 5
    # this shows iteration over multiple training epochs
    for n in range(n_epochs):

        # this will be replaced by a DataLoader with a DistributedSampler
        n_batches = 10
        for i in range(n_batches):
            # simulate a classification problem; inputs must be on the
            # same device as the metric states
            preds = torch.randn(10, 5).softmax(dim=-1).to(rank)
            target = torch.randint(5, (10,)).to(rank)

            # metric on current batch
            acc = metric(preds, target)
            if rank == 0:  # print only for rank 0
                print(f"Accuracy on batch {i}: {acc}")

        # metric on all batches and all accelerators using custom accumulation
        # accuracy is the same across all accelerators
        acc = metric.compute()
        print(f"Accuracy on all data: {acc}, accelerator rank: {rank}")

        # reset internal state so the metric is ready for new data
        metric.reset()

    # cleanup
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2  # number of GPUs to parallelize over
    mp.spawn(metric_ddp, args=(world_size,), nprocs=world_size, join=True)
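Why does every rank report the same value from compute()? Conceptually, each process accumulates its own copy of the metric states, and those states are reduced across processes before the final value is computed. The plain-Python simulation below mimics that idea, assuming count-based states reduced by summation; a list stands in for the processes, no torch.distributed call is made, and the per-rank numbers are invented.

```python
# Simulated per-process metric states, one dict per rank (illustrative numbers)
rank_states = [
    {"correct": 42, "total": 50},  # state accumulated on rank 0
    {"correct": 30, "total": 50},  # state accumulated on rank 1
]


def sync_sum(states):
    """Reduce each named state across ranks by summation (a stand-in for a sum reduction)."""
    return {key: sum(state[key] for state in states) for key in states[0]}


def compute(state):
    """Final metric value from the already-reduced state."""
    return state["correct"] / state["total"]


synced = sync_sum(rank_states)
# every rank computes from the same reduced state, so every rank reports the same value
results = [compute(synced) for _ in rank_states]
print(synced, results)
```

Rank 0 alone would report 0.84 and rank 1 alone 0.60; after the reduction both report 72 / 100 = 0.72.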

Implementing your own Module metric

Implementing your own metric is as easy as subclassing a torch.nn.Module. Simply subclass torchmetrics.Metric and implement the update and compute methods:

import torch
from torchmetrics import Metric


class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        # call `self.add_state` for every internal state that is needed for the metric's computations
        # dist_reduce_fx indicates the function that should be used to reduce
        # state from multiple processes
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def _input_format(self, preds: torch.Tensor, target: torch.Tensor):
        # convert class probabilities of shape (N, C) into labels of shape (N,)
        if preds.ndim == target.ndim + 1:
            preds = preds.argmax(dim=-1)
        return preds, target

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # update metric states
        preds, target = self._input_format(preds, target)
        assert preds.shape == target.shape

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        # compute final result
        return self.correct.float() / self.total
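To make the update/compute/reset contract concrete without PyTorch, here is a framework-free stand-in for MyAccuracy that works on plain Python lists of class labels. It is an illustrative sketch, not how torchmetrics.Metric is implemented: the class name `PlainAccuracy` is hypothetical, and its `__call__` (which returns the current batch's value while also accumulating into the global state, mirroring how `metric(preds, target)` is used in the examples above) is a simplification.

```python
class PlainAccuracy:
    """Hypothetical, framework-free mirror of the MyAccuracy update/compute/reset contract."""

    def __init__(self):
        self.reset()

    def reset(self):
        # restore the accumulated states to their default values
        self.correct = 0
        self.total = 0

    def update(self, preds, target):
        # preds and target are equal-length lists of class labels
        assert len(preds) == len(target)
        self.correct += sum(p == t for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        # final value over everything seen since the last reset
        return self.correct / self.total

    def __call__(self, preds, target):
        # value on the current batch only, computed with a throwaway instance
        batch = PlainAccuracy()
        batch.update(preds, target)
        # ...while also accumulating into this instance's global state
        self.update(preds, target)
        return batch.compute()


metric = PlainAccuracy()
print(metric([0, 1, 2], [0, 1, 1]))  # accuracy on batch 1
print(metric([2, 2], [2, 0]))        # accuracy on batch 2
print(metric.compute())              # accuracy over both batches
metric.reset()                       # ready for a new epoch
```

The design point is the same as in MyAccuracy: only the raw counts are stored, so the overall value (3 correct out of 5) can be computed at any time and the states can be reset between epochs.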

Functional metrics

Similar to torch.nn, most metrics have both a module-based and a functional version. The functional versions are simple Python functions that take torch.Tensor inputs and return the corresponding metric as a torch.Tensor.

import torch

# import our library
import torchmetrics

# simulate a classification problem
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

acc = torchmetrics.functional.accuracy(preds, target)
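For probability inputs of shape (N, C) like the ones above, plain top-1 accuracy amounts to taking the highest-probability class in each row and comparing it with the target label. A plain-Python sketch of that computation follows; `top1_accuracy` is a hypothetical helper, and the library function accepts further options (for example a decision threshold and averaging modes) that are not reproduced here.

```python
def top1_accuracy(probs, target):
    """Fraction of rows whose highest-probability class equals the target label."""
    # arg-max over the class dimension of each row (first index wins on ties)
    preds = [max(range(len(row)), key=row.__getitem__) for row in probs]
    correct = sum(p == t for p, t in zip(preds, target))
    return correct / len(target)


probs = [
    [0.1, 0.7, 0.2],  # predicted class 1
    [0.6, 0.3, 0.1],  # predicted class 0
    [0.2, 0.2, 0.6],  # predicted class 2
    [0.5, 0.4, 0.1],  # predicted class 0
]
target = [1, 0, 1, 0]
print(top1_accuracy(probs, target))
```

Three of the four arg-max predictions match the targets, so the result is 0.75.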

Covered domains and example metrics

We currently have implemented metrics across several domains, including classification, audio, image and text.

In total, TorchMetrics contains 80+ metrics!

Contribute!

The Lightning + TorchMetrics team is hard at work adding even more metrics. But we're looking for incredible contributors like you to submit new metrics and improve existing ones!

Join our Slack to get help becoming a contributor!

Community

For help or questions, join our huge community on Slack!

Citation

We’re excited to continue the strong legacy of open source software and have been inspired over the years by Caffe, Theano, Keras, PyTorch, torchbearer, ignite, sklearn and fast.ai.

If you want to cite this framework feel free to use GitHub's built-in citation option to generate a bibtex or APA-Style citation based on this file (but only if you loved it 😊).

License

Please observe the Apache 2.0 license that is listed in this repository. In addition, the Lightning framework is Patent Pending.
