Fabric

class lightning.fabric.fabric.Fabric(*, accelerator='auto', strategy='auto', devices='auto', num_nodes=1, precision=None, plugins=None, callbacks=None, loggers=None)[source]

Bases: object

Fabric accelerates your PyTorch training or inference code with minimal changes required.

Key Features:
  • Automatic placement of models and data onto the device.

  • Automatic support for mixed and double precision (smaller memory footprint).

  • Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies (data-parallel training, sharded training, etc.).

  • Automated spawning of processes, no launch utilities required.

  • Multi-node support.

Parameters:
  • accelerator (Union[str, Accelerator]) – The hardware to run on. Possible choices are: "cpu", "cuda", "mps", "gpu", "tpu", "auto". Defaults to "auto".

  • strategy (Union[str, Strategy]) – Strategy for how to run across multiple devices. Possible choices are: "dp", "ddp", "ddp_spawn", "deepspeed", "fsdp", "auto". Defaults to "auto".

  • devices (Union[list[int], str, int]) – Number of devices to train on (int), which GPUs to train on (list or str), or "auto". The value applies per node. Defaults to "auto".

  • num_nodes (int) – Number of GPU nodes for distributed training. Defaults to 1.

  • precision (Union[Literal[64, 32, 16], Literal['transformer-engine', 'transformer-engine-float16', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true'], Literal['64', '32', '16', 'bf16'], None]) – Double precision ("64"), full precision ("32"), half precision AMP ("16-mixed"), or bfloat16 precision AMP ("bf16-mixed"). If None, defaults will be used based on the device.

  • plugins (Union[Precision, ClusterEnvironment, CheckpointIO, list[Union[Precision, ClusterEnvironment, CheckpointIO]], None]) – One or several custom plugins as a single plugin or list of plugins.

  • callbacks (Union[list[Any], Any, None]) – A single callback or a list of callbacks. A callback can contain any arbitrary methods that can be invoked through call() by the user.

  • loggers (Union[Logger, list[Logger], None]) – A single logger or a list of loggers. See log() for more information.

Example:

# Basic usage
fabric = Fabric(accelerator="gpu", devices=2)

# Set up model and optimizer
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

# Training loop
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    fabric.backward(loss)
    optimizer.step()

_setup_dataloader(dataloader, use_distributed_sampler=True, move_to_device=True)[source]

Set up a single dataloader for accelerated training.

Parameters:
  • dataloader (DataLoader) – The dataloader to accelerate.

  • use_distributed_sampler (bool) – If set True (default), automatically wraps or replaces the sampler on the dataloader for distributed training. If you have a custom sampler defined, set this argument to False.

  • move_to_device (bool) – If set True (default), moves the data returned by the dataloader automatically to the correct device. Set this to False and alternatively use to_device() manually on the returned data.

Return type:

DataLoader

Returns:

The wrapped dataloader.

all_gather(data, group=None, sync_grads=False)[source]

Gather tensors or collections of tensors from multiple processes.

This method needs to be called on all processes and the tensors need to have the same shape across all processes, otherwise your program will stall forever.

Parameters:
  • data (Union[Tensor, dict, list, tuple]) – int, float, tensor of shape (batch, …), or a (possibly nested) collection thereof.

  • group (Optional[Any]) – the process group to gather results from. Defaults to all processes (world).

  • sync_grads (bool) – flag that allows users to synchronize gradients for the all_gather operation

Return type:

Union[Tensor, dict, list, tuple]

Returns:

A tensor of shape (world_size, batch, …), or if the input was a collection the output will also be a collection with tensors of this shape. For the special case where world_size is 1, no additional dimension is added to the tensor(s).
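
Example (a minimal sketch; fabric is assumed to be launched already, and correct / total are placeholders computed in your own evaluation loop):

# Runs on every process; each rank contributes its local value
local_acc = torch.tensor(correct / total, device=fabric.device)
gathered = fabric.all_gather(local_acc)   # shape: (world_size,)
global_acc = gathered.mean().item()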

all_reduce(data, group=None, reduce_op='mean')[source]

Reduce tensors or collections of tensors from multiple processes.

The reduction on tensors is applied in-place, meaning the result will be placed back into the input tensor. This method needs to be called on all processes and the tensors need to have the same shape across all processes, otherwise your program will stall forever.

Parameters:
  • data (Union[Tensor, dict, list, tuple]) – int, float, tensor of shape (batch, …), or a (possibly nested) collection thereof. Tensor will be modified in-place.

  • group (Optional[Any]) – the process group to reduce results across. Defaults to all processes (world).

  • reduce_op (Union[ReduceOp, str, None]) – the reduction operation. Defaults to ‘mean’. Can also be a string ‘sum’ or ReduceOp. Some strategies may limit the choices here.

Return type:

Union[Tensor, dict, list, tuple]

Returns:

A tensor of the same shape as the input with values reduced pointwise across processes. The same is applied to tensors in a collection if a collection is given as input.
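
Example (a minimal sketch; running_loss and num_batches are placeholders accumulated in your own training loop):

# Runs on every process; the input tensor is modified in-place
local_loss = torch.tensor(running_loss / num_batches, device=fabric.device)
avg_loss = fabric.all_reduce(local_loss, reduce_op="mean")
fabric.print(f"average loss: {avg_loss.item():.4f}")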

autocast()[source]

A context manager to automatically convert operations for the chosen precision.

Use this only if the forward method of your model does not cover all operations you wish to run with the chosen precision setting.

Return type:

AbstractContextManager
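
Example (a minimal sketch; model is assumed to be set up with setup(), and loss_fn, inputs, and targets come from your own code):

with fabric.autocast():
    # operations here run with the precision configured on Fabric, e.g. "bf16-mixed"
    logits = model(inputs)
    loss = loss_fn(logits, targets)
fabric.backward(loss)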

backward(tensor, *args, model=None, **kwargs)[source]

Replaces loss.backward() in your training loop. Handles precision automatically for you.

Parameters:
  • tensor (Tensor) – The tensor (loss) to back-propagate gradients from.

  • *args (Any) – Optional positional arguments passed to the underlying backward function.

  • model (Optional[_FabricModule]) – Optional model instance for plugins that require the model for backward(). Required when using DeepSpeed strategy with multiple models.

  • **kwargs (Any) – Optional named keyword arguments passed to the underlying backward function.

Return type:

None

Note

When using strategy="deepspeed" and multiple models were set up, it is required to pass in the model as argument here.

Example:

loss = criterion(output, target)
fabric.backward(loss)

# With DeepSpeed and multiple models
fabric.backward(loss, model=model)

barrier(name=None)[source]

Wait for all processes to enter this call.

Use this to synchronize all parallel processes, but only if necessary, otherwise the overhead of synchronization will cause your program to slow down. This method needs to be called on all processes. Failing to do so will cause your program to stall forever.

Return type:

None
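
Example (a minimal sketch; download_dataset is a hypothetical helper of your own):

if fabric.global_rank == 0:
    download_dataset("datasets/")   # hypothetical helper; runs only on rank 0
fabric.barrier()                    # all other ranks wait here until rank 0 is done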

broadcast(obj, src=0)[source]

Send a tensor from one process to all others.

This method needs to be called on all processes. Failing to do so will cause your program to stall forever.

Parameters:
  • obj (TypeVar(TBroadcast)) – The object to broadcast to all other members. Any serializable object is supported, but it is most efficient with the object being a Tensor.

  • src (int) – The (global) rank of the process that should send the data to all others.

Return type:

TypeVar(TBroadcast)

Returns:

The transferred data, the same value on every rank.
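
Example (a minimal sketch; generate_run_id is a hypothetical helper of your own):

# Pick a value on rank 0 and share it with every process
run_id = generate_run_id() if fabric.global_rank == 0 else None
run_id = fabric.broadcast(run_id, src=0)   # every rank now holds the same value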

call(hook_name, *args, **kwargs)[source]

Trigger the callback methods with the given name and arguments.

Objects registered via Fabric(callbacks=...) are not required to implement a method with the given name; only the ones that define a matching method will get called.

Parameters:
  • hook_name (str) – The name of the callback method.

  • *args (Any) – Optional positional arguments that get passed down to the callback method.

  • **kwargs (Any) – Optional keyword arguments that get passed down to the callback method.

Return type:

None

Example:

class MyCallback:
    def on_train_epoch_end(self, results):
        ...

fabric = Fabric(callbacks=[MyCallback()])
fabric.call("on_train_epoch_end", results={...})

clip_gradients(module, optimizer, clip_val=None, max_norm=None, norm_type=2.0, error_if_nonfinite=True)[source]

Clip the gradients of the model to a given max value or max norm.

Parameters:
  • module (Union[Module, _FabricModule]) – The module whose parameters should be clipped.

  • optimizer (Union[Optimizer, _FabricOptimizer]) – The optimizer referencing the parameters to be clipped.

  • clip_val (Union[float, int, None]) – If passed, gradients will be clipped to this value. Cannot be used together with max_norm.

  • max_norm (Union[float, int, None]) – If passed, clips the gradients in such a way that the p-norm of the resulting parameters is no larger than the given value. Cannot be used together with clip_val.

  • norm_type (Union[float, int]) – The type of norm if max_norm was passed. Can be 'inf' for infinity norm. Defaults to 2-norm.

  • error_if_nonfinite (bool) – An error is raised if the total norm of the gradients is NaN or infinite. Only applies when max_norm is used.

Return type:

Optional[Tensor]

Returns:

The total norm of the gradients (before clipping was applied) as a scalar tensor if max_norm was passed, otherwise None.

Raises:

ValueError – If both clip_val and max_norm are provided, or if neither is provided.

Example:

# Clip by value
fabric.clip_gradients(model, optimizer, clip_val=1.0)

# Clip by norm
total_norm = fabric.clip_gradients(model, optimizer, max_norm=1.0)

init_module(empty_init=None)[source]

Instantiate the model and its parameters under this context manager to reduce peak memory usage.

The parameters get created on the device and with the right data type right away without wasting memory being allocated unnecessarily.

Parameters:

empty_init (Optional[bool]) – Whether to initialize the model with empty weights (uninitialized memory). If None, the strategy will decide. Some strategies may not support all options. Set this to True if you are loading a checkpoint into a large model.

Return type:

AbstractContextManager
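
Example (a minimal sketch; MyModel is a placeholder for your own model class):

# Parameters are created directly on the target device and in the target dtype
with fabric.init_module():
    model = MyModel()

# When loading a checkpoint into a large model, skip the expensive random init
with fabric.init_module(empty_init=True):
    model = MyModel()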

init_tensor()[source]

Tensors that you instantiate under this context manager will be created on the device right away and have the right data type depending on the precision setting in Fabric.

Return type:

AbstractContextManager
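
Example (a minimal sketch):

with fabric.init_tensor():
    # created on fabric.device with the dtype implied by the precision setting
    noise = torch.randn(64, 128)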

launch(function=<function _do_nothing>, *args, **kwargs)[source]

Launch and initialize all the processes needed for distributed execution.

Parameters:
  • function (Callable[[Fabric], Any]) – Optional function to launch when using a spawn/fork-based strategy, for example, when using the XLA strategy (accelerator="tpu"). The function must accept at least one argument, to which the Fabric object itself will be passed. If not provided, only process initialization will be performed.

  • *args (Any) – Optional positional arguments to be passed to the function.

  • **kwargs (Any) – Optional keyword arguments to be passed to the function.

Return type:

Any

Returns:

Returns the output of the function that ran in worker process with rank 0.

Raises:
  • RuntimeError – If called when the script was launched through the CLI.

  • TypeError – If function is provided but not callable, or if function doesn’t accept required arguments.

Note

The launch() method should only be used if you intend to specify accelerator, devices, and so on in the code (programmatically). If you are launching with the Lightning CLI, fabric run ..., remove launch() from your code.

launch() is a no-op when called multiple times and no function is passed in.

Example:

def train_function(fabric):
    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters())
    model, optimizer = fabric.setup(model, optimizer)
    # ... training code ...

fabric = Fabric(accelerator="tpu", devices=8)
fabric.launch(train_function)

load(path, state=None, strict=True)[source]

Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.)

How and which processes load gets determined by the strategy. This method must be called on all processes!

Parameters:
  • path (Union[str, Path]) – A path to where the file is located.

  • state (Optional[dict[str, Union[Module, Optimizer, Any]]]) – A dictionary of objects whose state will be restored in-place from the checkpoint path. If no state is given, then the checkpoint will be returned in full.

  • strict (bool) – Whether to enforce that the keys in state match the keys in the checkpoint.

Return type:

dict[str, Any]

Returns:

The remaining items that were not restored into the given state dictionary. If no state dictionary is given, the full checkpoint will be returned.

Example:

# Load full checkpoint
checkpoint = fabric.load("checkpoint.pth")

# Load into existing objects
state = {"model": model, "optimizer": optimizer}
remainder = fabric.load("checkpoint.pth", state)
epoch = remainder.get("epoch", 0)

load_raw(path, obj, strict=True)[source]

Load the state of a module or optimizer from a single state-dict file.

Use this for loading a raw PyTorch model checkpoint created without Fabric. This is conceptually equivalent to obj.load_state_dict(torch.load(path)), but is agnostic to the strategy being used.

Parameters:
  • path (Union[str, Path]) – A path to where the file is located.

  • obj (Union[Module, Optimizer]) – A torch.nn.Module or torch.optim.Optimizer instance whose state will be restored in-place from the file.

  • strict (bool) – Whether to enforce that the keys in the object’s state-dict match the keys in the checkpoint.

Return type:

None
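
Example (a minimal sketch; the file paths and MyModel are placeholders):

model = MyModel()
fabric.load_raw("path/to/model.pt", model)         # fills the module's parameters in-place

optimizer = torch.optim.Adam(model.parameters())
fabric.load_raw("path/to/optimizer.pt", optimizer)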

log(name, value, step=None)[source]

Log a scalar to all loggers that were added to Fabric.

Parameters:
  • name (str) – The name of the metric to log.

  • value (Any) – The metric value to collect. If the value is a torch.Tensor, it gets detached from the graph automatically.

  • step (Optional[int]) – Optional step number. Most Logger implementations auto-increment the step value by one with every log call. You can specify your own value here.

Return type:

None
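
Example (a minimal sketch; loss and global_step come from your own training loop):

fabric.log("train/loss", loss)                     # tensors are detached automatically
fabric.log("train/loss", loss, step=global_step)   # with an explicit step value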

log_dict(metrics, step=None)[source]

Log multiple scalars at once to all loggers that were added to Fabric.

Parameters:
  • metrics (Mapping[str, Any]) – A dictionary where the key is the name of the metric and the value is the scalar to be logged. Any torch.Tensor in the dictionary gets detached from the graph automatically.

  • step (Optional[int]) – Optional step number. Most Logger implementations auto-increment this value by one with every log call. You can specify your own value here.

Return type:

None
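
Example (a minimal sketch; loss, acc, and global_step come from your own training loop):

fabric.log_dict({"train/loss": loss, "train/acc": acc}, step=global_step)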

no_backward_sync(module, enabled=True)[source]

Skip gradient synchronization during backward to avoid redundant communication overhead.

Use this context manager when performing gradient accumulation to speed up training with multiple devices. Both the model’s .forward() and the fabric.backward() call need to run under this context.

Parameters:
  • module (_FabricModule) – The module for which to control the gradient synchronization. Must be a module that was set up with setup() or setup_module().

  • enabled (bool) – Whether the context manager is enabled or not. True means skip the sync, False means do not skip.

Return type:

AbstractContextManager

Returns:

A context manager that controls gradient synchronization.

Raises:

TypeError – If the module was not set up with Fabric first.

Note

For strategies that don’t support gradient sync control, a warning is emitted and the context manager becomes a no-op. For single-device strategies, it is always a no-op.

Example:

# Accumulate gradients over 8 batches
for batch_idx, batch in enumerate(dataloader):
    with fabric.no_backward_sync(model, enabled=(batch_idx % 8 != 0)):
        output = model(batch)
        loss = criterion(output, target)
        fabric.backward(loss)

    if batch_idx % 8 == 0:
        optimizer.step()
        optimizer.zero_grad()

print(*args, **kwargs)[source]

Print something only on the first process. If running on multiple machines, it will print from the first process in each machine.

Arguments passed to this method are forwarded to the Python built-in print() function.

Return type:

None
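
Example (a minimal sketch; epoch and loss come from your own training loop):

fabric.print(f"epoch {epoch}: loss={loss.item():.4f}")   # printed only from the first process on each machine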

rank_zero_first(local=False)[source]

The code block under this context manager is executed first on the main process (rank 0); only after it completes do the other processes run the same code in parallel.

Parameters:

local (bool) – Set this to True if the local rank should be the one going first. Useful if you are downloading data and the filesystem isn’t shared between the nodes.

Return type:

Generator

Example:

with fabric.rank_zero_first():
    dataset = MNIST("datasets/", download=True)

run(*args, **kwargs)[source]

All the code inside this run method gets accelerated by Fabric.

You can pass arbitrary arguments to this function when overriding it.

Return type:

Any
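
Example (a sketch of the subclassing pattern; MyModel and the loop body are placeholders, and depending on how you start your script you may combine this with launch() or the fabric run CLI):

class MyFabric(Fabric):
    def run(self, num_epochs):
        model = MyModel()
        optimizer = torch.optim.Adam(model.parameters())
        model, optimizer = self.setup(model, optimizer)
        for epoch in range(num_epochs):
            ...  # your training loop

MyFabric(accelerator="gpu", devices=2).run(num_epochs=3)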

save(path, state, filter=None)[source]

Save checkpoint contents to a file.

How and which processes save gets determined by the strategy. For example, the ddp strategy saves checkpoints only on process 0, while the fsdp strategy saves files from every rank. This method must be called on all processes!

Parameters:
  • path (Union[str, Path]) – A path to where the file(s) should be saved.

  • state (dict[str, Union[Module, Optimizer, Any]]) – A dictionary with contents to be saved. If the dict contains modules or optimizers, their state-dict will be retrieved and converted automatically.

  • filter (Optional[dict[str, Callable[[str, Any], bool]]]) – An optional dictionary of filter callables that return a boolean indicating whether the given item should be saved (True) or filtered out (False). Each filter key should match a state key; the corresponding filter is applied to the state-dict generated for that entry.

Raises:
  • TypeError – If filter is not a dictionary or contains non-callable values.

  • ValueError – If filter keys don’t match state keys.

Return type:

None

Example:

state = {"model": model, "optimizer": optimizer, "epoch": epoch}
fabric.save("checkpoint.pth", state)

# With filter
def param_filter(name, param):
    return "bias" not in name  # Save only non-bias parameters

fabric.save("checkpoint.pth", state, filter={"model": param_filter})

static seed_everything(seed=None, workers=None, verbose=True)[source]

Helper function to seed everything without explicitly importing Lightning.

See seed_everything() for more details.

Return type:

int
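
Example (a minimal sketch; call this before creating models or dataloaders):

fabric = Fabric(accelerator="auto")
fabric.seed_everything(42)   # seeds the random, numpy, and torch RNGs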

setup(module, *optimizers, scheduler=None, move_to_device=True, _reapply_compile=True)[source]

Set up a model and its optimizers for accelerated training.

Parameters:
  • module (Module) – A torch.nn.Module to set up.

  • *optimizers (Optimizer) – The optimizer(s) to set up. Can be zero or more optimizers.

  • scheduler (Optional[_LRScheduler]) – An optional learning rate scheduler to set up. Must be provided after optimizers if used.

  • move_to_device (bool) – If set True (default), moves the model to the correct device. Set this to False and alternatively use to_device() manually.

  • _reapply_compile (bool) – If True (default), and the model was previously compiled with torch.compile, the torch._dynamo.OptimizedModule wrapper will be removed and reapplied with the same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP, FSDP, etc.). Set it to False if compiling DDP/FSDP is causing issues.

Return type:

Any

Returns:

If no optimizers are passed, returns the wrapped module. If optimizers are passed, returns a tuple containing the wrapped module and optimizers, and optionally the scheduler if provided, in the same order they were passed in.

Note

For certain strategies like FSDP, you may need to set up the model first using setup_module(), then create the optimizer, and finally set up the optimizer using setup_optimizers().

Example:

# Basic usage
model, optimizer = fabric.setup(model, optimizer)

# With multiple optimizers and scheduler
model, opt1, opt2, scheduler = fabric.setup(model, opt1, opt2, scheduler=scheduler)

# Model only
model = fabric.setup(model)

setup_dataloaders(*dataloaders, use_distributed_sampler=True, move_to_device=True)[source]

Set up one or multiple dataloaders for accelerated training. If you need different settings for each dataloader, call this method individually for each one.

Parameters:
  • *dataloaders (DataLoader) – One or more PyTorch DataLoader instances to set up.

  • use_distributed_sampler (bool) – If set True (default), automatically wraps or replaces the sampler on the dataloader(s) for distributed training. If you have a custom sampler defined, set this argument to False.

  • move_to_device (bool) – If set True (default), moves the data returned by the dataloader(s) automatically to the correct device. Set this to False and alternatively use to_device() manually on the returned data.

Return type:

Union[DataLoader, list[DataLoader]]

Returns:

If a single dataloader is passed, returns the wrapped dataloader. If multiple dataloaders are passed, returns a list of wrapped dataloaders in the same order they were passed in.

Example:

# Single dataloader
train_loader = fabric.setup_dataloaders(train_loader)

# Multiple dataloaders
train_loader, val_loader = fabric.setup_dataloaders(train_loader, val_loader)

setup_module(module, move_to_device=True, _reapply_compile=True)[source]

Set up a model for accelerated training or inference.

This is the same as calling .setup(model) with no optimizers. It is useful for inference or for certain strategies like FSDP that require setting up the module before the optimizer can be created and set up. See also setup_optimizers().

Parameters:
  • module (Module) – A torch.nn.Module to set up.

  • move_to_device (bool) – If set True (default), moves the model to the correct device. Set this to False and alternatively use to_device() manually.

  • _reapply_compile (bool) – If True (default), and the model was previously compiled with torch.compile, the torch._dynamo.OptimizedModule wrapper will be removed and reapplied with the same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP, FSDP, etc.). Set it to False if compiling DDP/FSDP is causing issues.

Return type:

_FabricModule

Returns:

The wrapped model as a _FabricModule.

Example:

# Set up model first (useful for FSDP)
model = fabric.setup_module(model)

# Then create and set up optimizer
optimizer = torch.optim.Adam(model.parameters())
optimizer = fabric.setup_optimizers(optimizer)

setup_optimizers(*optimizers)[source]

Set up one or more optimizers for accelerated training.

Some strategies do not allow setting up model and optimizer independently. For them, you should call .setup(model, optimizer, ...) instead to jointly set them up.

Parameters:

*optimizers (Optimizer) – One or more optimizers to set up. Must provide at least one optimizer.

Return type:

Union[_FabricOptimizer, tuple[_FabricOptimizer, ...]]

Returns:

If a single optimizer is passed, returns the wrapped optimizer. If multiple optimizers are passed, returns a tuple of wrapped optimizers in the same order they were passed in.

Raises:

RuntimeError – If using DeepSpeed or XLA strategies, which require joint model-optimizer setup.

Note

This method cannot be used with DeepSpeed or XLA strategies. Use setup() instead for those strategies.

Example:

# Single optimizer
optimizer = fabric.setup_optimizers(optimizer)

# Multiple optimizers
opt1, opt2 = fabric.setup_optimizers(opt1, opt2)

sharded_model()[source]

Instantiate a model under this context manager to prepare it for model-parallel sharding.

Return type:

AbstractContextManager

Deprecated: This context manager is deprecated in favor of init_module(); use it instead.

to_device(obj)[source]

Move a torch.nn.Module or a collection of tensors to the current device, if it is not already on that device.

Parameters:

obj (Union[Module, Tensor, Any]) – An object to move to the device. Can be an instance of torch.nn.Module, a tensor, or a (nested) collection of tensors (e.g., a dictionary).

Return type:

Union[Module, Tensor, Any]

Returns:

A reference to the object that was moved to the new device.
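
Example (a minimal sketch; dataloader and model are placeholders for your own objects):

batch = fabric.to_device(next(iter(dataloader)))   # works for tensors and nested collections
model = fabric.to_device(model)                    # modules are moved and returned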

property device: device

The current device this process runs on.

Use this to create tensors directly on the device if needed.

property global_rank: int

The global index of the current process across all devices and nodes.

property is_global_zero: bool

Whether this rank is rank zero.

property local_rank: int

The index of the current process among the processes running on the local node.

property logger: Logger

Returns the first logger in the list passed to Fabric, which is considered the main logger.

property loggers: list[lightning.fabric.loggers.logger.Logger]

Returns all loggers passed to Fabric.

property node_rank: int

The index of the current node.

property world_size: int

The total number of processes running across all devices and nodes.
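
Example (a short sketch of how these properties are typically used; the loop body is a placeholder):

print(f"rank {fabric.global_rank}/{fabric.world_size} on {fabric.device}")
if fabric.is_global_zero:
    ...  # e.g., create output directories only once
weights = torch.empty(10, 10, device=fabric.device)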