Utilities

StateDict

class StateDict(dict=None, /, **kwargs)

A dict that implements the Stateful protocol. It is handy for capturing stateful objects that do not already implement the Stateful protocol or cannot implement it (e.g. primitive types).

from torchsnapshot import Snapshot, StateDict

model = Model()
progress = StateDict(current_epoch=0)
app_state = {"model": model, "progress": progress}

# Load from the last snapshot if available
...

while progress["current_epoch"] < NUM_EPOCHS:
    # Train for an epoch
    ...
    progress["current_epoch"] += 1

    # progress is captured by the snapshot
    Snapshot.take("foo/bar", app_state, backend=...)

load_state_dict(state_dict: Dict[str, Any]) → None
state_dict() → Dict[str, Any]
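
To make the two methods above concrete, here is a minimal sketch of a dict that satisfies the Stateful protocol. It assumes the protocol requires only state_dict() and load_state_dict(); the class name SketchStateDict is hypothetical and this is not necessarily how the library implements StateDict. The save and restore steps are simulated with a plain copy instead of a real snapshot.

```python
from collections import UserDict
from typing import Any, Dict


class SketchStateDict(UserDict):
    """Hypothetical stand-in for StateDict: a dict plus the two Stateful methods."""

    def state_dict(self) -> Dict[str, Any]:
        # Expose the underlying mapping as the state to be saved.
        return self.data

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Replace the current contents with the restored state.
        self.data = dict(state_dict)


progress = SketchStateDict(current_epoch=0)
progress["current_epoch"] += 1

# Copy the state; this stands in for taking a snapshot.
saved = dict(progress.state_dict())

# Keep mutating, then roll back; this stands in for a restore.
progress["current_epoch"] += 5
progress.load_state_dict(saved)
```

Because the object behaves like an ordinary dict, the training loop can read and update progress["current_epoch"] directly while still exposing its contents through the protocol methods.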

RNGState

class RNGState

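A stateful object that captures the RNG state. When an RNGState is included in the app state, the RNG state is guaranteed to be the same after Snapshot.take as it is after Snapshot.restore, so random values drawn after either call match.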

import torch
from torchsnapshot import RNGState, Snapshot

app_state = {
    "rng_state": RNGState(),
}
snapshot = Snapshot.take("foo/bar", app_state, backend=...)
after_take = torch.rand(1)

snapshot.restore(app_state)
after_restore = torch.rand(1)

torch.testing.assert_close(after_take, after_restore)

Note: only the state returned by torch.get_rng_state() is currently captured. Capturing other RNG states is a TODO.

load_state_dict(state_dict: Dict[str, Tensor]) → None
state_dict() → Dict[str, Tensor]
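
The same save-and-restore pattern can be sketched for an RNG that RNGState does not cover. The example below wraps Python's built-in random module in a hypothetical PyRandomState class that follows the state_dict() / load_state_dict() shape shown above. It is an illustration only: the class is not part of torchsnapshot, and whether a snapshot can serialize this particular state is not addressed here. The take and restore steps are simulated with direct method calls.

```python
import random
from typing import Any, Dict


class PyRandomState:
    """Hypothetical Stateful wrapper around Python's global `random` state."""

    def state_dict(self) -> Dict[str, Any]:
        # Capture the current state of the global Python RNG.
        return {"py_rng_state": random.getstate()}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Put the global Python RNG back into the captured state.
        random.setstate(state_dict["py_rng_state"])


rng = PyRandomState()

saved = rng.state_dict()       # stands in for Snapshot.take
after_take = random.random()

rng.load_state_dict(saved)     # stands in for Snapshot.restore
after_restore = random.random()

# Both draws start from the same RNG state, so they are equal.
assert after_take == after_restore
```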
