fl_server_core.utils.torch_serialization

Functions:

Name Description
from_torch

Serialize a PyTorch object into bytes.

from_torch_module

Serialize a PyTorch module into bytes.

from_torch_module_or_tensor

Serialize a PyTorch module or tensor into bytes.

from_torch_tensor

Serialize a PyTorch tensor into bytes.

is_torchscript_instance

Check if an object is an instance of torch.jit.ScriptModule or torch.jit.ScriptFunction.

to_torch

Convert a serialized PyTorch object back into a PyTorch object.

to_torch_module

Convert a serialized PyTorch module back into a PyTorch module.

to_torch_module_or_tensor

Convert a serialized PyTorch module or tensor back into a PyTorch module or tensor.

to_torch_tensor

Convert a serialized PyTorch tensor back into a PyTorch tensor.

Attributes:

Name Type Description
T

Attributes

T module-attribute

T = TypeVar('T')

Functions

from_torch

from_torch(obj: Any, *args, **kwargs) -> bytes

Serialize a PyTorch object into bytes.

Parameters:

Name Type Description Default

obj

Any

The PyTorch object to serialize.

required

*args

Additional arguments to pass to the torch.save or torch.jit.save function.

()

**kwargs

Additional keyword arguments to pass to the torch.save or torch.jit.save function.

{}

Returns:

Name Type Description
bytes bytes

The serialized PyTorch object as bytes.

Source code in fl_server_core/utils/torch_serialization.py
def from_torch(obj: Any, *args, **kwargs) -> bytes:
    """
    Serialize a PyTorch object into bytes.

    Args:
        obj (Any): The PyTorch object to serialize.
        *args: Additional arguments to pass to the `torch.save` or `torch.jit.save` function.
        **kwargs: Additional keyword arguments to pass to the `torch.save` or `torch.jit.save` function.

    Returns:
        bytes: The serialized PyTorch object as bytes.
    """
    buffer = BytesIO()
    if is_torchscript_instance(obj):
        torch.jit.save(obj, buffer, *args, **kwargs)
    else:
        torch.save(obj, buffer, *args, **kwargs)
    return buffer.getvalue()
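
The in-memory buffer pattern used by from_torch can be tried without the package. The following is a minimal sketch that assumes only that PyTorch is installed; `tensor_to_bytes` is a hypothetical helper name that mirrors the non-TorchScript branch above, not part of fl_server_core.

```python
from io import BytesIO

import torch


def tensor_to_bytes(tensor: torch.Tensor) -> bytes:
    """Hypothetical helper: write the tensor into an in-memory buffer with torch.save."""
    buffer = BytesIO()
    torch.save(tensor, buffer)
    # getvalue() returns the whole buffer regardless of the current position.
    return buffer.getvalue()


original = torch.arange(6, dtype=torch.float32).reshape(2, 3)
payload = tensor_to_bytes(original)

# Reading back needs a fresh file-like object positioned at the start.
restored = torch.load(BytesIO(payload), weights_only=True)
```

Plain tensors can be read back with `weights_only=True`; whole pickled modules cannot, which is presumably why to_torch in this module loads with `weights_only=False`.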

from_torch_module

from_torch_module(obj: Module, *args, **kwargs) -> bytes

Serialize a PyTorch module into bytes.

Parameters:

Name Type Description Default

obj

Module

The PyTorch module to serialize.

required

*args

Additional arguments to pass to the torch.save or torch.jit.save function.

()

**kwargs

Additional keyword arguments to pass to the torch.save or torch.jit.save function.

{}

Returns:

Name Type Description
bytes bytes

The serialized PyTorch module as bytes.

Source code in fl_server_core/utils/torch_serialization.py
def from_torch_module(obj: torch.nn.Module, *args, **kwargs) -> bytes:
    """
    Serialize a PyTorch module into bytes.

    Args:
        obj (torch.nn.Module): The PyTorch module to serialize.
        *args: Additional arguments to pass to the `torch.save` or `torch.jit.save` function.
        **kwargs: Additional keyword arguments to pass to the `torch.save` or `torch.jit.save` function.

    Returns:
        bytes: The serialized PyTorch module as bytes.
    """
    return from_torch(obj, *args, **kwargs)
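
Because from_torch_module delegates to from_torch, an eager module goes through torch.save while a TorchScript module goes through torch.jit.save. The sketch below walks both routes by hand, assuming only PyTorch; the variable names are illustrative, and `torch.jit.trace` is used here simply as one convenient way to obtain a TorchScript module.

```python
from io import BytesIO

import torch

torch.manual_seed(0)
eager = torch.nn.Linear(3, 2)
example_input = torch.ones(1, 3)

# Eager route: the whole module object is pickled by torch.save.
eager_buffer = BytesIO()
torch.save(eager, eager_buffer)
eager_restored = torch.load(BytesIO(eager_buffer.getvalue()), weights_only=False)

# TorchScript route: a traced module is written as a TorchScript archive.
traced = torch.jit.trace(eager, example_input)
script_buffer = BytesIO()
torch.jit.save(traced, script_buffer)
script_restored = torch.jit.load(BytesIO(script_buffer.getvalue()))

with torch.no_grad():
    expected = eager(example_input)
    eager_matches = torch.allclose(eager_restored(example_input), expected)
    script_matches = torch.allclose(script_restored(example_input), expected)
```

Loading a pickled module with `weights_only=False` can execute arbitrary code embedded in the payload, so this route is only appropriate for bytes from a trusted source.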

from_torch_module_or_tensor

from_torch_module_or_tensor(obj: Module | Tensor, *args, **kwargs) -> bytes

Serialize a PyTorch module or tensor into bytes.

Parameters:

Name Type Description Default

obj

Module | Tensor

The PyTorch module or tensor to serialize.

required

*args

Additional arguments to pass to the torch.save or torch.jit.save function.

()

**kwargs

Additional keyword arguments to pass to the torch.save or torch.jit.save function.

{}

Returns:

Name Type Description
bytes bytes

The serialized PyTorch module or tensor as bytes.

Source code in fl_server_core/utils/torch_serialization.py
def from_torch_module_or_tensor(obj: torch.nn.Module | torch.Tensor, *args, **kwargs) -> bytes:
    """
    Serialize a PyTorch module or tensor into bytes.

    Args:
        obj (torch.nn.Module | torch.Tensor): The PyTorch module or tensor to serialize.
        *args: Additional arguments to pass to the `torch.save` or `torch.jit.save` function.
        **kwargs: Additional keyword arguments to pass to the `torch.save` or `torch.jit.save` function.

    Returns:
        bytes: The serialized PyTorch module or tensor as bytes.
    """
    return from_torch(obj, *args, **kwargs)

from_torch_tensor

from_torch_tensor(obj: Tensor, *args, **kwargs) -> bytes

Serialize a PyTorch tensor into bytes.

Parameters:

Name Type Description Default

obj

Tensor

The PyTorch tensor to serialize.

required

*args

Additional arguments to pass to the torch.save or torch.jit.save function.

()

**kwargs

Additional keyword arguments to pass to the torch.save or torch.jit.save function.

{}

Returns:

Name Type Description
bytes bytes

The serialized PyTorch tensor as bytes.

Source code in fl_server_core/utils/torch_serialization.py
def from_torch_tensor(obj: torch.Tensor, *args, **kwargs) -> bytes:
    """
    Serialize a PyTorch tensor into bytes.

    Args:
        obj (torch.Tensor): The PyTorch tensor to serialize.
        *args: Additional arguments to pass to the `torch.save` or `torch.jit.save` function.
        **kwargs: Additional keyword arguments to pass to the `torch.save` or `torch.jit.save` function.

    Returns:
        bytes: The serialized PyTorch tensor as bytes.
    """
    return from_torch(obj, *args, **kwargs)

is_torchscript_instance

is_torchscript_instance(obj: Any) -> bool

Check if an object is an instance of torch.jit.ScriptModule or torch.jit.ScriptFunction.

Parameters:

Name Type Description Default

obj

Any

The object to check.

required

Returns:

Name Type Description
bool bool

True if the object is an instance of torch.jit.ScriptModule or torch.jit.ScriptFunction, otherwise False.

Source code in fl_server_core/utils/torch_serialization.py
def is_torchscript_instance(obj: Any) -> bool:
    """
    Check if an object is an instance of `torch.jit.ScriptModule` or `torch.jit.ScriptFunction`.

    Args:
        obj (Any): The object to check.

    Returns:
        bool: `True` if the object is an instance of `torch.jit.ScriptModule` or `torch.jit.ScriptFunction`,
            otherwise `False`.
    """
    return isinstance(obj, torch.jit.ScriptModule | torch.jit.ScriptFunction)
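
A short sketch of what this check distinguishes, assuming only PyTorch. `is_script` is a hypothetical local copy of the check, written with a tuple instead of the `|` union so it also runs on Python versions before 3.10.

```python
import torch


def is_script(obj) -> bool:
    """Hypothetical local equivalent of is_torchscript_instance."""
    return isinstance(obj, (torch.jit.ScriptModule, torch.jit.ScriptFunction))


eager = torch.nn.Linear(2, 1)
traced = torch.jit.trace(eager, torch.zeros(1, 2))

eager_is_script = is_script(eager)      # False: a plain nn.Module
traced_is_script = is_script(traced)    # True: tracing yields a ScriptModule subclass
traced_is_module = isinstance(traced, torch.nn.Module)  # True as well
```

Since a TorchScript module is also a `torch.nn.Module`, an `isinstance` test against `torch.nn.Module` alone cannot tell the two apart, which is why from_torch tests for TorchScript first.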

to_torch

to_torch(obj: Any, supported_types: Type[T] | Tuple[Type[T], ...])

Convert a serialized PyTorch object back into a PyTorch object.

Parameters:

Name Type Description Default

obj

Any

The serialized PyTorch object.

required

supported_types

Type[T] | Tuple[Type[T], ...]

The expected type or types of the PyTorch object.

required

Returns:

Name Type Description
T

The deserialized PyTorch object.

Raises:

Type Description
TorchDeserializationException

If there is an error during deserialization or if the deserialized object is not of the expected type.

Source code in fl_server_core/utils/torch_serialization.py
def to_torch(obj: Any, supported_types: Type[T] | Tuple[Type[T], ...]):
    """
    Convert a serialized PyTorch object back into a PyTorch object.

    Args:
        obj (Any): The serialized PyTorch object.
        supported_types (Type[T] | Tuple[Type[T], ...]): The expected type or types of the PyTorch object.

    Returns:
        T: The deserialized PyTorch object.

    Raises:
        TorchDeserializationException: If there is an error during deserialization or if the deserialized
            object is not of the expected type.
    """
    obj = BytesIO(obj) if isinstance(obj, Buffer) else obj
    try:
        # torch.load supports torch.nn.Module as well as TorchScript (but with a user warning)
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="'torch.load' received a zip file that looks like a TorchScript archive",
                category=UserWarning
            )
            t_obj = torch.load(obj, weights_only=False)
    except Exception as e:
        getLogger("fl.server").error(f"Error loading torch object: {e}")
        raise TorchDeserializationException("Error loading torch object") from e
    if isinstance(t_obj, supported_types):
        return t_obj
    getLogger("fl.server").error("Loaded torch object is not of expected type.")
    raise TorchDeserializationException("Loaded torch object is not of expected type.")
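
The warning filter in to_torch relies on standard-library behaviour: the `message` argument of `warnings.filterwarnings` is a regular expression matched against the start of the warning text, and the filter is undone when the `catch_warnings` block exits. The sketch below reproduces that behaviour without PyTorch; `emit_warnings` is a hypothetical stand-in for the call that would raise the warnings.

```python
import warnings

NOISY_PREFIX = "'torch.load' received a zip file that looks like a TorchScript archive"


def emit_warnings() -> None:
    """Hypothetical stand-in: emits one warning to silence and one to keep."""
    warnings.warn(NOISY_PREFIX + " dispatching to 'torch.jit.load'", UserWarning)
    warnings.warn("something else worth seeing", UserWarning)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Added last, so it is consulted first and wins for matching messages.
    warnings.filterwarnings("ignore", message=NOISY_PREFIX, category=UserWarning)
    emit_warnings()

messages = [str(w.message) for w in caught]
```

Only the second warning ends up in `messages`; the one whose text starts with the filtered prefix is dropped.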

to_torch_module

to_torch_module(obj: Any) -> Module

Convert a serialized PyTorch module back into a PyTorch module.

Parameters:

Name Type Description Default

obj

Any

The serialized PyTorch module.

required

Returns:

Type Description
Module

The deserialized PyTorch module.

Raises:

Type Description
TorchDeserializationException

If there is an error during deserialization or if the deserialized object is not a PyTorch module.

Source code in fl_server_core/utils/torch_serialization.py
def to_torch_module(obj: Any) -> torch.nn.Module:
    """
    Convert a serialized PyTorch module back into a PyTorch module.

    Args:
        obj (Any): The serialized PyTorch module.

    Returns:
        torch.nn.Module: The deserialized PyTorch module.

    Raises:
        TorchDeserializationException: If there is an error during deserialization or if the deserialized
            object is not a PyTorch module.
    """
    return to_torch(obj, torch.nn.Module)

to_torch_module_or_tensor

to_torch_module_or_tensor(obj: Any) -> Module | Tensor

Convert a serialized PyTorch module or tensor back into a PyTorch module or tensor.

Parameters:

Name Type Description Default

obj

Any

The serialized PyTorch module or tensor.

required

Returns:

Type Description
Module | Tensor

The deserialized PyTorch module or tensor.

Raises:

Type Description
TorchDeserializationException

If there is an error during deserialization or if the deserialized object is not a PyTorch module or tensor.

Source code in fl_server_core/utils/torch_serialization.py
def to_torch_module_or_tensor(obj: Any) -> torch.nn.Module | torch.Tensor:
    """
    Convert a serialized PyTorch module or tensor back into a PyTorch module or tensor.

    Args:
        obj (Any): The serialized PyTorch module or tensor.

    Returns:
        torch.nn.Module | torch.Tensor: The deserialized PyTorch module or tensor.

    Raises:
        TorchDeserializationException: If there is an error during deserialization or if the deserialized
            object is not a PyTorch module or tensor.
    """
    return to_torch(obj, (torch.nn.Module, torch.Tensor))
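
The typed wrappers differ only in what they pass as `supported_types`: a single type or a tuple of types, both of which `isinstance` accepts. The sketch below imitates that load-then-check flow, assuming only PyTorch. `DeserializationError` and `load_checked` are hypothetical stand-ins for TorchDeserializationException and to_torch, and logging is left out.

```python
from io import BytesIO

import torch


class DeserializationError(Exception):
    """Hypothetical stand-in for TorchDeserializationException."""


def load_checked(data: bytes, supported_types):
    """Load bytes with torch.load and verify the type of the result."""
    try:
        loaded = torch.load(BytesIO(data), weights_only=False)
    except Exception as e:
        raise DeserializationError("Error loading torch object") from e
    if isinstance(loaded, supported_types):
        return loaded
    raise DeserializationError("Loaded torch object is not of expected type.")


buffer = BytesIO()
torch.save(torch.zeros(2), buffer)
tensor_bytes = buffer.getvalue()

# A tuple of types accepts either kind of object.
accepted = load_checked(tensor_bytes, (torch.nn.Module, torch.Tensor))

# A tensor payload is rejected when only modules are expected.
try:
    load_checked(tensor_bytes, torch.nn.Module)
    wrong_type_rejected = False
except DeserializationError:
    wrong_type_rejected = True

# Bytes that are not a torch payload fail inside torch.load and are wrapped.
try:
    load_checked(b"not a torch payload", torch.Tensor)
    garbage_rejected = False
except DeserializationError:
    garbage_rejected = True
```

Wrapping both failure modes in one exception type lets callers handle a malformed upload and an unexpected object kind with a single `except` clause.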

to_torch_tensor

to_torch_tensor(obj: Any) -> Tensor

Convert a serialized PyTorch tensor back into a PyTorch tensor.

Parameters:

Name Type Description Default

obj

Any

The serialized PyTorch tensor.

required

Returns:

Type Description
Tensor

The deserialized PyTorch tensor.

Raises:

Type Description
TorchDeserializationException

If there is an error during deserialization or if the deserialized object is not a PyTorch tensor.

Source code in fl_server_core/utils/torch_serialization.py
def to_torch_tensor(obj: Any) -> torch.Tensor:
    """
    Convert a serialized PyTorch tensor back into a PyTorch tensor.

    Args:
        obj (Any): The serialized PyTorch tensor.

    Returns:
        torch.Tensor: The deserialized PyTorch tensor.

    Raises:
        TorchDeserializationException: If there is an error during deserialization or if the deserialized
            object is not a PyTorch tensor.
    """
    return to_torch(obj, torch.Tensor)