Module src.utils¶
View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0
from collections import Counter
import torch
from torch.utils.data import DataLoader
from config import Config # type:ignore [import]
def log_data_distribution(config: Config, train_loader: DataLoader, test_loader: DataLoader) -> None:
    """
    Calculate and log the data distribution.

    Writes the train/test sample sizes and the per-label sample counts to the
    TensorBoard summary writer held by ``config``. Runs only on round 0, so the
    (static) distribution is not re-logged on every training round.

    Args:
        config (Config): training configuration and logging handle
        train_loader (DataLoader): training data loader
        test_loader (DataLoader): testing data loader
    """
    if config.args.round != 0:
        return

    def add_label_distribution(tag: str, loader: DataLoader) -> None:
        """Log one scalar per label: the number of samples carrying that label."""
        if hasattr(loader.dataset, "targets"):
            # Fast path: the dataset exposes its labels directly. Some datasets
            # (e.g. torchvision's CIFAR10) store `targets` as a plain list rather
            # than a tensor, so only call .tolist() on actual tensors.
            raw = loader.dataset.targets
            targets = raw.tolist() if torch.is_tensor(raw) else list(raw)
        else:
            # Slow path: iterate the loader and concatenate the label batches.
            targets = torch.cat([target for _, target in loader], dim=0).tolist()
        counter = Counter(targets)
        for label_idx in sorted(counter.keys()):
            # The label index doubles as the scalar's global step, so the
            # distribution renders as one point per label in TensorBoard.
            config.summary_writer.add_scalar(tag, counter[label_idx], label_idx)

    config.summary_writer.add_scalar("Data/Training Sample Size", len(train_loader.dataset))  # type: ignore[arg-type]
    config.summary_writer.add_scalar("Data/Testing Sample Size", len(test_loader.dataset))  # type: ignore[arg-type]
    add_label_distribution("Data/Training Label Distribution", train_loader)
    add_label_distribution("Data/Testing Label Distribution", test_loader)
Functions¶
log_data_distribution¶
def log_data_distribution(
config: config.Config,
train_loader: torch.utils.data.dataloader.DataLoader,
test_loader: torch.utils.data.dataloader.DataLoader
) -> None
Calculate and log the data distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | Config | training configuration and logging handle | required |
train_loader | DataLoader | training data loader | required |
test_loader | DataLoader | testing data loader | required |
View Source
def log_data_distribution(config: Config, train_loader: DataLoader, test_loader: DataLoader) -> None:
    """
    Calculate and log the data distribution.

    Args:
        config (Config): training configuration and logging handle
        train_loader (DataLoader): training data loader
        test_loader (DataLoader): testing data loader
    """
    # The distribution does not change between rounds, so record it once.
    if config.args.round != 0:
        return

    writer = config.summary_writer

    def emit_label_counts(tag, loader: DataLoader):
        # Prefer the dataset's own label container; otherwise gather the
        # labels by iterating over the batches produced by the loader.
        dataset = loader.dataset
        if hasattr(dataset, "targets"):
            labels = dataset.targets.tolist()
        else:
            batch_labels = [targets for _, targets in loader]
            labels = torch.cat(batch_labels, dim=0).tolist()
        # One scalar per label; the label value serves as the global step.
        for label, count in sorted(Counter(labels).items()):
            writer.add_scalar(tag, count, label)

    writer.add_scalar("Data/Training Sample Size", len(train_loader.dataset))  # type: ignore[arg-type]
    writer.add_scalar("Data/Testing Sample Size", len(test_loader.dataset))  # type: ignore[arg-type]
    emit_label_counts("Data/Training Label Distribution", train_loader)
    emit_label_counts("Data/Testing Label Distribution", test_loader)