diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index 4e86ebacd..e91e4e36e 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -756,9 +756,10 @@ def forward( self.index_remappings_array, self.index_remappings_array_offsets, ) - if self.timestep_prefetch_size.get() <= 0: - self.prefetch(indices, offsets) - self.timestep_prefetch_size.decrement() + if self.lxu_cache_weights.numel() > 0: + if self.timestep_prefetch_size.get() <= 0: + self.prefetch(indices, offsets) + self.timestep_prefetch_size.decrement() lxu_cache_locations = self.lxu_cache_locations_list.pop()