[Performance] Make _to_consolidated compatible with compile
ghstack-source-id: ddd498c547f0f1ff8aee01d0990061bfff5502eb
Pull Request resolved: #1041
vmoens committed Oct 16, 2024
1 parent fe6db77 commit de7014b
Showing 1 changed file with 124 additions and 2 deletions.
tensordict/base.py
@@ -3521,9 +3521,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):

         flat_size = []
         start = 0
+        sorting_index = 0

         def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
-            nonlocal start
+            nonlocal start, sorting_index
             n = value.element_size() * value.numel()
             if need_padding:
                 pad = n % 8
@@ -3541,7 +3542,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
                     start,
                     stop,
                     pad,
+                    flat_size[-1],
+                    sorting_index,
                 )
+            sorting_index = sorting_index + 1
             start = stop

         def assign(
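Aside (not part of the diff): the new flat_size[-1] and sorting_index fields record, for every leaf, its byte length and the order in which it was packed into the consolidated storage. A minimal sketch of how those two fields let the flat storage be split back into per-leaf buffers without data-dependent sorting — the metadata layout, key names, and sizes below are made up for illustration:

    import torch

    # hypothetical metadata: key -> (start, stop, pad, flat_size, sorting_index)
    leaves = {
        ("a",): (0, 32, 0, 32, 0),
        ("b", "c"): (32, 48, 0, 16, 1),
    }
    keys = [None] * len(leaves)
    lengths = [None] * len(leaves)
    for key, (_start, _stop, _pad, flat_size, idx) in leaves.items():
        keys[idx] = key          # reorder by recorded insertion index
        lengths[idx] = flat_size
    storage = torch.zeros(48, dtype=torch.uint8)
    # same idea as split_storage in the new method below
    slice_map = dict(zip(keys, storage.split(lengths)))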
@@ -10390,7 +10394,7 @@ def to(self, *args, **kwargs) -> T:
             return result

         if self.is_consolidated() and dtype is None:
-            return self._to_consolidated(
+            return self._to_consolidated_compile(
                 device=device,
                 pin_memory=non_blocking_pin,
                 num_threads=num_threads,
@@ -10542,6 +10546,124 @@ def copy_dict(d):

         return result

+    def _to_consolidated_compile(self, *, device, pin_memory, num_threads, non_blocking):
+
+        def get_l(metadata, lengths=None, pos=None, keys=None, prefix=()):
+            root = False
+            if lengths is None:
+                lengths = []
+                pos = []
+                keys = []
+                root = True
+            for k, v in metadata["leaves"].items():
+                lengths.append(v[-2])
+                pos.append(v[-1])
+                keys.append(prefix + (k,))
+            for k, d in metadata.items():
+                if "leaves" in d:
+                    get_l(d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,))
+            if root:
+                # l = torch.empty(len(lengths), dtype=torch.long)
+                # l[torch.as_tensor(pos)] = torch.as_tensor(lengths)
+                out0 = [None, ] * len(pos)
+                out1 = [None, ] * len(pos)
+                for p, l, k in zip(pos, lengths, keys):
+                    out0[p] = k
+                    out1[p] = l
+                return out0, out1
+
+        def split_storage(consolidated):
+            keys, splits = get_l(consolidated["metadata"])
+            return dict(zip(keys, consolidated["storage"].split(splits)))
+
+        if num_threads is None:
+            # unspecified num_threads should mean 0
+            num_threads = 0
+        storage = self._consolidated["storage"]
+        if pin_memory:
+            storage = storage.pin_memory()
+        storage_cast = storage.to(device, non_blocking=True)
+
+        _consolidated = {"storage": storage_cast}
+        if "metadata" in self._consolidated:
+            # faster than deepcopy
+            def copy_dict(d):
+                return {
+                    k: v if not isinstance(v, dict) else copy_dict(v)
+                    for k, v in d.items()
+                }
+
+            _consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
+
+        slice_map = split_storage(_consolidated)
+
+        def set_(name, x):
+            if not isinstance(name, tuple):
+                name = (name,)
+            if x.is_nested:
+                from torch._subclasses.fake_tensor import FakeTensor
+                from torch._subclasses.functional_tensor import FunctionalTensor
+                from torch.nested._internal.nested_tensor import (
+                    _tensor_symint_registry,
+                    NestedTensor,
+                )
+                from torch.nested._internal.ops import extract_kwargs
+
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                kwargs = extract_kwargs(x)
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                kwargs["offsets"] = slice_map[(*name[:-1], "<NJT_OFFSETS>"+name[-1],)].view(offsets.dtype).view(offsets.shape)
+                if lengths is not None:
+                    kwargs["lengths"] = slice_map[(*name[:-1], "<NJT_LENGTHS>"+name[-1],)].view(lengths.dtype).view(lengths.shape)
+                    ragged_source = lengths
+                else:
+                    ragged_source = offsets
+                new_thing = kwargs.get("lengths", kwargs.get("offsets"))
+                if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+                    from torch._subclasses.functional_tensor import (
+                        mb_unwrap_functional_tensor,
+                    )
+
+                    # Temporary hack until we have the union find
+                    tgt = mb_unwrap_functional_tensor(new_thing)
+                    src = mb_unwrap_functional_tensor(ragged_source)
+                    tgt.nested_int_memo = src.nested_int_memo
+                else:
+                    _tensor_symint_registry[new_thing] = _tensor_symint_registry[
+                        ragged_source
+                    ]
+
+                return NestedTensor(
+                    slice_map[(*name[:-1], "<NJT_VALUES>"+name[-1],)].view(values.dtype).view(values.shape),
+                    **kwargs,
+                )
+            return slice_map[name].view(x.dtype).view(x.shape)
+
+        result = self._fast_apply(
+            set_, device=torch.device(device), num_threads=num_threads, named=True, nested_keys=True,
+        )
+        result._consolidated = _consolidated
+
+        if non_blocking in (False, None):
+            if device.type == "cuda" and non_blocking is False:
+                # sending to CUDA force sync
+                cuda_device = device
+            elif storage.device.type == "cuda":
+                # sending from cuda: need sync unless intentionally not asked for
+                cuda_device = storage.device.type
+            else:
+                cuda_device = None
+            if cuda_device is not None:
+                torch.cuda.current_stream(cuda_device).synchronize()
+
+        return result
+
     def _sync_all(self):
         if _has_cuda:
             # TODO: dynamo doesn't like torch.cuda.is_initialized
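For context, a minimal usage sketch (not from the repository) of the path this commit targets, assuming tensordict's public consolidate() API: once a TensorDict is consolidated, calling .to(device) with no dtype change now dispatches to the compile-friendly _to_consolidated_compile, so the transfer is intended to be traceable by torch.compile. Keys and shapes below are illustrative.

    import torch
    from tensordict import TensorDict

    td = TensorDict(
        {"obs": torch.randn(4, 3), "nested": {"reward": torch.zeros(4, 1)}},
        batch_size=[4],
    ).consolidate()  # packs every leaf into one flat storage plus metadata

    @torch.compile
    def send(td):
        # is_consolidated() is True and no dtype is requested, so to() takes
        # the _to_consolidated_compile branch added in this commit.
        return td.to("cuda", non_blocking=True)

    if torch.cuda.is_available():
        td_cuda = send(td)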
