-
Notifications
You must be signed in to change notification settings - Fork 77
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NOMERG, WIP, POC] Auto-nested TensorDict #201
base: main
Are you sure you want to change the base?
Conversation
Co-authored-by: Tom Begley <[email protected]> Co-authored-by: Ruggero Vasile <[email protected]>
Co-authored-by: Tom Begley <[email protected]> Co-authored-by: Ruggero Vasile <[email protected]>
#209) Co-authored-by: Tom Begley <[email protected]> Co-authored-by: Ruggero Vasile <[email protected]>
441b521
to
34a0275
Compare
test_chunk: fails for auto-nested due to usage of cat function. Tentatively disable it for that specific case |
Co-authored-by: Tom Begley <[email protected]> Co-authored-by: Ruggero Vasile <[email protected]>
…ested-tensordicts
…219) Co-authored-by: Tom Begley <[email protected]> Co-authored-by: Ruggero Vasile <[email protected]>
We're putting this PR on hold for now. |
For the benefit of anyone who picks this up in future, copying my comment from #220 about the reasons for some outstanding test failures: The following tests fail because
The following tests fail because we can't instantiate a TensorDict from a Python
Finally |
Description
This PR adds support for auto-nesting inside
TensorDict
. This is a proof-of-concept with missing features. Supporting auto-nested values is challenging because of the large number of methods in theTensorDict
class and its children which employ recursion. Checking for cycles during iteration also inevitably introduces some overhead. These trade-offs still need to be varefully benchmarked and evaluated.Here's a summary of the state of this branch and outstanding issues. We have implemented the following:
_TensorDictKeysView
that can detect a cycle ad raise an error or continue (internal usage only)_apply_safe
which can safely map any function onto all entries of the TensorDict, preserving auto-nesting if detected.The updated keys view is useful for iterating over all values in the TensorDict and applying some in-place operation, or aggregating some computed quantities. For example, zeroing all values in the TensorDict
or alternatively, in the implementation of
any
On the other hand,
_apply_safe
can be used to reimplement any function which returns a TensorDict of the same structure as the input. For example, implementingto_tensordict
is as simple asFixed so far
apply_
: implemented with_apply_safe
expand
: implemented with_apply_safe
__eq__
: implemented with_apply_safe
__ne__
: implemented with_apply_safe
to_tensordict
: implemented with_apply_safe
zero_
: implemented with_TensorDictKeysView
clone
: implemented with_apply_safe
__repr__
: fixed manually, without either paradigmall
: implemented with_TensorDictKeysView
whendim
is not specified, and_apply_safe
when it isany
: implemented with_TensorDictKeysView
whendim
is not specified, and_apply_safe
when it islock
: implemented with_TensorDictKeysView
unlock
: implemented with_TensorDictKeysView
_index_tensordict
: implemented with_apply_safe
masked_fill_
: implemented with_TensorDictKeysView
Outstanding bugs
split
: doesn't neatly fit paradigm of_apply_safe
since we return list of TensorDicts. Could possibly use_apply_safe
inside a list comprehension, but we would not be able to usetorch.split
, we'd have to manually compute indexes and slice which risks both being slow and also deviating fromtorch.split
behaviour if not carefully tested.select
:KeyError
usesset(self.keys(include_nested=True))
in the error message which fails on auto-nestedto_dict
: could be implemented with_TensorDictKeysView
, but we need a convenience function for setting nested entries of a Pythondict
.zero_
: fails for TensorDict variants with lazily computed values, as the in-place update applies to a lazily computed value and doesn't persist. Suggest replacingvalue.zero_
we have currently withself.set_(key, 0, no_check=True)
apply
: needs to be updated for auto-nested case. Unclear if we can use_apply_safe
._TensorDictKeysView._items
: need to handle theSubTensorDict
case._TensorDictKeysView
: fails when instantiated with lazy variants of TensorDictassert_allclose_td
: fails in the auto-nested case. Shouldn't be hard to fix the recursion error, but ideally if there is auto-nesting we would check that the auto-nesting structure exists in both tensordicts.masked_select
: needs to be updated, could probably be done with_apply_safe
_index_tensordict
: preserve_is_memmap
etc. in nested values (really an_apply_safe
bug)unbind
: may work automatically onceapply
is fixedpad
: recursive implementation. Could potentially use_apply_safe
.is_contiguous
: recursive implementation needs to be updated__setitem__
: gets stuck in a recursive loop in the auto-nested case.stack
: fails in auto-nested case.cat
: fails in auto-nested case.flatten_keys
: doesn't make sense in auto-nested case, raise informative error and test for itmemmap_
: fails in recursive case. Probably needs some care...TensorDict
from adict
with auto-nested values causes a recursion error.Tests to refactor
test_lock_write
: replace calls toitems
withinclude_nested=True
with_TensorDictKeysView
.test_apply
: replace calls tokeys
with_TensorDictKeysView
.test_apply_other
: replace calls tokeys
with_TensorDictKeysView
.test_masking_set
: helper functionzeros_like
is recursive and fails on auto-nested tensordicts. Potential to use_apply_safe
.test_entry_type
: replace call tokeys
with_TensorDictKeysView
test_update
: key comparison is currently callingset(td.keys(True))
. Need to find alternative way to check that key structure is preserved.Original issue is fixed,test_getitem_range
: failing on all test cases, seems that_index_tensordict
is doing something incorrect when we passrange
.assert_allclose_td
is causing failure nowtest_to_dict_nested
: has a recursive checker algorithm that fails on auto-nested casetest_unflatten_keys
: doesn't really make sense in the auto-nested case.test_batchsize_reset
: some issue with_index_tensordict
it seemstest_shared_inheritance
:unbind
doesn't preserveis_shared
.Open questions
select
: what should happen when we select keys from an auto-nested tensordict.detect_loop
: we don't actually use this in any of our implementations, should we have such a public method?