Module src.training

View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0

import pickle
import torch
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset
from torch.optim import Optimizer
from torcheval.metrics.functional import (
    multiclass_accuracy,
    multiclass_auroc,
    multiclass_recall,
    multiclass_precision,
)
from typing import Dict, Tuple

from config import Config  # type:ignore [import]


def get_datasets() -> Tuple[Dataset, Dataset]:
    """
    Load and return the training and testing datasets.

    Returns:
        Tuple[Dataset, Dataset]: training and testing datasets
    """
    with open("./data/client-train.pt", "rb") as f:
        trainset = pickle.load(f)
    with open("./data/client-test.pt", "rb") as f:
        testset = pickle.load(f)
    return trainset, testset


def get_data_loader(batch_size: int) -> Tuple[DataLoader, DataLoader]:
    """
    Get the training and testing data loaders.

    Args:
        batch_size (int): batch size

    Returns:
        Tuple[DataLoader, DataLoader]: training and testing data loaders
    """
    trainset, testset = get_datasets()
    train_loader = DataLoader(trainset, batch_size=batch_size)
    test_loader = DataLoader(testset, batch_size=batch_size)
    return train_loader, test_loader


def train_epoch(config: Config, model: Module, train_loader: DataLoader, optimizer: Optimizer, epoch: int):
    """
    Train a model for one epoch.

    Args:
        config (Config): training configuration and logging handle
        model (Module): model to train
        train_loader (DataLoader): training data loader
        optimizer (Optimizer): training optimizer
        epoch (int): current epoch
    """
    dataset_length = len(train_loader.dataset)  # type: ignore[arg-type]
    dataloader_length = len(train_loader)

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader, 0):
        data, target = data.to(config.device), target.to(config.device)
        optimizer.zero_grad()
        output = model(data)
        loss = config.loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % config.log_interval == 0 or batch_idx == dataloader_length - 1:
            x = batch_idx * config.batch_size + len(data)
            msg = "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch, x, dataset_length,
                100. * x / dataset_length, loss.item()
            )
            config.logger.info(msg)

    config.summary_writer.add_scalar("Loss/Training", loss.item(), config.get_global_training_epoch(epoch))


def test(config: Config, model: Module, test_loader: DataLoader, epoch: int) -> Dict[str, float]:
    """
    Test a model.

    Args:
        config (Config): training configuration and logging handle
        model (Module): model to test
        test_loader (DataLoader): testing data loader
        epoch (int): current epoch

    Returns:
        Dict[str, float]: calculated metrics
    """
    model.eval()
    test_loss = 0.
    outputs, targets = torch.tensor([]), torch.tensor([])
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(config.device), target.to(config.device)
            output = model(data)
            test_loss += config.loss(output, target).detach().cpu().item()
            outputs = torch.cat((outputs, output.detach().cpu()), dim=0)
            targets = torch.cat((targets, target.detach().cpu()), dim=0)

    metrics: Dict[str, float] = dict(
        loss=test_loss/float(len(targets)),
        accuracy=multiclass_accuracy(outputs, targets, num_classes=10).item(),
        auroc=multiclass_auroc(outputs, targets, num_classes=10).item(),
        recall=multiclass_recall(outputs, targets, num_classes=10).item(),
        precision=multiclass_precision(outputs, targets, num_classes=10).item(),
    )

    msg = "Test set: Average loss: {:.4f}, Accuracy: {}".format(
        metrics["loss"], metrics["accuracy"]
    )
    config.logger.info(msg)

    global_epoch = config.get_global_training_epoch(epoch)
    config.summary_writer.add_scalar("Loss/Testing", metrics["loss"], global_epoch)
    config.summary_writer.add_scalar("Metrics/Accuracy", metrics["accuracy"], global_epoch)
    config.summary_writer.add_scalar("Metrics/AUROC", metrics["auroc"], global_epoch)
    config.summary_writer.add_scalar("Metrics/Recall", metrics["recall"], global_epoch)
    config.summary_writer.add_scalar("Metrics/Precision", metrics["precision"], global_epoch)

    return metrics

Functions

get_data_loader

def get_data_loader(
    batch_size: int
) -> Tuple[torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader]

Get the training and testing data loaders.

Parameters:

Name        Type  Description  Default
batch_size  int   batch size   required

Returns:

Type                           Description
Tuple[DataLoader, DataLoader]  training and testing data loaders
View Source
def get_data_loader(batch_size: int) -> Tuple[DataLoader, DataLoader]:
    """
    Get the training and testing data loaders.

    Args:
        batch_size (int): batch size

    Returns:
        Tuple[DataLoader, DataLoader]: training and testing data loaders
    """
    trainset, testset = get_datasets()
    train_loader = DataLoader(trainset, batch_size=batch_size)
    test_loader = DataLoader(testset, batch_size=batch_size)
    return train_loader, test_loader
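
A short usage sketch, assuming the pickled dataset files described under get_datasets already exist in ./data/; the batch size of 64 is illustrative:

train_loader, test_loader = get_data_loader(batch_size=64)

# Peek at the first training batch to confirm tensor shapes.
data, target = next(iter(train_loader))
print(data.shape, target.shape)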

get_datasets

def get_datasets() -> Tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]

Load and return the training and testing datasets.

Returns:

Type                     Description
Tuple[Dataset, Dataset]  training and testing datasets
View Source
def get_datasets() -> Tuple[Dataset, Dataset]:
    """
    Load and return the training and testing datasets.

    Returns:
        Tuple[Dataset, Dataset]: training and testing datasets
    """
    with open("./data/client-train.pt", "rb") as f:
        trainset = pickle.load(f)
    with open("./data/client-test.pt", "rb") as f:
        testset = pickle.load(f)
    return trainset, testset
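
get_datasets deserializes entire Dataset objects with pickle, so the two files must contain pickled datasets rather than raw tensors. The following sketch produces compatible files; the TensorDataset, the 28x28 single-channel input shape, and the random labels are illustrative stand-ins (only the 10-class layout is grounded, via num_classes=10 in test):

import pickle
import torch
from torch.utils.data import TensorDataset

# Illustrative stand-in data: 100 samples, 10 classes, 28x28 single-channel inputs.
dataset = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))

with open("./data/client-train.pt", "wb") as f:
    pickle.dump(dataset, f)
with open("./data/client-test.pt", "wb") as f:
    pickle.dump(dataset, f)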

test

def test(
    config: config.Config,
    model: torch.nn.modules.module.Module,
    test_loader: torch.utils.data.dataloader.DataLoader,
    epoch: int
) -> Dict[str, float]

Test a model.

Parameters:

Name         Type        Description                                 Default
config       Config      training configuration and logging handle  required
model        Module      model to test                               required
test_loader  DataLoader  testing data loader                         required
epoch        int         current epoch                               required

Returns:

Type              Description
Dict[str, float]  calculated metrics
View Source
def test(config: Config, model: Module, test_loader: DataLoader, epoch: int) -> Dict[str, float]:
    """
    Test a model.

    Args:
        config (Config): training configuration and logging handle
        model (Module): model to test
        test_loader (DataLoader): testing data loader
        epoch (int): current epoch

    Returns:
        Dict[str, float]: calculated metrics
    """
    model.eval()
    test_loss = 0.
    outputs, targets = torch.tensor([]), torch.tensor([])
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(config.device), target.to(config.device)
            output = model(data)
            test_loss += config.loss(output, target).detach().cpu().item()
            outputs = torch.cat((outputs, output.detach().cpu()), dim=0)
            targets = torch.cat((targets, target.detach().cpu()), dim=0)

    metrics: Dict[str, float] = dict(
        loss=test_loss/float(len(targets)),
        accuracy=multiclass_accuracy(outputs, targets, num_classes=10).item(),
        auroc=multiclass_auroc(outputs, targets, num_classes=10).item(),
        recall=multiclass_recall(outputs, targets, num_classes=10).item(),
        precision=multiclass_precision(outputs, targets, num_classes=10).item(),
    )

    msg = "Test set: Average loss: {:.4f}, Accuracy: {}".format(
        metrics["loss"], metrics["accuracy"]
    )
    config.logger.info(msg)

    global_epoch = config.get_global_training_epoch(epoch)
    config.summary_writer.add_scalar("Loss/Testing", metrics["loss"], global_epoch)
    config.summary_writer.add_scalar("Metrics/Accuracy", metrics["accuracy"], global_epoch)
    config.summary_writer.add_scalar("Metrics/AUROC", metrics["auroc"], global_epoch)
    config.summary_writer.add_scalar("Metrics/Recall", metrics["recall"], global_epoch)
    config.summary_writer.add_scalar("Metrics/Precision", metrics["precision"], global_epoch)

    return metrics
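
A sketch of a single evaluation call; the Config attributes relied on here (batch_size, device) match the accesses in the source above, but how Config and the model are constructed is project-specific and assumed:

_, test_loader = get_data_loader(batch_size=config.batch_size)
model = model.to(config.device)

metrics = test(config, model, test_loader, epoch=0)
print("accuracy: {:.4f}, auroc: {:.4f}".format(metrics["accuracy"], metrics["auroc"]))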

train_epoch

def train_epoch(
    config: config.Config,
    model: torch.nn.modules.module.Module,
    train_loader: torch.utils.data.dataloader.DataLoader,
    optimizer: torch.optim.optimizer.Optimizer,
    epoch: int
)

Train a model for one epoch.

Parameters:

Name          Type        Description                                 Default
config        Config      training configuration and logging handle  required
model         Module      model to train                              required
train_loader  DataLoader  training data loader                        required
optimizer     Optimizer   training optimizer                          required
epoch         int         current epoch                               required
View Source
def train_epoch(config: Config, model: Module, train_loader: DataLoader, optimizer: Optimizer, epoch: int):
    """
    Train a model for one epoch.

    Args:
        config (Config): training configuration and logging handle
        model (Module): model to train
        train_loader (DataLoader): training data loader
        optimizer (Optimizer): training optimizer
        epoch (int): current epoch
    """
    dataset_length = len(train_loader.dataset)  # type: ignore[arg-type]
    dataloader_length = len(train_loader)

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader, 0):
        data, target = data.to(config.device), target.to(config.device)
        optimizer.zero_grad()
        output = model(data)
        loss = config.loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % config.log_interval == 0 or batch_idx == dataloader_length - 1:
            x = batch_idx * config.batch_size + len(data)
            msg = "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch, x, dataset_length,
                100. * x / dataset_length, loss.item()
            )
            config.logger.info(msg)

    config.summary_writer.add_scalar("Loss/Training", loss.item(), config.get_global_training_epoch(epoch))