src.config

Classes:

| Name | Description |
| --- | --- |
| `Config` | Training configuration class including logging (summary writer) handle. |

Classes

Config

Bases: ContextDecorator



Training configuration class including logging (summary writer) handle.

Example:

```python
with Config(args) as config:
    trained_model, metrics, sample_size = train(model, config)
    # ...
```
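Because `Config` subclasses `contextlib.ContextDecorator`, it can also be applied as a decorator, closing the summary writer when the wrapped function returns. A minimal torch-free sketch of that pattern (the `Writer` and `MiniConfig` names below are illustrative stand-ins, not part of `src.config`):

```python
from contextlib import ContextDecorator

class Writer:
    """Stand-in for SummaryWriter: only tracks whether close() was called."""
    def __init__(self):
        self.closed = False
    def close(self):
        self.closed = True

class MiniConfig(ContextDecorator):
    def __init__(self):
        self.summary_writer = Writer()
    def __enter__(self):
        return self
    def __exit__(self, *args):
        # Runs on normal return and on exception alike.
        self.summary_writer.close()

cfg = MiniConfig()

@cfg  # equivalent to wrapping the call in `with cfg: ...`
def train():
    return "done"

print(train(), cfg.summary_writer.closed)  # done True
```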

Methods:

| Name | Description |
| --- | --- |
| `__enter__` | |
| `__exit__` | |
| `__init__` | |
| `get_global_training_epoch` | Get the global training epoch. |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `args` | | |
| `batch_size` | | |
| `device` | | |
| `epochs` | | |
| `log_interval` | | |
| `logger` | | |
| `loss` | | |
| `optimizer` | | |
| `scheduler` | | |
| `summary_writer` | | |
Source code in src/config.py

````python
class Config(ContextDecorator):
    """
    Training configuration class including logging (summary writer) handle.

    Example:

    ```python
    with Config(args) as config:
        trained_model, metrics, sample_size = train(model, config)
        # ...
    ```
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    epochs = 2
    batch_size = 256
    optimizer = partial(SGD, lr=0.05, momentum=0.9, nesterov=True, weight_decay=0.0001)
    scheduler = partial(StepLR, step_size=5, gamma=0.95)
    loss = CrossEntropyLoss(reduction="mean")
    logger = logging.getLogger("fl.client")
    log_interval = 20

    def __init__(self, args: Namespace) -> None:
        self.args = args
        self.summary_writer = SummaryWriter(f"s3://trainings/{self.args.training_id}/{self.args.client_id}")

    def get_global_training_epoch(self, local_epoch: int) -> int:
        """
        Get the global training epoch.

        Calculates and returns the global training epoch based on the local epoch and the training round.

        Note:

        - `self.args.round` (training round) is zero based
        - testing rounds are not considered or included

        Example:

        Consider a scenario where the client returns the model to the server after every three local epochs.
        With `local_epoch == 1` in the third training round (`self.args.round == 2`, since rounds
        are zero based), the global training epoch is 1 + 3*2 = 7.

        Args:
            local_epoch (int): local training epoch

        Returns:
            int: global training epoch
        """
        # NOTE:
        # - self.args.round is zero based
        # - testing rounds are not considered
        return max(local_epoch, 0) + self.epochs * self.args.round

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.summary_writer.close()
````

Attributes

- `args` (instance attribute): `args = args`
- `batch_size` (class attribute): `batch_size = 256`
- `device` (class attribute): `device = device('cuda' if is_available() else 'cpu')`
- `epochs` (class attribute): `epochs = 2`
- `log_interval` (class attribute): `log_interval = 20`
- `logger` (class attribute): `logger = getLogger('fl.client')`
- `loss` (class attribute): `loss = CrossEntropyLoss(reduction='mean')`
- `optimizer` (class attribute): `optimizer = partial(SGD, lr=0.05, momentum=0.9, nesterov=True, weight_decay=0.0001)`
- `scheduler` (class attribute): `scheduler = partial(StepLR, step_size=5, gamma=0.95)`
- `summary_writer` (instance attribute): `summary_writer = SummaryWriter(f's3://trainings/{training_id}/{client_id}')`
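`optimizer` and `scheduler` are stored as `functools.partial` factories with the hyperparameters pre-bound, so each client can supply its own model parameters at training time (e.g. `config.optimizer(model.parameters())`). A torch-free sketch of that pattern (the `SGD` class below is an illustrative stand-in, not `torch.optim.SGD`):

```python
from functools import partial

class SGD:
    """Stand-in optimizer: records the parameters and bound hyperparameters."""
    def __init__(self, params, lr, momentum=0.0, nesterov=False, weight_decay=0.0):
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum
        self.nesterov = nesterov
        self.weight_decay = weight_decay

# Config-style factory: hyperparameters fixed now, parameters supplied later.
make_optimizer = partial(SGD, lr=0.05, momentum=0.9, nesterov=True, weight_decay=0.0001)

# At training time each client binds its own parameters.
opt = make_optimizer([0.1, 0.2, 0.3])
print(opt.lr, opt.nesterov)  # 0.05 True
```

Keeping the factory on the class (rather than a constructed optimizer) matters because an optimizer instance is tied to one model's parameters, while the partial can be reused across clients and rounds.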

Functions

__enter__
__enter__()
Source code in src/config.py

```python
def __enter__(self):
    return self
```
__exit__
__exit__(*args)
Source code in src/config.py

```python
def __exit__(self, *args):
    self.summary_writer.close()
```
__init__
__init__(args: Namespace) -> None
Source code in src/config.py

```python
def __init__(self, args: Namespace) -> None:
    self.args = args
    self.summary_writer = SummaryWriter(f"s3://trainings/{self.args.training_id}/{self.args.client_id}")
```
get_global_training_epoch
get_global_training_epoch(local_epoch: int) -> int

Get the global training epoch.

Calculates and returns the global training epoch based on the local epoch and the training round.

Note:

- `self.args.round` (training round) is zero based
- testing rounds are not considered or included

Example:

Consider a scenario where the client returns the model to the server after every three local epochs. With `local_epoch == 1` in the third training round (`args.round == 2`, since rounds are zero based), the global training epoch is 1 + 3*2 = 7.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `local_epoch` | `int` | local training epoch | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `int` | `int` | global training epoch |

Source code in src/config.py

```python
def get_global_training_epoch(self, local_epoch: int) -> int:
    """
    Get the global training epoch.

    Calculates and returns the global training epoch based on the local epoch and the training round.

    Note:

    - `self.args.round` (training round) is zero based
    - testing rounds are not considered or included

    Example:

    Consider a scenario where the client returns the model to the server after every three local epochs.
    With `local_epoch == 1` in the third training round (`self.args.round == 2`, since rounds
    are zero based), the global training epoch is 1 + 3*2 = 7.

    Args:
        local_epoch (int): local training epoch

    Returns:
        int: global training epoch
    """
    # NOTE:
    # - self.args.round is zero based
    # - testing rounds are not considered
    return max(local_epoch, 0) + self.epochs * self.args.round
```
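The epoch arithmetic above can be checked in isolation; a standalone sketch of the same formula, using the docstring's example values (three local epochs per round; the function name here is illustrative):

```python
def global_training_epoch(local_epoch: int, epochs_per_round: int, round_idx: int) -> int:
    # round_idx is zero based; negative local epochs are clamped to 0,
    # mirroring max(local_epoch, 0) in Config.get_global_training_epoch.
    return max(local_epoch, 0) + epochs_per_round * round_idx

print(global_training_epoch(1, 3, 2))   # 7  (third round, local epoch 1)
print(global_training_epoch(0, 3, 0))   # 0  (very first epoch overall)
print(global_training_epoch(-5, 3, 1))  # 3  (negative local epoch clamped to 0)
```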