Fabric
- class lightning.fabric.fabric.Fabric(*, accelerator='auto', strategy='auto', devices='auto', num_nodes=1, precision=None, plugins=None, callbacks=None, loggers=None)
Bases: object
Fabric accelerates your PyTorch training or inference code with minimal changes required.
- Key Features:
Automatic placement of models and data onto the device.
Automatic support for mixed precision (smaller memory footprint) and double precision.
Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies (data-parallel training, sharded training, etc.).
Automated spawning of processes, no launch utilities required.
Multi-node support.
- Parameters:
accelerator (Union[str, Accelerator]) – The hardware to run on. Possible choices are: "cpu", "cuda", "mps", "gpu", "tpu", "auto". Defaults to "auto".
strategy (Union[str, Strategy]) – Strategy for how to run across multiple devices. Possible choices are: "dp", "ddp", "ddp_spawn", "deepspeed", "fsdp", "auto". Defaults to "auto".
devices (Union[list[int], str, int]) – Number of devices to train on (int), which GPUs to train on (list or str), or "auto". The value applies per node. Defaults to "auto".
num_nodes (int) – Number of GPU nodes for distributed training. Defaults to 1.
precision (Union[Literal[64, 32, 16], Literal['transformer-engine', 'transformer-engine-float16', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true'], Literal['64', '32', '16', 'bf16'], None]) – Double precision ("64"), full precision ("32"), half precision AMP ("16-mixed"), or bfloat16 precision AMP ("bf16-mixed"). If None, defaults will be used based on the device.
plugins (Union[Precision, ClusterEnvironment, CheckpointIO, list[Union[Precision, ClusterEnvironment, CheckpointIO]], None]) – One or several custom plugins, given as a single plugin or a list of plugins.
callbacks (Union[list[Any], Any, None]) – A single callback or a list of callbacks. A callback can contain any arbitrary methods that can be invoked through call() by the user.
loggers (Union[Logger, list[Logger], None]) – A single logger or a list of loggers. See log() for more information.
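Example:
    import torch
    from lightning.fabric import Fabric

    # Basic usage (MyModel and dataloader are user-defined)
    fabric = Fabric(accelerator="gpu", devices=2)
    fabric.launch()  # start one process per device

    # Set up model, optimizer, and dataloader
    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters())
    model, optimizer = fabric.setup(model, optimizer)
    dataloader = fabric.setup_dataloaders(dataloader)

    # Training loop
    for batch in dataloader:
        optimizer.zero_grad()
        loss = model(batch)
        fabric.backward(loss)
        optimizer.step()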
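The precision parameter accepts both short legacy values ("64", "32", "16", "bf16" and the ints 64, 32, 16) and the unified "<dtype>-<mode>" strings. The sketch below shows one plausible normalization from the former to the latter. It assumes the alias table 64 → "64-true", 32 → "32-true", 16 → "16-mixed", "bf16" → "bf16-mixed"; normalize_precision is a hypothetical helper, not part of Fabric's public API.

```python
# Hypothetical helper: a sketch of how the legacy precision aliases in the
# `precision` signature could be normalized to the unified "<dtype>-<mode>"
# strings. The alias table is an assumption for illustration, not Fabric's code.
from typing import Optional, Union

LEGACY_ALIASES = {"64": "64-true", "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed"}
UNIFIED = {
    "transformer-engine",
    "transformer-engine-float16",
    "16-true",
    "16-mixed",
    "bf16-true",
    "bf16-mixed",
    "32-true",
    "64-true",
}


def normalize_precision(precision: Union[int, str, None]) -> Optional[str]:
    """Return a unified precision string, or None to keep the device defaults."""
    if precision is None:
        return None  # Fabric then picks a default based on the device
    key = str(precision)  # the ints 64, 32, 16 behave like their string aliases
    if key in LEGACY_ALIASES:
        return LEGACY_ALIASES[key]
    if key in UNIFIED:
        return key
    raise ValueError(f"Unsupported precision value: {precision!r}")


print(normalize_precision(16))         # 16-mixed
print(normalize_precision("bf16"))     # bf16-mixed
print(normalize_precision("32-true"))  # 32-true
```

Keeping None distinct from every string lets the accelerator choose its own default instead of forcing full precision.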
- _setup_dataloader(dataloader, use_distributed_sampler=True, move_to_device=True)
Set up a single dataloader for accelerated training.
- Parameters:
dataloader (DataLoader) – The dataloader to accelerate.
use_distributed_sampler (bool) – If set to True (default), automatically wraps or replaces the sampler on the dataloader for distributed training. If you have a custom sampler defined, set this argument to False.
move_to_device (bool) – If set to True (default), moves the data returned by the dataloader automatically to the correct device. Set this to False and alternatively use to_device() manually on the returned data.
- Return type:
- Returns:
The wrapped dataloader.
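The use_distributed_sampler flag matters because, without a distributed sampler, every process would iterate over the full dataset. The sketch below shows how indices get split across processes, assuming a sampler that behaves like torch.utils.data.DistributedSampler with shuffle=False and drop_last=False. shard_indices is a hypothetical helper, and the sampler Fabric actually installs may shuffle.

```python
# A sketch of how a distributed sampler splits dataset indices across processes,
# modeled on torch.utils.data.DistributedSampler with shuffle=False and
# drop_last=False. Pure Python; `shard_indices` is a hypothetical helper.
import math


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> list[int]:
    """Indices that process `rank` reads when `num_replicas` processes share a dataset."""
    if not 0 <= rank < num_replicas:
        raise ValueError("rank must be in [0, num_replicas)")
    indices = list(range(dataset_len))
    # Every rank must see the same number of samples, so pad by wrapping around.
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    padding = total_size - dataset_len
    indices += (indices * math.ceil(padding / max(dataset_len, 1)))[:padding]
    # Interleave: rank r takes every num_replicas-th index starting at r.
    return indices[rank:total_size:num_replicas]


for r in range(3):
    print(r, shard_indices(10, num_replicas=3, rank=r))
```

With 10 samples and 3 processes, each rank receives 4 indices, and indices 0 and 1 are read twice so that all ranks run the same number of steps.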
- all_gather(data, group=None, sync_grads=False)
Gather tensors or collections of tensors from multiple processes.
This method needs to be called on all processes, and the tensors need to have the same shape across all processes; otherwise, your program will stall forever.
- Parameters:
data (Union[Tensor, dict, list, tuple]) – int, float, tensor of shape (batch, …), or a (possibly nested) collection thereof.
group (Optional[Any]) – The process group to gather results from. Defaults to all processes (world).
sync_grads (bool) – Flag that allows users to synchronize gradients for the all_gather operation.
- Return type:
- Returns:
A tensor of shape (world_size, batch, …), or if the input was a collection the output will also be a collection with tensors of this shape. For the special case where world_size is 1, no additional dimension is added to the tensor(s).
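The shape rule above can be modeled without any distributed backend. In this pure-Python sketch, FakeTensor (a hypothetical stand-in for a tensor) wraps nested lists, and per_rank[r] is what rank r passed in. It shows the new leading world_size dimension, the world_size == 1 special case, and leaf-by-leaf gathering of collections; simulate_all_gather is illustrative only and performs no communication.

```python
# A pure-Python model of the all_gather output-shape contract. FakeTensor wraps
# nested lists and stands in for a tensor; per_rank[r] is what rank r passed in.
# `simulate_all_gather` is hypothetical and ignores real communication.
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeTensor:
    data: Any  # a number (0-dim) or nested lists, e.g. shape (batch,)


def shape(data: Any) -> tuple[int, ...]:
    """Shape of a nested-list 'tensor' (a plain number has shape ())."""
    dims = []
    while isinstance(data, list):
        dims.append(len(data))
        data = data[0] if data else None
    return tuple(dims)


def simulate_all_gather(per_rank: list[Any]) -> Any:
    first = per_rank[0]
    if isinstance(first, dict):  # gather collections key by key
        return {k: simulate_all_gather([d[k] for d in per_rank]) for k in first}
    if isinstance(first, (list, tuple)):  # gather collections element by element
        return type(first)(simulate_all_gather(list(items)) for items in zip(*per_rank))
    leaves = [x.data if isinstance(x, FakeTensor) else x for x in per_rank]
    if len({shape(leaf) for leaf in leaves}) > 1:
        raise ValueError("shapes differ across ranks; a real all_gather would stall")
    if len(leaves) == 1:
        return FakeTensor(leaves[0])  # world_size == 1: no extra dimension
    return FakeTensor(leaves)  # new leading dimension of size world_size


out = simulate_all_gather([FakeTensor([1, 2]), FakeTensor([3, 4]), FakeTensor([5, 6])])
print(shape(out.data))  # (3, 2)
```

Here the sketch raises on a shape mismatch, whereas a real collective simply hangs, which is why equal shapes on every process matter.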
- all_reduce(data, group=None, reduce_op='mean')
Reduce tensors or collections of tensors from multiple processes.
The reduction on tensors is applied in place, meaning the result will be placed back into the input tensor. This method needs to be called on all processes, and the tensors need to have the same shape across all processes; otherwise, your program will stall forever.
- Parameters:
data (Union[Tensor, dict, list, tuple]) – int, float, tensor of shape (batch, …), or a (possibly nested) collection thereof. Tensors will be modified in place.
group (Optional[Any]) – The process group to reduce results across. Defaults to all processes (world).
reduce_op (Union[ReduceOp, str, None]) – The reduction operation. Defaults to 'mean'. Can also be the string 'sum' or a ReduceOp. Some strategies may limit the choices here.
- Return type:
- Returns:
A tensor of the same shape as the input with values reduced pointwise across processes. The same is applied to tensors in a collection if a collection is given as input.
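To make "reduced pointwise across processes" concrete, the sketch below computes what every rank would receive for reduce_op="sum" and reduce_op="mean". Flat lists of numbers stand in for tensors and dicts for collections. simulate_all_reduce is a hypothetical helper: it returns new values instead of modifying inputs in place and does not model ReduceOp objects or process groups.

```python
# A pure-Python sketch of the all_reduce contract: every rank passes data of the
# same shape and receives the pointwise "sum" or "mean" across ranks. Flat lists
# of numbers stand in for tensors and dicts for collections; this hypothetical
# helper returns new values rather than writing back in place.
from typing import Any


def simulate_all_reduce(per_rank: list[Any], reduce_op: str = "mean") -> Any:
    first = per_rank[0]
    if isinstance(first, dict):  # reduce collections key by key
        return {k: simulate_all_reduce([d[k] for d in per_rank], reduce_op) for k in first}
    if reduce_op not in ("sum", "mean"):
        raise ValueError(f"unsupported reduce_op in this sketch: {reduce_op!r}")
    world_size = len(per_rank)
    if isinstance(first, (int, float)):  # scalars
        total = sum(per_rank)
        return total / world_size if reduce_op == "mean" else total
    if any(len(x) != len(first) for x in per_rank):
        raise ValueError("shape mismatch; a real collective would stall instead")
    totals = [sum(vals) for vals in zip(*per_rank)]  # pointwise across ranks
    return [t / world_size for t in totals] if reduce_op == "mean" else totals


print(simulate_all_reduce([[1.0, 2.0], [3.0, 4.0]], "sum"))  # [4.0, 6.0]
print(simulate_all_reduce([[1.0, 2.0], [3.0, 4.0]]))         # [2.0, 3.0]
```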
- autocast()
A context manager to automatically convert operations for the chosen precision.
Use this only if the forward method of your model does not cover all operations you wish to run with the chosen precision setting.
- Return type:
- backward(tensor, *args, model=None, **kwargs)
Replaces loss.backward() in your training loop. Handles precision automatically for you.
- Parameters:
tensor (Tensor) – The tensor (loss) to back-propagate gradients from.
*args (Any) – Optional positional arguments passed to the underlying backward function.
model (Optional[_FabricModule]) – Optional model instance for plugins that require the model for backward(). Required when using the DeepSpeed strategy with multiple models.
**kwargs (Any) – Optional named keyword arguments passed to the underlying backward function.
- Return type:
Note
When using strategy="deepspeed" with multiple models set up, you must pass the model as an argument here.
Example:
    loss = criterion(output, target)
    fabric.backward(loss)

    # With DeepSpeed and multiple models
    fabric.backward(loss, model=model)
- barrier(name=None)
Wait for all processes to enter this call.
Use this to synchronize all parallel processes, but only when necessary; otherwise, the synchronization overhead will slow down your program. This method needs to be called on all processes. Failing to do so will cause your program to stall forever.
- Return type:
- broadcast(obj, src=0)
Send data (for example, a tensor) from the process with rank src to all other processes.
This method needs to be called on all processes. Failing to do so will cause your program to stall forever.
- Parameters:
- Return type:
TypeVar(TBroadcast)
- Returns:
The transferred data, the same value on every rank.
- call(hook_name, *args, **kwargs)
Trigger the callback methods with the given name and arguments.
Not all objects registered via Fabric(callbacks=...) must implement a method with the given name. The ones that have a matching method name will get called.
- Parameters:
- Return type:
Example:
    class MyCallback:
        def on_train_epoch_end(self, results):
            ...

    fabric = Fabric(callbacks=[MyCallback()])
    fabric.call("on_train_epoch_end", results={...})
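The dispatch rule described for call() (invoke the named method on every callback that defines it, silently skip the rest) fits in a few lines. dispatch_hook below is a hypothetical stand-in written for illustration, not Fabric's implementation.

```python
# A minimal sketch of the dispatch rule described for call(): iterate over the
# registered callbacks and invoke the method only on objects that define it.
# `dispatch_hook` is a hypothetical stand-in, not Fabric's implementation.
from typing import Any


def dispatch_hook(callbacks: list[Any], hook_name: str, *args: Any, **kwargs: Any) -> None:
    for callback in callbacks:
        method = getattr(callback, hook_name, None)
        if callable(method):  # objects without the hook are skipped
            method(*args, **kwargs)


class RecordResults:
    def __init__(self) -> None:
        self.seen: list[Any] = []

    def on_train_epoch_end(self, results: Any) -> None:
        self.seen.append(results)


class Unrelated:  # defines no on_train_epoch_end, so it is skipped
    pass


recorder = RecordResults()
dispatch_hook([recorder, Unrelated()], "on_train_epoch_end", results={"loss": 0.1})
print(recorder.seen)  # [{'loss': 0.1}]
```

Looking the method up with getattr at call time means callbacks can be plain objects with no shared base class.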
- clip_gradients(module, optimizer, clip_val=None, max_norm=None, norm_type=2.0, error_if_nonfinite=True)
Clip the gradients of the model to a given max value or max norm.
- Parameters:
module (Union[Module, _FabricModule]) – The module whose parameters should be clipped.
optimizer (Union[Optimizer, _FabricOptimizer]) – The optimizer referencing the parameters to be clipped.
clip_val (Union[float, int, None]) – If passed, gradients will be clipped to this value. Cannot be used together with max_norm.
max_norm (Union[float, int, None]) – If passed, clips the gradients in such a way that the total p-norm of the resulting gradients is no larger than the given value. Cannot be used together with clip_val.
norm_type (Union[float, int]) – The type of norm if max_norm was passed. Can be 'inf' for infinity norm. Defaults to 2-norm.
error_if_nonfinite (bool) – An error is raised if the total norm of the gradients is NaN or infinite. Only applies when max_norm is used.
- Return type:
- Returns:
The total norm of the gradients (before clipping was applied) as a scalar tensor if max_norm was passed, otherwise None.
- Raises:
ValueError – If both clip_val and max_norm are provided, or if neither is provided.
Example:
    # Clip by value
    fabric.clip_gradients(model, optimizer, clip_val=1.0)

    # Clip by norm
    total_norm = fabric.clip_gradients(model, optimizer, max_norm=1.0)
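The two clipping modes differ in kind: clip_val clamps each gradient element independently, while max_norm rescales all gradients together so that their combined p-norm fits the budget. The pure-Python sketch below operates on flat lists of gradient values (one list per parameter). It assumes the usual total-norm scaling of torch.nn.utils.clip_grad_norm_ (coefficient max_norm / (total_norm + 1e-6), applied only when below 1), and it assumes a RuntimeError for non-finite norms. clip_gradients_sketch is hypothetical.

```python
# A pure-Python sketch of the two clipping modes on flat lists of gradient values
# (one list per parameter). It mirrors the documented contract: exactly one of
# clip_val / max_norm, and the returned norm is measured before clipping.
# Hypothetical helper; the scaling rule follows torch.nn.utils.clip_grad_norm_.
import math
from typing import Optional, Union


def clip_gradients_sketch(
    grads: list[list[float]],
    clip_val: Optional[float] = None,
    max_norm: Optional[float] = None,
    norm_type: Union[float, int, str] = 2.0,
    error_if_nonfinite: bool = True,
) -> Optional[float]:
    if (clip_val is None) == (max_norm is None):
        raise ValueError("Provide exactly one of clip_val or max_norm")
    if clip_val is not None:
        for g in grads:
            for i, v in enumerate(g):
                g[i] = max(-clip_val, min(clip_val, v))  # clamp each element
        return None
    flat = [abs(v) for g in grads for v in g]
    if norm_type in ("inf", math.inf):
        total_norm = max(flat, default=0.0)  # infinity norm: largest magnitude
    else:
        p = float(norm_type)
        total_norm = sum(v ** p for v in flat) ** (1.0 / p)  # one norm over all grads
    if error_if_nonfinite and not math.isfinite(total_norm):
        raise RuntimeError("total norm of gradients is non-finite")  # assumed error type
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:  # only scale down, never up
        for g in grads:
            for i, v in enumerate(g):
                g[i] = v * clip_coef
    return total_norm


grads = [[3.0], [4.0]]
print(clip_gradients_sketch(grads, max_norm=1.0))  # 5.0 (norm before clipping)
```

After the call, the gradients are rescaled to roughly [0.6] and [0.8], so their joint 2-norm is just under 1.0 while their direction is preserved.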
- init_module(empty_init=None)
Instantiate the model and its parameters under this context manager to reduce peak memory usage.
The parameters are created directly on the device and with the right data type, so no memory is wasted on unnecessary allocations.
- init_tensor()
Tensors that you instantiate under this context manager will be created on the device right away and have the right data type depending on the precision setting in Fabric.
- Return type:
- launch(function=<function _do_nothing>, *args, **kwargs)
Launch and initialize all the processes needed for distributed execution.
- Parameters:
function (Callable[[Fabric], Any]) – Optional function to launch when using a spawn/fork-based strategy, for example, when using the XLA strategy (accelerator="tpu"). The function must accept at least one argument, to which the Fabric object itself will be passed. If not provided, only process initialization will be performed.
*args (Any) – Optional positional arguments to be passed to the function.
**kwargs (Any) – Optional keyword arguments to be passed to the function.
- Return type:
- Returns:
The output of the function that ran in the worker process with rank 0.
- Raises:
RuntimeError – If called when the script was launched through the CLI.
TypeError – If function is provided but not callable, or if function doesn't accept the required arguments.
Note
The launch() method should only be used if you intend to specify accelerator, devices, and so on in the code (programmatically). If you are launching with the Lightning CLI (fabric run ...), remove launch() from your code.
Repeated calls to launch() without a function are a no-op.
Example:
    def train_function(fabric):
        # MyModel is user-defined; create the objects inside the launched function
        model = MyModel()
        optimizer = torch.optim.Adam(model.parameters())
        model, optimizer = fabric.setup(model, optimizer)
        # ... training code ...

    fabric = Fabric(accelerator="tpu", devices=8)
    fabric.launch(train_function)
- load(path, state=None, strict=True)
Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.).
The strategy determines which processes load the checkpoint and how. This method must be called on all processes!
- Parameters:
path (Union[str, Path]) – A path to where the file is located.
state (Optional[dict[str, Union[Module, Optimizer, Any]]]) – A dictionary of objects whose state will be restored in place from the checkpoint path. If no state is given, then the checkpoint will be returned in full.
strict (bool) – Whether to enforce that the keys in state match the keys in the checkpoint.
- Return type:
- Returns:
The remaining items that were not restored into the given state dictionary. If no state dictionary is given, the full checkpoint will be returned.
Example:
    # Load full checkpoint
    checkpoint = fabric.load("checkpoint.pth")

    # Load into existing objects
    state = {"model": model, "optimizer": optimizer}
    remainder = fabric.load("checkpoint.pth", state)
    epoch = remainder.get("epoch", 0)
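The return value of load() is easiest to understand as "whatever the checkpoint holds that you did not ask for". The simplified model below uses plain dicts. It assumes that objects exposing load_state_dict() are restored in place, that plain values in state are overwritten with the checkpoint value, and that strict mode raises KeyError for requested keys missing from the checkpoint. restore_from_checkpoint and Counter are hypothetical.

```python
# A simplified model of the load() contract: stateful objects in `state` are
# restored in place, plain values are replaced in the dict, and everything else
# in the checkpoint is returned. The strict-mode KeyError is an assumption of
# this sketch; `restore_from_checkpoint` is a hypothetical helper.
from typing import Any


def restore_from_checkpoint(
    checkpoint: dict[str, Any], state: dict[str, Any], strict: bool = True
) -> dict[str, Any]:
    remainder = dict(checkpoint)  # leave the caller's checkpoint untouched
    missing = [k for k in state if k not in remainder]
    if strict and missing:
        raise KeyError(f"keys requested but not in checkpoint: {missing}")
    for key, obj in list(state.items()):
        if key not in remainder:
            continue
        value = remainder.pop(key)
        if hasattr(obj, "load_state_dict"):
            obj.load_state_dict(value)  # restore in place (model, optimizer, ...)
        else:
            state[key] = value  # plain values are overwritten in the state dict
    return remainder


class Counter:  # hypothetical stateful object
    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> dict[str, int]:
        return {"count": self.count}

    def load_state_dict(self, sd: dict[str, int]) -> None:
        self.count = sd["count"]


counter = Counter()
remainder = restore_from_checkpoint({"counter": {"count": 7}, "epoch": 3, "note": "extra"}, {"counter": counter})
print(counter.count, remainder)  # 7 {'epoch': 3, 'note': 'extra'}
```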
- load_raw(path, obj, strict=True)
Load the state of a module or optimizer from a single state-dict file.
Use this for loading a raw PyTorch model checkpoint created without Fabric. This is conceptually equivalent to obj.load_state_dict(torch.load(path)), but is agnostic to the strategy being used.
- log(name, value, step=None)
Log a scalar to all loggers that were added to Fabric.
- Parameters:
- Return type:
- log_dict(metrics, step=None)
Log multiple scalars at once to all loggers that were added to Fabric.
- Parameters:
metrics (Mapping[str, Any]) – A dictionary where the key is the name of the metric and the value is the scalar to be logged. Any torch.Tensor in the dictionary gets detached from the graph automatically.
step (Optional[int]) – Optional step number. Most Logger implementations auto-increment this value by one with every log call. You can specify your own value here.
- Return type:
- no_backward_sync(module, enabled=True)
Skip gradient synchronization during backward to avoid redundant communication overhead.
Use this context manager when performing gradient accumulation to speed up training with multiple devices. Both the model's .forward() and the fabric.backward() call need to run under this context.
- Parameters:
module (_FabricModule) – The module for which to control the gradient synchronization. Must be a module that was set up with setup() or setup_module().
enabled (bool) – Whether the context manager is enabled or not. True means skip the sync, False means do not skip.
- Return type:
- Returns:
A context manager that controls gradient synchronization.
- Raises:
TypeError – If the module was not set up with Fabric first.
Note
For strategies that don’t support gradient sync control, a warning is emitted and the context manager becomes a no-op. For single-device strategies, it is always a no-op.
Example:
    # Accumulate gradients over 8 batches
    for batch_idx, (inputs, target) in enumerate(dataloader):
        is_accumulating = (batch_idx + 1) % 8 != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            output = model(inputs)
            loss = criterion(output, target)
            fabric.backward(loss)
        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()
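When accumulating over N batches, gradients only need to be synchronized on the batch that closes each window, which is also the batch on which the optimizer steps. The sketch below computes that schedule using (batch_idx + 1) % accumulate, so that every window contains exactly accumulate batches. accumulation_schedule is a hypothetical helper.

```python
# A pure-Python sketch of the accumulation schedule used with no_backward_sync:
# gradients are synchronized only on the batch that ends each window, which is
# also the batch where the optimizer steps. Hypothetical helper for illustration.
def accumulation_schedule(num_batches: int, accumulate: int) -> list[tuple[int, bool]]:
    """Return (batch_idx, sync_and_step) pairs for each batch."""
    schedule = []
    for batch_idx in range(num_batches):
        is_accumulating = (batch_idx + 1) % accumulate != 0  # skip sync inside a window
        schedule.append((batch_idx, not is_accumulating))
    return schedule


steps = [i for i, sync in accumulation_schedule(16, accumulate=8) if sync]
print(steps)  # [7, 15]
```

If the number of batches is not a multiple of accumulate, the trailing partial window is never stepped in this scheme. For example, with 20 batches, batches 16 to 19 accumulate gradients but trigger no optimizer step.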
- print(*args, **kwargs)
Print something only on the first process. If running on multiple machines, it will print from the first process in each machine.
Arguments passed to this method are forwarded to the Python built-in print() function.
- Return type:
- rank_zero_first(local=False)
The code block under this context manager gets executed first on the main process (rank 0); only after it completes do the other processes run the code in parallel.
- Parameters:
local (bool) – Set this to True if the local rank should be the one going first. Useful if you are downloading data and the filesystem isn't shared between the nodes.
- Return type:
Example:
    with fabric.rank_zero_first():
        dataset = MNIST("datasets/", download=True)
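The ordering guarantee can be simulated with threads standing in for processes. In the sketch below, rank 0 runs the guarded block alone (for example, a download), and the other ranks start only after it signals completion. run_ranks is hypothetical: it models only the documented ordering, not Fabric's barrier implementation or the local=True variant.

```python
# A thread-based simulation of the ordering rank_zero_first() provides: rank 0
# runs the guarded block alone, and the other ranks start only after it has
# finished. Threads stand in for processes and threading.Event for the barrier;
# this is a hypothetical model, not Fabric's implementation.
import threading


def run_ranks(world_size: int) -> list[str]:
    log: list[str] = []
    lock = threading.Lock()
    rank0_done = threading.Event()

    def worker(rank: int) -> None:
        if rank != 0:
            rank0_done.wait()  # non-zero ranks block until rank 0 is done
        with lock:
            log.append(f"rank {rank} ran block")
        if rank == 0:
            rank0_done.set()  # release the other ranks

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log


print(run_ranks(4)[0])  # rank 0 ran block
```

The other ranks may finish in any order; only rank 0 going first is guaranteed.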
- run(*args, **kwargs)
All the code inside this run method gets accelerated by Fabric.
You can pass arbitrary arguments to this function when overriding it.
- Return type:
- save(path, state, filter=None)
Save checkpoint contents to a file.
The strategy determines which processes save and how. For example, the ddp strategy saves checkpoints only on process 0, while the fsdp strategy saves files from every rank. This method must be called on all processes!
- Parameters:
path (Union[str, Path]) – A path to where the file(s) should be saved.
state (dict[str, Union[Module, Optimizer, Any]]) – A dictionary with contents to be saved. If the dict contains modules or optimizers, their state-dict will be retrieved and converted automatically.
filter (Optional[dict[str, Callable[[str, Any], bool]]]) – An optional dictionary containing filter callables that return a boolean indicating whether the given item should be saved (True) or filtered out (False). Each filter key should match a state key, and its filter will be applied to the generated state_dict.
- Raises:
TypeError – If filter is not a dictionary or contains non-callable values.
ValueError – If filter keys don't match state keys.
- Return type:
Example:
    state = {"model": model, "optimizer": optimizer, "epoch": epoch}
    fabric.save("checkpoint.pth", state)

    # With filter: save only the non-bias parameters of the model
    def param_filter(name, param):
        return "bias" not in name

    fabric.save("checkpoint.pth", state, filter={"model": param_filter})
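The filter argument pairs each state key with a callable that receives (name, value) entries of that item's state dict. The sketch below applies the documented validation and filtering rules to plain dicts standing in for state dicts. apply_save_filter and its filter_fns parameter (renamed to avoid shadowing Python's built-in filter) are hypothetical.

```python
# A sketch of the validation and filtering rules documented for
# save(filter=...), applied to plain dicts standing in for state dicts.
# `apply_save_filter` is a hypothetical helper, not Fabric's implementation.
from typing import Any, Callable, Optional

FilterFn = Callable[[str, Any], bool]


def apply_save_filter(
    state: dict[str, Any], filter_fns: Optional[dict[str, FilterFn]] = None
) -> dict[str, Any]:
    if filter_fns is None:
        return dict(state)
    if not isinstance(filter_fns, dict) or not all(callable(f) for f in filter_fns.values()):
        raise TypeError("filter must be a dict mapping state keys to callables")
    unknown = set(filter_fns) - set(state)
    if unknown:
        raise ValueError(f"filter keys not found in state: {sorted(unknown)}")
    result = dict(state)
    for key, keep in filter_fns.items():
        # keep(name, value) -> True saves the entry, False drops it
        result[key] = {name: value for name, value in state[key].items() if keep(name, value)}
    return result


state = {"model": {"layer.weight": 1.0, "layer.bias": 0.5}, "epoch": 3}
print(apply_save_filter(state, {"model": lambda name, _: "bias" not in name}))
# {'model': {'layer.weight': 1.0}, 'epoch': 3}
```

Keys without a filter, such as epoch here, are saved unchanged.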
- static seed_everything(seed=None, workers=None, verbose=True)
Helper function to seed everything without explicitly importing Lightning.
See seed_everything() for more details.
- Return type:
- setup(module, *optimizers, scheduler=None, move_to_device=True, _reapply_compile=True)
Set up a model and its optimizers for accelerated training.
- Parameters:
module (Module) – A torch.nn.Module to set up.
*optimizers (Optimizer) – The optimizer(s) to set up. Can be zero or more optimizers.
scheduler (Optional[_LRScheduler]) – An optional learning rate scheduler to set up. Must be provided after optimizers if used.
move_to_device (bool) – If set to True (default), moves the model to the correct device. Set this to False and alternatively use to_device() manually.
_reapply_compile (bool) – If True (default) and the model was torch.compile'd before, the corresponding torch._dynamo.OptimizedModule wrapper will be removed and reapplied with the same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP, FSDP, etc.). Set it to False if compiling DDP/FSDP is causing issues.
- Return type:
- Returns:
If no optimizers are passed, returns the wrapped module. If optimizers are passed, returns a tuple containing the wrapped module and optimizers, and optionally the scheduler if provided, in the same order they were passed in.
Note
For certain strategies like FSDP, you may need to set up the model first using setup_module(), then create the optimizer, and finally set up the optimizer using setup_optimizers().
Example:
    # The three calls below are alternatives; a model is set up only once.

    # Basic usage
    model, optimizer = fabric.setup(model, optimizer)

    # With multiple optimizers and scheduler
    model, opt1, opt2, scheduler = fabric.setup(model, opt1, opt2, scheduler=scheduler)

    # Model only
    model = fabric.setup(model)
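The unpacking pattern in these examples follows from the documented return rule: the module alone when no optimizers are given, otherwise a tuple of module, optimizers, and (if provided) scheduler, in the order passed. The sketch below encodes that rule with strings standing in for the wrapped objects. pack_setup_outputs is hypothetical.

```python
# A sketch of the return-value rule documented for setup(): the module alone if
# no optimizers are given, otherwise a tuple of module, optimizers and, if given,
# the scheduler, in the order passed. Strings stand in for the wrapped objects;
# `pack_setup_outputs` is a hypothetical helper.
from typing import Any, Optional


def pack_setup_outputs(module: Any, *optimizers: Any, scheduler: Optional[Any] = None) -> Any:
    if not optimizers:
        return module  # model only: no tuple to unpack
    outputs = (module, *optimizers)
    if scheduler is not None:
        outputs += (scheduler,)  # scheduler always comes last
    return outputs


model, opt1, opt2, sched = pack_setup_outputs("model", "opt1", "opt2", scheduler="sched")
print(model, opt1, opt2, sched)  # model opt1 opt2 sched
```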
- setup_dataloaders(*dataloaders, use_distributed_sampler=True, move_to_device=True)
Set up one or multiple dataloaders for accelerated training. If you need different settings for each dataloader, call this method individually for each one.
- Parameters:
*dataloaders (DataLoader) – One or more PyTorch DataLoader instances to set up.
use_distributed_sampler (bool) – If set to True (default), automatically wraps or replaces the sampler on the dataloader(s) for distributed training. If you have a custom sampler defined, set this argument to False.
move_to_device (bool) – If set to True (default), moves the data returned by the dataloader(s) automatically to the correct device. Set this to False and alternatively use to_device() manually on the returned data.
- Return type:
- Returns:
If a single dataloader is passed, returns the wrapped dataloader. If multiple dataloaders are passed, returns a list of wrapped dataloaders in the same order they were passed in.
Example:
    # Single dataloader
    train_loader = fabric.setup_dataloaders(train_loader)

    # Multiple dataloaders
    train_loader, val_loader = fabric.setup_dataloaders(train_loader, val_loader)
- setup_module(module, move_to_device=True, _reapply_compile=True)
Set up a model for accelerated training or inference.
This is the same as calling .setup(model) with no optimizers. It is useful for inference or for certain strategies like FSDP that require setting up the module before the optimizer can be created and set up. See also setup_optimizers().
- Parameters:
module (Module) – A torch.nn.Module to set up.
move_to_device (bool) – If set to True (default), moves the model to the correct device. Set this to False and alternatively use to_device() manually.
_reapply_compile (bool) – If True (default) and the model was torch.compile'd before, the corresponding torch._dynamo.OptimizedModule wrapper will be removed and reapplied with the same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP, FSDP, etc.). Set it to False if compiling DDP/FSDP is causing issues.
- Return type:
_FabricModule
- Returns:
The wrapped model as a _FabricModule.
Example:
    # Set up model first (useful for FSDP)
    model = fabric.setup_module(model)

    # Then create and set up optimizer
    optimizer = torch.optim.Adam(model.parameters())
    optimizer = fabric.setup_optimizers(optimizer)
- setup_optimizers(*optimizers)
Set up one or more optimizers for accelerated training.
Some strategies do not allow setting up model and optimizer independently. For them, you should call .setup(model, optimizer, ...) instead to jointly set them up.
- Parameters:
*optimizers (Optimizer) – One or more optimizers to set up. Must provide at least one optimizer.
- Return type:
- Returns:
If a single optimizer is passed, returns the wrapped optimizer. If multiple optimizers are passed, returns a tuple of wrapped optimizers in the same order they were passed in.
- Raises:
RuntimeError – If using DeepSpeed or XLA strategies, which require joint model-optimizer setup.
Note
This method cannot be used with DeepSpeed or XLA strategies. Use setup() instead for those strategies.
Example:
    # Single optimizer
    optimizer = fabric.setup_optimizers(optimizer)

    # Multiple optimizers
    opt1, opt2 = fabric.setup_optimizers(opt1, opt2)
- sharded_model()
Instantiate a model under this context manager to prepare it for model-parallel sharding.
- Return type:
AbstractContextManager
Deprecated: This context manager is deprecated in favor of init_module(); use it instead.
- to_device(obj)
Move a torch.nn.Module or a collection of tensors to the current device, if it is not already on that device.
- property device: device
The current device this process runs on.
Use this to create tensors directly on the device if needed.
- property local_rank: int
The index of the current process among the processes running on the local node.
- property logger: Logger
Returns the first logger in the list passed to Fabric, which is considered the main logger.
- property loggers: list[lightning.fabric.loggers.logger.Logger]
Returns all loggers passed to Fabric.