Skip to content

src.utils

Functions:

Name Description
log_data_distribution

Calculate and log the data distribution.

Functions

log_data_distribution

log_data_distribution(config: Config, train_loader: DataLoader, test_loader: DataLoader) -> None

Calculate and log the data distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `Config` | training configuration and logging handle | required |
| `train_loader` | `DataLoader` | training data loader | required |
| `test_loader` | `DataLoader` | testing data loader | required |

Source code in src/utils.py
def log_data_distribution(config: Config, train_loader: DataLoader, test_loader: DataLoader) -> None:
    """
    Calculate and log the data distribution.

    Writes the sample count of the training and testing datasets, followed by
    the per-label sample counts of each, to ``config.summary_writer``. Nothing
    is logged unless ``config.args.round`` is 0.

    Args:
        config (Config): training configuration and logging handle
        train_loader (DataLoader): training data loader
        test_loader (DataLoader): testing data loader
    """
    if config.args.round != 0:
        return

    def add_label_distribution(tag: str, loader: DataLoader) -> None:
        """Log one scalar per distinct label under ``tag``.

        The scalar value is the number of samples with that label; the label
        itself is passed as the third positional argument of ``add_scalar``
        (presumably the step axis of the writer -- confirm against its API).
        """
        if hasattr(loader.dataset, "targets"):
            # ``targets`` is not guaranteed to be a tensor: it may be a plain
            # Python list or a NumPy array, neither of which is safe to call
            # ``.tolist()`` on uniformly. ``torch.as_tensor`` accepts all of
            # them and returns an existing tensor unchanged.
            targets = torch.as_tensor(loader.dataset.targets).tolist()
        else:
            # No precomputed labels: collect them by iterating the loader,
            # which is expected to yield ``(inputs, targets)`` pairs.
            batches = [target for _, target in loader]
            # ``torch.cat`` rejects an empty list, so an empty loader simply
            # produces an empty distribution.
            targets = torch.cat(batches, dim=0).tolist() if batches else []
        # Labels are unique Counter keys, so sorting the items orders by label.
        for label_idx, count in sorted(Counter(targets).items()):
            config.summary_writer.add_scalar(tag, count, label_idx)

    config.summary_writer.add_scalar("Data/Training Sample Size", len(train_loader.dataset))  # type: ignore[arg-type]
    config.summary_writer.add_scalar("Data/Testing Sample Size", len(test_loader.dataset))  # type: ignore[arg-type]
    add_label_distribution("Data/Training Label Distribution", train_loader)
    add_label_distribution("Data/Testing Label Distribution", test_loader)