[Feature Request] Saving and loading memmap tensordicts / tensorclasses #176

Open
2 tasks
vmoens opened this issue Jan 23, 2023 · 1 comment
Labels
enhancement New feature or request

Comments

vmoens commented Jan 23, 2023

Motivation

We can easily save tensordicts using torchsnapshot. However, since MemmapTensors are already backed by a file on disk, it would be fairly easy to save a tensordict that contains only MemmapTensors (i.e. one for which is_memmap() returns True).
Ideally, the saved directory structure would mirror that of the original tensordict, plus metadata. For example,

import torch
from tensordict import TensorDict

tensordict = TensorDict({"a": torch.randn(3, 4), "b": {"c": torch.randn(3, 4)}}, [3, 4])
tensordict.memmap_()
save_memmap_tensordict(tensordict, "/path/to/save")  # function proposed in this issue

would result in

/path/to/save/metadata.pt
/path/to/save/a.memmap
/path/to/save/b/metadata.pt
/path/to/save/b/c.memmap

(we'd need metadata for each subtensordict since they may have a different device / batch_size).
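
To make the proposed layout concrete, here is a minimal sketch that only computes the file paths for a nested structure. It uses a plain Python dict in place of a TensorDict, and `planned_layout` is a hypothetical helper name, not part of tensordict; the `metadata.pt` and `.memmap` names are taken from the listing above.

```python
import posixpath
from typing import Any, Iterator, Mapping


def planned_layout(nested: Mapping[str, Any], root: str) -> Iterator[str]:
    """Yield the files the proposal would write for `nested` under `root`.

    Hypothetical helper: a plain mapping stands in for a TensorDict, and
    any non-mapping value is treated as a leaf tensor.
    """
    # One metadata file per (sub-)tensordict, since each level may carry
    # its own device / batch_size.
    yield posixpath.join(root, "metadata.pt")
    for key, value in nested.items():
        if isinstance(value, Mapping):
            # Nested tensordict: recurse into a sub-directory named after the key.
            yield from planned_layout(value, posixpath.join(root, key))
        else:
            # Leaf tensor: one memory-mapped file named after the key.
            yield posixpath.join(root, f"{key}.memmap")


structure = {"a": "tensor", "b": {"c": "tensor"}}
paths = list(planned_layout(structure, "/path/to/save"))
for p in paths:
    print(p)
```

For the `{"a": ..., "b": {"c": ...}}` structure this walks the keys in insertion order and emits one metadata file at each nesting level.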

Loading from such a directory would also be easy (and would not create a copy):

tensordict_loaded = load_from_memmap("/path/to/save/", mode="r+")
assert tensordict_loaded.is_memmap()
assert (tensordict_loaded == tensordict).all()
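
The "no copy" property with mode="r+" can be illustrated with the standard library alone. The sketch below is an assumption-laden stand-in: it uses Python's `mmap` and `memoryview` rather than MemmapTensor, and `load_view` is a hypothetical helper, not the proposed `load_from_memmap`. It maps an existing file read-write, exposes it as a shaped view, and shows that a write through the view reaches the file on disk.

```python
import mmap
import os
import tempfile
from array import array


def load_view(path, shape, fmt="f"):
    """Map `path` read-write and return (mapping, shaped view) without copying.

    Hypothetical helper; the shape and format would come from the metadata file.
    """
    with open(path, "r+b") as f:
        mm = mmap.mmap(f.fileno(), 0)  # length 0 maps the whole file
    return mm, memoryview(mm).cast(fmt, shape=shape)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "a.memmap")
    # Write 12 float32 values as the raw backing file of a (3, 4) "tensor".
    with open(path, "wb") as f:
        f.write(array("f", range(12)).tobytes())

    mm, view = load_view(path, shape=(3, 4))
    loaded = view.tolist()  # nested lists, 3 rows of 4 floats

    # Write through a flat view of the same mapping: this edits the file itself.
    flat = memoryview(mm).cast("f")
    flat[0] = 42.0
    mm.flush()

    # Views must be released before the mapping can be closed.
    flat.release()
    view.release()
    mm.close()

    with open(path, "rb") as f:
        on_disk = array("f", f.read())
```

Because the view shares memory with the mapped file, the in-place write is visible when the file is read back, which is the behaviour an "r+" load of a memmap tensordict would be expected to have.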

This should work for tensordict and tensorclass, but a little extra work may be needed for the latter.

  • TensorDict
  • tensorclass


@vmoens vmoens added the enhancement New feature or request label Jan 23, 2023
@vmoens vmoens self-assigned this Jan 23, 2023
@maximilianigl

Is this related to the warning I'm getting:
"The method <bound method TensorDictBase.load_memmap_ of [...] wasn't explicitly implemented for tensorclass. This fallback will be deprecated in future releases because it is inefficient and non-compilable. Please raise an issue in tensordict repo to support this method!"?

I'm getting this when saving a replay_buffer which has (custom) Tensorclasses in it.
