Utilities
StateDict
- class StateDict(dict=None, /, **kwargs)
A dict that implements the Stateful protocol. It is handy for capturing state that does not already implement the Stateful protocol or can’t implement it (e.g. primitive types).
```python
from torchsnapshot import Snapshot, StateDict

model = Model()
progress = StateDict(current_epoch=0)
app_state = {"model": model, "progress": progress}

# Load from the last snapshot if available
...

while progress["current_epoch"] < NUM_EPOCHS:
    # Train for an epoch
    ...
    progress["current_epoch"] += 1

    # progress is captured by the snapshot
    Snapshot.take("foo/bar", app_state, backend=...)
```
- load_state_dict(state_dict: Dict[str, Any]) → None
- state_dict() → Dict[str, Any]
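To make the protocol concrete, here is a minimal sketch in plain Python that does not require torchsnapshot. It assumes the Stateful protocol amounts to the two methods listed above. `SketchStateDict`, `take`, and `restore` are hypothetical stand-ins that keep the "snapshot" in memory; they are not torchsnapshot's implementation or storage format. The sketch simulates a run that stops after 3 epochs and a second run that resumes from the saved progress.

```python
import copy
from collections import UserDict
from typing import Any, Dict, Protocol, runtime_checkable


@runtime_checkable
class Stateful(Protocol):
    """Assumed shape of the Stateful protocol: just these two methods."""

    def state_dict(self) -> Dict[str, Any]: ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...


class SketchStateDict(UserDict):
    """Hypothetical stand-in for StateDict: a dict that is also Stateful."""

    def state_dict(self) -> Dict[str, Any]:
        # Return a shallow copy so later mutations don't alter a saved state.
        return dict(self.data)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Replace the current contents with the saved ones.
        self.data.clear()
        self.data.update(state_dict)


def take(app_state: Dict[str, Stateful]) -> Dict[str, Dict[str, Any]]:
    """Hypothetical in-memory snapshot: deep-copy every object's state."""
    return {key: copy.deepcopy(obj.state_dict()) for key, obj in app_state.items()}


def restore(saved: Dict[str, Dict[str, Any]], app_state: Dict[str, Stateful]) -> None:
    """Load each saved state back into the matching live object."""
    for key, obj in app_state.items():
        obj.load_state_dict(saved[key])


NUM_EPOCHS = 5

# First run: train for 3 epochs, snapshotting after each one, then stop.
progress = SketchStateDict(current_epoch=0)
app_state = {"progress": progress}
saved = None
while progress["current_epoch"] < 3:
    progress["current_epoch"] += 1
    saved = take(app_state)

# Second run: fresh objects, resumed from the last snapshot.
progress = SketchStateDict(current_epoch=0)
app_state = {"progress": progress}
restore(saved, app_state)
resumed_at = progress["current_epoch"]
while progress["current_epoch"] < NUM_EPOCHS:
    progress["current_epoch"] += 1

print(resumed_at, progress["current_epoch"])  # 3 5
```

The plain integer epoch counter cannot carry `state_dict`/`load_state_dict` methods itself, which is why it is wrapped in a dict-like object that does.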
RNGState
- class RNGState
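When captured in app state, it is guaranteed that the RNG state immediately after Snapshot.take and the RNG state immediately after Snapshot.restore will be the same.
```python
import torch
from torchsnapshot import RNGState, Snapshot

app_state = {
    "rng_state": RNGState(),
}
snapshot = Snapshot.take("foo/bar", app_state, backend=...)
after_take = torch.rand(1)

snapshot.restore(app_state)
after_restore = torch.rand(1)

# The first draw after take matches the first draw after restore
torch.testing.assert_close(after_take, after_restore)
```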
TODO: augment this to capture RNG states other than torch.get_rng_state().
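The guarantee can be illustrated without PyTorch using Python's built-in `random` module. In this minimal sketch, `SketchRNGState` is a hypothetical analogue of RNGState (not its implementation): saving records the generator state with `random.getstate()`, and loading puts it back with `random.setstate()`, so the first draw after a restore equals the first draw after the save.

```python
import random
from typing import Any, Dict


class SketchRNGState:
    """Hypothetical Stateful wrapper around Python's global `random` generator."""

    def state_dict(self) -> Dict[str, Any]:
        # Capture the full internal state of the global generator.
        return {"python_random": random.getstate()}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Rewind the global generator to the captured state.
        random.setstate(state_dict["python_random"])


rng_state = SketchRNGState()
random.seed(1234)

saved = rng_state.state_dict()    # plays the role of Snapshot.take
after_take = random.random()      # first draw after the save

random.random()                   # advance the generator further
rng_state.load_state_dict(saved)  # plays the role of snapshot.restore
after_restore = random.random()   # first draw after the restore

assert after_take == after_restore
```

Capturing additional generators, as the TODO suggests, would amount to storing one more entry per generator in the returned dict.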