Skip to content

Commit

Permalink
Flag to disable uvm caching for pt2 export (pytorch#2308)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2308

Avoid conditional of module output for pt2 tracing, when uvm caching is disabled no need to prefetch

Reviewed By: suo, sryap

Differential Revision: D53361712

fbshipit-source-id: e1d1efb07d73f0d40dce856a5b2d99775a788b08
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed Feb 5, 2024
1 parent dad9720 commit 7889f64
Showing 1 changed file with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -756,9 +756,10 @@ def forward(
self.index_remappings_array,
self.index_remappings_array_offsets,
)
if self.timestep_prefetch_size.get() <= 0:
self.prefetch(indices, offsets)
self.timestep_prefetch_size.decrement()
if self.lxu_cache_weights.numel() > 0:
if self.timestep_prefetch_size.get() <= 0:
self.prefetch(indices, offsets)
self.timestep_prefetch_size.decrement()

lxu_cache_locations = self.lxu_cache_locations_list.pop()

Expand Down

0 comments on commit 7889f64

Please sign in to comment.