Make the scratch pad tensor UVA (#2844)
Summary:
Pull Request resolved: #2844

Before this diff, the scratch pad in SSD TBE (see D55998215 for more
detail) was a CPU tensor that was later transferred to the GPU so that
the TBE kernels could access it.  The scratch pad transfer was highly
inefficient because TBE over-provisioned the scratch pad buffer
allocation (it did not know the exact number of cache-missed rows),
causing extra data transfer.  The extra transfer could be large since
the number of cache-missed rows was normally much smaller than the
over-provisioned size.
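
To make the cost concrete, here is a toy back-of-the-envelope
calculation; the batch size, miss count, and embedding dimension are
assumptions chosen purely for illustration:

```python
# Toy illustration of the over-provisioning overhead (all numbers assumed).
num_lookups = 65536      # rows the scratch pad is provisioned for (worst case)
num_cache_misses = 2048  # rows that actually miss the cache
emb_dim = 128
row_bytes = emb_dim * 4  # fp32 rows

transferred = num_lookups * row_bytes    # bytes copied with over-provisioning
needed = num_cache_misses * row_bytes    # bytes actually required
print(transferred / needed)              # 32.0x more traffic than necessary
```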

There are two ways to avoid the extra data transfer:

(1) Let TBE know the exact number of cache-missed rows on the host.
This requires a device-to-host data transfer, which creates a sync
point between host and device (not desirable in most training jobs).
However, it allows TBE to use `cudaMemcpy`, which utilizes the DMA
engine and lets the memory copy overlap efficiently with other compute
kernels.

(2) Make the scratch pad accessible from both CPU and GPU, i.e., make
it a UVA tensor.  This does not require host-device synchronization.
However, the memory copy has to be done through CUDA loads/stores,
which requires a kernel running on SMs.  Thus, overlapping the memory
copy with compute kernels requires careful SM management.

Based on the tradeoffs explained above, we chose to implement (2) to
avoid the host-device sync point; a minimal sketch of the UVA
allocation is shown below.
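
A minimal, hypothetical sketch of the UVA allocation path in option (2):
it assumes `fbgemm_gpu` with CUDA support is installed, and the variable
names and shapes are illustrative only, not the actual TBE code.

```python
import torch
import fbgemm_gpu  # noqa: F401  # registers the torch.ops.fbgemm operators

device = torch.device("cuda")
num_rows, emb_dim = 2048, 128  # assumed scratch pad shape

# new_managed_tensor (the op used in the diff below) allocates a UVA/UVM
# buffer: the first argument only supplies the dtype/device, the second is
# the target shape.
scratch_pad = torch.ops.fbgemm.new_managed_tensor(
    torch.zeros(1, device=device, dtype=torch.float32),
    [num_rows, emb_dim],
)

# The buffer is addressable from both sides: in the diff below, host-side
# code (ssd_db.get_cuda) fills it directly, and GPU kernels such as
# masked_index_put read it in place via CUDA loads/stores -- no explicit
# host-to-device cudaMemcpy and no sync point.
```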

Reviewed By: q10

Differential Revision: D58631974

fbshipit-source-id: f1ed0e4b23447010eb7409a08ca195e419f6089a
sryap authored and facebook-github-bot committed Jul 19, 2024
1 parent 3e845e4 commit c44c2d4
Showing 1 changed file with 66 additions and 48 deletions.

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -309,7 +309,7 @@ def __init__(
self.ssd_event_get_inputs_cpy = torch.cuda.Event()

self.timesteps_prefetched: List[int] = []
self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = []
self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = []
# TODO: add type annotation
# pyre-fixme[4]: Attribute must be annotated.
self.ssd_prefetch_data = []
@@ -509,6 +509,7 @@ def evict(
stream: torch.cuda.Stream,
pre_event: torch.cuda.Event,
post_event: torch.cuda.Event,
is_rows_uvm: bool,
name: Optional[str] = "",
) -> None:
"""
@@ -527,14 +528,17 @@
pre_event (Event): The CUDA event that the stream has to wait on
post_event (Event): The CUDA event that the current will record on
when the eviction is done
is_rows_uvm (bool): A flag to indicate whether `rows` is a UVM
tensor (which is accessible on both host and
device)
Returns:
None
"""
with record_function(f"## ssd_evict_{name} ##"):
with torch.cuda.stream(stream):
stream.wait_event(pre_event)

rows_cpu = self.to_pinned_cpu(rows)
rows_cpu = rows if is_rows_uvm else self.to_pinned_cpu(rows)

rows.record_stream(stream)

@@ -554,19 +558,21 @@

def _evict_from_scratch_pad(self, grad: Tensor) -> None:
assert len(self.ssd_scratch_pads) > 0, "There must be at least one scratch pad"
(inserted_rows_gpu, post_bwd_evicted_indices_cpu, actions_count_cpu) = (
(inserted_rows, post_bwd_evicted_indices_cpu, actions_count_cpu, do_evict) = (
self.ssd_scratch_pads.pop(0)
)
torch.cuda.current_stream().record_event(self.ssd_event_backward)
self.evict(
rows=inserted_rows_gpu,
indices_cpu=post_bwd_evicted_indices_cpu,
actions_count_cpu=actions_count_cpu,
stream=self.ssd_eviction_stream,
pre_event=self.ssd_event_backward,
post_event=self.ssd_event_evict_sp,
name="scratch_pad",
)
if do_evict:
torch.cuda.current_stream().record_event(self.ssd_event_backward)
self.evict(
rows=inserted_rows,
indices_cpu=post_bwd_evicted_indices_cpu,
actions_count_cpu=actions_count_cpu,
stream=self.ssd_eviction_stream,
pre_event=self.ssd_event_backward,
post_event=self.ssd_event_evict_sp,
is_rows_uvm=True,
name="scratch_pad",
)

def _compute_cache_ptrs(
self,
@@ -575,7 +581,7 @@ def _compute_cache_ptrs(
linear_index_inverse_indices: torch.Tensor,
unique_indices_count_cumsum: torch.Tensor,
cache_set_inverse_indices: torch.Tensor,
inserted_rows_gpu: torch.Tensor,
inserted_rows: torch.Tensor,
unique_indices_length: torch.Tensor,
inserted_indices: torch.Tensor,
actions_count_cpu: torch.Tensor,
@@ -596,7 +602,7 @@
unique_indices_count_cumsum,
cache_set_inverse_indices,
self.lxu_cache_weights,
inserted_rows_gpu,
inserted_rows,
unique_indices_length,
inserted_indices,
)
@@ -616,17 +622,18 @@
# Store scratch pad info for post backward eviction
self.ssd_scratch_pads.append(
(
inserted_rows_gpu,
inserted_rows,
post_bwd_evicted_indices_cpu,
actions_count_cpu,
linear_cache_indices.numel() > 0,
)
)

# pyre-fixme[7]: Expected `Tensor` but got `Tuple[typing.Any, Tensor,
# typing.Any, Tensor]`.
return (
lxu_cache_ptrs,
inserted_rows_gpu,
inserted_rows,
post_bwd_evicted_indices_cpu,
actions_count_cpu,
)
@@ -659,7 +666,6 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
1, # for now assume prefetch_dist == 1
self.lru_state,
)

# Transfer evicted indices from GPU to CPU right away to increase a
# chance of overlapping with compute on the default stream
(evicted_indices_cpu,) = (
@@ -685,49 +691,61 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
evicted_rows = self.lxu_cache_weights[
assigned_cache_slots.clamp(min=0).long(), :
]
inserted_rows = torch.empty(
evicted_rows.shape,
dtype=self.lxu_cache_weights.dtype,
pin_memory=True,
)

if linear_cache_indices.numel() > 0:
inserted_rows = torch.ops.fbgemm.new_managed_tensor(
torch.zeros(
1,
device=self.current_device,
dtype=self.lxu_cache_weights.dtype,
),
evicted_rows.shape,
)
else:
inserted_rows = torch.empty(
evicted_rows.shape,
dtype=self.lxu_cache_weights.dtype,
device=self.current_device,
)

current_stream = torch.cuda.current_stream()

inserted_indices_cpu = self.to_pinned_cpu(inserted_indices)

# Ensure the previous iterations l3_db.set(..) has completed.
current_stream.wait_event(self.ssd_event_evict)
current_stream.wait_event(self.ssd_event_evict_sp)
current_stream.wait_event(self.ssd_event_get_inputs_cpy)

self.record_function_via_dummy_profile(
"## ssd_get ##",
self.ssd_db.get_cuda,
inserted_indices_cpu,
inserted_rows,
actions_count_cpu,
)

if linear_cache_indices.numel() > 0:
self.record_function_via_dummy_profile(
"## ssd_get ##",
self.ssd_db.get_cuda,
inserted_indices_cpu,
inserted_rows,
actions_count_cpu,
)
current_stream.record_event(self.ssd_event_get)
# TODO: T123943415 T123943414 this is a big copy that is (mostly) unnecessary with a decent cache hit rate.
# Should we allocate on HBM?
inserted_rows_gpu = inserted_rows.cuda(non_blocking=True)

torch.ops.fbgemm.masked_index_put(
self.lxu_cache_weights,
assigned_cache_slots,
inserted_rows_gpu,
inserted_rows,
actions_count_gpu,
)

# Evict rows from cache to SSD
self.evict(
rows=evicted_rows,
indices_cpu=evicted_indices_cpu,
actions_count_cpu=actions_count_cpu,
stream=self.ssd_eviction_stream,
pre_event=self.ssd_event_get,
post_event=self.ssd_event_evict,
name="cache",
)
if linear_cache_indices.numel() > 0:
# Evict rows from cache to SSD
self.evict(
rows=evicted_rows,
indices_cpu=evicted_indices_cpu,
actions_count_cpu=actions_count_cpu,
stream=self.ssd_eviction_stream,
pre_event=self.ssd_event_get,
post_event=self.ssd_event_evict,
is_rows_uvm=False,
name="cache",
)

# TODO: keep only necessary tensors
self.ssd_prefetch_data.append(
@@ -737,7 +755,7 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
linear_index_inverse_indices,
unique_indices_count_cumsum,
cache_set_inverse_indices,
inserted_rows_gpu,
inserted_rows,
unique_indices_length,
inserted_indices,
actions_count_cpu,
Expand All @@ -762,7 +780,7 @@ def forward(
prefetch_data = self.ssd_prefetch_data.pop(0)
(
lxu_cache_ptrs,
inserted_rows_gpu,
inserted_rows,
post_bwd_evicted_indices_cpu,
actions_count_cpu,
) = self._compute_cache_ptrs(*prefetch_data)
@@ -804,7 +822,7 @@ def forward(
# codegen/genscript/optimizer_args.py
ssd_tensors={
"row_addrs": lxu_cache_ptrs,
"inserted_rows": inserted_rows_gpu,
"inserted_rows": inserted_rows,
"post_bwd_evicted_indices": post_bwd_evicted_indices_cpu,
"actions_count": actions_count_cpu,
},