
Snapshot

Snapshot is the core API of TorchSnapshot. The class represents application state persisted in storage. A user can take a snapshot of an application via Snapshot.take() (a.k.a. saving a checkpoint), or restore the state of an application from a snapshot via Snapshot.restore() (a.k.a. loading a checkpoint).

Describing the application state

Before using Snapshot to save or restore application state, the user needs to describe the application state. This is done by creating a dictionary that contains all stateful objects that the user wishes to capture as application state:

app_state = {"model": model, "optimizer": optimizer}

Any object that exposes .state_dict() and .load_state_dict() is considered a stateful object. Common PyTorch objects such as Module, Optimizer, and learning rate schedulers all qualify as stateful objects and can be captured directly. Objects that don’t meet this requirement can be captured via StateDict:

from torchsnapshot import StateDict

extra_state = StateDict(iterations=0)
app_state = {"model": model, "optimizer": optimizer, "extra_state": extra_state}
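Custom classes can also qualify by implementing the two methods themselves. The sketch below uses a hypothetical TrainingProgress class (it is not part of TorchSnapshot) to show the minimal protocol described above: state_dict() returns a plain dictionary, and load_state_dict() updates the existing object from such a dictionary.

```python
from typing import Any, Dict


class TrainingProgress:
    """Hypothetical stateful object that tracks progress through training."""

    def __init__(self) -> None:
        self.epoch = 0
        self.step = 0

    def state_dict(self) -> Dict[str, Any]:
        # Return a plain dict describing the current state.
        return {"epoch": self.epoch, "step": self.step}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Update this object in place from a previously returned dict.
        self.epoch = state_dict["epoch"]
        self.step = state_dict["step"]


progress = TrainingProgress()
progress.epoch, progress.step = 3, 1200
saved = progress.state_dict()

restored = TrainingProgress()
restored.load_state_dict(saved)
print(restored.epoch, restored.step)  # 3 1200
```

Such an object could then be placed in app_state next to the model and optimizer, e.g. {"model": model, "optimizer": optimizer, "progress": progress}.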

Taking a snapshot

Once the application state is described, the user can take a snapshot of the application via Snapshot.take(). Snapshot.take() persists the application state to the user-specified path and returns a Snapshot object, which is a reference to the snapshot.

from torchsnapshot import Snapshot

snapshot = Snapshot.take(path="/path/to/my/snapshot", app_state=app_state)

The user-specified path can optionally start with a URI prefix. By default, the prefix is fs://, which indicates that the path is a file system location. TorchSnapshot also provides performant and reliable integration with commonly used cloud object stores. A storage backend can be selected by prepending the corresponding URI prefix (e.g. s3:// for S3, gs:// for Google Cloud Storage).
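To make the prefix rule concrete, the following helper splits a path into a scheme and a location, falling back to fs when no prefix is given. split_storage_path is hypothetical and not a TorchSnapshot function; it only illustrates the naming convention, while the actual backend selection happens inside TorchSnapshot.

```python
from typing import Tuple


def split_storage_path(path: str, default_scheme: str = "fs") -> Tuple[str, str]:
    """Hypothetical helper: split 'scheme://location' into its two parts."""
    scheme, sep, location = path.partition("://")
    if not sep:
        # No URI prefix: treat the whole string as a file system path.
        return default_scheme, path
    return scheme, location


print(split_storage_path("/path/to/my/snapshot"))
# ('fs', '/path/to/my/snapshot')
print(split_storage_path("s3://bucket/path/to/my/snapshot"))
# ('s3', 'bucket/path/to/my/snapshot')
```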

snapshot = Snapshot.take(
    path="s3://bucket/path/to/my/snapshot",
    app_state=app_state
)

Note

Do not move GPU tensors to CPU before saving them with TorchSnapshot. TorchSnapshot implements various optimizations for increasing the throughput and decreasing the host memory usage of GPU-to-storage transfers. Moving GPU tensors to CPU manually will lower the throughput and increase the chance of “out of memory” issues.

Restoring from a snapshot

To restore from a snapshot, the user first needs to obtain a reference to the snapshot. As seen previously, in the process where the snapshot is taken, a reference to the snapshot is returned by Snapshot.take(). In another process (which is more common for resumption), a reference can be obtained by creating a Snapshot object with the snapshot path:

from torchsnapshot import Snapshot

snapshot = Snapshot(path="/path/to/my/snapshot")

To restore the application state from the snapshot, invoke Snapshot.restore() with the application state:

snapshot.restore(app_state=app_state)

Note

Snapshot.restore() restores stateful objects in-place to avoid creating unnecessary intermediate copies of the state.
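In-place restoration also means that other objects holding references to the restored ones (an optimizer, for instance, references the model's parameters) keep pointing at valid, updated state. The sketch below models the idea in plain Python; restore_in_place and Counter are hypothetical and deliberately simplified, not TorchSnapshot's implementation. It calls load_state_dict() on each existing object instead of replacing the object.

```python
from typing import Any, Dict


class Counter:
    """Minimal stateful object used for illustration."""

    def __init__(self, value: int = 0) -> None:
        self.value = value

    def state_dict(self) -> Dict[str, Any]:
        return {"value": self.value}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.value = state_dict["value"]


def restore_in_place(app_state: Dict[str, Any], saved: Dict[str, Dict[str, Any]]) -> None:
    """Hypothetical, simplified restore: mutate each stateful object in place."""
    for name, stateful in app_state.items():
        stateful.load_state_dict(saved[name])


counter = Counter()
alias = counter  # another component holding a reference to the same object
app_state = {"counter": counter}

restore_in_place(app_state, {"counter": {"value": 42}})
print(alias.value, app_state["counter"] is alias)  # 42 True
```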

Distributed snapshot

TODO

Elasticity

TODO

Reproducibility

TODO

Taking a snapshot asynchronously

TODO

API Reference

class Snapshot(path: str, pg: Optional[ProcessGroup] = None)

Snapshot represents the persisted program state at one point in time.

Basic usage:

from torchsnapshot import Snapshot

# Define the program state
app_state = {"model": model, "optimizer": optimizer}

# At an appropriate time, persist the program state as a snapshot
snapshot = Snapshot.take(path=path, app_state=app_state)

# On resuming, restore the program state from a snapshot
snapshot.restore(app_state)

Overview:

At a high level, torchsnapshot saves each value in the state dicts as a file/object in the corresponding storage system. It also saves a manifest describing the persisted values and the structure of the original state dicts.

Compared with torch.save() and torch.load(), torchsnapshot:

  • Enables efficient random access of persisted model weights.

  • Accelerates persistence by parallelizing writes.

    • For replicated values, persistence is parallelized across ranks.

  • Enables flexible yet robust elasticity (changing world size on restore).

Elasticity:

Elasticity is implemented by making persisted values available to newly joined ranks and having those ranks correctly restore the corresponding runtime objects from the persisted values.

For the purpose of elasticity, all persisted values fall into one of the categories in [per-rank, replicated, sharded].

per-rank:

By default, all non-sharded values are treated as per-rank.

On save, the value is only saved by the owning rank.

On load, the value is only made available to the same rank.

replicated:

A user can hint that any non-sharded value is replicated via glob patterns (the replicated argument of take() and async_take()).

On save, the value is saved only once (by any one rank).

On load, the value is made available to all ranks, including newly joined ranks.

sharded:

Specific types are always treated as sharded (e.g. ShardedTensor).

On save, all shard-owning ranks save their shards.

On load, all shards are made available to all ranks, including newly joined ranks. All ranks can read from all shards to restore the runtime object from the persisted values. (ShardedTensor resharding is powered by torch.distributed.checkpoint.)

If all values within a snapshot are either replicated or sharded, the snapshot is automatically reshardable.

If a snapshot contains per-rank values, it cannot be resharded unless the per-rank values are explicitly coerced to replicated on load.
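The categorization and reshardability rules above can be summarized in a few lines. The sketch below is a simplified model, not TorchSnapshot's code: categorize and is_reshardable are hypothetical, the replicated patterns are assumed to be matched against logical paths with fnmatch semantics (TorchSnapshot's exact matching rules may differ), the set of sharded paths is taken as given, and the requirement that all ranks specify the same patterns is not modeled.

```python
from fnmatch import fnmatch
from typing import Dict, Iterable, List, Set


def categorize(
    paths: Iterable[str], replicated_globs: List[str], sharded: Set[str]
) -> Dict[str, str]:
    """Hypothetical: assign each logical path to per-rank, replicated, or sharded."""
    categories: Dict[str, str] = {}
    for path in paths:
        if path in sharded:
            # Sharded types (e.g. ShardedTensor) are always treated as sharded.
            categories[path] = "sharded"
        elif any(fnmatch(path, pattern) for pattern in replicated_globs):
            # Non-sharded values matching a replicated glob become replicated.
            categories[path] = "replicated"
        else:
            # Everything else defaults to per-rank.
            categories[path] = "per-rank"
    return categories


def is_reshardable(categories: Dict[str, str]) -> bool:
    # Automatically reshardable only if no value is per-rank.
    return all(c in ("replicated", "sharded") for c in categories.values())


paths = ["model/weight", "model/bias", "embedding/table", "extra_state/iterations"]
cats = categorize(paths, replicated_globs=["model/*"], sharded={"embedding/table"})
print(cats["extra_state/iterations"], is_reshardable(cats))  # per-rank False
```

Widening the pattern list to ["*"] would mark every non-sharded path as replicated, making the (modeled) snapshot reshardable.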

classmethod async_take(path: str, app_state: Dict[str, T], pg: Optional[ProcessGroup] = None, replicated: Optional[List[str]] = None) → PendingSnapshot

Asynchronously take a snapshot from the program state.

This method creates a consistent snapshot of the app state (i.e. changes to the app state after this method returns have no effect on the snapshot). The asynchronicity is a result of performing storage I/O in the background.

Parameters:
  • app_state – The program state to take the snapshot from.

  • path – The location to save the snapshot.

  • pg – The process group for the processes taking the snapshot. When unspecified:

    • If distributed is initialized, the global process group will be used.

    • If distributed is not initialized, a single process is assumed.

  • replicated – A list of glob patterns for hinting the matching paths as replicated. Note that patterns not specified by all ranks are ignored.

Returns:

A handle with which the newly taken snapshot can be obtained via .wait(). Note that waiting on the handle is optional. The snapshot will be committed regardless of whether .wait() is invoked.
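The consistency guarantee described above (state captured when the call returns, I/O finished later) can be illustrated with the standard library. This is not TorchSnapshot's implementation: async_save is a hypothetical function that copies the state synchronously, so later mutations cannot leak into the saved data, and hands only the file write to a background thread. The returned Future's .result() plays the role of PendingSnapshot.wait().

```python
import copy
import json
import os
import tempfile
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Dict

_io_pool = ThreadPoolExecutor(max_workers=1)


def async_save(path: str, state: Dict[str, Any]) -> "Future[str]":
    """Hypothetical sketch: consistent copy now, storage I/O in the background."""
    frozen = copy.deepcopy(state)  # later changes to `state` do not affect `frozen`

    def _write() -> str:
        with open(path, "w") as f:
            json.dump(frozen, f)
        return path

    return _io_pool.submit(_write)


state = {"iterations": 10}
target = os.path.join(tempfile.mkdtemp(), "state.json")
pending = async_save(target, state)
state["iterations"] = 11  # mutation after the call returns

pending.result()  # optional here, analogous to PendingSnapshot.wait()
with open(target) as f:
    print(json.load(f))  # {'iterations': 10}
```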

get_manifest() → Dict[str, Entry]

Returns the snapshot’s manifest.

Returns:

The snapshot’s manifest.

property metadata: SnapshotMetadata

read_object(path: str, obj_out: Optional[T] = None, memory_budget_bytes: Optional[int] = None) → T

Read a persisted object from the snapshot’s content.

The persisted object to read is specified by its path in the snapshot metadata. Available paths can be obtained via snapshot.get_manifest().

A path in the snapshot metadata has the following format:

RANK/STATEFUL_NAME/STATE_DICT_KEY[/NESTED_CONTAINER_KEY...]

The rank only matters when the persisted object is “per-rank”. Any rank can be used when the persisted object is “replicated” or “sharded”.
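To see how this format maps onto a state dict, the sketch below flattens nested dictionaries into paths of the form shown above. manifest_paths is hypothetical and only mirrors the documented format for nested dicts (other container types are not handled); in practice, the available paths should be taken from snapshot.get_manifest().

```python
from typing import Any, Dict, List


def manifest_paths(rank: int, app_state_dicts: Dict[str, Dict[str, Any]]) -> List[str]:
    """Hypothetical: build RANK/STATEFUL_NAME/STATE_DICT_KEY[/NESTED...] paths."""
    paths: List[str] = []

    def walk(prefix: str, value: Any) -> None:
        if isinstance(value, dict):
            # Each nested container key contributes one path segment.
            for key, child in value.items():
                walk(f"{prefix}/{key}", child)
        else:
            paths.append(prefix)

    for stateful_name, state_dict in app_state_dicts.items():
        walk(f"{rank}/{stateful_name}", state_dict)
    return paths


print(manifest_paths(0, {
    "model": {"fc.weight": 0.0, "fc.bias": 0.0},
    "extra_state": {"iterations": 0, "progress": {"epoch": 3}},
}))
# ['0/model/fc.weight', '0/model/fc.bias', '0/extra_state/iterations', '0/extra_state/progress/epoch']
```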

If the persisted object is a sharded tensor, obj_out must be supplied. read_object will correctly populate obj_out’s local shards according to its sharding spec.
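Populating a differently sharded obj_out comes down to overlap arithmetic: each local shard copies the parts of every persisted shard that intersect its own range. The sketch below models this for a 1-D value using plain Python lists; fill_local_shard is hypothetical, and this is a simplified illustration under that assumption, not how TorchSnapshot or torch.distributed.checkpoint implement resharding.

```python
from typing import List, Tuple

# A shard is (global_offset, values) for a 1-D value.
Shard = Tuple[int, List[float]]


def fill_local_shard(persisted: List[Shard], offset: int, out: List[float]) -> None:
    """Populate `out`, covering [offset, offset + len(out)), from persisted shards."""
    end = offset + len(out)
    for src_offset, src in persisted:
        src_end = src_offset + len(src)
        lo, hi = max(offset, src_offset), min(end, src_end)
        if lo < hi:  # this persisted shard overlaps the local range
            out[lo - offset:hi - offset] = src[lo - src_offset:hi - src_offset]


# Saved by 2 ranks (shards of length 3), restored on 3 ranks (shards of length 2).
persisted = [(0, [0.0, 1.0, 2.0]), (3, [3.0, 4.0, 5.0])]
local = [0.0] * 2
fill_local_shard(persisted, offset=2, out=local)  # rank 1 of 3 owns [2, 4)
print(local)  # [2.0, 3.0]
```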

Parameters:
  • path – The path to the persisted object.

  • obj_out – If specified and the object type supports in-place load, read_object will directly read the persisted object into obj_out’s buffer.

  • memory_budget_bytes – When specified, the read operation will keep the temporary memory buffer size below this threshold.

Returns:

The object read from the snapshot’s content.
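The memory_budget_bytes parameter bounds the temporary buffer used while reading. One common way to honor such a budget is to copy data in chunks no larger than the budget directly into a preallocated output buffer. The sketch below shows that strategy with a hypothetical read_into function over a local file; TorchSnapshot's internal strategy may differ.

```python
import os
import tempfile
from typing import Optional


def read_into(path: str, out: bytearray, memory_budget_bytes: Optional[int] = None) -> bytearray:
    """Hypothetical sketch: fill `out` from `path` while holding at most
    `memory_budget_bytes` of temporary data at a time."""
    # Without a budget, fall back to a single read of the whole buffer.
    chunk_size = memory_budget_bytes or max(len(out), 1)
    pos = 0
    with open(path, "rb") as f, memoryview(out) as view:
        while pos < len(out):
            chunk = f.read(min(chunk_size, len(out) - pos))  # temporary buffer
            if not chunk:
                raise EOFError(f"{path} is shorter than the output buffer")
            view[pos:pos + len(chunk)] = chunk  # copy into the caller's buffer
            pos += len(chunk)
    return out


data = bytes(range(10))
blob_path = os.path.join(tempfile.mkdtemp(), "blob.bin")
with open(blob_path, "wb") as f:
    f.write(data)

buf = bytearray(len(data))
read_into(blob_path, buf, memory_budget_bytes=4)  # chunks of at most 4 bytes
print(buf == data)  # True
```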

restore(app_state: Dict[str, T]) → None

Restores the program state from the snapshot.

Parameters:

app_state – The program state to restore from the snapshot.

classmethod take(path: str, app_state: Dict[str, T], pg: Optional[ProcessGroup] = None, replicated: Optional[List[str]] = None) → Snapshot

Take a snapshot from the program state.

Parameters:
  • app_state – The program state to take the snapshot from.

  • path – The location to save the snapshot.

  • pg – The process group for the processes taking the snapshot. When unspecified:

    • If distributed is initialized, the global process group will be used.

    • If distributed is not initialized, a single process is assumed.

  • replicated – A list of glob patterns for hinting the matching paths as replicated. Note that patterns not specified by all ranks are ignored.

Returns:

The newly taken snapshot.
