Describe the bug
TensorDict and tensordict.nn.make_tensordict expect a dictionary to be passed.
Passing a dictionary with non-string keys raises an error: IndexError: tuple index out of range
The same is true of the tensordict.TensorDict constructor.
To Reproduce
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 236, in __init__
    self.set(key, value, non_blocking=non_blocking)
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/base.py", line 2315, in set
    return self._set_tuple(
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 1615, in _set_tuple
    td = self._get_str(key[0], None)
IndexError: tuple index out of range
>>> from tensordict.nn import make_tensordict
>>> d = make_tensordict(d)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/functional.py", line 379, in make_tensordict
    return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device)
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 1332, in from_dict
    out = cls(
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 236, in __init__
    self.set(key, value, non_blocking=non_blocking)
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/base.py", line 2315, in set
    return self._set_tuple(
  File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 1615, in _set_tuple
    td = self._get_str(key[0], None)
IndexError: tuple index out of range
Expected behavior
When the dictionary has string keys, a Python dictionary is converted to a TensorDict, e.g.
d = {"1": torch.randn(2), "2": torch.randn(2)}
d = TensorDict(d, batch_size=2)
This works as expected. However, when the keys are not strings, as in
d = {1: torch.randn(2), 2: torch.randn(2)}
d = TensorDict(d, batch_size=2)
it raises an error.
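Until non-string keys are handled more gracefully, one possible workaround is to convert the keys to strings before building the TensorDict. The helper below is only a sketch: `stringify_keys` is a hypothetical name, not part of tensordict, and it assumes the keys are plain scalars such as ints whose string forms do not collide (for example 1 and "1").

```python
from typing import Any, Dict


def stringify_keys(d: Dict[Any, Any]) -> Dict[str, Any]:
    """Return a copy of `d` whose keys are strings, recursing into nested dicts.

    Hypothetical helper for illustration; not part of tensordict.
    """
    out: Dict[str, Any] = {}
    for key, value in d.items():
        # Leave string keys alone; convert everything else with str().
        new_key = key if isinstance(key, str) else str(key)
        if new_key in out:
            raise ValueError(f"keys collide after conversion to str: {new_key!r}")
        # Nested dictionaries get the same treatment.
        out[new_key] = stringify_keys(value) if isinstance(value, dict) else value
    return out
```

With torch and tensordict installed, passing `stringify_keys({1: torch.randn(2), 2: torch.randn(2)})` to TensorDict with batch_size=2 should then follow the same path as the string-key example.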
System info
Installation: python -m pip install tensordict==0.3.2
Python version: 3.8.13
PyTorch version: 2.2.2+cu121
TensorDict requires keys to be strings, tuples of strings, or nested tuples of strings; no other key type is allowed.
The main reason is that tensordicts can also be indexed along the "shape" dimension, and allowing other key types (e.g. ints) would make indexing ambiguous.
Example
data = TensorDict({"a": torch.arange(3)}, batch_size=[3])
data[1]  # indexes along the shape dimension: the entry "a" is now tensor(1)
data = TensorDict({1: torch.arange(3)}, batch_size=[3])
data[1]  # should this take the second element along the shape dimension, or the '1' key?
That being said, we should probably catch this error to make things clearer for our users!
Previously, when a value was anything other than a tensordict, dictionary, scalar, or tensor, the library raised an explicit error saying that the value's data type was outside this set.
So I think something similar for keys would be beneficial.
Should I go ahead and add type checking for this, if you confirm that keys may only be strings, tuples of strings, and so on?
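A rough sketch of what such a check could look like is given below. `is_nested_key` and `check_keys` are hypothetical names, not tensordict functions, and the accepted key types simply follow the description earlier in this thread (strings and possibly nested tuples of strings). Where the check would live in the library is left open.

```python
from typing import Any


def is_nested_key(key: Any) -> bool:
    """Return True for a str or a non-empty (possibly nested) tuple of valid keys."""
    if isinstance(key, str):
        return True
    if isinstance(key, tuple):
        return len(key) > 0 and all(is_nested_key(k) for k in key)
    return False


def check_keys(source: dict) -> None:
    """Raise a TypeError naming the first invalid key found in `source`.

    Hypothetical helper for illustration; nested dictionaries are checked too.
    """
    for key, value in source.items():
        if not is_nested_key(key):
            raise TypeError(
                "Expected keys to be str or (nested) tuples of str, "
                f"got {key!r} of type {type(key).__name__}."
            )
        if isinstance(value, dict):
            check_keys(value)
```

Running a check like this before any value is set would replace the IndexError with a message that names the offending key and its type.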
Thanks
Additional context
Reason and Possible fixes
I think the code at an abstract level works in 2 steps:
Thus, the culprit might
which calls
So what is happening is a lookup by string key, where the keys might not be strings.