diff --git a/defs.bzl b/defs.bzl index 43d17b13e..9da6f6920 100644 --- a/defs.bzl +++ b/defs.bzl @@ -105,7 +105,8 @@ def get_fbgemm_inline_avx2_srcs(msvc = False, buck = False): asm_srcs = ["src/FbgemmFP16UKernelsAvx2.cc"] if buck: return select({ - "DEFAULT": asm_srcs if not msvc else intrinsics_srcs, + "DEFAULT": asm_srcs, + "ovr_config//compiler:cl": intrinsics_srcs, "ovr_config//cpu:arm64": intrinsics_srcs, }) return asm_srcs if not msvc else intrinsics_srcs @@ -135,7 +136,8 @@ def get_fbgemm_inline_avx512_srcs(msvc = False, buck = False): ] if buck: return select({ - "DEFAULT": asm_srcs if not msvc else intrinsics_srcs, + "DEFAULT": asm_srcs, + "ovr_config//compiler:cl": intrinsics_srcs, "ovr_config//cpu:arm64": intrinsics_srcs, }) return asm_srcs if not msvc else intrinsics_srcs diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake index 27b4ec884..99d5e5b0c 100644 --- a/fbgemm_gpu/FbgemmGpu.cmake +++ b/fbgemm_gpu/FbgemmGpu.cmake @@ -60,7 +60,6 @@ set(GPU_ONLY_OPTIMIZERS lamb partial_rowwise_adam partial_rowwise_lamb - ensemble_rowwise_adagrad lars_sgd none rowwise_adagrad_with_counter) @@ -87,7 +86,6 @@ set(GPU_OPTIMIZERS ${COMMON_OPTIMIZERS} ${GPU_ONLY_OPTIMIZERS}) set(VBE_OPTIMIZERS rowwise_adagrad rowwise_adagrad_with_counter - ensemble_rowwise_adagrad sgd dense) @@ -265,10 +263,10 @@ list(APPEND gen_gpu_host_source_files foreach(optimizer ${ALL_OPTIMIZERS}) list(APPEND gen_cpu_source_files "gen_embedding_backward_split_${optimizer}_cpu.cpp" - "gen_embedding_backward_split_${optimizer}_pt2_cpu_wrapper.cpp") + "gen_embedding_backward_split_${optimizer}_pt2_cpu_wrapper.cpp" + "gen_embedding_split_${optimizer}_pt2_autograd.cpp") list(APPEND gen_gpu_host_source_files "gen_embedding_backward_split_${optimizer}.cpp" - "gen_embedding_split_${optimizer}_pt2_autograd.cpp" "gen_embedding_backward_split_${optimizer}_pt2_cuda_wrapper.cpp") endforeach() @@ -456,6 +454,7 @@ set(fbgemm_gpu_sources_static_cpu codegen/training/forward/embedding_forward_split_cpu.cpp codegen/inference/embedding_forward_quantized_host_cpu.cpp codegen/training/backward/embedding_backward_dense_host_cpu.cpp + codegen/training/pt2/pt2_autograd_utils.cpp codegen/utils/embedding_bounds_check_host_cpu.cpp src/config/feature_gates.cpp src/memory_utils/memory_utils.cpp @@ -473,6 +472,7 @@ set(fbgemm_gpu_sources_static_cpu src/layout_transform_ops/layout_transform_ops_cpu.cpp src/quantize_ops/quantize_ops_cpu.cpp src/quantize_ops/quantize_ops_meta.cpp + src/sparse_ops/sparse_async_cumsum.cpp src/sparse_ops/sparse_ops_cpu.cpp src/sparse_ops/sparse_ops_meta.cpp src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp @@ -481,6 +481,7 @@ set(fbgemm_gpu_sources_static_cpu src/split_embeddings_cache/lru_cache_populate_byte.cpp src/split_embeddings_cache/lxu_cache.cpp src/split_embeddings_cache/split_embeddings_cache_ops.cpp + src/split_embeddings_utils/split_embeddings_utils_cpu.cpp codegen/training/index_select/batch_index_select_dim0_ops.cpp codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp) diff --git a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py index 6091bcc8e..6d354900a 100644 --- a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py @@ -25,11 +25,7 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - 
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") def generate_unary_feature( diff --git a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py index c919199ee..e43106b8c 100644 --- a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py +++ b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py @@ -21,11 +21,7 @@ # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") def benchmark_hbc_function( diff --git a/fbgemm_gpu/bench/jagged_tensor_benchmark.py b/fbgemm_gpu/bench/jagged_tensor_benchmark.py index 46337701e..814f950b0 100644 --- a/fbgemm_gpu/bench/jagged_tensor_benchmark.py +++ b/fbgemm_gpu/bench/jagged_tensor_benchmark.py @@ -31,10 +31,7 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu" ) @@ -47,7 +44,6 @@ torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu" ) - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") @click.group() diff --git a/fbgemm_gpu/bench/quantize_ops_benchmark.py b/fbgemm_gpu/bench/quantize_ops_benchmark.py index 81eb07bea..54755fff6 100644 --- a/fbgemm_gpu/bench/quantize_ops_benchmark.py +++ b/fbgemm_gpu/bench/quantize_ops_benchmark.py @@ -34,11 +34,7 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @click.group() diff --git a/fbgemm_gpu/bench/stride_gemm_benchmark.py b/fbgemm_gpu/bench/stride_gemm_benchmark.py index 2609f7fbf..bca34b8a9 100644 --- a/fbgemm_gpu/bench/stride_gemm_benchmark.py +++ b/fbgemm_gpu/bench/stride_gemm_benchmark.py @@ -20,11 +20,7 @@ # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @click.group() diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index afdcb8b3c..ac37444ff 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -335,7 +335,6 @@ def generate() -> None: 
lars_sgd(), partial_rowwise_adam(), partial_rowwise_lamb(), - ensemble_rowwise_adagrad(), rowwise_adagrad(), approx_rowwise_adagrad(), rowwise_adagrad_with_weight_decay(), diff --git a/fbgemm_gpu/codegen/genscript/generate_forward_split.py b/fbgemm_gpu/codegen/genscript/generate_forward_split.py index 285cf9a55..894ce104c 100644 --- a/fbgemm_gpu/codegen/genscript/generate_forward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_forward_split.py @@ -83,6 +83,7 @@ def generate_pt2_wrappers() -> None: f"gen_embedding_forward_split_pt2_cpu_wrapper.cpp", has_cpu_support=True, is_forward=True, + has_vbe_support=True, ) # Generate PT2 forward wrapper (CUDA) diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index acf2af31f..15c100ed5 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -1020,127 +1020,6 @@ def adam() -> Dict[str, Any]: } -def ensemble_rowwise_adagrad() -> Dict[str, Any]: - split_precomputation = """ - at::acc_type g_local_sum_square = 0.0; - """ - split_precomputation += generate_optimized_grad_sum_loop_access( - """ - const float4* grad = &{grad_vec}.acc; - auto gx = grad->x; - auto gy = grad->y; - auto gz = grad->z; - auto gw = grad->w; - g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw; - """ - ) - split_precomputation += """ - const at::acc_type g_avg_square = - GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type) / D; - - at::acc_type multiplier; - at::acc_type coef_ema; - at::acc_type should_ema; - at::acc_type should_swap; - if (threadIdx.x == 0) { - at::acc_type new_sum_square_grads = momentum2[idx] + g_avg_square; - momentum2[idx] = new_sum_square_grads; - multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); - - coef_ema = (row_counter[idx] > step_start) ? 
(momentum*1.0) : 0.0; - if (step_mode == 1) { - // row_counter[idx] tracks the number of appearances of this ID - row_counter[idx] += 1.0; - should_ema = floorf(row_counter[idx] / step_ema) - floorf((row_counter[idx]-1.0) / step_ema); - should_swap = floorf(row_counter[idx] / step_swap) - floorf((row_counter[idx]-1.0) / step_swap); - } else if (step_mode == 2) { - should_ema = floorf((iter*1.0 - row_counter[idx]) / step_ema); - should_swap = floorf((iter*1.0 - prev_iter[idx]) / step_swap); - // row_counter[idx] records the step of last ema - if (should_ema > 0.5) { - coef_ema = powf(coef_ema, (iter*1.0 - row_counter[idx]) / step_ema); - row_counter[idx] = iter*1.0; - } - // prev_iter[idx] records the step of last swap - if (should_swap > 0.5) { - prev_iter[idx] = iter*1.0; - } - } else { - should_ema = 0.0; - should_swap = 0.0; - } - } - multiplier = SHFL_SYNC(multiplier, 0); - coef_ema = SHFL_SYNC(coef_ema, 0); - should_ema = SHFL_SYNC(should_ema, 0); - should_swap = SHFL_SYNC(should_swap, 0); - """ - - split_weight_update = """ - weight_new.acc.x = weight_new.acc.x - multiplier * grad.acc.x; - weight_new.acc.y = weight_new.acc.y - multiplier * grad.acc.y; - weight_new.acc.z = weight_new.acc.z - multiplier * grad.acc.z; - weight_new.acc.w = weight_new.acc.w - multiplier * grad.acc.w; - - if (should_ema > 0.5) { // slow table ema - Vec4T m_t(&momentum1[idx * D + d]); - m_t.acc.x = (1.0 - coef_ema) * weight_new.acc.x + coef_ema * m_t.acc.x + (momentum - coef_ema) * multiplier * grad.acc.x; - m_t.acc.y = (1.0 - coef_ema) * weight_new.acc.y + coef_ema * m_t.acc.y + (momentum - coef_ema) * multiplier * grad.acc.y; - m_t.acc.z = (1.0 - coef_ema) * weight_new.acc.z + coef_ema * m_t.acc.z + (momentum - coef_ema) * multiplier * grad.acc.z; - m_t.acc.w = (1.0 - coef_ema) * weight_new.acc.w + coef_ema * m_t.acc.w + (momentum - coef_ema) * multiplier * grad.acc.w; - m_t.store(&momentum1[idx * D + d]); - } - - if (should_swap > 0.5) { // slow-to-fast swap - Vec4T m_t(&momentum1[idx * D + d]); - weight_new.acc.x = m_t.acc.x * 1.0; - weight_new.acc.y = m_t.acc.y * 1.0; - weight_new.acc.z = m_t.acc.z * 1.0; - weight_new.acc.w = m_t.acc.w * 1.0; - } - """ - - split_weight_update_cpu = "" # TODO - - return { - "optimizer": "ensemble_rowwise_adagrad", - "is_prototype_optimizer": True, - "args": OptimizerArgsSet.create( - [ - OptimItem( - ArgType.PLACEHOLDER_TENSOR, - "momentum1", - ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR], - ), - OptimItem( - ArgType.PLACEHOLDER_TENSOR, - "momentum2", - ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR], - ), - OptimItem(ArgType.TENSOR, "prev_iter"), - OptimItem(ArgType.TENSOR, "row_counter"), - OptimItem(ArgType.FLOAT, "learning_rate"), - OptimItem(ArgType.FLOAT, "eps"), - OptimItem(ArgType.FLOAT, "step_ema"), - OptimItem(ArgType.FLOAT, "step_swap"), - OptimItem(ArgType.FLOAT, "step_start"), - OptimItem(ArgType.FLOAT, "momentum"), - OptimItem(ArgType.INT, "iter"), - OptimItem(ArgType.INT, "step_mode"), - ] - ), - "split_precomputation": split_precomputation, - "split_weight_update": split_weight_update, - "split_post_update": "", - "split_weight_update_cpu": split_weight_update_cpu, - "has_cpu_support": False, - "has_gpu_support": True, - "has_vbe_support": True, - "has_global_weight_decay_support": False, - "has_ssd_support": False, - } - - def partial_rowwise_adam() -> Dict[str, Any]: split_precomputation = """ at::acc_type g_local_sum_square = 0.0; diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp 
b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp index 2b126c96d..92eff015f 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp @@ -41,6 +41,16 @@ inline uint32_t pruned_hash_function(uint32_t h) { return h; } +inline uint64_t pruned_hash_function(uint64_t k) { + // MurmorHash3 64-bit mixing function. + k ^= k >> 33; + k *= (0xff51afd7ed558ccd); + k ^= k >> 33; + k *= (0xc4ceb9fe1a85ec53); + k ^= k >> 33; + return k; +} + } // namespace void pruned_hashmap_insert_{{ wdesc }}_cpu( @@ -404,58 +414,72 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu( TENSOR_ON_CPU(offsets); TENSOR_ON_CPU(hash_table); TENSOR_ON_CPU(hash_table_offsets); + TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table); int32_t T = hash_table_offsets.size(0) - 1; int32_t B = (offsets.size(0) - 1) / T; TORCH_CHECK(B > 0); + auto dense_indices = empty_like(indices); - const auto* indices_acc = indices.data_ptr(); - auto* dense_indices_acc = dense_indices.data_ptr(); - const auto* offsets_acc = offsets.data_ptr(); - const auto hash_table_acc = hash_table.accessor(); - const auto hash_table_offsets_acc = hash_table_offsets.accessor(); -for (const auto t : c10::irange(T)) { - int64_t table_start = hash_table_offsets_acc[t]; - int64_t table_end = hash_table_offsets_acc[t + 1]; - int64_t capacity = table_end - table_start; -for (const auto b : c10::irange(B)) { - int32_t indices_start = offsets_acc[t * B + b]; - int32_t indices_end = offsets_acc[t * B + b + 1]; - int32_t L = indices_end - indices_start; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup_{{ wdesc }}_cpu", [&] { + using hash_t = + std::conditional_t, uint64_t, uint32_t>; - if (table_start == table_end) { -for (const auto l : c10::irange(L)) { - dense_indices_acc[indices_start + l] = indices_acc[indices_start + l]; - } - } else { -for (const auto l : c10::irange(L)) { - int32_t idx = indices_acc[indices_start + l]; - uint32_t slot = pruned_hash_function(static_cast(idx)) % capacity; - while (true) { - int32_t slot_sparse_idx = hash_table_acc[table_start + static_cast(slot)][0]; - - // empty slot - if (slot_sparse_idx == -1) { - dense_indices_acc[indices_start + l] = -1; - break; - } - // already exists - if (slot_sparse_idx == idx) { - dense_indices_acc[indices_start + l] = hash_table_acc[table_start + static_cast(slot)][1]; - break; + const auto* indices_acc = indices.data_ptr(); + auto* dense_indices_acc = dense_indices.data_ptr(); + + const auto* offsets_acc = offsets.data_ptr(); + const auto hash_table_acc = hash_table.accessor(); + const auto hash_table_offsets_acc = hash_table_offsets.accessor(); + + for (const auto t : c10::irange(T)) { + const auto table_start = hash_table_offsets_acc[t]; + const auto table_end = hash_table_offsets_acc[t + 1]; + const auto capacity = table_end - table_start; + + for (const auto b : c10::irange(B)) { + const auto indices_start = offsets_acc[t * B + b]; + const auto indices_end = offsets_acc[t * B + b + 1]; + const auto L = indices_end - indices_start; + + if (table_start == table_end) { + for (const auto l : c10::irange(L)) { + dense_indices_acc[indices_start + l] = indices_acc[indices_start + l]; + } + + } else { + for (const auto l : c10::irange(L)) { + const auto idx = indices_acc[indices_start + l]; + auto slot = pruned_hash_function(static_cast(idx)) % capacity; + + while (true) { + const auto slot_sparse_idx = hash_table_acc[table_start + 
static_cast(slot)][0]; + + // empty slot + if (slot_sparse_idx == -1) { + dense_indices_acc[indices_start + l] = -1; + break; + } + // already exists + if (slot_sparse_idx == idx) { + dense_indices_acc[indices_start + l] = hash_table_acc[table_start + static_cast(slot)][1]; + break; + } + // linear probe + slot = (slot + 1) % capacity; } - // linear probe - slot = (slot + 1) % capacity; } } } } - } + }); + return dense_indices; } {% if not weighted %} + Tensor pruned_array_lookup_cpu( Tensor indices, Tensor offsets, @@ -465,37 +489,46 @@ Tensor pruned_array_lookup_cpu( TENSOR_ON_CPU(offsets); TENSOR_ON_CPU(index_remappings); TENSOR_ON_CPU(index_remappings_offsets); + TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings); int32_t T = index_remappings_offsets.size(0) - 1; int32_t B = (offsets.size(0) - 1) / T; TORCH_CHECK(B > 0); + auto dense_indices = empty_like(indices); - const auto* indices_acc = indices.data_ptr(); - auto* dense_indices_acc = dense_indices.data_ptr(); - const auto* offsets_acc = offsets.data_ptr(); - const auto index_remappings_acc = index_remappings.data_ptr(); - const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr(); - at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) { - for (const auto t : c10::irange(begin, end)) { - int64_t index_remappings_start = index_remappings_offsets_acc[t]; - int64_t index_remappings_end = index_remappings_offsets_acc[t + 1]; - int64_t capacity = index_remappings_end - index_remappings_start; - int32_t indices_start = offsets_acc[t * B]; - int32_t indices_end = offsets_acc[(t + 1) * B]; - if (capacity > 0) { - for (const auto i : c10::irange(indices_start,indices_end)) { - int32_t idx = indices_acc[i]; - dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx]; - } - } else { - std::memcpy( - dense_indices_acc + indices_start, - indices_acc + indices_start, - (indices_end - indices_start) * sizeof(int32_t)); - } - } + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cpu", [&] { + const auto* indices_acc = indices.data_ptr(); + auto* dense_indices_acc = dense_indices.data_ptr(); + const auto* offsets_acc = offsets.data_ptr(); + + const auto index_remappings_acc = index_remappings.data_ptr(); + const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr(); + + at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) { + for (const auto t : c10::irange(begin, end)) { + const auto index_remappings_start = index_remappings_offsets_acc[t]; + const auto index_remappings_end = index_remappings_offsets_acc[t + 1]; + const auto capacity = index_remappings_end - index_remappings_start; + + const auto indices_start = offsets_acc[t * B]; + const auto indices_end = offsets_acc[(t + 1) * B]; + + if (capacity > 0) { + for (const auto i : c10::irange(indices_start, indices_end)) { + auto idx = indices_acc[i]; + dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx]; + } + } else { + std::memcpy( + dense_indices_acc + indices_start, + indices_acc + indices_start, + (indices_end - indices_start) * sizeof(index_t)); + } + } + }); }); + return dense_indices; } diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp index 41fd137dd..b6f55b961 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp @@ -21,6 +21,7 @@ #include 
"fbgemm_gpu/embedding_common.h" #include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/utils/ops_utils.h" +#include "fbgemm_gpu/utils/tensor_utils.h" using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -374,29 +375,37 @@ class PrunedMapCPU : public torch::jit::CustomClassHolder { } Tensor lookup(Tensor indices, Tensor offsets) const { + TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets); + int32_t T = maps_.size(); TORCH_CHECK(T > 0); int32_t B = (offsets.size(0) - 1) / T; TORCH_CHECK(B > 0); TORCH_CHECK(maps_.size() == T); + auto dense_indices = empty_like(indices); - const auto* indices_acc = indices.data_ptr(); - auto* dense_indices_acc = dense_indices.data_ptr(); - const auto* offsets_acc = offsets.data_ptr(); - for (const auto t : c10::irange(T)) { - auto& map = maps_[t]; - for (const auto b : c10::irange(B)) { - int32_t indices_start = offsets_acc[t * B + b]; - int32_t indices_end = offsets_acc[t * B + b + 1]; - int32_t L = indices_end - indices_start; - for (const auto l : c10::irange(L)) { - int32_t slot_sparse_index = indices_acc[indices_start + l]; - auto it = map.find(slot_sparse_index); - dense_indices_acc[indices_start + l] = - it != map.end() ? it->second : -1; + + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "PrunedMapCPU::lookup", [&] { + const auto* indices_acc = indices.data_ptr(); + auto* dense_indices_acc = dense_indices.data_ptr(); + const auto* offsets_acc = offsets.data_ptr(); + + for (const auto t : c10::irange(T)) { + auto& map = maps_[t]; + for (const auto b : c10::irange(B)) { + const auto indices_start = offsets_acc[t * B + b]; + const auto indices_end = offsets_acc[t * B + b + 1]; + const auto L = indices_end - indices_start; + for (const auto l : c10::irange(L)) { + const auto slot_sparse_index = indices_acc[indices_start + l]; + const auto it = map.find(slot_sparse_index); + dense_indices_acc[indices_start + l] = + it != map.end() ? 
it->second : -1; + } } } - } + }); + return dense_indices; } diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu index 7d4eebcce..846cd4763 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu @@ -14,19 +14,20 @@ using Tensor = at::Tensor; namespace nbit { +template __global__ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel( - const pta::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 indices, - const pta::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 offsets, - const pta::PackedTensorAccessor64 + const pta::PackedTensorAccessor64 hash_table, const pta::PackedTensorAccessor32 hash_table_offsets, const int32_t B, const int32_t T, - pta::PackedTensorAccessor32 + pta::PackedTensorAccessor32 dense_indices) { // uint32_t capacity = hash_table.size(0); const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y; @@ -35,9 +36,9 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru if (b_t >= B * T) { return; } - const int32_t indices_start = offsets[t * B + b]; - const int32_t indices_end = offsets[t * B + b + 1]; - const int32_t L = indices_end - indices_start; + const index_t indices_start = offsets[t * B + b]; + const index_t indices_end = offsets[t * B + b + 1]; + const index_t L = indices_end - indices_start; const int64_t table_start = hash_table_offsets[t]; const int64_t table_end = hash_table_offsets[t + 1]; @@ -51,6 +52,9 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru return; } + using hash_t = + std::conditional_t, uint64_t, uint32_t>; + const uint32_t subwarp_id = threadIdx.x / 4; const uint32_t subwarp_tid = threadIdx.x % 4; #ifdef USE_ROCM @@ -58,13 +62,15 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru #else const uint32_t subwarp_mask = static_cast(0xF) << (4 * subwarp_id); #endif + for (int32_t l_start = 0; l_start + subwarp_id < L; l_start += kWarpSize / 4) { - const int32_t idx = indices[indices_start + l_start + subwarp_id]; - uint32_t slot_start = - pruned_hash_function(static_cast(idx)) % capacity; + const index_t idx = indices[indices_start + l_start + subwarp_id]; + hash_t slot_start = + pruned_hash_function(static_cast(idx)) % capacity; + while (true) { - const uint32_t slot = (slot_start + subwarp_tid) % capacity; + const hash_t slot = (slot_start + subwarp_tid) % capacity; const int2 val = *reinterpret_cast( &hash_table[table_start + static_cast(slot)][0]); const int32_t slot_sparse_idx = val.x; @@ -78,6 +84,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru found = true; dense_indices[indices_start + l_start + subwarp_id] = slot_dense_idx; } + if (__any_sync(subwarp_mask, found)) { break; } else if (__any_sync(subwarp_mask, empty)) { @@ -89,19 +96,20 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru } } +template __global__ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel( - const pta::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 indices, - const pta::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 offsets, - const pta::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 index_remappings, const 
pta::PackedTensorAccessor32 index_remappings_offsets, const int32_t B, const int32_t T, - pta::PackedTensorAccessor32 + pta::PackedTensorAccessor32 dense_indices) { const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y; const int32_t t = b_t / B; @@ -109,22 +117,22 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru if (b_t >= B * T) { return; } - const int32_t indices_start = offsets[t * B + b]; - const int32_t indices_end = offsets[t * B + b + 1]; - const int32_t L = indices_end - indices_start; + const index_t indices_start = offsets[t * B + b]; + const index_t indices_end = offsets[t * B + b + 1]; + const index_t L = indices_end - indices_start; const int64_t index_remappings_start = index_remappings_offsets[t]; const int64_t index_remappings_end = index_remappings_offsets[t + 1]; const int64_t capacity = index_remappings_end - index_remappings_start; if (capacity > 0) { - for (int32_t l = threadIdx.x; l < L; l += blockDim.x) { - int32_t idx = indices[indices_start + l]; + for (index_t l = threadIdx.x; l < L; l += blockDim.x) { + index_t idx = indices[indices_start + l]; dense_indices[indices_start + l] = index_remappings[index_remappings_start + idx]; } } else { - for (int32_t l = threadIdx.x; l < L; l += blockDim.x) { + for (index_t l = threadIdx.x; l < L; l += blockDim.x) { dense_indices[indices_start + l] = indices[indices_start + l]; } } @@ -132,6 +140,8 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru } // namespace nbit +using namespace nbit; + Tensor pruned_hashmap_lookup_cuda( Tensor indices, Tensor offsets, @@ -139,6 +149,7 @@ Tensor pruned_hashmap_lookup_cuda( Tensor hash_table_offsets) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( indices, offsets, hash_table, hash_table_offsets); + TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table); CUDA_DEVICE_GUARD(indices); @@ -149,23 +160,25 @@ Tensor pruned_hashmap_lookup_cuda( TORCH_CHECK(hash_table.size(0) < std::numeric_limits::max()); constexpr size_t kForwardMaxThreads = 256; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup_cuda", [&] { #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name = - "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel"; + const auto func_name = + "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel"; #endif - nbit::int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<< - nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize), - dim3(kWarpSize, kForwardMaxThreads / kWarpSize), - 0, - at::cuda::getCurrentCUDAStream()>>>( - MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, hash_table, int32_t, 2, 64), - MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32), - B, - T, - MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32)); + int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<< + nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize), + dim3(kWarpSize, kForwardMaxThreads / kWarpSize), + 0, + at::cuda::getCurrentCUDAStream()>>>( + MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, hash_table, index_t, 2, 64), + MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32), + B, + T, + MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32)); + }); C10_CUDA_KERNEL_LAUNCH_CHECK(); return dense_indices; @@ -178,6 
+191,7 @@ Tensor pruned_array_lookup_cuda( Tensor index_remappings_offsets) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( indices, offsets, index_remappings, index_remappings_offsets); + TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings); CUDA_DEVICE_GUARD(indices); @@ -204,23 +218,26 @@ Tensor pruned_array_lookup_cuda( TORCH_CHECK(dense_indices.dim() == 1, "Tensor dim: ", dense_indices.dim()); constexpr size_t kForwardMaxThreads = 256; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cuda", [&] { #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name = - "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel"; + const auto func_name = + "int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel"; #endif - nbit::int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<< - nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize), - dim3(kWarpSize, kForwardMaxThreads / kWarpSize), - 0, - at::cuda::getCurrentCUDAStream()>>>( - MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, index_remappings, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32), - B, - T, - MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32)); + int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<< + nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize), + dim3(kWarpSize, kForwardMaxThreads / kWarpSize), + 0, + at::cuda::getCurrentCUDAStream()>>>( + MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, index_remappings, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32), + B, + T, + MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32)); + }); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return dense_indices; } diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu index bc4e7ba74..5dd5c30b1 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu @@ -7,7 +7,7 @@ */ // clang-format off -{% set wdesc = "weighted" if weighted else "unweighted" %} +{%- set wdesc = "weighted" if weighted else "unweighted" %} #include "fbgemm_gpu/embedding_forward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor.h" @@ -22,7 +22,7 @@ namespace nbit { `Tensor int_nbit_split_embedding*_codegen_forward_*_cuda(...)` later in the same generated source file. 
*/ -{% for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] %} +{%- for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] %} template __launch_bounds__(WarpsPerBlock * kWarpSize) __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L( @@ -31,30 +31,30 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no const pta::PackedTensorAccessor32 weights_placements, const pta::PackedTensorAccessor32 weights_offsets, const pta::PackedTensorAccessor32 weights_tys, - {% if not nobag %} + {%- if not nobag %} const pta::PackedTensorAccessor32 D_offsets, - {% else %} + {%- else %} const int64_t D, - {% endif %} + {%- endif %} FixedDivisor fd_B, // FixedDivisor(div_round_up(B, OutputRowsPerThread)) const pta::PackedTensorAccessor32 indices, const pta::PackedTensorAccessor32 offsets, - {% if not nobag %} + {%- if not nobag %} const int64_t pooling_mode, - {% endif %} + {%- endif %} const int64_t row_alignment, - {% if weighted %} + {%- if weighted %} pta::PackedTensorAccessor32 indice_weights, - {% endif %} - {% if type_map[emb_weight_type].enum_name == "FP8" %} + {%- endif %} + {%- if type_map[emb_weight_type].enum_name == "FP8" %} const int fp8_exponent_bits, const int fp8_exponent_bias, - {% endif %} + {%- endif %} pta::PackedTensorAccessor32 output, // [B][total_D], const pta::PackedTensorAccessor64 lxu_cache_weights, const pta::PackedTensorAccessor32 lxu_cache_locations ); -{% endfor %} // for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] +{%- endfor %} // for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] } @@ -107,58 +107,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no C10_CUDA_KERNEL_LAUNCH_CHECK(); \ {%- endmacro %} - -Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda( - Tensor dev_weights, - Tensor uvm_weights, - Tensor weights_placements, - Tensor weights_offsets, - Tensor weights_tys, - {% if not nobag %} - Tensor D_offsets, - const int64_t total_D, - {% else %} - const int64_t D, - {% endif %} - const int64_t max_int2_D, - const int64_t max_int4_D, - const int64_t max_int8_D, - const int64_t max_float16_D, - const int64_t max_float32_D, - Tensor indices, - Tensor offsets, - {% if not nobag %} - const int64_t pooling_mode, - {% endif %} - const int64_t row_alignment, - {% if weighted %} - Tensor indice_weights, - {% endif %} - const int64_t output_dtype, - Tensor lxu_cache_weights, - Tensor lxu_cache_locations, - const int64_t max_float8_D, - const int64_t fp8_exponent_bits, - const int64_t fp8_exponent_bias -) { - TENSOR_ON_CUDA_GPU(dev_weights); - TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights); - TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights); - TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights); - TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights); - {% if not nobag %} - TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights); - {% endif %} - TENSORS_ON_SAME_DEVICE(indices, dev_weights); - TENSORS_ON_SAME_DEVICE(offsets, dev_weights); - {% if weighted %} - TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights); - {% endif %} - TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights); - TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights); - - CUDA_DEVICE_GUARD(dev_weights); - +{%- macro construct_and_return_output_tensor() %} // kernels assume indices are contiguous. 
indices = indices.contiguous(); @@ -180,8 +129,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ TORCH_CHECK(D > 0); {%- endif %} + // Construct output tensor Tensor output; const int kINT8QparamsBytes = 8; + SparseType o_dtype = static_cast(output_dtype); TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8); @@ -216,11 +167,63 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ if (B == 0 || indices.numel() == 0) { return output; } +{%- endmacro %} - using index_t = int32_t; +template +Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda_impl( + Tensor dev_weights, + Tensor uvm_weights, + Tensor weights_placements, + Tensor weights_offsets, + Tensor weights_tys, + {%- if not nobag %} + Tensor D_offsets, + const int64_t total_D, + {%- else %} + const int64_t D, + {%- endif %} + const int64_t max_int2_D, + const int64_t max_int4_D, + const int64_t max_int8_D, + const int64_t max_float16_D, + const int64_t max_float32_D, + Tensor indices, + Tensor offsets, + {%- if not nobag %} + const int64_t pooling_mode, + {%- endif %} + const int64_t row_alignment, + {%- if weighted %} + Tensor indice_weights, + {%- endif %} + const int64_t output_dtype, + Tensor lxu_cache_weights, + Tensor lxu_cache_locations, + const int64_t max_float8_D, + const int64_t fp8_exponent_bits, + const int64_t fp8_exponent_bias +) { + TENSOR_ON_CUDA_GPU(dev_weights); + TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights); + {%- if not nobag %} + TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights); + {%- endif %} + TENSORS_ON_SAME_DEVICE(indices, dev_weights); + TENSORS_ON_SAME_DEVICE(offsets, dev_weights); + {%- if weighted %} + TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights); + {%- endif %} + TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights); + TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights); - constexpr int32_t kWarpsPerBlock = 4; + CUDA_DEVICE_GUARD(dev_weights); + + {{- construct_and_return_output_tensor() }} + constexpr int32_t kWarpsPerBlock = 4; const auto device_only = lxu_cache_weights.numel() == 0 && uvm_weights.numel() == 0; #define Y(...) 
\ if (device_only) { \ @@ -397,6 +400,104 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ })); #undef X + return output; +} + +Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda( + Tensor dev_weights, + Tensor uvm_weights, + Tensor weights_placements, + Tensor weights_offsets, + Tensor weights_tys, + {%- if not nobag %} + Tensor D_offsets, + const int64_t total_D, + {%- else %} + const int64_t D, + {%- endif %} + const int64_t max_int2_D, + const int64_t max_int4_D, + const int64_t max_int8_D, + const int64_t max_float16_D, + const int64_t max_float32_D, + Tensor indices, + Tensor offsets, + {%- if not nobag %} + const int64_t pooling_mode, + {%- endif %} + const int64_t row_alignment, + {%- if weighted %} + Tensor indice_weights, + {%- endif %} + const int64_t output_dtype, + Tensor lxu_cache_weights, + Tensor lxu_cache_locations, + const int64_t max_float8_D, + const int64_t fp8_exponent_bits, + const int64_t fp8_exponent_bias +) { + // All argument tensors need to be on the same CUDA device + TENSOR_ON_CUDA_GPU(dev_weights); + TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights); + {%- if not nobag %} + TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights); + {%- endif %} + TENSORS_ON_SAME_DEVICE(indices, dev_weights); + TENSORS_ON_SAME_DEVICE(offsets, dev_weights); + {%- if weighted %} + TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights); + {%- endif %} + TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights); + TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights); + + // indices and offsets need to have the same scalar type + TENSORS_HAVE_SAME_TYPE(indices, offsets); + // Only int32_t and int64_t indices are supported at the moment + TENSOR_SCALAR_TYPE_IS_ONE_OF(indices, at::ScalarType::Long, at::ScalarType::Int); + + CUDA_DEVICE_GUARD(dev_weights); + + // Create output tensor ref + Tensor output; + + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ 'int_nbit_split_embedding' + ('_nobag' if nobag else '') + '_codegen_forward_' + wdesc + '_cuda' }}", [&] { + output = int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda_impl( + dev_weights, + uvm_weights, + weights_placements, + weights_offsets, + weights_tys, + {%- if not nobag %} + D_offsets, + total_D, + {%- else %} + D, + {%- endif %} + max_int2_D, + max_int4_D, + max_int8_D, + max_float16_D, + max_float32_D, + indices, + offsets, + {%- if not nobag %} + pooling_mode, + {%- endif %} + row_alignment, + {%- if weighted %} + indice_weights, + {%- endif %} + output_dtype, + lxu_cache_weights, + lxu_cache_locations, + max_float8_D, + fp8_exponent_bits, + fp8_exponent_bias); + }); + return output; } diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index b623b92d0..3666de5b9 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -35,8 +35,12 @@ #include "fbgemm_gpu/utils/ops_utils.h" #include #include "fbgemm_gpu/utils/dispatch_macros.h" -#include "fbgemm_gpu/split_embeddings_utils.cuh" +#include "fbgemm_gpu/embedding_common.h" +#include "fbgemm_gpu/split_embeddings_utils.h" #include 
"fbgemm_gpu/config/feature_gates.h" +{%- if has_vbe_support %} +#include "fbgemm_gpu/utils/pt2_autograd_utils.h" +{%- endif %} using Tensor = at::Tensor; @@ -236,9 +240,9 @@ enum SSDTensor { const Tensor& /*prev_iter_dev*/, {%- endif %} {%- if "iter" not in args_pt2.split_function_arg_names %} - const int64_t iter, + const int64_t /*iter*/, {%- endif %} - const double gwd_lower_bound, + const double /*gwd_lower_bound*/, {%- endif %} {# /* if is_gwd */ #} {%- for arg_type in args_pt2.split_function_args %} {{ arg_type.split(' ')[0]}}{%- if not loop.last %}{{ "," }}{%- endif %} @@ -617,7 +621,6 @@ class {{ autograd_func }} : const c10::SymInt, const int64_t, const c10::SymInt)>(); - auto [ vbe_row_output_offsets, vbe_b_t_map @@ -850,6 +853,11 @@ static torch::autograd::variable_list backward( // {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda) weights_dev = weights_dev.flatten(); {%- endif %} + {%- if vbe %} + if (weights_host.numel() > 1){ + grad_output = reshape_vbe_output(grad_output, B_offsets, vbe_b_t_map, D_offsets); + } + {%- endif %} {%- set grad_indice_weights_op = "{}_embedding_codegen_grad_indice_weights{}_pt2_wrapper".format(fwd_mdesc, vdesc) @@ -883,7 +891,7 @@ static torch::autograd::variable_list backward( {%- else %} const Tensor& /*feature_requires_grad*/ {%- endif %} - )>(); + )>(); const auto grad_indice_weights = !indice_weights.defined() ? Variable() : @@ -1014,7 +1022,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( {%- if not ssd %} {%- if has_vbe_support %} // has vbe support and on gpu - if (B_offsets.has_value() && !(weights[0].numel() > 0)) { + if (B_offsets.has_value()) { {%- if has_global_weight_decay_support %} // vbe and has gwd support if (apply_global_weight_decay && weight_decay > 0) { diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp index c74355207..5b2b066fe 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp @@ -30,9 +30,12 @@ using Tensor = at::Tensor; using namespace fbgemm_gpu; +{%- for vbe in ([True, False] if has_vbe_support else [False]) %} +{%- set vdesc = "_vbe" if vbe else "" %} + {%- if is_forward %} {#-/* PT2 wrapper function for backward grad_indice_weights CPU */#} -Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper( +Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper( const Tensor& grad_output, const Tensor& host_weights, const Tensor& /*dev_weights*/, @@ -45,7 +48,16 @@ Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper( const Tensor& indices, const Tensor& offsets, const Tensor& /*lxu_cache_locations*/, - const Tensor& feature_requires_grad) { + {%- if vbe %} + const Tensor& feature_requires_grad, + const Tensor& vbe_row_output_offsets, + const Tensor& vbe_b_t_map, + const int64_t info_B_num_bits, + const int64_t info_B_mask_int64 + {%- else %} + const Tensor& feature_requires_grad + {%- endif %} +) { static auto op = torch::Dispatcher::singleton() .findSchemaOrThrow( @@ -67,7 +79,7 @@ Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper( {% if is_forward %} {#-/* PT2 wrapper function for forward CPU */#} -Tensor split_embedding_codegen_forward_{{ wdesc }}_pt2_cpu_wrapper( +Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper( 
const Tensor& host_weights, const Tensor& /*dev_weights*/, const Tensor& /*uvm_weights*/, @@ -84,30 +96,77 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}_pt2_cpu_wrapper( const Tensor& indice_weights, const Tensor& /*lxu_cache_locations*/, const Tensor& /*uvm_cache_stats*/, + {%- if vbe %} + const Tensor& vbe_row_output_offsets, /*vbe_output_offsets_feature_rank*/ + const Tensor& vbe_b_t_map, /*vbe_B_offsets_rank_per_feature*/ + const c10::SymInt vbe_output_size, + const int64_t info_B_num_bits, + const int64_t info_B_mask_int64, + {%- endif %} const bool /*is_experimental = false*/, const int64_t output_dtype = static_cast(SparseType::FP32)) { - static auto op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "") - .typed(); + static auto op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "") + .typed(); + {%- if vbe %} + // TODO: remove this after vbe is implemented for CPU kernel + Tensor vbe_B_offsets_rank_per_feature = vbe_b_t_map; + Tensor vbe_output_offsets_feature_rank = vbe_row_output_offsets; + const auto output = op.call( + host_weights, + weights_offsets, + D_offsets, + total_D, + hash_size_cumsum, + indices, + offsets, + pooling_mode, + indice_weights, + output_dtype); + auto options = at::TensorOptions() + .dtype(output.options().dtype()) + .device(host_weights.options().device()); + const int64_t vbe_output_size_ = vbe_output_size.guard_int(__FILE__, __LINE__); + Tensor output_new = at::empty({vbe_output_size_}, options); + const int32_t T = D_offsets.numel() - 1; + const int32_t R = vbe_B_offsets_rank_per_feature.size(1) - 1; - return op.call( - host_weights, - weights_offsets, - D_offsets, - total_D, - hash_size_cumsum, - indices, - offsets, - pooling_mode, - indice_weights, - output_dtype); -} + for (int32_t r = 0; r < R; r++){ + auto D_offset = 0; + for (int32_t t = 0; t < T; t++){ + const int32_t o_begin = vbe_output_offsets_feature_rank[r * T + t].item(); + const int32_t o_end = vbe_output_offsets_feature_rank[r * T + t + 1].item(); + const int32_t D = D_offsets[t + 1].item() - D_offsets[t].item(); + const int32_t b_begin = vbe_B_offsets_rank_per_feature[t][r].item(); + const int32_t b_end = vbe_B_offsets_rank_per_feature[t][r + 1].item(); + + TORCH_CHECK((o_end - o_begin) == ((b_end - b_begin) * D)); + auto values = output.index({torch::indexing::Slice(b_begin, b_end), torch::indexing::Slice(D_offset, D_offset + D)}).flatten(); + output_new.index_put_({torch::indexing::Slice(o_begin, o_end)}, values); + D_offset += D; + } + } + return output_new; + {%- else %} + return op.call( + host_weights, + weights_offsets, + D_offsets, + total_D, + hash_size_cumsum, + indices, + offsets, + pooling_mode, + indice_weights, + output_dtype); + {%- endif %} + } {% else %} {#-/* PT2 wrapper function for backward CPU */#} -Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_pt2_cpu_wrapper( +Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper( const Tensor& grad_output, const Tensor& host_weights, const Tensor& /*dev_weights*/, @@ -127,8 +186,13 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_pt2_cpu_wrap const int64_t /*BT_block_size*/, const int64_t /*max_segment_length_per_warp*/, const bool stochastic_rounding, - const int64_t /*info_B_num_bits*/, - const int64_t /*info_B_mask_int64*/, + const int64_t info_B_num_bits, + const int64_t info_B_mask_int64, + {%- if vbe %} + const Tensor& 
B_offsets, + const Tensor& vbe_row_output_offsets, + const Tensor& vbe_b_t_map, + {%- endif %} const bool /*use_uniq_cache_locations*/, const bool /*use_homogeneous_placements*/, {{ args_pt2.split_function_args | join(", ") }} @@ -194,29 +258,30 @@ namespace { TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- if is_forward %} DISPATCH_TO_CPU( - "split_embedding_codegen_grad_indice_weights_pt2_wrapper", - split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper); + "split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_wrapper", + split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper); {%- endif %} {%- for weighted in [True, False] %} {%- set wdesc = "weighted" if weighted else "unweighted" %} {%- if is_forward %} - {%- set embedding_codegen_forward_op = "split_embedding_codegen_forward_{}_pt2".format( - wdesc + {%- set embedding_codegen_forward_op = "split_embedding_codegen_forward_{}{}_pt2".format( + wdesc, vdesc ) %} DISPATCH_TO_CPU("{{ embedding_codegen_forward_op }}_wrapper", {{ embedding_codegen_forward_op }}_cpu_wrapper); {%- else %} - {%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}_pt2".format( - optimizer, wdesc + {%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}{}_pt2".format( + optimizer, wdesc, vdesc ) %} DISPATCH_TO_CPU("{{ embedding_codegen_backward_op }}_wrapper", {{ embedding_codegen_backward_op }}_cpu_wrapper); {%- endif %} {%- endfor %} {#-/*for weighted*/#} } - } // namespace +{%- endfor %} {#-/* for vbe in [True, False] */#} + {% endif %} // if has_cpu_support // clang-format on diff --git a/fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp b/fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp new file mode 100644 index 000000000..071acf90a --- /dev/null +++ b/fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +using Tensor = at::Tensor; + +namespace fbgemm_gpu { + +//////////////////////////////////////////////////////////////////////////////// +// Helper Functions +//////////////////////////////////////////////////////////////////////////////// + +Tensor reshape_vbe_output( + const Tensor& grad_output, + const Tensor& B_offsets, + const Tensor& B_offsets_rank_per_feature, + const Tensor& D_offsets) { + /* FOR CPU VBE to use the same backend */ + const auto T = D_offsets.numel() - 1; + int32_t max_B = 0; + int32_t total_D = 0; + // find max_B, total_D to create output [max_B, total_D] + for (int32_t t = 0; t < T; t++) { + auto b = B_offsets[t + 1].item() - B_offsets[t].item(); + max_B = std::max(max_B, b); + total_D += D_offsets[t + 1].item() - D_offsets[t].item(); + } + auto grad_output_ = at::empty({max_B, total_D}, grad_output.options()); + // for each feature + auto offset = 0; + + const int32_t R = B_offsets_rank_per_feature.size(1) - 1; + for (int32_t r = 0; r < R; r++) { + auto D_offset = 0; + for (int32_t t = 0; t < T; t++) { + const int32_t b_begin = B_offsets_rank_per_feature[t][r].item(); + const int32_t b_end = + B_offsets_rank_per_feature[t][r + 1].item(); + const int32_t D = + D_offsets[t + 1].item() - D_offsets[t].item(); + const int32_t b = b_end - b_begin; + const int32_t num_elm = b * D; + auto values = grad_output.slice(0, offset, offset + num_elm); + values = values.reshape({b, D}); + grad_output_.index_put_( + {at::indexing::Slice(b_begin, b_end), + at::indexing::Slice(D_offset, D_offset + D)}, + values); + D_offset += D; + offset += num_elm; + } + } + return grad_output_; +} +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/codegen/training/python/lookup_args.template b/fbgemm_gpu/codegen/training/python/lookup_args.template index 54fa11177..357aad622 100644 --- a/fbgemm_gpu/codegen/training/python/lookup_args.template +++ b/fbgemm_gpu/codegen/training/python/lookup_args.template @@ -60,10 +60,6 @@ class OptimizerArgs(NamedTuple): eps: float beta1: float beta2: float - step_ema: float - step_swap: float - step_start: float - step_mode: int weight_decay: float weight_decay_mode: int eta: float diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template index e03a879cb..2f14b27de 100644 --- a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template +++ b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template @@ -145,18 +145,6 @@ def invoke( {%- if "beta2" in args.split_function_arg_names %} beta2=optimizer_args.beta2, {%- endif %} - {%- if "step_ema" in args.split_function_arg_names %} - step_ema=optimizer_args.step_ema, - {%- endif %} - {%- if "step_swap" in args.split_function_arg_names %} - step_swap=optimizer_args.step_swap, - {%- endif %} - {%- if "step_start" in args.split_function_arg_names %} - step_start=optimizer_args.step_start, - {%- endif %} - {%- if "step_mode" in args.split_function_arg_names %} - step_mode=optimizer_args.step_mode, - {%- endif %} {%- if "weight_decay" in args.split_function_arg_names %} weight_decay=optimizer_args.weight_decay, {%- endif %} @@ -327,18 +315,6 @@ def invoke( {%- if "beta2" in args.split_function_arg_names %} beta2=optimizer_args.beta2, {%- endif %} - {%- if "step_ema" in args.split_function_arg_names %} - step_ema=optimizer_args.step_ema, - {%- endif %} - {%- if "step_swap" in args.split_function_arg_names %} - 
step_swap=optimizer_args.step_swap, - {%- endif %} - {%- if "step_start" in args.split_function_arg_names %} - step_start=optimizer_args.step_start, - {%- endif %} - {%- if "step_mode" in args.split_function_arg_names %} - step_mode=optimizer_args.step_mode, - {%- endif %} {%- if "weight_decay" in args.split_function_arg_names %} weight_decay=optimizer_args.weight_decay, {%- endif %} diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template b/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template index b9be5cd4c..6c2380e7c 100644 --- a/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template +++ b/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template @@ -90,18 +90,6 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer): {%- if "beta2" in args.split_function_arg_names %} beta2: float = 0.999, {%- endif %} - {%- if "step_ema" in args.split_function_arg_names %} - step_ema: float = 10000, - {%- endif %} - {%- if "step_swap" in args.split_function_arg_names %} - step_swap: float = 10000, - {%- endif %} - {%- if "step_start" in args.split_function_arg_names %} - step_start: float = 0, - {%- endif %} - {%- if "step_mode" in args.split_function_arg_names %} - step_mode: int = 2, - {%- endif %} {%- if "weight_decay" in args.split_function_arg_names %} weight_decay: float = 0.0, {%- endif %} @@ -130,18 +118,6 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer): {%- if "beta2" in args.split_function_arg_names %} beta2=beta2, {%- endif %} - {%- if "step_ema" in args.split_function_arg_names %} - step_ema=step_ema, - {%- endif %} - {%- if "step_swap" in args.split_function_arg_names %} - step_swap=step_swap, - {%- endif %} - {%- if "step_start" in args.split_function_arg_names %} - step_start=step_start, - {%- endif %} - {%- if "step_mode" in args.split_function_arg_names %} - step_mode=step_mode, - {%- endif %} {%- if "weight_decay" in args.split_function_arg_names %} weight_decay=weight_decay, {%- endif %} @@ -186,7 +162,7 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer): rowwise = False {% endif %} {% elif state_tensor == "momentum2" %} - {% if optimizer in ["partial_rowwise_adam", "partial_rowwise_lamb", "ensemble_rowwise_adagrad"] %} + {% if optimizer in ["partial_rowwise_adam", "partial_rowwise_lamb"] %} rowwise = True {% else %} rowwise = False @@ -236,18 +212,6 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer): {%- if "beta2" in args.split_function_arg_names %} self.beta2 = beta2 {%- endif %} - {%- if "step_ema" in args.split_function_arg_names %} - self.step_ema = step_ema - {%- endif %} - {%- if "step_swap" in args.split_function_arg_names %} - self.step_swap = step_swap - {%- endif %} - {%- if "step_start" in args.split_function_arg_names %} - self.step_start = step_start - {%- endif %} - {%- if "step_mode" in args.split_function_arg_names %} - self.step_mode = step_mode - {%- endif %} {%- if "weight_decay" in args.split_function_arg_names %} self.weight_decay = weight_decay {%- endif %} diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu index 08e22baa9..8d8ee6ab5 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu @@ -233,22 +233,24 @@ void bounds_check_indices_cuda( constexpr size_t kNumThreads = 256; const auto max_B_ = vbe ? 
max_B : B; - AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices", [&] { - const auto bounds_check_kernel = - (vbe ? bounds_check_indices_kernel - : bounds_check_indices_kernel); - TORCH_DSA_KERNEL_LAUNCH( - bounds_check_kernel, - div_round_up(max_B_ * T, kNumThreads / fbgemm_gpu::kWarpSize), - dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize), - 0, - at::cuda::getCurrentCUDAStream(), - rows_per_table.packed_accessor32(), - indices.packed_accessor32(), - offsets.packed_accessor32(), - vbe ? B_offsets.value().data_ptr() : nullptr, - bounds_check_mode_, - warning.packed_accessor32(), - FixedDivisor(max_B_)); - }); + AT_DISPATCH_INDEX_TYPES( + indices.scalar_type(), "bounds_check_indices_cuda", [&] { + const auto bounds_check_kernel = + (vbe ? bounds_check_indices_kernel + : bounds_check_indices_kernel); + TORCH_DSA_KERNEL_LAUNCH( + bounds_check_kernel, + div_round_up(max_B_ * T, kNumThreads / fbgemm_gpu::kWarpSize), + dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize), + 0, + at::cuda::getCurrentCUDAStream(), + rows_per_table + .packed_accessor32(), + indices.packed_accessor32(), + offsets.packed_accessor32(), + vbe ? B_offsets.value().data_ptr() : nullptr, + bounds_check_mode_, + warning.packed_accessor32(), + FixedDivisor(max_B_)); + }); } diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp index 1098378d0..1d0cd1348 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp @@ -70,7 +70,7 @@ void bounds_check_indices_cpu( const auto rows_per_table_acc = rows_per_table.accessor(); auto warning_acc = warning.data_ptr(); - AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices", [&] { + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices_cpu", [&] { auto offsets_acc = offsets.accessor(); auto indices_acc = indices.accessor(); auto num_indices = indices.numel(); diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/jagged_tensor_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/jagged_tensor_ops.rst index 92e8f1148..a85168bfc 100644 --- a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/jagged_tensor_ops.rst +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/jagged_tensor_ops.rst @@ -3,14 +3,22 @@ Jagged Tensor Operators .. automodule:: fbgemm_gpu +.. _jagged-tensor-ops-stable-api: + +Stable API +---------- + +.. autofunction:: torch.ops.fbgemm.jagged_to_padded_dense + +Other API +--------- + .. autofunction:: torch.ops.fbgemm.jagged_2d_to_dense .. autofunction:: torch.ops.fbgemm.jagged_1d_to_dense .. autofunction:: torch.ops.fbgemm.dense_to_jagged -.. autofunction:: torch.ops.fbgemm.jagged_to_padded_dense - .. autofunction:: torch.ops.fbgemm.jagged_dense_elementwise_add .. autofunction:: torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_modules.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_modules.rst index 654373f40..7970ce6f9 100644 --- a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_modules.rst +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_modules.rst @@ -3,5 +3,13 @@ Pooled Embedding Modules .. automodule:: fbgemm_gpu +.. _pooled-embedding-modules-stable-api: + +Stable API +---------- + .. 
autoclass:: fbgemm_gpu.permute_pooled_embedding_modules.PermutePooledEmbeddings :members: __call__ + +Other API +--------- diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst index 52e2fd47d..9e9d545d7 100644 --- a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst @@ -3,6 +3,14 @@ Pooled Embedding Operators .. automodule:: fbgemm_gpu +.. _pooled-embedding-operators-stable-api: + +Stable API +---------- + .. autofunction:: torch.ops.fbgemm.merge_pooled_embeddings .. autofunction:: torch.ops.fbgemm.permute_pooled_embs + +Other API +--------- diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/quantize_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/quantize_ops.rst new file mode 100644 index 000000000..3b47f8bcd --- /dev/null +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/quantize_ops.rst @@ -0,0 +1,14 @@ +Quantization Operators +====================== + +.. automodule:: fbgemm_gpu + +.. _quantize-ops-stable-api: + +Stable API +---------- + +.. autofunction:: torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf + +Other API +--------- diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst new file mode 100644 index 000000000..e5a4213f7 --- /dev/null +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst @@ -0,0 +1,27 @@ +Sparse Operators +================ + +.. automodule:: fbgemm_gpu + +.. _sparse-ops-stable-api: + +Stable API +---------- + +.. autofunction:: torch.ops.fbgemm.permute_2D_sparse_data + +.. autofunction:: torch.ops.fbgemm.permute_1D_sparse_data + +.. autofunction:: torch.ops.fbgemm.expand_into_jagged_permute + +.. autofunction:: torch.ops.fbgemm.asynchronous_complete_cumsum + +.. autofunction:: torch.ops.fbgemm.offsets_range + +.. autofunction:: torch.ops.fbgemm.segment_sum_csr + +.. autofunction:: torch.ops.fbgemm.keyed_jagged_index_select_dim1 + +.. autofunction:: torch.ops.fbgemm.block_bucketize_sparse_features + +Other API diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/table_batched_embedding_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/table_batched_embedding_ops.rst index bbd39d873..9b5453786 100644 --- a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/table_batched_embedding_ops.rst +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/table_batched_embedding_ops.rst @@ -1,6 +1,11 @@ Table Batched Embedding (TBE) Training Module ============================================= +.. _table-batched-embedding-ops-stable-api: + +Stable API +---------- + .. autoclass:: fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen :members: forward, split_embedding_weights, @@ -8,3 +13,6 @@ Table Batched Embedding (TBE) Training Module set_learning_rate, update_hyper_parameters, set_optimizer_step + +Other API +--------- diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-stable-api/python_api.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-stable-api/python_api.rst new file mode 100644 index 000000000..54b4a6baa --- /dev/null +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-stable-api/python_api.rst @@ -0,0 +1,37 @@ +FBGEMM_GPU Stable Python API +============================ + +We provide the stable API support starting from FBGEMM_GPU v1.0. The following +outlines our supports: + +- API backward compatibility guarantees via thorough testing. 
We guarantee that + our stable APIs will be backward compatible within a major version, meaning + that the stable APIs for v1.0.0 will be compatible with every future release + within that major version, unless explicitly announced in advance. + +- Enhanced documentation, ensuring that every stable API has comprehensive and + up-to-date documentation. + +- Functionality guarantees are provided only through our unit testing framework. + We do NOT guarantee any functionality that is NOT explicitly tested and + documented in our unit tests. + +- No performance guarantees. However, we are committed to providing support on + a best-effort basis. + +Stable APIs +----------- + +Our stable APIs can be found via the links below: + +- :ref:`Table batched embedding (TBE) modules <table-batched-embedding-ops-stable-api>` + +- :ref:`Pooled embedding operators <pooled-embedding-operators-stable-api>` + +- :ref:`Pooled embedding modules <pooled-embedding-modules-stable-api>` + +- :ref:`Sparse operators <sparse-ops-stable-api>` + +- :ref:`Jagged tensor operators <jagged-tensor-ops-stable-api>` + +- :ref:`Quantization operators <quantize-ops-stable-api>` diff --git a/fbgemm_gpu/docs/src/index.rst b/fbgemm_gpu/docs/src/index.rst index ba0d8ba6b..2b92f0d3d 100644 --- a/fbgemm_gpu/docs/src/index.rst +++ b/fbgemm_gpu/docs/src/index.rst @@ -56,6 +56,14 @@ Table of Contents fbgemm_gpu-overview/jagged-tensor-ops/JaggedTensorOps.rst +.. _fbgemm.toc.api.stable: + +.. toctree:: + :maxdepth: 1 + :caption: FBGEMM Stable API + + fbgemm_gpu-stable-api/python_api.rst + .. _fbgemm.toc.api.cpp: .. toctree:: @@ -89,8 +97,10 @@ Table of Contents :maxdepth: 1 :caption: FBGEMM_GPU Python Operators API - fbgemm_gpu-python-api/jagged_tensor_ops.rst + fbgemm_gpu-python-api/sparse_ops.rst fbgemm_gpu-python-api/pooled_embedding_ops.rst + fbgemm_gpu-python-api/quantize_ops.rst + fbgemm_gpu-python-api/jagged_tensor_ops.rst .. _fbgemm-gpu.toc.api.python.modules: diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index 07765fa21..d90f12e0a 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -6,7 +6,8 @@ # pyre-unsafe import logging -from typing import List, Optional, Tuple +import sys +from typing import List, Optional, Tuple, Union import torch import triton # @manual @@ -43,7 +44,7 @@ def get_fp8_constants() -> Tuple[torch.dtype, tl.dtype, float, float]: return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12 -def convert_fp8_type(tensor, dtype) -> triton.TensorWrapper: +def reinterpret_fp8_type(tensor: torch.Tensor, dtype: tl.dtype) -> TensorWrapper: """ Converts tensor to triton fp8 type. @@ -57,6 +58,28 @@ def convert_fp8_type(tensor, dtype) -> triton.TensorWrapper: return tl_reinterpret(tensor, dtype=dtype) +def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype: + """ + Maps torch dtype to triton dtype. + + Args: + dtype (torch.dtype): input dtype. + + Returns: + tl.dtype: triton dtype.
+ """ + if dtype == torch.float16: + return tl.float16 + elif dtype == torch.bfloat16: + return tl.bfloat16 + elif dtype == torch.float32: + return tl.float32 + elif dtype == torch.int32: + return tl.int32 + else: + raise ValueError(f"Unsupported dtype {dtype}") + + def init_to_zero(name): return lambda nargs: nargs[name].zero_() @@ -213,11 +236,6 @@ def get_configs_io_bound() -> List[Config]: "k_key", ], ) -@triton.heuristics( - { - "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, - } -) @triton.jit def _kernel_matmul_fp8_row( A_ptr, @@ -246,7 +264,6 @@ def _kernel_matmul_fp8_row( BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, USE_BIAS: tl.constexpr, AB_DTYPE: tl.constexpr, NUM_SMS: tl.constexpr, @@ -751,6 +768,7 @@ def _kernel_matmul_fp8_row_tma_persistent( stride_cn, dot_out_dtype: tl.constexpr, c_dtype: tl.constexpr, + bias_dtype: tl.constexpr, allow_tf32: tl.constexpr, fp8_fast_accum: tl.constexpr, BLOCK_M: tl.constexpr, @@ -818,7 +836,6 @@ def _kernel_matmul_fp8_row_tma_persistent( dtype_fp8 = tl.float8e4nv scale_dtype = tl.float32 - bias_dtype = tl.float32 for _ in range(0, k_tiles * tiles_per_SM): ki = tl.where(ki == k_tiles - 1, 0, ki + 1) @@ -880,10 +897,14 @@ def _kernel_matmul_fp8_row_tma_persistent( if HAS_TMA_DESC: print( - "TMA benchmarks will be running with experimental grid constant TMA descriptor." + "TMA benchmarks will be running with experimental grid constant TMA descriptor.", + file=sys.stderr, ) else: - print("TMA benchmarks will be running without grid constant TMA descriptor.") + print( + "TMA benchmarks will be running without grid constant TMA descriptor.", + file=sys.stderr, + ) class TmaAutoTuneHelper: @@ -964,7 +985,7 @@ def get_tma_descriptor_kernel_param(self, name): return self.cuda_descriptors[name] -@torch.library.custom_op("triton::matmul_fp8_row", mutates_args=()) +@torch._library.triton_op("triton::matmul_fp8_row", mutates_args=()) def matmul_fp8_row( a: torch.Tensor, b: torch.Tensor, @@ -976,6 +997,7 @@ def matmul_fp8_row( fp8_fast_accum: bool = True, imprecise_acc: bool = False, tma_persistent: bool = True, + no_use_persistent: bool = False, ) -> torch.Tensor: """ Performs matmul on [M, K] and [N, K] fp8 matrices with row-wise scalings [M], [N]. @@ -995,15 +1017,15 @@ def matmul_fp8_row( torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :]) """ # Get datatypes and constants to use. - _, tl_dtype, _, _ = get_fp8_constants() + pt_fp8_dtype, _, _, _ = get_fp8_constants() # Handle 3D+ a shape a_shape = a.shape a = a.view(-1, a.size(-1)) - # Reinterpret inputs into proper triton fp8 dtype. - a_tl = convert_fp8_type(a, tl_dtype) - b_tl = convert_fp8_type(b, tl_dtype) + # View inputs into proper torch fp8 dtype. 
+ assert a.dtype == pt_fp8_dtype + assert b.dtype == pt_fp8_dtype M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device = ( - prep_matmul(a_tl, b_tl, dot_out_dtype) + prep_matmul(a, b, dot_out_dtype) ) output_shape = a_shape[:-1] + (N,) @@ -1035,7 +1057,38 @@ def persistent_grid(META): ), ) - if tma_persistent: + if no_use_persistent: + logger.info("Using non-persistent kernel") + if bias is not None: + raise AssertionError("bias is not supported in non-persistent kernel") + # pyre-ignore + torch._library.capture_triton(_kernel_matmul_fp8_row_non_persistent)[grid]( + a, + b, + c, + M, + N, + K, + m_key, + n_key, + k_key, + a_scale, + b_scale, + # bias, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + dot_out_dtype=dot_out_dtype_triton, + allow_tf32=allow_tf32, + fp8_fast_accum=fp8_fast_accum, + # GROUP_M=8, + # USE_BIAS=bias is not None, + AB_DTYPE=False, + ) + elif tma_persistent: # used by TMA persistent kernel desc_helper = TmaAutoTuneHelper() desc_helper.init_tma_descriptor("a") @@ -1049,22 +1102,22 @@ def persistent_grid_tma(META): nonlocal desc_helper desc_helper.fill_2d_tma_descriptor( "a", - a_tl.data_ptr(), + a.data_ptr(), M, K, META["BLOCK_M"], META["BLOCK_K"], - a_tl.element_size(), + a.element_size(), ) desc_helper.fill_2d_tma_descriptor( "b", - b_tl.data_ptr(), + b.data_ptr(), N, K, META["BLOCK_N"], META["BLOCK_K"], - b_tl.element_size(), + b.element_size(), ) desc_helper.fill_2d_tma_descriptor( "c", @@ -1111,8 +1164,14 @@ def persistent_grid_tma(META): desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale") desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias") - # pyre-ignore[28]: - _kernel_matmul_fp8_row_tma_persistent[persistent_grid_tma]( + bias_dtype_triton = None + if bias is not None: + bias_dtype_triton = map_dtype_to_triton(bias.dtype) + + # pyre-ignore + torch._library.capture_triton(_kernel_matmul_fp8_row_tma_persistent)[ + persistent_grid_tma + ]( desc_a, desc_b, desc_c, @@ -1133,6 +1192,7 @@ def persistent_grid_tma(META): c.stride(1), dot_out_dtype=dot_out_dtype_triton, c_dtype=c_dtype_triton, + bias_dtype=bias_dtype_triton, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum, GROUP_M=8, @@ -1141,9 +1201,9 @@ def persistent_grid_tma(META): USE_BIAS=bias is not None, ) elif imprecise_acc: - _kernel_matmul_fp8_row_imprecise_acc[grid]( - a_tl, - b_tl, + torch._library.capture_triton(_kernel_matmul_fp8_row_imprecise_acc)[grid]( + a, + b, c, M, N, @@ -1168,9 +1228,9 @@ def persistent_grid_tma(META): AB_DTYPE=False, ) elif fp8_fast_accum: - _kernel_matmul_fp8_row[persistent_grid]( - a_tl, - b_tl, + torch._library.capture_triton(_kernel_matmul_fp8_row)[persistent_grid]( + a, + b, c, M, N, @@ -1196,9 +1256,11 @@ def persistent_grid_tma(META): NUM_SMS=NUM_SMS, ) else: - _kernel_matmul_fp8_row_no_fast_acc[persistent_grid]( - a_tl, - b_tl, + torch._library.capture_triton(_kernel_matmul_fp8_row_no_fast_acc)[ + persistent_grid + ]( + a, + b, c, M, N, @@ -1659,13 +1721,13 @@ def matmul_fp8_block( Tensor: [M, N] output tensor, (a / a_scale) @ (b / b_scale) """ # Get datatypes and constants to use. - _, tl_dtype, _, _ = get_fp8_constants() + _, tl_fp8_dtype, _, _ = get_fp8_constants() # Handle 3D+ a shape a_shape = a.shape a = a.view(-1, a.size(-1)) - # Reinterpret inputs into proper triton fp8 dtype. - a_tl = convert_fp8_type(a, tl_dtype) - b_tl = convert_fp8_type(b, tl_dtype) + # View inputs into proper triton fp8 dtype. 
+ a_tl = reinterpret_fp8_type(a, tl_fp8_dtype) + b_tl = reinterpret_fp8_type(b, tl_fp8_dtype) M, N, K, m_key, n_key, k_key, c, _, dot_out_dtype_triton, device = prep_matmul( a_tl, b_tl, dot_out_dtype @@ -1794,14 +1856,18 @@ def get_matmul_tune(M: int, N: int, K: int) -> Tuple[int, int, int]: def prep_matmul( - a: TensorWrapper, b: TensorWrapper, dot_out_dtype: Optional[torch.dtype] -) -> Tuple[int, int, int, int, int, int, torch.Tensor, str, str, torch.device]: + a: Union[TensorWrapper, torch.Tensor], + b: Union[TensorWrapper, torch.Tensor], + dot_out_dtype: Optional[torch.dtype], +) -> Tuple[ + int, int, int, int, int, int, torch.Tensor, tl.dtype, tl.dtype, torch.device +]: """ Shared bookkeeping for a @ b.T matmul. Args: - a (TensorWrapper): [M, K] input tensor. - b (TensorWrapper): [N, K] input tensor. + a (torch.Tensor): [M, K] input tensor. + b (torch.Tensor): [N, K] input tensor. dot_out_dtype (tl.dtype): Output type of tensor core. Returns: @@ -1812,7 +1878,8 @@ def prep_matmul( n_key (int): Autotuning key for N dim. k_key (int): Autotuning key for K dim. c (Tensor): [M, N] output tensor. - dot_out_dtype (torch.dtype): Output type of tensor core. + c_dtype_triton (tl.dtype): Type of output tensor. + dot_out_dtype (tl.dtype): Output type of tensor core. device (torch.device): Device of output tensor. """ device = a.device @@ -1827,11 +1894,20 @@ def prep_matmul( # allocates output assert a.dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, tl.float8e4nv, tl.float8e4b15, tl.float8e5, tl.float8e4b8, - ] and b.dtype in [ + ] + assert b.dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, tl.float8e4nv, tl.float8e4b15, tl.float8e5, @@ -1847,12 +1923,7 @@ def prep_matmul( assert isinstance( dot_out_dtype, torch.dtype ), f"dot_out_dtype type {type(dot_out_dtype)} must be a torch.dtype" - if dot_out_dtype == torch.bfloat16: - dot_out_dtype_triton = tl.bfloat16 - elif dot_out_dtype == torch.float32: - dot_out_dtype_triton = tl.float32 - else: - dot_out_dtype_triton = tl.int32 + dot_out_dtype_triton = map_dtype_to_triton(dot_out_dtype) return M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device @@ -2383,3 +2454,213 @@ def quantize_fp8_block( x_scale = x_scale.to(output_device) # pyre-ignore del x, x_padded return x_fp8.view(x_shape), 1 / x_scale # pyre-ignore + + +def need_split_k(SIZE_M, SIZE_N, SIZE_K): + return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 + + +def prune_configs(configs, named_args, **kwargs): + + SIZE_M = named_args["A"].shape[0] + SIZE_N = named_args["B"].shape[1] + SIZE_K = named_args["C"].shape[1] + + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_SIZE_M, BLOCK_SIZE_N, _ = ( + kw["BLOCK_M"], + kw["BLOCK_N"], + kw["BLOCK_K"], + ) + SPLIT_K = kw["SPLIT_K"] + if SIZE_M <= 32 and BLOCK_SIZE_M != 32: + continue + if SIZE_N <= 32 and BLOCK_SIZE_N != 32: + continue + # skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(SIZE_M, SIZE_N, SIZE_K): + continue + pruned_configs.append(config) + logging.info(f"pruned_configs: config len{len(pruned_configs)}") + return pruned_configs + + +def get_full_non_persistent_tuning_space(use_split_k): + if torch.version.hip is None: + logger.warning("Using HIP configs on CUDA device, this may be slow.") + configs = [] + block_mn_range = [32, 64, 128, 256] + block_k_range = [32, 64, 128] + split_k_range = [1] + num_warps_range = [1, 2, 4, 8, 16] + group_m_range = 
[1, 4, 8] + num_stage_range = [0] + + for block_m in block_mn_range: + for block_n in block_mn_range: + for block_k in block_k_range: + for num_warps in num_warps_range: + for group_m in group_m_range: + for split_k in split_k_range: + for num_stages in num_stage_range: + configs.append( + triton.Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "GROUP_M": group_m, + "SPLIT_K": split_k, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + + return configs + + +MATMUL_CONFIGS: List[Config] = get_full_non_persistent_tuning_space(True) + + +@triton.autotune( + configs=MATMUL_CONFIGS, + key=["M", "N", "K"], + prune_configs_by={ + "early_config_prune": prune_configs, + "perf_model": None, + "top_k": None, + }, +) +@triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + } +) +@triton.jit +def _kernel_matmul_fp8_row_non_persistent( + A, + B, + C, + M, + N, + K, + m_key, + n_key, + k_key, + A_scale, + B_scale, + stride_am, + stride_ak, + stride_bn, + stride_bk, + stride_cm, + stride_cn, + dot_out_dtype: tl.constexpr, + allow_tf32: tl.constexpr, + fp8_fast_accum: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + AB_DTYPE: tl.constexpr, +) -> None: + """Matmul kernel of [M, K] @ [N, K] with row-wise scales + + performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles. + + Args: + A (TensorWrapper): [M, K] input tensor. + B (TensorWrapper): [N, K] input tensor. + C (TensorWrapper): [M, N] output tensor. + M (int): M dimension of input tensor. + N (int): N dimension of input tensor. + K (int): K dimension of input tensor. + m_key (int): Autotuning key for M dimension of input tensor. + n_key (int): Autotuning key for N dimension of input tensor. + k_key (int): Autotuning key for K dimension of input tensor. + A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A + B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B + stride_am (int): Stride of M dimension of A. + stride_ak (int): Stride of K dimension of A. + stride_bn (int): Stride of N dimension of B. + stride_bk (int): Stride of K dimension of B. + stride_cm (int): Stride of M dimension of C. + stride_cn (int): Stride of N dimension of C. + dot_out_dtype (torch.dtype): Output type of tensor core. + allow_tf32 (bool): Whether to use TF32 for tensor core. + fp8_fast_accum (bool): Whether to use fast accumulation for tensor core. + BLOCK_M (int): Block size for M dimension. + BLOCK_N (int): Block size for N dimension. + BLOCK_K (int): Block size for K dimension. + GROUP_M (int): Number of groups for M dimension swizzle. + SPLIT_K (int): Number of SM's to launch per row. + EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K. + AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core. + """ + # Matrix multiplication. + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # Re-order program ID for better L2 performance (swizzle). + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # Do matrix multiplication. 
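    # [Editorial note] Worked example of the swizzle computed above
    # (illustrative): with GROUP_M=2, grid_m=4, grid_n=3 -> width=6, program
    # IDs 0..5 map to tiles (0,0), (1,0), (0,1), (1,1), (0,2), (1,2). Each
    # group of GROUP_M tile-rows is traversed column-by-column, so tiles that
    # reuse the same B columns stay hot in L2.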
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # Pointers. + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype) + + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + if AB_DTYPE: + a = a.to(C.dtype.element_ty) + b = b.to(C.dtype.element_ty) + if fp8_fast_accum: + acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + else: + acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # Invert scaling. + a_scale = tl.load(A_scale + rm, mask=rm < M) + b_scale = tl.load(B_scale + rn, mask=rn < N) + # Invert vector, then multiply on matrix for speed. + # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`. + scale = a_scale[:, None] * b_scale[None, :] + acc *= scale + + acc = acc.to(C.dtype.element_ty) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # Handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt index a2916c9e5..dd6f165ed 100644 --- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt @@ -43,11 +43,13 @@ else() src/quantize/cutlass_extensions/f8f8bf16_blockwise.cu src/quantize/cutlass_extensions/f8f8bf16_cublas.cu src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu + src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu src/quantize/cutlass_extensions/f8f8bf16_tensorwise.cu src/quantize/cutlass_extensions/i8i8bf16.cu src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu src/quantize/cutlass_extensions/i8i8bf16_dynamic.cu src/quantize/cutlass_extensions/bf16i4bf16_rowwise.cu + src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu src/quantize/quantize.cu src/quantize/quantize.cpp) endif() diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu index 787c0547c..00974a9fe 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu @@ -1437,6 +1437,7 @@ __global__ void dequantize_fp8_cache_kernel( auto MAX_T = cache_K.size(1); auto D_H = cache_K_dq.size(3); auto D_H_q = cache_K.size(3); + // TODO: support D_H < 128 for small model used in testing. 
CUDA_KERNEL_ASSERT(D_H == 128); auto b = blockIdx.x; diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/ck_utility.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/ck_utility.hip new file mode 100644 index 000000000..25f532ad3 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/ck_utility.hip @@ -0,0 +1,44 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#if defined(USE_ROCM) + +#include "ck/ck.hpp" +#include "ck/stream_config.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/flush_icache.hpp" + +namespace fbgemm_gpu { + +void flush_icache_ck() +{ + hipDeviceProp_t deviceProps; + hip_check_error(hipGetDeviceProperties(&deviceProps, 0)); + int32_t gpu_block3 = deviceProps.multiProcessorCount * 60; + + auto stream = at::cuda::getCurrentHIPStream().stream(); + + ck::flush_icache<<>>(); + hip_check_error(hipGetLastError()); +} + +} // namespace fbgemm_gpu + +#endif // defined(USE_ROCM) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu new file mode 100644 index 000000000..871543a2f --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu @@ -0,0 +1,298 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include + +// clang-format off +// The fixed ordering of the headers is required for CUTLASS 3.2+ +#include +#include // @manual +#include // @manual +#include // @manual +// clang-format on + +#include "cutlass_extensions/include/kernel_mode.h" + +namespace fbgemm_gpu { + +#if CUDART_VERSION >= 12000 + +template < + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K, + bool PONG, + typename WEIGHT_SCALE_DTYPE> +at::Tensor bf16i4bf16_rowwise_batched_impl( + at::Tensor X, // BF16 + at::Tensor WQ, // INT4 + at::Tensor w_scale, + at::Tensor w_zp) { + // XQ: B x M x K + // WQ: B x N x K + // output: B x M x N + int B = X.size(0); + int M = X.size(1); + int N = WQ.size(1); + int K = X.size(2); + + int num_groups = w_scale.size(0) / B; + + TORCH_CHECK(X.is_cuda() && X.is_contiguous()); + TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); + TORCH_CHECK(w_scale.is_cuda() && w_scale.is_contiguous()); + TORCH_CHECK(w_zp.is_cuda() && w_zp.is_contiguous()); + TORCH_CHECK(K >= num_groups && K % num_groups == 0); + + int group_size = K / num_groups; + + auto Y = at::empty({B, M, N}, X.options().dtype(at::kBFloat16)); + + using ElementInputA = cutlass::bfloat16_t; + using LayoutInputA = cutlass::layout::ColumnMajor; + constexpr int AlignmentInputA = + 128 / + cutlass::sizeof_bits< + ElementInputA>::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) + + using ElementInputB = cutlass::int4b_t; + using LayoutInputB = cutlass::layout::RowMajor; + constexpr int AlignmentInputB = + 128 / + cutlass::sizeof_bits< + ElementInputB>::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) + + using ElementScale = WEIGHT_SCALE_DTYPE; + using ElementZeroPoint = WEIGHT_SCALE_DTYPE; + using ElementComputeEpilogue = float; + using ElementAccumulator = float; + + using ElementOutput = cutlass::bfloat16_t; + using LayoutOutput = cutlass::layout::ColumnMajor; + constexpr int AlignmentOutput = + 128 / + cutlass::sizeof_bits< + ElementOutput>::value; // Memory access granularity/alignment of C + // matrix in units of elements (up to 16 bytes) + + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Threadblock-level + // tile size + using ClusterShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Shape of the + // threadblocks in a + // cluster + using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput; + using PongSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + using MainLoopSchedule = + cute::conditional_t; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + EpilogueTileType, + ElementAccumulator, + ElementAccumulator, + ElementOutput, + LayoutOutput, + AlignmentOutput, + ElementOutput, + LayoutOutput, + AlignmentOutput, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + cute::tuple, + LayoutInputB, + AlignmentInputB, + ElementInputA, + LayoutInputA, + AlignmentInputA, 
+ ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainLoopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideInputA = typename Gemm::GemmKernel::StrideA; + using StrideInputB = typename Gemm::GemmKernel::StrideB; + using StrideOutput = typename Gemm::GemmKernel::StrideC; + using StrideS = typename CollectiveMainloop::StrideScale; + + StrideInputA stride_a = cutlass::make_cute_packed_stride( + StrideInputA{}, cute::make_shape(M, K, B)); + StrideInputB stride_b = cutlass::make_cute_packed_stride( + StrideInputB{}, cute::make_shape(N, K, B)); + StrideOutput stride_output = cutlass::make_cute_packed_stride( + StrideOutput{}, cute::make_shape(N, M, B)); + StrideS stride_S = cutlass::make_cute_packed_stride( + StrideS{}, cute::make_shape(N, num_groups, B)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {N, M, K, B}, + {reinterpret_cast(WQ.data_ptr()), + stride_b, + reinterpret_cast(X.data_ptr()), + stride_a, + reinterpret_cast(w_scale.data_ptr()), + stride_S, + group_size, + reinterpret_cast(w_zp.data_ptr())}, + {{1.0, 0.0}, + (ElementOutput*)Y.data_ptr(), + stride_output, + (ElementOutput*)Y.data_ptr(), + stride_output}}; + + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm(at::cuda::getCurrentCUDAStream()); + + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return Y; +} + +template +at::Tensor dispatch_bf16i4bf16_rowwise_batched_kernel( + at::Tensor X, // BF16 + at::Tensor WQ, // INT4 + at::Tensor w_scale, + at::Tensor w_zp) { + KernelMode kernel = get_batched_kernel_mode(X, WQ); + if (kernel == KernelMode::Small) { + return bf16i4bf16_rowwise_batched_impl< + 64, + 128, + 64, + 2, + 1, + 1, + true, + WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp); + } else if (kernel == KernelMode::Large) { + return bf16i4bf16_rowwise_batched_impl< + 128, + 128, + 64, + 2, + 1, + 1, + true, + WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp); + } else { + return bf16i4bf16_rowwise_batched_impl< + 128, + 128, + 64, + 2, + 1, + 1, + false, + WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp); + } +} + +at::Tensor bf16i4bf16_rowwise_batched( + at::Tensor X, // BF16 + at::Tensor WQ, // INT4 + at::Tensor w_scale, + at::Tensor w_zp) { + // Check datatypes. 
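  // [Editorial note] Expected shapes, derived from the implementation above
  // (for illustration): X is [B, M, K] in BF16, WQ is [B, N, K] packed INT4,
  // w_scale and w_zp are [B * num_groups, N] with num_groups = K / group_size,
  // and the result is [B, M, N] in BF16. The checks below enforce the
  // supported scale / zero-point dtypes.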
+ TORCH_CHECK( + (w_scale.dtype() == at::kFloat && w_zp.dtype() == at::kFloat) || + (w_scale.dtype() == at::kHalf && w_zp.dtype() == at::kHalf) || + (w_scale.dtype() == at::kBFloat16 && w_zp.dtype() == at::kBFloat16), + "Weight scale and zero point tensors must be float32, bfloat16, or float16, and dtype of weight scale and zero point tensors must be the same ."); + + if (w_scale.dtype() == at::kFloat) { + return dispatch_bf16i4bf16_rowwise_batched_kernel( + X, WQ, w_scale, w_zp); + } else if (w_scale.dtype() == at::kHalf) { + return dispatch_bf16i4bf16_rowwise_batched_kernel( + X, WQ, w_scale, w_zp); + } else if (w_scale.dtype() == at::kBFloat16) { + return dispatch_bf16i4bf16_rowwise_batched_kernel( + X, WQ, w_scale, w_zp); + } else { + throw std::runtime_error( + "Weight scale and zero point data type not supported in bf16i4bf16_rowwise_batched"); + } +} + +#else + +at::Tensor bf16i4bf16_rowwise_batched( + at::Tensor X, // BF16 + at::Tensor WQ, // INT4 + at::Tensor w_scale, + at::Tensor w_zp) { + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} + +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu new file mode 100644 index 000000000..a34c694e0 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu @@ -0,0 +1,515 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +// clang-format off +// The fixed ordering of the headers is required for CUTLASS 3.2+ +#include +#include // @manual +#include // @manual +#include // @manual +// clang-format on + +#include "cutlass_extensions/include/kernel_mode.h" + +namespace fbgemm_gpu { + +#if CUDART_VERSION >= 12000 + +// Cutlass rowwise batched kernel +template < + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K, + bool PONG, + bool FAST_ACCUM, + bool USE_BIAS, + typename INPUT_DTYPE, + typename BIAS_DTYPE> +at::Tensor f8f8bf16_rowwise_batched_impl( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias, + std::optional output) { + // XQ: B x M x K + // WQ: B x N x K + // output: B x M x N + int B = XQ.size(0); + int M = XQ.size(1); + int N = WQ.size(1); + int K = WQ.size(2); + TORCH_CHECK(XQ.size(-1) == K); + auto out_sizes = XQ.sizes().vec(); + out_sizes.back() = N; + + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); + TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); + + at::Tensor Y; + if (output.has_value()) { + Y = output.value(); + // Make sure the provided output has the proper shape and dtype. 
+ TORCH_CHECK(Y.sizes().vec() == out_sizes); + TORCH_CHECK(Y.dtype() == at::kBFloat16); + } else { + Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16)); + } + + using ElementInputA = INPUT_DTYPE; + using LayoutInputA = cutlass::layout::RowMajor; + constexpr int AlignmentInputA = 16 / sizeof(ElementInputA); + + using ElementInputB = cutlass::float_e4m3_t; + using LayoutInputB = cutlass::layout::ColumnMajor; + constexpr int AlignmentInputB = 16 / sizeof(ElementInputB); + + using ElementBias = BIAS_DTYPE; + + using ElementOutput = cutlass::bfloat16_t; + using LayoutOutput = cutlass::layout::RowMajor; + constexpr int AlignmentOutput = 16 / sizeof(ElementOutput); + + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Threadblock-level + // tile size + using ClusterShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Shape of the + // threadblocks in a + // cluster + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized + // based on the tile size + using KernelSchedule = cutlass::gemm::collective:: + KernelScheduleAuto; // Kernel to launch based on the default setting in + // the Collective Builder + + // Implement rowwise scaling epilogue. + using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0, + TileShape, + ElementComputeEpilogue, + cute::Stride, cute::Int<0>, int32_t>>; + + using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast< + PONG ? 2 : 1, + TileShape, + ElementComputeEpilogue, + cute::Stride, cute::Int<1>, int32_t>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast< + PONG ? 2 : 1, + TileShape, + ElementBias, + cute::Stride, cute::Int<1>, int32_t>>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + ElementComputeEpilogue, // First stage output type. + ElementComputeEpilogue, // First stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + cute::conditional_t< // Second stage output type. + USE_BIAS, + ElementBias, + ElementOutput>, + ElementComputeEpilogue, // Second stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeBias = cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, + ElementOutput, // Final (optional) stage output type. + ElementBias, // Final stage input types. 
+ cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeBias = + cutlass::epilogue::fusion::Sm90EVT; + + using EpilogueEVT = + cute::conditional_t; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementComputeEpilogue, + ElementOutput, + LayoutOutput, + AlignmentOutput, + ElementOutput, + LayoutOutput, + AlignmentOutput, + cutlass::epilogue::TmaWarpSpecialized, + EpilogueEVT>::CollectiveOp; + + using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; + using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using FastDefaultSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using FastPongSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using SlowAccum = cute::conditional_t; + using FastAccum = + cute::conditional_t; + using MainLoopSchedule = + cute::conditional_t; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementInputA, + LayoutInputA, + AlignmentInputA, + ElementInputB, + LayoutInputB, + AlignmentInputB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainLoopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideInputA = typename Gemm::GemmKernel::StrideA; + using StrideInputB = typename Gemm::GemmKernel::StrideB; + using StrideOutput = typename Gemm::GemmKernel::StrideC; + + StrideInputA stride_a = cutlass::make_cute_packed_stride( + StrideInputA{}, cute::make_shape(M, K, B)); + StrideInputB stride_b = cutlass::make_cute_packed_stride( + StrideInputB{}, cute::make_shape(N, K, B)); + StrideOutput stride_output = cutlass::make_cute_packed_stride( + StrideOutput{}, cute::make_shape(M, N, B)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, B}, + {reinterpret_cast(XQ.data_ptr()), + stride_a, + reinterpret_cast(WQ.data_ptr()), + stride_b}, + {{}, // Epilogue thread we populate below. 
+ (ElementOutput*)Y.data_ptr(), + stride_output, + (ElementOutput*)Y.data_ptr(), + stride_output}}; + + if constexpr (USE_BIAS) { + arguments.epilogue.thread = { + {reinterpret_cast(bias.value().data_ptr()), + ElementBias(0), + {cute::Int<0>(), cute::Int<1>(), int32_t(N)}}, // bias + // compute_1 + { + {reinterpret_cast(x_scale.data_ptr()), + ElementComputeEpilogue(0), + {cute::Int<1>(), cute::Int<0>(), int32_t(M)}}, // x_scale + // compute_0 + { + {reinterpret_cast(w_scale.data_ptr()), + ElementComputeEpilogue(0), + {cute::Int<0>(), cute::Int<1>(), int32_t(N)}}, // w_scale + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }, + {}, // Plus + }; + } else { + arguments.epilogue.thread = { + {reinterpret_cast(x_scale.data_ptr()), + ElementComputeEpilogue(0), + {cute::Int<1>(), cute::Int<0>(), int32_t(M)}}, // x_scale + // compute_0 + { + {reinterpret_cast(w_scale.data_ptr()), + ElementComputeEpilogue(0), + {cute::Int<0>(), cute::Int<1>(), int32_t(N)}}, // w_scale + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }; + } + + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm(at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return Y; +} + +// FP8 Rowwise batched Cutlass kernel dispatch. 
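// [Editorial note] Summary of the tile configurations selected below, for
// reference (TB_M x TB_N x TB_K, cluster TBS_M x TBS_N x TBS_K, schedule):
//   Small   -> 64 x 128 x 128,  2 x 1 x 1, non-pingpong
//   Medium  -> 64 x 128 x 128,  1 x 2 x 1, pingpong
//   Large   -> 128 x 128 x 128, 2 x 1 x 1, pingpong
//   Default -> 128 x 128 x 128, 1 x 2 x 1, non-pingpong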
+template +at::Tensor dispatch_fp8_rowwise_batched_kernel( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias, + std::optional output) { + KernelMode kernel = get_batched_kernel_mode(XQ, WQ); + TORCH_CHECK( + (XQ.dim() == 3 && WQ.dim() == 3), + "FP8 rowwise batched GEMM only supports 3D inputs"); + if (kernel == KernelMode::Small) { + return f8f8bf16_rowwise_batched_impl< + 64, + 128, + 128, + 2, + 1, + 1, + false, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); + } else if (kernel == KernelMode::Medium) { + return f8f8bf16_rowwise_batched_impl< + 64, + 128, + 128, + 1, + 2, + 1, + true, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); + } else if (kernel == KernelMode::Large) { + return f8f8bf16_rowwise_batched_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return f8f8bf16_rowwise_batched_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); + } +} + +at::Tensor f8f8bf16_rowwise_batched( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, // FP32 + at::Tensor w_scale, // FP32 + std::optional bias = c10::nullopt, + bool use_fast_accum = true, + std::optional output = c10::nullopt) { + // Check datatypes. + TORCH_CHECK( + x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat, + "Scale tensors must be float32."); + if (bias.has_value()) { + TORCH_CHECK( + bias.value().dtype() == at::kFloat || + bias.value().dtype() == at::kBFloat16, + "Bias type must be bfloat16 or float32 if provided."); + } + bool use_bias = bias.has_value(); + bool bf16_bias = use_bias && bias.value().dtype() == at::kBFloat16; + + // Templatize based on input dtype. 
+ bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2; + + if (use_bias) { + if (bf16_bias) { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + true, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + true, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + false, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + false, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); + } + } + } else { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + true, + true, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + true, + true, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + false, + true, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + false, + true, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } + } + } + } else { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + true, + false, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + true, + false, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + false, + false, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + false, + false, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } + } + } +} + +#else + +at::Tensor f8f8bf16_rowwise_batched( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias = c10::nullopt, + bool use_fast_accum = true, + std::optional output = c10::nullopt) { + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} + +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h index 94a68096d..9a267193a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h @@ -12,7 +12,7 @@ namespace fbgemm_gpu { -enum class KernelMode { Small, Large, Default }; +enum class KernelMode { Small, Medium, Large, Default }; inline KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) { auto M = XQ.size(0); @@ -31,4 +31,25 @@ inline KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) { } } +inline KernelMode get_batched_kernel_mode(at::Tensor XQ, at::Tensor WQ) { + auto B = XQ.size(0); + auto M = XQ.size(1); + auto K = XQ.size(2); + auto N = WQ.size(1); + auto BM = B * M; + // Heuristic to determine kernel mode + bool use_medium_kernel = + ((BM <= 512 && ((N <= 8192 && K < 8192) || (N < 
8192 && K <= 8192)))); + bool use_large_kernel = ((BM > 512 && (N >= 1024 || K >= 1024))); + if (BM <= 128 || N <= 128) { + return KernelMode::Small; + } else if (use_medium_kernel) { + return KernelMode::Medium; + } else if (use_large_kernel) { + return KernelMode::Large; + } else { + return KernelMode::Default; + } +} + } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index 39084712c..1abf8fb40 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -27,6 +27,11 @@ namespace fbgemm_gpu { +#ifdef USE_ROCM +// flush icache +void flush_icache_ck(); +#endif + // SmoothQuant kernels at::Tensor i8i8bf16(at::Tensor XQ, at::Tensor WQ, double scale, int64_t split_k); @@ -57,6 +62,14 @@ at::Tensor f8f8bf16_rowwise( std::optional bias = c10::nullopt, bool use_fast_accum = true, std::optional output = c10::nullopt); +at::Tensor f8f8bf16_rowwise_batched( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias = c10::nullopt, + bool use_fast_accum = true, + std::optional output = c10::nullopt); at::Tensor f8f8bf16_blockwise( at::Tensor XQ, at::Tensor WQ, @@ -83,6 +96,11 @@ at::Tensor bf16i4bf16_rowwise( at::Tensor WQ, at::Tensor w_scale, at::Tensor w_zp); +at::Tensor bf16i4bf16_rowwise_batched( + at::Tensor X, + at::Tensor WQ, + at::Tensor w_scale, + at::Tensor w_zp); at::Tensor per_tensor_quantize_i8(at::Tensor X, double scale); std::tuple per_tensor_dynamic_quantize_i8(at::Tensor X); @@ -132,11 +150,17 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "f8i4bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_zp) -> Tensor"); m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise); + m.def( + "f8f8bf16_rowwise_batched(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? bias=None, bool use_fast_accum=True, Tensor(a!)? 
output=None) -> Tensor"); m.def( "bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor"); m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise); + m.def( + "bf16i4bf16_rowwise_batched(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor"); + m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched); + m.def( "i8i8bf16_dynamic(Tensor XQ, Tensor WQ, Tensor scale, int split_k=1) -> Tensor"); m.impl("i8i8bf16_dynamic", i8i8bf16_dynamic); @@ -175,6 +199,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.impl( "quantize_fp8_per_tensor_fixed_scale", quantize_fp8_per_tensor_fixed_scale); + +#ifdef USE_ROCM + m.def("flush_icache_hip() -> ()"); + m.impl("flush_icache_hip", flush_icache_ck); +#endif } TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { @@ -188,6 +217,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { m.impl("i8i8bf16", i8i8bf16); m.impl("f8f8bf16", f8f8bf16); m.impl("f8f8bf16_cublas", f8f8bf16_cublas); + m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched); #endif } @@ -216,6 +246,21 @@ at::Tensor f8f8bf16_rowwise_meta( return Y; } +at::Tensor f8f8bf16_rowwise_batched_meta( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor /* x_scale */, + at::Tensor /* w_scale */, + std::optional /* bias = c10::nullopt */, + bool /* use_fast_accum = true */, + std::optional /* output = c10::nullopt */) { + int B = XQ.size(0); + int M = XQ.size(1); + int N = WQ.size(1); + auto Y = at::empty({B, M, N}, XQ.options().dtype(at::kBFloat16)); + return Y; +} + at::Tensor f8f8bf16_blockwise_meta( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 @@ -290,14 +335,28 @@ at::Tensor f8i4bf16_rowwise_meta( at::Tensor bf16i4bf16_rowwise_meta( at::Tensor X, // BF16 at::Tensor WQ, // INT4 - at::Tensor w_scale, - at::Tensor w_zp) { + at::Tensor /* w_scale */, + at::Tensor /* w_zp */ +) { int M = X.size(0); int N = WQ.size(0); auto Y = at::empty({M, N}, X.options().dtype(at::kBFloat16)); return Y; } +at::Tensor bf16i4bf16_rowwise_batched_meta( + at::Tensor X, // BF16 + at::Tensor WQ, // INT4 + at::Tensor /* w_scale */, + at::Tensor /* w_zp */ +) { + int B = X.size(0); + int M = X.size(1); + int N = WQ.size(1); + auto Y = at::empty({B, M, N}, X.options().dtype(at::kBFloat16)); + return Y; +} + std::vector quantize_fp8_per_row_meta( at::Tensor input, std::optional bs, @@ -331,8 +390,10 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { m.impl("i8i8bf16", i8i8bf16_meta); m.impl("f8f8bf16", f8f8bf16_meta); m.impl("f8f8bf16_cublas", f8f8bf16_cublas_meta); + m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched_meta); m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise_meta); m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_meta); + m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta); #endif } diff --git a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py index ab728db6f..4a7211c68 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py +++ b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py @@ -38,8 +38,12 @@ # pyre-fixme[21]: Could not find name `pow` in `triton.language.math`. 
from triton.language.math import pow except ImportError: - # @manual=//triton:triton - from triton.language.extra.cuda.libdevice import pow + try: + # @manual=//triton:triton + from triton.language.extra.libdevice import pow + except ImportError: + # @manual=//triton:triton + from triton.language.extra.cuda.libdevice import pow _INTERNAL_DTYPE_MAP: Dict[str, int] = {"": 0, "f32": 1, "f64": 2} diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index fdd038e2c..38a09f360 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -614,12 +614,16 @@ def test_quantize_fp8_per_tensor_with_ub( zq_ref = (x @ w.T).to(torch.bfloat16) torch.testing.assert_close(zq, zq_ref, atol=1.0e-3, rtol=1.0e-3) + @unittest.skipIf( + not torch.version.cuda, "Skip on AMD: BMM ops are not yet suported." + ) @settings(deadline=None) @given( B=st.sampled_from([1, 4]), M=st.sampled_from([2048, 4096]), N=st.sampled_from([128, 256]), K=st.sampled_from([256, 512]), + use_loopover=st.sampled_from([True, False]), ) def test_fp8_batched_gemm( self, @@ -627,6 +631,7 @@ def test_fp8_batched_gemm( M: int, N: int, K: int, + use_loopover: bool, ) -> None: x = torch.rand(size=(B, M, K), dtype=torch.bfloat16, device="cuda") * 0.1 w = torch.rand(size=(B, N, K), dtype=torch.bfloat16, device="cuda") * 0.01 @@ -655,7 +660,10 @@ def fp8_loopover_bmm( return y y_ref = torch.bmm(x, w.transpose(1, 2)) - y_fp8 = fp8_loopover_bmm(xq, wq, x_scale, w_scale) + if use_loopover: + y_fp8 = fp8_loopover_bmm(xq, wq, x_scale, w_scale) + else: + y_fp8 = torch.ops.fbgemm.f8f8bf16_rowwise_batched(xq, wq, x_scale, w_scale) torch.testing.assert_close(y_ref, y_fp8, atol=8.0e-2, rtol=8.0e-2) @unittest.skipIf(torch.version.hip, "Skip on AMD: Marlin not yet suported.") @@ -665,6 +673,7 @@ def fp8_loopover_bmm( M=st.sampled_from([2048, 4096]), N=st.sampled_from([256, 512]), K=st.sampled_from([256, 512]), + use_loopover=st.sampled_from([True, False]), ) def test_int4_batched_gemm( self, @@ -672,6 +681,7 @@ def test_int4_batched_gemm( M: int, N: int, K: int, + use_loopover: bool, ) -> None: if not MARLIN_ENABLED: return @@ -681,28 +691,48 @@ def test_int4_batched_gemm( wq = [] w_scale = [] group_size = 128 - for i in range(B): - _, wq_, w_scale_ = marlin_quantize(w[i].cuda().t().contiguous(), group_size) - wq.append(wq_) - w_scale.append(w_scale_) - wq = torch.stack(wq) - w_scale = torch.stack(w_scale) - - def int4_loopover_bmm( - x: torch.Tensor, - wq: torch.Tensor, - w_scale: torch.Tensor, - ) -> torch.Tensor: - B = x.shape[0] - M = x.shape[1] - N = w_scale.shape[2] - y = torch.empty((B, M, N), dtype=torch.bfloat16, device=x[0].device) + + if use_loopover: for i in range(B): - y[i] = torch.ops.marlin.marlin_gemm(x[i], wq[i], w_scale[i]) - return y + _, wq_, w_scale_ = marlin_quantize( + w[i].cuda().t().contiguous(), group_size + ) + wq.append(wq_) + w_scale.append(w_scale_) + wq = torch.stack(wq) + w_scale = torch.stack(w_scale) + + def int4_loopover_bmm( + x: torch.Tensor, + wq: torch.Tensor, + w_scale: torch.Tensor, + ) -> torch.Tensor: + B = x.shape[0] + M = x.shape[1] + N = w_scale.shape[2] + y = torch.empty((B, M, N), dtype=torch.bfloat16, device=x[0].device) + for i in range(B): + y[i] = torch.ops.marlin.marlin_gemm(x[i], wq[i], w_scale[i]) + return y + + y_int4 = int4_loopover_bmm(x, wq, w_scale) + else: + w_zp = [] + for i in range(B): + wq_, w_scale_, w_zp_ = 
int4_row_quantize(w[i], group_size) + + wq_ = pack_int4(wq_).contiguous().to(device="cuda") + w_scale_ = w_scale_.contiguous().to(device="cuda") + w_zp_ = w_zp_.contiguous().to(device="cuda") + wq.append(wq_) + w_scale.append(w_scale_) + w_zp.append(w_zp_) + wq = torch.stack(wq) + w_scale = torch.stack(w_scale).view(-1, N) + w_zp = torch.stack(w_zp).view(-1, N) + y_int4 = torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, wq, w_scale, w_zp) y_ref = torch.bmm(x, w.transpose(1, 2)) - y_int4 = int4_loopover_bmm(x, wq, w_scale) torch.testing.assert_close(y_ref, y_int4, atol=8.0e-2, rtol=8.0e-2) diff --git a/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py index 5ef0f1c32..db2260df4 100644 --- a/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py @@ -13,16 +13,13 @@ import torch +from fbgemm_gpu.utils.loader import load_torch_module + try: # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") def wrap_weight_to_parameter(weights: List[torch.Tensor]) -> List[torch.Tensor]: diff --git a/fbgemm_gpu/fbgemm_gpu/docs/__init__.py b/fbgemm_gpu/fbgemm_gpu/docs/__init__.py index 4b621cbe3..8d696532a 100644 --- a/fbgemm_gpu/fbgemm_gpu/docs/__init__.py +++ b/fbgemm_gpu/fbgemm_gpu/docs/__init__.py @@ -11,6 +11,8 @@ jagged_tensor_ops, merge_pooled_embedding_ops, permute_pooled_embedding_ops, + quantize_ops, + sparse_ops, ) except Exception: pass diff --git a/fbgemm_gpu/fbgemm_gpu/docs/quantize_ops.py b/fbgemm_gpu/fbgemm_gpu/docs/quantize_ops.py new file mode 100644 index 000000000..3662b12c7 --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/docs/quantize_ops.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from .common import add_docs + +add_docs( + torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf, + """ +FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(input, bit_rate) -> Tensor + +Convert FP32/16 to INT8/4/2 using rowwise quantization. + +Args: + input (Tensor): An input tensor. Must be either FP32 (`torch.float`) + or FP16 (`torch.half`) and must be 2 dimensions. + + bit_rate (int): Quantized bit rate (2 for INT2, 4 for INT4, or 8 for + INT8) + +Returns: + Quantized output (Tensor). 
Data type is `torch.uint8` (byte type) + +**Example:** + + >>> # Randomize input + >>> input = torch.randn(2, 4, dtype=torch.float32, device="cuda") + >>> print(input) + tensor([[ 0.8247, 0.0031, -1.0068, -1.2081], + [ 0.5427, 1.5772, 1.0291, -0.7626]], device='cuda:0') + >>> # Quantize + >>> output = torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(input, bit_rate=4) + >>> print(output) + tensor([[159, 1, 86, 48, 213, 188], + [248, 11, 254, 48, 26, 186]], device='cuda:0', dtype=torch.uint8) + """, +) diff --git a/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py new file mode 100644 index 000000000..333d5e5da --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py @@ -0,0 +1,468 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from .common import add_docs + +add_docs( + torch.ops.fbgemm.permute_2D_sparse_data, + """ +permute_2D_sparse_data(permute, lengths, values, weights=None, permuted_lengths_sum=None) -> Tuple[Tensor, Tensor, Optional[Tensor]] + +Permute 2D sparse data along the first dimension (dim 0). Note that 2D +refers to the number of dense dimensions. The input data is actually 3D +where the first two dimensions are dense and the last dimension is +jagged (sparse). The data to permute over can be less or more and with or +without repetitions. + +Args: + permute (Tensor): A 1D-tensor that describes how data is permuted along dim + 0. `permute[i]` indicates that data at position `permute[i]` is moved + to position `i`. The length of this tensor is the total amount of data + in dim 0 to be permuted. The values in `permute` must be >= 0 and < + `lengths.shape[0]` + + lengths (Tensor): A 2D-tensor that contains jagged shapes corresponding to + the other two dense dimensions. For example, in the case of the + embedding input, the 3D shape is (num features, batch size, bag size). + `lengths[t][b]` represents the bag size of feature `t` and sample `b`. + + values (Tensor): A 1D-input-tensor to be permuted. The length of this + tensor must be equal to `lengths.sum()`. This tensor can be of any data + type. + + weights (Optional[Tensor] = None): An optional 1D-float-tensor. It must + have the same length as `values`. It will be permuted the same way as + values + + permuted_lengths_sum (Optional[int] = None): An optional value that + represents the total number of elements in the permuted data (output + shape). If not provided, the operator will compute this data which may + cause a device-host synchronization (if using GPU). Thus, it is + recommended to supply this value to avoid such the synchronization. 
+ +Returns: + A tuple of permuted lengths, permuted indices and permuted weights + +**Example:** + + >>> permute = torch.tensor([1, 0, 2], dtype=torch.int32, device="cuda") + >>> lengths = torch.tensor([[2, 3, 4, 5], [1, 2, 4, 8], [0, 3, 2, 3]], dtype=torch.int64, device="cuda") + >>> values = torch.randint(low=0, high=100, size=(lengths.sum().item(),), dtype=torch.int64, device="cuda") + >>> print(values) + tensor([29, 12, 61, 98, 56, 94, 5, 89, 65, 48, 71, 54, 40, 33, 78, 68, 42, 21, + 60, 51, 15, 47, 48, 68, 52, 19, 38, 30, 38, 97, 97, 98, 18, 40, 42, 89, + 66], device='cuda:0') + >>> torch.ops.fbgemm.permute_2D_sparse_data(permute, lengths, values) + (tensor([[1, 2, 4, 8], + [2, 3, 4, 5], + [0, 3, 2, 3]], device='cuda:0'), + tensor([78, 68, 42, 21, 60, 51, 15, 47, 48, 68, 52, 19, 38, 30, 38, 29, 12, 61, + 98, 56, 94, 5, 89, 65, 48, 71, 54, 40, 33, 97, 97, 98, 18, 40, 42, 89, + 66], device='cuda:0'), + None) + """, +) + +add_docs( + torch.ops.fbgemm.permute_1D_sparse_data, + """ +permute_1D_sparse_data(permute, lengths, values, weights=None, permuted_lengths_sum=None) -> Tuple[Tensor, Tensor, Optional[Tensor]] + +Permute 1D sparse data. Note that 1D refers to the number of dense dimensions. +The input data is actually 2D where the first dimension is dense and the second +dimension is jagged (sparse). The data to permute over can be less or more and +with or without repetitions. + +Args: + permute (Tensor): A 1D-tensor that describes how data is permuted along dim + 0. `permute[i]` indicates that data at position `permute[i]` is moved + to position `i`. The length of this tensor is the total amount of data + in dim 0 to be permuted. The values in `permute` must be >= 0 and < + `lengths.numel()` + + lengths (Tensor): A 1D-tensor that contains jagged shapes corresponding to + the other dense dimension. `lengths[i]` represents the jagged shape of + data at position `i` in dim 0 + + values (Tensor): A 1D-input-tensor to be permuted. The length of this + tensor must be equal to `lengths.sum()`. This tensor can be of any data + type. + + weights (Optional[Tensor] = None): An optional 1D-float-tensor. It must + have the same length as `values`. It will be permuted the same way as + values + + permuted_lengths_sum (Optional[int] = None): An optional value that + represents the total number of elements in the permuted data (output + shape). If not provided, the operator will compute this data which may + cause a device-host synchronization (if using GPU). Thus, it is + recommended to supply this value to avoid such synchronization. 
+ +Returns: + A tuple of permuted lengths, permuted indices and permuted weights + +**Example:** + >>> permute = torch.tensor([1, 0, 3, 0], dtype=torch.int32, device="cuda") + >>> lengths = torch.tensor([2, 3, 4, 5], dtype=torch.int64, device="cuda") + >>> values = torch.randint(low=0, high=100, size=(lengths.sum().item(),), dtype=torch.int64, device="cuda") + >>> print(values) + tensor([ 1, 76, 24, 84, 94, 25, 15, 23, 31, 46, 9, 23, 34, 3], + device='cuda:0') + >>> torch.ops.fbgemm.permute_1D_sparse_data(permute, lengths, values) + (tensor([3, 2, 5, 2], device='cuda:0'), + tensor([24, 84, 94, 1, 76, 46, 9, 23, 34, 3, 1, 76], device='cuda:0'), + None) + """, +) + +add_docs( + torch.ops.fbgemm.expand_into_jagged_permute, + """ +expand_into_jagged_permute(permute, input_offset, output_offset, output_size) -> Tensor + +Expand the sparse data permute index from feature dimension to batch dimension, +for cases where the sparse features has different batch sizes across ranks. + +The op expands the permute from feature level to batch level by contiguously +mapping each bag of its corresponding features to the position the batch sits +on after feature permute. The op will automatically derive offset array of +feature and batch to compute the output permute. + +Args: + permute (Tensor): The feature level permute index. + + input_offset (Tensor): The exclusive offsets of feature-level length. + + output_offsets (Tensor): The exclusive offsets of feature-level permuted + length. + + output_size (int): The number of elements in the output tensor + +Returns: + The output follows the following formula + + >>> output_permute[feature_offset[permute[feature]] + batch] <- bag_offset[batch] + """, +) + +add_docs( + torch.ops.fbgemm.asynchronous_complete_cumsum, + """ +asynchronous_complete_cumsum(t_in) -> Tensor + +Compute complete cumulative sum. For the GPU operator, the operator is +nonblocking asynchronous. For the CPU operator, it is a blocking operator. + +Args: + t_in (Tensor): An input tensor + +Returns: + The complete cumulative sum of `t_in`. Shape is `t_in.numel() + 1` + +**Example:** + + >>> t_in = torch.tensor([7, 8, 2, 1, 0, 9, 4], dtype=torch.int64, device="cuda") + >>> torch.ops.fbgemm.asynchronous_complete_cumsum(t_in) + tensor([ 0, 7, 15, 17, 18, 18, 27, 31], device='cuda:0') + """, +) + +add_docs( + torch.ops.fbgemm.offsets_range, + """ +offsets_range(offsets, range_size) -> Tensor + +Generate an integer sequence from 0 to `(offsets[i+1] - offsets[i])` for every +`i`, where `0 <= i < offsets.numel()` + +Args: + offsets (Tensor): The offsets (complete cumulative sum values) + + range_size (int): The output size (the total sum) + +Returns: + A tensor that contains offsets range + +**Example:** + >>> # Generate example inputs + >>> lengths = torch.tensor([3, 4, 1, 9, 3, 7], dtype=torch.int64, device="cuda") + >>> offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + >>> range_size = offsets[-1].item() + >>> print(range_size) + 27 + >>> offsets = offsets[:-1] + >>> print(offsets) + tensor([ 0, 3, 7, 8, 17, 20], device='cuda:0') + >>> # Invoke + >>> torch.ops.fbgemm.offsets_range(offsets, range_size) + tensor([0, 1, 2, 0, 1, 2, 3, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 0, 1, 2, 3, + 4, 5, 6], device='cuda:0') + """, +) + +add_docs( + torch.ops.fbgemm.segment_sum_csr, + """ +segment_sum_csr(batch_size, csr_seg, values) -> Tensor + +Sum values within each segment on the given CSR data where each row has the +same number of non-zero elements. 
+ +Args: + batch_size (int): The row stride (number of non-zero elements in each row) + + csr_seg (Tensor): The complete cumulative sum of segment lengths. A segment + length is the number of rows within each segment. The shape of the + `csr_seg` tensor is `num_segments + 1` where `num_segments` is the + number of segments. + + values (Tensor): The values tensor to be segment summed. The number of + elements in the tensor must be multiple of `batch_size` + +Returns: + A tensor containing the segment sum results. Shape is the number of + segments. + +**Example:** + + >>> batch_size = 2 + >>> # Randomize inputs + >>> lengths = torch.tensor([3, 4, 1], dtype=torch.int, device="cuda") + >>> offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + >>> print(offsets) + tensor([0, 3, 7, 8], device='cuda:0', dtype=torch.int32) + >>> values = torch.randn(lengths.sum().item() * batch_size, dtype=torch.float32, device="cuda") + >>> print(values) + tensor([-2.8642e-01, 1.6451e+00, 1.1322e-01, 1.7335e+00, -8.4700e-02, + -1.2756e+00, 1.1206e+00, 9.6385e-01, 6.2122e-02, 1.3104e-03, + 2.2667e-01, 2.3113e+00, -1.1948e+00, -1.5463e-01, -1.0031e+00, + -3.5531e-01], device='cuda:0') + >>> # Invoke + >>> torch.ops.fbgemm.segment_sum_csr(batch_size, offsets, values) + tensor([ 1.8451, 3.3365, -1.3584], device='cuda:0') + """, +) + +add_docs( + torch.ops.fbgemm.keyed_jagged_index_select_dim1, + """ +keyed_jagged_index_select_dim1(values, lengths, offsets, indices, batch_size, weights=None, selected_lengths_sum=None) -> List[Tensor] + +Perform an index select operation on the batch dimension (dim 1) of the given +keyed jagged tensor (KJT) input. The same samples in the batch of every key +will be selected. Note that each KJT has 3 dimensions: (`num_keys`, `batch_size`, +jagged dim), where `num_keys` is the number of keys, and `batch_size` is the +batch size. This operator is similar to a permute operator. + +Args: + values (Tensor): The KJT values tensor which contains concatenated data of + every key + + lengths (Tensor): The KJT lengths tensor which contains the jagged shapes + of every key (dim 0) and sample (dim 1). Shape is `num_keys * + batch_size` + + offsets (Tensor): The KJT offsets tensor which is the complete cumulative + sum of `lengths`. Shape is `num_keys * batch_size + 1` + + indices (Tensor): The indices to select, i.e., samples in the batch to + select. The values of `indices` must be >= 0 and < `batch_size` + + batch_size (int): The batch size (dim 1 of KJT) + + weights (Optional[Tensor] = None): An optional float tensor which will be + selected the same way as `values`. Thus, it must have the same shape as + `values` + + selected_lengths_sum (Optional[int] = None): An optional value that + represents the total number of elements in the index select data + (output shape). If not provided, the operator will compute this data + which may cause a device-host synchronization (if using GPU). Thus, it + is recommended to supply this value to avoid such the synchronization. 
+ +Returns: + The index-select KJT tensor (as a list of values, lengths, and weights if + `weights` is not None) + +**Example:** + + >>> num_keys = 2 + >>> batch_size = 4 + >>> output_size = 3 + >>> # Randomize inputs + >>> lengths = torch.randint(low=0, high=10, size=(batch_size * num_keys,), dtype=torch.int64, device="cuda") + >>> print(lengths) + tensor([8, 5, 1, 4, 2, 7, 5, 9], device='cuda:0') + >>> offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + >>> print(offsets) + tensor([ 0, 8, 13, 14, 18, 20, 27, 32, 41], device='cuda:0') + >>> indices = torch.randint(low=0, high=batch_size, size=(output_size,), dtype=torch.int64, device="cuda") + >>> print(indices) + tensor([3, 3, 1], device='cuda:0') + >>> # Use torch.arange instead of torch.randn to simplify the example + >>> values = torch.arange(lengths.sum().item(), dtype=torch.float32, device="cuda") + >>> print(values) + tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., + 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., + 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40.], + device='cuda:0') + >>> # Invoke. Output = (output, lengths) + >>> torch.ops.fbgemm.keyed_jagged_index_select_dim1(values, lengths, offsets, indices, batch_size) + [tensor([14., 15., 16., 17., 14., 15., 16., 17., 8., 9., 10., 11., 12., 32., + 33., 34., 35., 36., 37., 38., 39., 40., 32., 33., 34., 35., 36., 37., + 38., 39., 40., 20., 21., 22., 23., 24., 25., 26.], device='cuda:0'), + tensor([4, 4, 5, 9, 9, 7], device='cuda:0')] + """, +) + +add_docs( + torch.ops.fbgemm.block_bucketize_sparse_features, + """ +block_bucketize_sparse_features(lengths, indices, bucketize_pos, sequence, block_sizes, my_size, weights=None, batch_size_per_feature=None, max_B= -1, block_bucketize_pos=None, keep_orig_idx=False) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]] + +Preprocess sparse features by partitioning sparse features into multiple +buckets. Every feature is split into the same number of buckets, but the bucket +sizes (widths) for the different features can be different. Moreover, the +bucket sizes within each feature can be different. + +Args: + lengths (Tensor): The lengths of the sparse features. The tensor contains + the lengths of each sample in a batch and each feature. Shape is `B * + T` where `B` is the batch size and `T` is the number of features + + indices (Tensor): The sparse data. Only support integer types. Shape is the + sum of `lengths` + + bucketize_pos (bool): If True, return the original relative indices within + a sample. For example, `indices = [9, 8, 2, 1, 0, 8, 9]` and `lengths = + [3, 4]`. The original relative indices within a sample for the indices + are `[0, 1, 2, 0, 1, 2, 3]` + + sequence (bool): If True, return the new indices positions in the original + indices positions (the tensor is called `unbucketize_permute_data`). + + block_sizes (Tensor): This tensor is used for the case where the bucket + size within a feature is uniform (i.e., when + `block_bucketize_pos=None`). The tensor contains bucket sizes (i.e., + bucket widths) for each feature. `block_sizes[t]` represents the + bucket size of feature `t`. Shape is the number of features. + + my_size (int): The number of buckets for each feature. Note that every + feature has the same number of buckets. + + weights (Optional[Tensor] = None): An optional float tensor that will be + bucketized the same way as `indices`. 
This tensor must have the same + shape as `indices` + + batch_size_per_feature (Optional[Tensor] = None): An optional tensor that + contains batch sizes for different features. If not None, batch sizes + are not uniform among features. Otherwise, the operator will assume + that the batch size is uniform and infer it from the `lengths` and + `block_sizes` tensors + + max_B (int = -1): The max batch size. Must be set if + `batch_size_per_feature` is not None + + block_bucketize_pos (Optional[List[Tensor]] = None): The input is used for + non-uniform bucket sizes within a feature. `block_bucketize_pos` is a + list of tensors. Each tensor contains the range offsets of buckets for + each feature. These range offsets are equivalent to the complete + cumulative sum of the bucket sizes. For example, `[0, 4, 20]` represents + two buckets. The first bucket size is `(4 - 0) = 4`, and the second + bucket size is `(20 - 4) = 16`. The length of `block_bucketize_pos` + must be equal to the number of features. + + keep_orig_idx (bool = False): If True, return original indices instead of + the relative indices within each bucket + +Return: + A tuple of tensors containing + + (1) Bucketized lengths. Shape is `lengths.numel() * my_size`. + + (2) Bucketized indices. Same shape as `indices`. + + (3) Bucketized weights or None if `weights` is None. Same shape as + `indices`. + + (4) Bucketized positions or None if `bucketize_pos=False`. Same shape as + `indices`. + + (5) `unbucketize_permute` or None if `sequence=False`. Same shape as + `indices` + +**Example**: + + >>> # Generate input example. Batch size = 2. Number of features = 4 + >>> lengths = torch.tensor([0, 2, 1, 3, 2, 3, 3, 1], dtype=torch.int, device="cuda") + >>> indices = torch.tensor([3, 4, 15, 11, 28, 29, 1, 10, 11, 12, 13, 11, 22, 20, 20], dtype=torch.int, device="cuda") + >>> block_sizes = torch.tensor([[5, 15, 10, 20]], dtype=torch.int, device="cuda") + >>> my_size = 2 # Number of buckets + >>> # Invoke with keep_orig_idx=False, bucketize_pos=False, and + >>> # sequence=False + >>> torch.ops.fbgemm.block_bucketize_sparse_features( + >>> lengths, + >>> indices, + >>> bucketize_pos=False, + >>> sequence=False, + >>> block_sizes=block_sizes, + >>> my_size=my_size, + >>> keep_orig_idx=False) + >>> # The first 8 values in the returned lengths are the lengths for bucket + >>> # 0 and the rest are the lengths for bucket 1 + (tensor([0, 2, 0, 1, 1, 0, 1, 0, 0, 0, 1, 2, 1, 3, 2, 1], device='cuda:0', + dtype=torch.int32), + tensor([ 3, 4, 11, 1, 11, 0, 13, 14, 0, 1, 2, 3, 2, 0, 0], + device='cuda:0', dtype=torch.int32), + None, + None, + None) + >>> # Invoke with keep_orig_idx=True, bucketize_pos=True, and + >>> # sequence=True + >>> torch.ops.fbgemm.block_bucketize_sparse_features( + >>> lengths, + >>> indices, + >>> bucketize_pos=True, + >>> sequence=True, + >>> block_sizes=block_sizes, + >>> my_size=my_size, + >>> keep_orig_idx=True) + (tensor([0, 2, 0, 1, 1, 0, 1, 0, 0, 0, 1, 2, 1, 3, 2, 1], device='cuda:0', + dtype=torch.int32), + tensor([ 3, 4, 11, 1, 11, 15, 28, 29, 10, 11, 12, 13, 22, 20, 20], + device='cuda:0', dtype=torch.int32), + None, + tensor([0, 1, 0, 0, 0, 0, 1, 2, 1, 0, 1, 2, 1, 2, 0], device='cuda:0', + dtype=torch.int32), + tensor([ 0, 1, 5, 2, 6, 7, 3, 8, 9, 10, 11, 4, 12, 13, 14], + device='cuda:0', dtype=torch.int32)) + >>> # Invoke with block_bucketize_pos + >>> block_bucketize_pos = [ + >>> torch.tensor([0, 2, 8], dtype=torch.int), + >>> torch.tensor([0, 5, 10], dtype=torch.int), + >>> torch.tensor([0, 7, 12], dtype=torch.int), + 
>>> torch.tensor([0, 2, 16], dtype=torch.int), + >>> ] + >>> torch.ops.fbgemm.block_bucketize_sparse_features( + >>> lengths, + >>> indices, + >>> bucketize_pos=False, + >>> sequence=False, + >>> block_sizes=block_sizes, + >>> my_size=my_size, + >>> block_bucketize_pos=block_bucketize_pos, + >>> keep_orig_idx=False) + (tensor([0, 0, 0, 1, 1, 1, 2, 1, 0, 2, 1, 2, 1, 2, 1, 0], device='cuda:0', + dtype=torch.int32), + tensor([14, 1, 6, 11, 10, 10, 1, 2, 7, 5, 14, 3, 4, 6, 9], + device='cuda:0', dtype=torch.int32), + None, + None, + None) + """, +) diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index de9a21ef9..71e0e2ccc 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -20,20 +20,18 @@ from fbgemm_gpu import open_source # noqa: F401 except Exception: load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings") + load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip" ) else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu" ) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 554bd0b00..1730dbedc 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -145,6 +145,15 @@ class GlobalWeightDecayDefinition: lower_bound: float = 0.0 +@dataclass(frozen=True) +class EnsembleModeDefinition: + step_ema: float = 10000 + step_swap: float = 10000 + step_start: float = 0 + step_ema_coef: float = 0.6 + step_mode: StepMode = StepMode.USE_ITER + + # Keep in sync with fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh class UVMCacheStatsIndex(enum.IntEnum): num_calls = 0 @@ -449,8 +458,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module): Adam. Note that default is different from torch.nn.optim.Adagrad default of 1e-10 - momentum (float = 0.9): Momentum used by LARS-SGD and - ENSEMBLE_ROWWISE_ADAGRAD + momentum (float = 0.9): Momentum used by LARS-SGD weight_decay (float = 0.0): Weight decay used by LARS-SGD, LAMB, ADAM, and rowwise-Adagrad. 
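The new `EnsembleModeDefinition` dataclass above bundles what used to be the separate `step_ema` / `step_swap` / `step_start` / `step_mode` constructor arguments. The following is a minimal sketch, not part of the patch, of how that schedule is consumed by the `ensemble_and_swap` hook added further down; the concrete values and the helper function are illustrative assumptions only.

from typing import Tuple

from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    EnsembleModeDefinition,
)

# Defaults: run the EMA update and the weight swap every 10,000 iterations.
ensemble_mode = EnsembleModeDefinition(
    step_ema=10000,
    step_swap=10000,
    step_ema_coef=0.6,
)

def ensemble_schedule(iter_: int, mode: EnsembleModeDefinition) -> Tuple[bool, bool]:
    # Mirrors the checks in ensemble_and_swap(): both events key off the
    # global iteration counter maintained by the TBE module.
    should_ema = iter_ % int(mode.step_ema) == 0
    should_swap = iter_ % int(mode.step_swap) == 0
    return should_ema, should_swap

# When should_ema is True, the rowwise "sparse_ema" optimizer state is blended
# toward the live weights:
#     sparse_ema <- coef_ema * sparse_ema + (1 - coef_ema) * weights
# which is what `states[i][1].lerp_(weights_cpu, 1.0 - coef_ema)` computes in
# ensemble_and_swap(); a swap then copies sparse_ema back into the embedding
# weights.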
@@ -473,14 +481,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module): beta2 (float = 0.999): The beta2 value used by LAMB and ADAM - step_ema (float = 10000): Used by ENSEMBLE_ROWWISE_ADAGRAD - - step_swap (float = 10000): Used by ENSEMBLE_ROWWISE_ADAGRAD - - step_start (float = 0.0): Used by ENSEMBLE_ROWWISE_ADAGRAD - - step_mode: (StepMode = StepMode.USE_ITER): Used by - ENSEMBLE_ROWWISE_ADAGRAD + ensemble_mode (Optional[EnsembleModeDefinition] = None): + Used by Ensemble Rowwise Adagrad counter_based_regularization (Optional[CounterBasedRegularizationDefinition] = None): Used by Rowwise Adagrad @@ -598,10 +600,7 @@ def __init__( # noqa C901 eta: float = 0.001, beta1: float = 0.9, beta2: float = 0.999, - step_ema: float = 10000, - step_swap: float = 10000, - step_start: float = 0, - step_mode: StepMode = StepMode.USE_ITER, + ensemble_mode: Optional[EnsembleModeDefinition] = None, counter_based_regularization: Optional[ CounterBasedRegularizationDefinition ] = None, @@ -623,7 +622,10 @@ def __init__( # noqa C901 self.uuid = str(uuid.uuid4()) self.logging_table_name: str = self.get_table_name_for_logging(table_names) self.pooling_mode = pooling_mode - self.bounds_check_mode_int: int = bounds_check_mode.value + # If environment variable is set, it overwrites the default bounds check mode. + self.bounds_check_mode_int: int = int( + os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value) + ) self.weights_precision = weights_precision self.output_dtype: int = output_dtype.as_int() assert ( @@ -920,6 +922,12 @@ def __init__( # noqa C901 self.gwd_start_iter: int = global_weight_decay.start_iter self.gwd_lower_bound: float = global_weight_decay.lower_bound + if ensemble_mode is None: + ensemble_mode = EnsembleModeDefinition() + self._ensemble_mode: Dict[str, float] = { + key: float(fval) for key, fval in ensemble_mode.__dict__.items() + } + if counter_based_regularization is None: counter_based_regularization = CounterBasedRegularizationDefinition() if cowclip_regularization is None: @@ -957,10 +965,6 @@ def __init__( # noqa C901 eps=eps, beta1=beta1, beta2=beta2, - step_ema=step_ema, - step_swap=step_swap, - step_start=step_start, - step_mode=step_mode.value, weight_decay=weight_decay, weight_decay_mode=opt_arg_weight_decay_mode.value, eta=eta, @@ -995,11 +999,13 @@ def __init__( # noqa C901 if ( optimizer_state_dtypes is None or "momentum1" not in optimizer_state_dtypes + or optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD ) else optimizer_state_dtypes["momentum1"].as_dtype() ) rowwise = optimizer in [ OptimType.EXACT_ROWWISE_ADAGRAD, + OptimType.ENSEMBLE_ROWWISE_ADAGRAD, ] self._apply_split( construct_split_state( @@ -1029,7 +1035,6 @@ def __init__( # noqa C901 rowwise = optimizer in ( OptimType.PARTIAL_ROWWISE_ADAM, OptimType.PARTIAL_ROWWISE_LAMB, - OptimType.ENSEMBLE_ROWWISE_ADAGRAD, ) momentum2_dtype = ( torch.float32 @@ -1059,9 +1064,7 @@ def __init__( # noqa C901 else: # NOTE: make TorchScript work! self._register_nonpersistent_buffers("momentum2") - if self._used_rowwise_adagrad_with_counter or ( - optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD - ): + if self._used_rowwise_adagrad_with_counter: self._apply_split( construct_split_state( embedding_specs, @@ -1865,18 +1868,15 @@ def forward( # noqa: C901 assert self._feature_is_enabled( FeatureGateName.TBE_ENSEMBLE_ROWWISE_ADAGRAD ), "ENSEMBLE_ROWWISE_ADAGRAD is an inactive or deprecated feature!" 
+ with torch.no_grad(): + if self.training: + self.ensemble_and_swap(self._ensemble_mode) return self._report_io_size_count( "fwd_output", - invokers.lookup_ensemble_rowwise_adagrad.invoke( + invokers.lookup_rowwise_adagrad.invoke( common_args, self.optimizer_args, momentum1, - momentum2, - prev_iter, - row_counter, - iter=int(self.iter.item()), - apply_global_weight_decay=False, - gwd_lower_bound=0.0, ), ) @@ -1908,8 +1908,9 @@ def forward( # noqa: C901 ), ) elif self._used_rowwise_adagrad_with_global_weight_decay: + iter_ = int(self.iter.item()) apply_global_weight_decay = ( - self.step >= self.gwd_start_iter and self.training + iter_ >= self.gwd_start_iter and self.training ) return self._report_io_size_count( "fwd_output", @@ -1917,7 +1918,7 @@ def forward( # noqa: C901 common_args, self.optimizer_args, momentum1, - iter=int(self.iter.item()), + iter=iter_, apply_global_weight_decay=apply_global_weight_decay, prev_iter_dev=self.prev_iter_dev, gwd_lower_bound=self.gwd_lower_bound, @@ -1935,6 +1936,33 @@ def forward( # noqa: C901 raise ValueError(f"Invalid OptimType: {self.optimizer}") + def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None: + should_ema = self.iter.item() % int(ensemble_mode["step_ema"]) == 0 + should_swap = self.iter.item() % int(ensemble_mode["step_swap"]) == 0 + if should_ema or should_swap: + weights = self.split_embedding_weights() + states = self.split_optimizer_states() + for i in range(len(self.embedding_specs)): + if should_ema: + step_start = int(ensemble_mode["step_start"]) + if int(ensemble_mode["step_mode"]) == 1: + should_ema_reset = self.iter.item() % step_start == 0 + elif int(ensemble_mode["step_mode"]) == 2: + should_ema_reset = self.iter.item() <= step_start + else: + should_ema_reset = (self.iter.item() <= step_start) or ( + self.iter.item() % step_start == 0 + ) + coef_ema = ( + 0.0 if should_ema_reset else ensemble_mode["step_ema_coef"] + ) + weights_cpu = weights[i].to( + dtype=states[i][1].dtype, device=states[i][1].device + ) + states[i][1].lerp_(weights_cpu, 1.0 - coef_ema) + if should_swap: + weights[i].copy_(states[i][1], non_blocking=True) + def reset_uvm_cache_stats(self) -> None: assert ( self.gather_uvm_cache_stats @@ -2343,9 +2371,7 @@ def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]: list_of_state_dict = [ { "sum": states[0], - "exp_avg": states[1], - "prev_iter": states[2], - "row_counter": states[3], + "sparse_ema": states[1], } for states in split_optimizer_states ] @@ -2390,8 +2416,7 @@ def split_optimizer_states( (8) `PARTIAL_ROWWISE_LAMB`: `momentum1`, `momentum2` (rowwise) - (9) `ENSEMBLE_ROWWISE_ADAGRAD`: `momentum2` (rowwise), `momentum1`, - `prev_iter` (rowwise), `row_counter` (rowwise) + (9) `ENSEMBLE_ROWWISE_ADAGRAD`: `momentum1` (rowwise), `momentum2` (10) `NONE`: no states (throwing an error) @@ -2428,19 +2453,6 @@ def get_optimizer_states( return splits states: List[List[torch.Tensor]] = [] - # For ensemble_rowwise_adagrad, momentum2 ("sum") should go first, - # as it is the default optimizer state for embedding pruning later. 
- if self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD: - states.append( - get_optimizer_states( - self.momentum2_dev, - self.momentum2_host, - self.momentum2_uvm, - self.momentum2_physical_offsets, - self.momentum2_physical_placements, - rowwise=True, - ) - ) if self.optimizer not in (OptimType.EXACT_SGD,): states.append( get_optimizer_states( @@ -2452,6 +2464,7 @@ def get_optimizer_states( rowwise=self.optimizer in [ OptimType.EXACT_ROWWISE_ADAGRAD, + OptimType.ENSEMBLE_ROWWISE_ADAGRAD, ], ) ) @@ -2460,6 +2473,7 @@ def get_optimizer_states( OptimType.PARTIAL_ROWWISE_ADAM, OptimType.LAMB, OptimType.PARTIAL_ROWWISE_LAMB, + OptimType.ENSEMBLE_ROWWISE_ADAGRAD, ): states.append( get_optimizer_states( @@ -2475,7 +2489,6 @@ def get_optimizer_states( if ( self._used_rowwise_adagrad_with_counter or self._used_rowwise_adagrad_with_global_weight_decay - or self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD ): states.append( get_optimizer_states( @@ -2487,10 +2500,7 @@ def get_optimizer_states( rowwise=True, ) ) - if ( - self._used_rowwise_adagrad_with_counter - or self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD - ): + if self._used_rowwise_adagrad_with_counter: states.append( get_optimizer_states( self.row_counter_dev, diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 2c02db4b9..c1f13bd74 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -37,7 +37,6 @@ apply_split_helper, CounterBasedRegularizationDefinition, CowClipDefinition, - StepMode, UVMCacheStatsIndex, WeightDecayMode, ) @@ -116,10 +115,6 @@ def __init__( eta: float = 0.001, # used by LARS-SGD, beta1: float = 0.9, # used by LAMB and ADAM beta2: float = 0.999, # used by LAMB and ADAM - step_ema: float = 10000, # used by ENSEMBLE_ROWWISE_ADAGRAD - step_swap: float = 10000, # used by ENSEMBLE_ROWWISE_ADAGRAD - step_start: float = 0, # used by ENSEMBLE_ROWWISE_ADAGRAD - step_mode: StepMode = StepMode.USE_ITER, # used by ENSEMBLE_ROWWISE_ADAGRAD counter_based_regularization: Optional[ CounterBasedRegularizationDefinition ] = None, # used by Rowwise Adagrad @@ -499,8 +494,6 @@ def __init__( # pyre-fixme[4]: Attribute must be annotated. 
self.ssd_prefetch_data = [] - # Scratch pad value queue - self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = [] # Scratch pad eviction data queue self.ssd_scratch_pad_eviction_data: List[ Tuple[Tensor, Tensor, Tensor, bool] @@ -508,6 +501,9 @@ def __init__( self.ssd_location_update_data: List[Tuple[Tensor, Tensor]] = [] if self.prefetch_pipeline: + # Scratch pad value queue + self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = [] + # pyre-ignore[4] # Scratch pad index queue self.scratch_pad_idx_queue = torch.classes.fbgemm.SSDScratchPadIndicesQueue( @@ -535,10 +531,6 @@ def __init__( eps=eps, beta1=beta1, beta2=beta2, - step_ema=step_ema, - step_swap=step_swap, - step_start=step_start, - step_mode=step_mode.value, weight_decay=weight_decay, weight_decay_mode=weight_decay_mode.value, eta=eta, @@ -1407,15 +1399,16 @@ def prefetch( # noqa C901 if t.is_cuda: t.record_stream(forward_stream) - # Store scratch pad info for the lookup in the next iteration - # prefetch - self.ssd_scratch_pads.append( - ( - inserted_rows, - post_bwd_evicted_indices_cpu, - actions_count_cpu, + if self.prefetch_pipeline: + # Store scratch pad info for the lookup in the next iteration + # prefetch + self.ssd_scratch_pads.append( + ( + inserted_rows, + post_bwd_evicted_indices_cpu, + actions_count_cpu, + ) ) - ) # Store scratch pad info for post backward eviction self.ssd_scratch_pad_eviction_data.append( diff --git a/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h b/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h index 551e83d7c..f81276a9d 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h +++ b/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h @@ -11,7 +11,7 @@ #include #ifdef FBGEMM_FBCODE -#include "fbgemm_gpu/config/feature_gates_fb.h" +#include "deeplearning/fbgemm/fbgemm_gpu/fb/include/fbgemm_gpu/config/feature_gates_fb.h" #endif /// @defgroup fbgemm-gpu-config FBGEMM_GPU Configuration diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh index 97353e03c..2164afd3e 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh @@ -88,6 +88,7 @@ __device__ inline int32_t padded_D( __device__ inline uint32_t pruned_hash_function(uint32_t h) { // MurmorHash3 32-bit mixing function. + // https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp h ^= h >> 16; h *= 0x85ebca6b; h ^= h >> 13; @@ -96,6 +97,17 @@ __device__ inline uint32_t pruned_hash_function(uint32_t h) { return h; } +__device__ inline uint64_t pruned_hash_function(uint64_t k) { + // MurmorHash3 64-bit mixing function. 
+ // https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp + k ^= k >> 33; + k *= (0xff51afd7ed558ccd); + k ^= k >> 33; + k *= (0xc4ceb9fe1a85ec53); + k ^= k >> 33; + return k; +} + // ---------------------- START cp.async helpers, copied from CUTLASS /// CUTLASS helper to get SMEM pointer diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h index 9bea430ef..41ba190fc 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h +++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h @@ -9,8 +9,11 @@ #pragma once #include +#include #include #include +#include + #include namespace fbgemm_gpu { @@ -924,6 +927,44 @@ at::Tensor index_add_with_unique_indices_cuda( const int consecutive_range_start, const int consecutive_range_length); +torch::autograd::variable_list group_index_select_dim0_decomposed( + at::TensorList input_group, + at::TensorList indices_group); + +torch::autograd::variable_list group_index_select_dim0_autograd_impl( + at::TensorList all_indices_input, + const int64_t group_size); + +torch::autograd::variable_list group_index_select_dim0( + at::TensorList input_group, + at::TensorList indices_group); + +torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu( + at::TensorList all_indices_input, + const int64_t group_size); + +torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu( + at::TensorList all_inputs, + c10::SymIntArrayRef output_shape_group_ref); + +std::pair, std::vector> +group_index_select_dim0_unpack( + at::TensorList all_indices_input, + const int64_t group_size); + +class GroupIndexSelectDim0Op + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + at::TensorList all_indices_input, + const int64_t group_size); + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output_group); +}; + ///@ingroup sparse-data-cuda void group_index_select_or_add_cuda( const int64_t* input_ptrs, diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh index 42fe5eb4c..8351e046c 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh @@ -12,14 +12,7 @@ #include #include #include "fbgemm_gpu/embedding_common.h" - -// These values are adjusted in backward based on B and T -constexpr int DEFAULT_INFO_NUM_BITS = 32; -constexpr int DEFAULT_INFO_B_NUM_BITS = 26; -constexpr uint32_t DEFAULT_INFO_B_MASK = (1u << DEFAULT_INFO_B_NUM_BITS) - 1; -constexpr uint32_t MAX_T = - (1u << (DEFAULT_INFO_NUM_BITS - DEFAULT_INFO_B_NUM_BITS)) - 1; -constexpr uint32_t MAX_B = (1u << DEFAULT_INFO_B_NUM_BITS) - 1; +#include "fbgemm_gpu/split_embeddings_utils.h" /** * "Transpose" embedding inputs by sorting indices by their values. @@ -50,11 +43,6 @@ transpose_embedding_input( const int64_t fixed_L_per_warp = 0, const int64_t num_warps_per_feature = 0); -std::tuple -get_infos_metadata(at::Tensor unused, int64_t B, int64_t T); - -std::tuple adjust_info_B_num_bits(int32_t B, int32_t T); - // Use these functions instead of directly calling cub functions // to reduce code size and compilation time. 
// Arguments are the same as cub::DeviceRadixSort::SortPairs @@ -77,15 +65,3 @@ DECL_RADIX_SORT_PAIRS_FN(int64_t, int64_t); DECL_RADIX_SORT_PAIRS_FN(int64_t, int32_t); #undef DECL_RADIX_SORT_PAIRS_FN - -std::tuple -generate_vbe_metadata( - const at::Tensor& B_offsets, - const at::Tensor& B_offsets_rank_per_feature, - const at::Tensor& output_offsets_feature_rank, - const at::Tensor& D_offsets, - const int64_t D, - const bool nobag, - const int64_t max_B_feature_rank, - const int64_t info_B_num_bits, - const int64_t total_B); diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h new file mode 100644 index 000000000..b41681012 --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +// These values are adjusted in backward based on B and T +constexpr int DEFAULT_INFO_NUM_BITS = 32; +constexpr int DEFAULT_INFO_B_NUM_BITS = 26; +constexpr uint32_t DEFAULT_INFO_B_MASK = (1u << DEFAULT_INFO_B_NUM_BITS) - 1; +constexpr uint32_t MAX_T = + (1u << (DEFAULT_INFO_NUM_BITS - DEFAULT_INFO_B_NUM_BITS)) - 1; +constexpr uint32_t MAX_B = (1u << DEFAULT_INFO_B_NUM_BITS) - 1; + +std::tuple +get_infos_metadata(at::Tensor unused, int64_t B, int64_t T); + +std::tuple adjust_info_B_num_bits(int32_t B, int32_t T); + +std::tuple +generate_vbe_metadata( + const at::Tensor& B_offsets, + const at::Tensor& B_offsets_rank_per_feature, + const at::Tensor& output_offsets_feature_rank, + const at::Tensor& D_offsets, + const int64_t D, + const bool nobag, + const int64_t max_B_feature_rank, + const int64_t info_B_num_bits, + const int64_t total_B); diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h new file mode 100644 index 000000000..3aff58c9a --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +// #include +// #include +// #include "fbgemm_gpu/embedding_common.h" +// #include "fbgemm_gpu/utils/dispatch_macros.h" +// #include "fbgemm_gpu/utils/ops_utils.h" +// #include "fbgemm_gpu/utils/tensor_utils.h" + +using Tensor = at::Tensor; + +namespace fbgemm_gpu { + +//////////////////////////////////////////////////////////////////////////////// +// Helper Functions +//////////////////////////////////////////////////////////////////////////////// + +Tensor reshape_vbe_output( + const Tensor& grad_output, + const Tensor& B_offsets, + const Tensor& B_offsets_rank_per_feature, + const Tensor& D_offsets); +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h index b1ab0306c..60cca19ef 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h @@ -299,3 +299,77 @@ inline at::Tensor aligned_grad_output_tensor_for_cuda_backwards( } return aligned_grad_output; } + +template +std::string tensor_scalar_type_is_one_of( + const at::Tensor& ten, + const ScalarTypes&... 
ttypes) { + auto has_match = false; + + ( + [&](const auto& ttype) { + if (ten.scalar_type() == ttype) { + has_match = true; + } + }(ttypes), + ...); + + if (has_match) { + return ""; + } + + std::string msg = "Tensor's scalar type ("; + msg.append(toString(ten.scalar_type())); + msg.append(") did not match any one of the following types: ["); + ( + [&](const auto& ttype) { + msg.append(toString(ttype)); + msg.append(", "); + }(ttypes), + ...); + + msg.append("]"); + return msg; +} + +#define TENSOR_SCALAR_TYPE_IS_ONE_OF(...) \ + do { \ + const auto has_match = tensor_scalar_type_is_one_of(__VA_ARGS__); \ + TORCH_CHECK(has_match.empty(), has_match); \ + } while (false) + +template +std::string tensors_have_same_scalar_type(const Tensors&... tensors) { + std::optional dtype; + bool have_same_type = true; + + ( + [&](const auto& tensor) { + if (!dtype) { + dtype = tensor.scalar_type(); + } else if (*dtype != tensor.scalar_type()) { + have_same_type = false; + } + }(tensors), + ...); + + if (have_same_type) { + return ""; + } + + std::string msg = "Tensors' scalar types ("; + ( + [&](const auto& tensor) { + msg.append(toString(tensor.scalar_type())); + msg.append(", "); + }(tensors), + ...); + msg.append(") are not one and the same!"); + return msg; +} + +#define TENSORS_HAVE_SAME_SCALAR_TYPE(...) \ + do { \ + const auto have_same_type = tensors_have_same_scalar_type(__VA_ARGS__); \ + TORCH_CHECK(have_same_type.empty(), have_same_type); \ + } while (false) diff --git a/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu index e3c88b101..6870cdfb9 100644 --- a/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu +++ b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu @@ -9,19 +9,19 @@ // clang-format off #include "fbgemm_gpu/utils/cub_namespace_prefix.cuh" #include -#include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/utils/cub_namespace_postfix.cuh" // clang-format on #include +#include #include #include #include #include #include -#include "ATen/Parallel.h" #include "fbgemm_gpu/layout_transform_ops.cuh" #include "fbgemm_gpu/sparse_ops.h" +#include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/utils/ops_utils.h" #include "fbgemm_gpu/utils/tensor_utils.h" diff --git a/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops_cpu.cpp b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops_cpu.cpp index e9159f37a..a01de3a67 100644 --- a/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops_cpu.cpp +++ b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops_cpu.cpp @@ -7,9 +7,9 @@ */ #include +#include #include #include -#include "ATen/Parallel.h" #include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/utils/ops_utils.h" diff --git a/fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h b/fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h index 0498bda96..9e091f1c1 100644 --- a/fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h @@ -9,8 +9,8 @@ #pragma once #include "../ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h" -#include -#include +#include +#include #include "mvai_infra/experimental/ps_training/tps_client/TrainingParameterServiceClient.h" namespace ps { diff --git a/fbgemm_gpu/src/sparse_ops/common.h b/fbgemm_gpu/src/sparse_ops/common.h new file mode 100644 index 000000000..1cdd8ce9e --- /dev/null +++ 
b/fbgemm_gpu/src/sparse_ops/common.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +using Tensor = at::Tensor; + +namespace fbgemm_gpu { + +namespace { +inline Tensor native_empty_like(const Tensor& self) { + return at::native::empty_like( + self, + c10::optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().layout_opt(), + self.options().device_opt(), + self.options().pinned_memory_opt(), + c10::nullopt); +} + +} // namespace + +}; // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/sparse_ops/sparse_async_cumsum.cpp b/fbgemm_gpu/src/sparse_ops/sparse_async_cumsum.cpp new file mode 100644 index 000000000..e3f04b58e --- /dev/null +++ b/fbgemm_gpu/src/sparse_ops/sparse_async_cumsum.cpp @@ -0,0 +1,151 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include "common.h" +#include "fbgemm_gpu/sparse_ops.h" +#include "fbgemm_gpu/utils/dispatch_macros.h" +#include "fbgemm_gpu/utils/ops_utils.h" +#include "fbgemm_gpu/utils/tensor_utils.h" + +using Tensor = at::Tensor; + +namespace fbgemm_gpu { + +// 1D exclusive scan: output[i] = input[i-1] + input[i-2] + input[i-3] +// Used as a helper to several functions below. +template +U exclusive_scan_ptrs_cpu( + const int64_t N, + const T* const input, + U* const output) { + U cumsum = 0; + for (const auto i : c10::irange(N)) { + output[i] = cumsum; + cumsum += input[i]; + } + return cumsum; +} + +void asynchronous_exclusive_cumsum_cpu_out(Tensor& t_out, const Tensor& t_in) { + TENSOR_ON_CPU(t_in); + TENSOR_ON_CPU(t_out); + + const auto t_in_contig = t_in.expect_contiguous(); + at::native::resize_(t_out, t_in_contig->sizes(), c10::nullopt); + + FBGEMM_DISPATCH_ALL_TYPES( + t_in_contig->scalar_type(), + "asynchronous_exclusive_cumsum_cpu_kernel", + [&] { + exclusive_scan_ptrs_cpu( + t_in_contig->numel(), + t_in_contig->data_ptr(), + t_out.data_ptr()); + }); +} + +Tensor asynchronous_exclusive_cumsum_cpu(const Tensor& t_in) { + TENSOR_ON_CPU(t_in); + + const auto t_in_contig = t_in.expect_contiguous(); + auto output = native_empty_like(*t_in_contig); + asynchronous_exclusive_cumsum_cpu_out(output, *t_in_contig); + return output; +} + +Tensor asynchronous_inclusive_cumsum_cpu(const Tensor& t_in) { + TENSOR_ON_CPU(t_in); + + const auto t_in_contig = t_in.expect_contiguous(); + auto output = native_empty_like(*t_in_contig); + FBGEMM_DISPATCH_ALL_TYPES( + t_in_contig->scalar_type(), + "asynchronous_inclusive_cumsum_cpu_kernel", + [&] { + scalar_t cumsum = 0; + const auto* input_ptr = t_in_contig->data_ptr(); + const auto N = t_in_contig->numel(); + auto* output_ptr = output.data_ptr(); + + for (const auto i : c10::irange(N)) { + cumsum += input_ptr[i]; + output_ptr[i] = cumsum; + } + }); + return output; +} + +Tensor asynchronous_complete_cumsum_cpu_out(Tensor& t_out, const Tensor& t_in) { + TENSOR_ON_CPU(t_in); + TENSOR_ON_CPU(t_out); + const auto num_dims = t_in.dim(); + TORCH_CHECK(num_dims == 1 || num_dims == 2); + const auto t_in_contig = t_in.expect_contiguous(); + const auto t_out_contig = t_out.expect_contiguous(); + + FBGEMM_DISPATCH_ALL_TYPES( + t_in_contig->scalar_type(), + "asynchronous_complete_cumsum_cpu_kernel", + [&] { + 
if (num_dims == 1) { + const auto N = t_in_contig->numel(); + t_out.data_ptr()[N] = exclusive_scan_ptrs_cpu( + N, t_in_contig->data_ptr(), t_out.data_ptr()); + } else { + const auto num_vecs = t_in_contig->size(0); + const auto N = t_in_contig->size(1); + at::parallel_for(0, num_vecs, 1, [&](int64_t start, int64_t end) { + for (const auto i : c10::irange(start, end)) { + scalar_t* out_ptr = t_out.data_ptr() + i * (N + 1); + out_ptr[N] = exclusive_scan_ptrs_cpu( + N, t_in_contig->data_ptr() + i * N, out_ptr); + } + }); + } + }); + return t_out; +} + +Tensor asynchronous_complete_cumsum_cpu(const Tensor& t_in) { + const auto num_dims = t_in.dim(); + TORCH_CHECK(num_dims == 1 || num_dims == 2); + auto output = num_dims == 1 + ? at::empty({t_in.numel() + 1}, t_in.options()) + : at::empty({t_in.size(0), t_in.size(1) + 1}, t_in.options()); + + return asynchronous_complete_cumsum_cpu_out(output, t_in); +} + +} // namespace fbgemm_gpu + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def( + "asynchronous_exclusive_cumsum(Tensor t_in) -> Tensor", + {PT2_COMPLIANT_TAG}); + m.def( + "asynchronous_inclusive_cumsum(Tensor t_in) -> Tensor", + {PT2_COMPLIANT_TAG}); + m.def( + "asynchronous_complete_cumsum(Tensor t_in) -> Tensor", + {PT2_COMPLIANT_TAG}); +} + +TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { + DISPATCH_TO_CPU( + "asynchronous_exclusive_cumsum", + fbgemm_gpu::asynchronous_exclusive_cumsum_cpu); + DISPATCH_TO_CPU( + "asynchronous_inclusive_cumsum", + fbgemm_gpu::asynchronous_inclusive_cumsum_cpu); + DISPATCH_TO_CPU( + "asynchronous_complete_cumsum", + fbgemm_gpu::asynchronous_complete_cumsum_cpu); +} diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp index a80eea05e..88d9ef2e6 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp @@ -20,6 +20,7 @@ #include #include +#include "common.h" #include "fbgemm_gpu/sparse_ops.h" #include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/utils/ops_utils.h" @@ -128,16 +129,6 @@ Tensor pack_segments_autograd( return PackSegments::apply(t_in, lengths, max_length)[0]; } -Tensor native_empty_like(const Tensor& self) { - return at::native::empty_like( - self, - c10::optTypeMetaToScalarType(self.options().dtype_opt()), - self.options().layout_opt(), - self.options().device_opt(), - self.options().pinned_memory_opt(), - c10::nullopt); -} - template void prefix_sum(const int length, const T* const array, T* const presum) { presum[0] = 0; @@ -1317,115 +1308,6 @@ bucketize_sparse_features_cpu( return {new_lengths, new_indices, new_weights, new_pos}; } -// 1D exclusive scan: output[i] = input[i-1] + input[i-2] + input[i-3] -// Used as a helper to several functions below. 
-template -U exclusive_scan_ptrs_cpu( - const int64_t N, - const T* const input, - U* const output) { - U cumsum = 0; - for (const auto i : c10::irange(N)) { - output[i] = cumsum; - cumsum += input[i]; - } - return cumsum; -} - -void asynchronous_exclusive_cumsum_cpu_out( - at::Tensor& t_out, - const Tensor& t_in) { - TENSOR_ON_CPU(t_in); - TENSOR_ON_CPU(t_out); - - const auto t_in_contig = t_in.expect_contiguous(); - at::native::resize_(t_out, t_in_contig->sizes(), c10::nullopt); - - FBGEMM_DISPATCH_ALL_TYPES( - t_in_contig->scalar_type(), - "asynchronous_exclusive_cumsum_cpu_kernel", - [&] { - exclusive_scan_ptrs_cpu( - t_in_contig->numel(), - t_in_contig->data_ptr(), - t_out.data_ptr()); - }); -} - -Tensor asynchronous_exclusive_cumsum_cpu(const Tensor& t_in) { - TENSOR_ON_CPU(t_in); - - const auto t_in_contig = t_in.expect_contiguous(); - auto output = native_empty_like(*t_in_contig); - asynchronous_exclusive_cumsum_cpu_out(output, *t_in_contig); - return output; -} - -Tensor asynchronous_inclusive_cumsum_cpu(const Tensor& t_in) { - TENSOR_ON_CPU(t_in); - - const auto t_in_contig = t_in.expect_contiguous(); - auto output = native_empty_like(*t_in_contig); - FBGEMM_DISPATCH_ALL_TYPES( - t_in_contig->scalar_type(), - "asynchronous_inclusive_cumsum_cpu_kernel", - [&] { - scalar_t cumsum = 0; - const auto* input_ptr = t_in_contig->data_ptr(); - const auto N = t_in_contig->numel(); - auto* output_ptr = output.data_ptr(); - - for (const auto i : c10::irange(N)) { - cumsum += input_ptr[i]; - output_ptr[i] = cumsum; - } - }); - return output; -} - -at::Tensor asynchronous_complete_cumsum_cpu_out( - at::Tensor& t_out, - const at::Tensor& t_in) { - TENSOR_ON_CPU(t_in); - TENSOR_ON_CPU(t_out); - const auto num_dims = t_in.dim(); - TORCH_CHECK(num_dims == 1 || num_dims == 2); - const auto t_in_contig = t_in.expect_contiguous(); - const auto t_out_contig = t_out.expect_contiguous(); - - FBGEMM_DISPATCH_ALL_TYPES( - t_in_contig->scalar_type(), - "asynchronous_complete_cumsum_cpu_kernel", - [&] { - if (num_dims == 1) { - const auto N = t_in_contig->numel(); - t_out.data_ptr()[N] = exclusive_scan_ptrs_cpu( - N, t_in_contig->data_ptr(), t_out.data_ptr()); - } else { - const auto num_vecs = t_in_contig->size(0); - const auto N = t_in_contig->size(1); - at::parallel_for(0, num_vecs, 1, [&](int64_t start, int64_t end) { - for (const auto i : c10::irange(start, end)) { - scalar_t* out_ptr = t_out.data_ptr() + i * (N + 1); - out_ptr[N] = exclusive_scan_ptrs_cpu( - N, t_in_contig->data_ptr() + i * N, out_ptr); - } - }); - } - }); - return t_out; -} - -Tensor asynchronous_complete_cumsum_cpu(const Tensor& t_in) { - const auto num_dims = t_in.dim(); - TORCH_CHECK(num_dims == 1 || num_dims == 2); - auto output = num_dims == 1 - ? 
at::empty({t_in.numel() + 1}, t_in.options()) - : at::empty({t_in.size(0), t_in.size(1) + 1}, t_in.options()); - - return asynchronous_complete_cumsum_cpu_out(output, t_in); -} - template void reorder_batched_ad_lengths_( const Tensor& cat_ad_lengths, @@ -2969,19 +2851,84 @@ Tensor pack_segments_cpu( const int64_t max_length) { return pack_segments_forward_cpu(t_in, lengths, max_length); } -namespace { -Tensor index_select_dim0( - const Tensor& input, - const Tensor& indices, - std::optional /*consecutive_range_start*/, - std::optional /*consecutive_range_length*/, - std::optional /*skip_indices_sorting_fwd*/) { - return at::index_select(input, 0, indices); + +torch::autograd::variable_list group_index_select_dim0_autograd_impl( + at::TensorList all_indices_input, + const int64_t group_size) { + return GroupIndexSelectDim0Op::apply(all_indices_input, group_size); +} + +std::pair, std::vector> +group_index_select_dim0_unpack( + at::TensorList all_indices_input, + const int64_t group_size) { + std::vector indices_group; + std::vector input_group; + + indices_group.reserve(group_size); + input_group.reserve(group_size); + + for (const auto i : c10::irange(group_size)) { + indices_group.push_back(all_indices_input[i]); + input_group.push_back(all_indices_input[group_size + i]); + } + + TORCH_CHECK(group_size == static_cast(indices_group.size())); + + return std::make_pair(input_group, indices_group); } torch::autograd::variable_list group_index_select_dim0( at::TensorList input_group, at::TensorList indices_group) { + const auto group_size = indices_group.size(); + std::vector output_group; + + if (group_size == 0) { + return std::vector(); + } + + // Pack input_group and indices_group into TensorList + std::vector all_indices_input_vec; + all_indices_input_vec.reserve(group_size * 2); + + for (const Tensor& index : indices_group) { + all_indices_input_vec.push_back(index); + } + for (const Tensor& input : input_group) { + all_indices_input_vec.push_back(input); + } + + at::TensorList all_indices_input_tensor = all_indices_input_vec; + + static auto forward_op = + at::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") + .typed(); + auto res = forward_op.call(all_indices_input_tensor, group_size); + TORCH_CHECK(res.size() == group_size + 2); + // only return the outputs (the first group_size elements) + res.resize(group_size); + return res; +} + +torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu( + at::TensorList all_indices_input, + const int64_t group_size) { + throw std::runtime_error( + "group_index_select_dim0_forward_impl is not implemented for CPU"); +} + +torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu( + at::TensorList all_inputs, + c10::SymIntArrayRef output_shape_group_ref) { + throw std::runtime_error( + "group_index_select_dim0_backward_impl is not implemented for CPU"); +} + +torch::autograd::variable_list group_index_select_dim0_decomposed( + at::TensorList input_group, + at::TensorList indices_group) { int num_groups = input_group.size(); TORCH_CHECK(num_groups == (int)indices_group.size()) std::vector output_group; @@ -2992,18 +2939,83 @@ torch::autograd::variable_list group_index_select_dim0( return output_group; } -torch::autograd::variable_list group_index_select_dim0_gpu_impl_cpu( +torch::autograd::variable_list GroupIndexSelectDim0Op::forward( + torch::autograd::AutogradContext* ctx, at::TensorList all_indices_input, const int64_t group_size) { - throw std::runtime_error( - 
"group_index_select_dim0_gpu_impl is not implemented for CPU"); + at::AutoDispatchBelowADInplaceOrView guard; + static auto forward_op = + at::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") + .typed(); + auto result = forward_op.call(all_indices_input, group_size); + TORCH_CHECK(static_cast(result.size()) == group_size + 2); + + auto [input_group, indices_group] = + group_index_select_dim0_unpack(all_indices_input, group_size); + const auto input_dim = input_group[0].dim(); + std::vector input_shape_group; + input_shape_group.reserve(group_size * input_dim); + + for (const auto i : c10::irange(group_size)) { + const auto& input = input_group[i]; + // Copy input shape + auto input_shape = input.sym_sizes().vec(); + input_shape_group.insert( + input_shape_group.end(), input_shape.begin(), input_shape.end()); + } + + // save indices, args_tensor, saved_data + auto saved_tensors = std::vector(indices_group); + saved_tensors.insert( + saved_tensors.end(), result.cbegin() + group_size, result.cend()); + saved_tensors.push_back(input_group[0]); + ctx->save_for_backward(saved_tensors); + ctx->saved_data["input_shape_group"] = input_shape_group; + + return result; } -torch::autograd::variable_list group_index_select_dim0_gpu_backward_cpu( - at::TensorList all_inputs, - c10::SymIntArrayRef output_shape_group_ref) { - throw std::runtime_error( - "group_index_select_dim0_gpu_backward is not implemented for CPU"); +torch::autograd::variable_list GroupIndexSelectDim0Op::backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output_group) { + TORCH_CHECK(grad_output_group.size() >= 2); + if (grad_output_group.size() == 2) { + // empty outputs + return torch::autograd::variable_list(1); + } + // remove redundant grads + auto group_size = grad_output_group.size() - 2; + grad_output_group.resize(group_size); + + auto saved_tensors = ctx->get_saved_variables(); + TORCH_CHECK(saved_tensors.size() == group_size + 3); + auto output_shape_group = + ctx->saved_data["input_shape_group"].toSymIntVector(); + grad_output_group.insert( + grad_output_group.end(), saved_tensors.begin(), saved_tensors.end()); + static auto backward_op = + at::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_backward", "") + .typed(); + auto res = backward_op.call(grad_output_group, output_shape_group); + // 1) Add group_size Variable()'s for indices + // Replace all empty tensors with Variable(). This must be done after the + // op.call to make __torch_dispatch__ work for the backward op. + std::fill(res.begin(), res.begin() + group_size, torch::autograd::Variable()); + // 3) Add 1 Variable() for group_size + res.push_back({}); + return res; +} + +namespace { +Tensor index_select_dim0( + const Tensor& input, + const Tensor& indices, + std::optional /*consecutive_range_start*/, + std::optional /*consecutive_range_length*/, + std::optional /*skip_indices_sorting_fwd*/) { + return at::index_select(input, 0, indices); } Tensor bottom_k_per_row( @@ -3100,15 +3112,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "block_bucketize_sparse_features_inference(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? 
block_bucketize_pos=None, bool return_bucket_mapping=False, bool keep_orig_idx=False) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?)"); m.def( "bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, SymInt my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?)"); - m.def( - "asynchronous_exclusive_cumsum(Tensor t_in) -> Tensor", - {PT2_COMPLIANT_TAG}); - m.def( - "asynchronous_inclusive_cumsum(Tensor t_in) -> Tensor", - {PT2_COMPLIANT_TAG}); - m.def( - "asynchronous_complete_cumsum(Tensor t_in) -> Tensor", - {PT2_COMPLIANT_TAG}); m.def( "reorder_batched_sequence_embeddings(Tensor cat_sequence_embeddings_offsets, Tensor cat_sequence_embeddings, Tensor reordered_cat_sequence_embeddings_offsets, Tensor batch_offsets, SymInt num_items_in_batch) -> Tensor"); m.def( @@ -3214,15 +3217,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { fbgemm_gpu::block_bucketize_sparse_features_inference_cpu); DISPATCH_TO_CPU( "bucketize_sparse_features", fbgemm_gpu::bucketize_sparse_features_cpu); - DISPATCH_TO_CPU( - "asynchronous_exclusive_cumsum", - fbgemm_gpu::asynchronous_exclusive_cumsum_cpu); - DISPATCH_TO_CPU( - "asynchronous_inclusive_cumsum", - fbgemm_gpu::asynchronous_inclusive_cumsum_cpu); - DISPATCH_TO_CPU( - "asynchronous_complete_cumsum", - fbgemm_gpu::asynchronous_complete_cumsum_cpu); DISPATCH_TO_CPU( "reorder_batched_ad_lengths", fbgemm_gpu::reorder_batched_ad_lengths_cpu); DISPATCH_TO_CPU( @@ -3268,13 +3262,14 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { "pack_segments_backward", fbgemm_gpu::pack_segments_backward_cpu); DISPATCH_TO_CPU("index_select_dim0", fbgemm_gpu::index_select_dim0); DISPATCH_TO_CPU( - "group_index_select_dim0", fbgemm_gpu::group_index_select_dim0); + "group_index_select_dim0", + fbgemm_gpu::group_index_select_dim0_decomposed); DISPATCH_TO_CPU( "group_index_select_dim0_gpu_impl", - fbgemm_gpu::group_index_select_dim0_gpu_impl_cpu); + fbgemm_gpu::group_index_select_dim0_forward_impl_cpu); DISPATCH_TO_CPU( "group_index_select_dim0_gpu_backward", - fbgemm_gpu::group_index_select_dim0_gpu_backward_cpu); + fbgemm_gpu::group_index_select_dim0_backward_impl_cpu); DISPATCH_TO_CPU("bottom_k_per_row", fbgemm_gpu::bottom_k_per_row); } @@ -3283,11 +3278,14 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) { } TORCH_LIBRARY_IMPL(fbgemm, AutogradCPU, m) { - m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0); + m.impl( + "group_index_select_dim0", + &fbgemm_gpu::group_index_select_dim0_decomposed); } TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { // CPU group_index_select_dim0 is decomposable m.impl( - "group_index_select_dim0", TORCH_FN(fbgemm_gpu::group_index_select_dim0)); + "group_index_select_dim0", + TORCH_FN(fbgemm_gpu::group_index_select_dim0_decomposed)); } diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index 6325017e8..0c3966fc3 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -193,442 +193,346 @@ class IndexSelectDim0GPUOp } }; -std::pair, std::vector> -group_index_select_dim0_unpack( +// need to combine input_group and indices_group into one tensor list +// to get this working with autograd. 
+static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( at::TensorList all_indices_input, const int64_t group_size) { - std::vector indices_group; - std::vector input_group; + // Unpack from TensorList + auto [input_group, indices_group] = + group_index_select_dim0_unpack(all_indices_input, group_size); + + // args_tensor stores kernel arguments: + // input_ptrs (group_size int64_t elements) + // output_ptrs (group_size int64_t elements) + // indices_ptrs (group_size int64_t elements) + // warp_offsets_group (group_size + 1 int64_t elements) + // num_cols_group (group_size int32_t elements) + int64_t args_ptrs_offsets[NUM_ARGS + 1]; + + const int64_t numels_num_cols_group_64 = + compute_num_int64s(group_size); + + // Initialize offsets + args_ptrs_offsets[P_input_ptrs] = group_size; + args_ptrs_offsets[P_output_ptrs] = group_size; + args_ptrs_offsets[P_indices_ptrs] = group_size; + args_ptrs_offsets[P_warp_offsets_group_ptrs] = group_size + 1; + args_ptrs_offsets[P_num_cols_group_ptrs] = numels_num_cols_group_64; + + // Compute offsets + int64_t offset = 0; + auto next = args_ptrs_offsets[0]; + for (const auto i : c10::irange(NUM_ARGS)) { + args_ptrs_offsets[i] = offset; + offset += next; + next = args_ptrs_offsets[i + 1]; + } + // Total number of int64_t elements required + args_ptrs_offsets[NUM_ARGS] = offset; + + // Allocate memory for GroupIndexSelectArgs + at::Tensor args_tensor = at::empty( + {static_cast(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))}, + at::TensorOptions().dtype(at::kByte).pinned_memory(true)); + + // Ensure that args_tensor is contiguous + TORCH_CHECK(args_tensor.is_contiguous()); + + // Initialize raw pointers to point to Tensor args_tensor + int64_t* input_ptrs = nullptr; + int64_t* output_ptrs = nullptr; + int64_t* indices_ptrs = nullptr; + int64_t* warp_offsets_group = nullptr; + int32_t* num_cols_group = nullptr; + + // Offset host pointers + offset_args( + &input_ptrs, + &output_ptrs, + &indices_ptrs, + &warp_offsets_group, + &num_cols_group, + reinterpret_cast(args_tensor.data_ptr()), + args_ptrs_offsets); + + auto& first_input = input_group[0]; + auto& first_indices = indices_group[0]; + + const int input_dim = first_input.dim(); + const int num_output_rows = first_indices.size(0); + const int num_input_rows = first_input.size(0); + Tensor input_reshaped = first_input.reshape({num_input_rows, -1}); + const int num_cols = input_reshaped.size(1); + const int cols_per_warp = get_group_index_select_cols_per_warp(); + int64_t warp_offset = 0; + bool use_var_cols = false; + + // Allocate memory for output_group + std::vector output_group; + output_group.reserve(group_size + 2); - indices_group.reserve(group_size); - input_group.reserve(group_size); + // We need to store contiguous inputs and indices outside the for-loop to + // guarantee that the contiguous tensors will outlive the kernel + // computation + std::vector> input_contigs; + std::vector> index_contigs; + input_contigs.reserve(group_size); + index_contigs.reserve(group_size); + // For each group, copy input to output for (const auto i : c10::irange(group_size)) { - indices_group.push_back(all_indices_input[i]); - input_group.push_back(all_indices_input[group_size + i]); - } + const auto& input = input_group[i]; + const auto& indices = indices_group[i]; - TORCH_CHECK(group_size == static_cast(indices_group.size())); + // Verify that all input tensors have the same number of dimensions + TORCH_CHECK( + input_dim == input.dim(), + "All inputs in group_index_select must have the 
same number of dimensions"); - return std::make_pair(input_group, indices_group); -} + // Verify that all tensors are on the same GPU + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input, indices); -class GroupIndexSelectDim0GPUOp - : public torch::autograd::Function { - public: - // need to combine input_group and indices_group into one tensor list - // to get this working with autograd. - static torch::autograd::variable_list forward_impl( - at::TensorList all_indices_input, - const int64_t group_size) { - // Unpack from TensorList - auto [input_group, indices_group] = - group_index_select_dim0_unpack(all_indices_input, group_size); - - // args_tensor stores kernel arguments: - // input_ptrs (group_size int64_t elements) - // output_ptrs (group_size int64_t elements) - // indices_ptrs (group_size int64_t elements) - // warp_offsets_group (group_size + 1 int64_t elements) - // num_cols_group (group_size int32_t elements) - int64_t args_ptrs_offsets[NUM_ARGS + 1]; - - const int64_t numels_num_cols_group_64 = - compute_num_int64s(group_size); - - // Initialize offsets - args_ptrs_offsets[P_input_ptrs] = group_size; - args_ptrs_offsets[P_output_ptrs] = group_size; - args_ptrs_offsets[P_indices_ptrs] = group_size; - args_ptrs_offsets[P_warp_offsets_group_ptrs] = group_size + 1; - args_ptrs_offsets[P_num_cols_group_ptrs] = numels_num_cols_group_64; - - // Compute offsets - int64_t offset = 0; - auto next = args_ptrs_offsets[0]; - for (const auto i : c10::irange(NUM_ARGS)) { - args_ptrs_offsets[i] = offset; - offset += next; - next = args_ptrs_offsets[i + 1]; + auto num_output_rows_ = indices.size(0); + + // Verify that all input tensors have the same shape[0] + TORCH_CHECK( + num_output_rows == num_output_rows_, + "The number of indices to be selected must be the same for the entire group"); + const auto input_reshaped_ = input.reshape({input.size(0), -1}); + + // Number of columns can be different + auto num_cols_ = input_reshaped_.size(1); + auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp; + + if (num_cols != num_cols_) { + use_var_cols = true; } - // Total number of int64_t elements required - args_ptrs_offsets[NUM_ARGS] = offset; - - // Allocate memory for GroupIndexSelectArgs - at::Tensor args_tensor = at::empty( - {static_cast(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))}, - at::TensorOptions().dtype(at::kByte).pinned_memory(true)); - - // Ensure that args_tensor is contiguous - TORCH_CHECK(args_tensor.is_contiguous()); - - // Initialize raw pointers to point to Tensor args_tensor - int64_t* input_ptrs = nullptr; - int64_t* output_ptrs = nullptr; - int64_t* indices_ptrs = nullptr; - int64_t* warp_offsets_group = nullptr; - int32_t* num_cols_group = nullptr; - - // Offset host pointers - offset_args( - &input_ptrs, - &output_ptrs, - &indices_ptrs, - &warp_offsets_group, - &num_cols_group, - reinterpret_cast(args_tensor.data_ptr()), - args_ptrs_offsets); - - auto& first_input = input_group[0]; - auto& first_indices = indices_group[0]; - - const int input_dim = first_input.dim(); - const int num_output_rows = first_indices.size(0); - const int num_input_rows = first_input.size(0); - Tensor input_reshaped = first_input.reshape({num_input_rows, -1}); - const int num_cols = input_reshaped.size(1); - const int cols_per_warp = get_group_index_select_cols_per_warp(); - int64_t warp_offset = 0; - bool use_var_cols = false; - - // Allocate memory for output_group - std::vector output_group; - output_group.reserve(group_size + 2); - - // We need to store contiguous inputs and 
indices outside the for-loop to - // guarantee that the contiguous tensors will outlive the kernel + + // Create output pointers + auto input_shape = input.sizes().vec(); + input_shape[0] = num_output_rows_; + Tensor output = at::empty(input_shape, input.options()); + // Ensure that the allocated output is contiguous + TORCH_CHECK(output.is_contiguous()) + output_group.push_back(output); + + // Store input and indices contigs to keep them alive during the kernel // computation - std::vector> input_contigs; - std::vector> index_contigs; - input_contigs.reserve(group_size); - index_contigs.reserve(group_size); - - // For each group, copy input to output - for (const auto i : c10::irange(group_size)) { - const auto& input = input_group[i]; - const auto& indices = indices_group[i]; - - // Verify that all input tensors have the same number of dimensions - TORCH_CHECK( - input_dim == input.dim(), - "All inputs in group_index_select must have the same number of dimensions"); - - // Verify that all tensors are on the same GPU - TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input, indices); - - auto num_output_rows_ = indices.size(0); - - // Verify that all input tensors have the same shape[0] - TORCH_CHECK( - num_output_rows == num_output_rows_, - "The number of indices to be selected must be the same for the entire group"); - const auto input_reshaped_ = input.reshape({input.size(0), -1}); - - // Number of columns can be different - auto num_cols_ = input_reshaped_.size(1); - auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp; - - if (num_cols != num_cols_) { - use_var_cols = true; - } - - // Create output pointers - auto input_shape = input.sizes().vec(); - input_shape[0] = num_output_rows_; - Tensor output = at::empty(input_shape, input.options()); - // Ensure that the allocated output is contiguous - TORCH_CHECK(output.is_contiguous()) - output_group.push_back(output); - - // Store input and indices contigs to keep them alive during the kernel - // computation - input_contigs.push_back(input.expect_contiguous()); - index_contigs.push_back(indices.expect_contiguous()); - - // Store args - input_ptrs[i] = reinterpret_cast(input_contigs[i]->data_ptr()); - output_ptrs[i] = reinterpret_cast(output.data_ptr()); - indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); - warp_offsets_group[i] = warp_offset; - num_cols_group[i] = num_cols_; - - warp_offset += warps_per_row * num_output_rows; - } + input_contigs.push_back(input.expect_contiguous()); + index_contigs.push_back(indices.expect_contiguous()); - // Store the last offset - warp_offsets_group[group_size] = warp_offset; - - // Transfer args tensor to GPU - args_tensor = args_tensor.to( - first_input.device(), - /*non_blocking=*/true); - - // Offset raw ptrs in GPU memory - offset_args( - &input_ptrs, - &output_ptrs, - &indices_ptrs, - &warp_offsets_group, - &num_cols_group, - reinterpret_cast(args_tensor.data_ptr()), - args_ptrs_offsets); - - int64_t saved_data[] = { - static_cast(group_size), - use_var_cols, - reinterpret_cast(warp_offsets_group), - reinterpret_cast(num_cols_group), - warp_offset, - }; - auto saved_data_t = at::empty( - {sizeof(saved_data) / sizeof(int64_t)}, - at::TensorOptions().dtype(at::kLong)); - TORCH_CHECK(saved_data_t.is_contiguous()); - memcpy(saved_data_t.data_ptr(), saved_data, sizeof(saved_data)); - - group_index_select_or_add_cuda( - input_ptrs, - output_ptrs, - indices_ptrs, - warp_offsets_group, - num_cols_group, - first_input.scalar_type(), - first_indices.scalar_type(), - 
first_input.device().index(), - num_output_rows, - /*total_num_warps=*/warp_offset, - group_size, - /*use_index_select=*/true, - use_var_cols); - - output_group.push_back(args_tensor); - output_group.push_back(saved_data_t); - - // return format: - // (group_size outputs, 1 args_tensor, 1 saved_data) - return output_group; + // Store args + input_ptrs[i] = reinterpret_cast(input_contigs[i]->data_ptr()); + output_ptrs[i] = reinterpret_cast(output.data_ptr()); + indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); + warp_offsets_group[i] = warp_offset; + num_cols_group[i] = num_cols_; + + warp_offset += warps_per_row * num_output_rows; } - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - at::TensorList all_indices_input, - const int64_t group_size) { - at::AutoDispatchBelowADInplaceOrView guard; - static auto forward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") - .typed(); - auto result = forward_op.call(all_indices_input, group_size); - TORCH_CHECK(static_cast(result.size()) == group_size + 2); - - auto [input_group, indices_group] = - group_index_select_dim0_unpack(all_indices_input, group_size); - const auto input_dim = input_group[0].dim(); - std::vector input_shape_group; - input_shape_group.reserve(group_size * input_dim); - - for (const auto i : c10::irange(group_size)) { - const auto& input = input_group[i]; - // Copy input shape - auto input_shape = input.sym_sizes().vec(); - input_shape_group.insert( - input_shape_group.end(), input_shape.begin(), input_shape.end()); - } + // Store the last offset + warp_offsets_group[group_size] = warp_offset; + + // Transfer args tensor to GPU + args_tensor = args_tensor.to( + first_input.device(), + /*non_blocking=*/true); + + // Offset raw ptrs in GPU memory + offset_args( + &input_ptrs, + &output_ptrs, + &indices_ptrs, + &warp_offsets_group, + &num_cols_group, + reinterpret_cast(args_tensor.data_ptr()), + args_ptrs_offsets); + + int64_t saved_data[] = { + static_cast(group_size), + use_var_cols, + reinterpret_cast(warp_offsets_group), + reinterpret_cast(num_cols_group), + warp_offset, + }; + auto saved_data_t = at::empty( + {sizeof(saved_data) / sizeof(int64_t)}, + at::TensorOptions().dtype(at::kLong)); + TORCH_CHECK(saved_data_t.is_contiguous()); + memcpy(saved_data_t.data_ptr(), saved_data, sizeof(saved_data)); + + group_index_select_or_add_cuda( + input_ptrs, + output_ptrs, + indices_ptrs, + warp_offsets_group, + num_cols_group, + first_input.scalar_type(), + first_indices.scalar_type(), + first_input.device().index(), + num_output_rows, + /*total_num_warps=*/warp_offset, + group_size, + /*use_index_select=*/true, + use_var_cols); + + output_group.push_back(args_tensor); + output_group.push_back(saved_data_t); + + // return format: + // (group_size outputs, 1 args_tensor, 1 saved_data) + return output_group; +} - // save indices, args_tensor, saved_data - auto saved_tensors = std::vector(indices_group); - saved_tensors.insert( - saved_tensors.end(), result.cbegin() + group_size, result.cend()); - saved_tensors.push_back(input_group[0]); - ctx->save_for_backward(saved_tensors); - ctx->saved_data["input_shape_group"] = input_shape_group; +static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( + at::TensorList all_inputs, + c10::SymIntArrayRef output_shape_group_ref) { + TORCH_CHECK(all_inputs.size() > 2); + + const int64_t group_size = (all_inputs.size() - 3) / 2; + + Tensor fwd_input = all_inputs[2 * 
group_size + 2]; + const int64_t output_dim = fwd_input.dim(); + Tensor saved_data = all_inputs[2 * group_size + 1]; + Tensor args_tensor_old = all_inputs[2 * group_size]; + Tensor first_indices = all_inputs[group_size]; + + auto grad_output_group = std::vector( + all_inputs.cbegin(), all_inputs.cbegin() + group_size); + std::vector output_shape_group; + output_shape_group.reserve(output_shape_group_ref.size()); + for (const auto& i : output_shape_group_ref) { + output_shape_group.push_back(i.as_int_unchecked()); + } - return result; + auto indices_group = std::vector( + all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size); + + // Retrieve saved data + TORCH_CHECK(saved_data.device() == at::kCPU); + TORCH_CHECK(saved_data.is_contiguous()); + int64_t* saved_data_ptr = saved_data.data_ptr(); + // Check that the size is the same + TORCH_CHECK(saved_data_ptr[0] == group_size); + const bool use_var_cols = saved_data_ptr[1]; + int64_t* warp_offsets_group = reinterpret_cast(saved_data_ptr[2]); + int32_t* num_cols_group = reinterpret_cast(saved_data_ptr[3]); + int64_t total_num_warps = saved_data_ptr[4]; + + // We checked in forward that all output rows are the same for all member + // in the group + const int num_input_rows = grad_output_group[0].size(0); + + std::vector outputs; + // Returning 3 outputs: + // 1) group_size Variable()'s for indices + // 2) group_size gradients for inputs + // 3) 1 Variable() for group_size + outputs.reserve(group_size * 2 + 1); + + // 1) Add group_size Variable()'s for indices + // c10::irange cannot be used in here as it + // triggers a build error of i being an unused variable. + // Add empty tensor with zero size here to make __torch_dispatch__ work for + // the backward op. Those empty tensors will be replaced with + // torch::autograd::Variable() outside of the op call. 
+ for (auto i = 0; i < group_size; i++) { + outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); } - static torch::autograd::variable_list backward_impl( - at::TensorList all_inputs, - c10::SymIntArrayRef output_shape_group_ref) { - TORCH_CHECK(all_inputs.size() > 2); - - const int64_t group_size = (all_inputs.size() - 3) / 2; - - Tensor fwd_input = all_inputs[2 * group_size + 2]; - const int64_t output_dim = fwd_input.dim(); - Tensor saved_data = all_inputs[2 * group_size + 1]; - Tensor args_tensor_old = all_inputs[2 * group_size]; - Tensor first_indices = all_inputs[group_size]; - - auto grad_output_group = std::vector( - all_inputs.cbegin(), all_inputs.cbegin() + group_size); - std::vector output_shape_group; - output_shape_group.reserve(output_shape_group_ref.size()); - for (const auto& i : output_shape_group_ref) { - output_shape_group.push_back(i.as_int_unchecked()); - } + // Allocate Tensor for ptrs of grad output and input, and indices + Tensor args_tensor = at::empty( + {group_size * 3}, + at::TensorOptions().dtype(at::kLong).pinned_memory(true)); + // Ensure that args_tensor is contiguous + TORCH_CHECK(args_tensor.is_contiguous()); + int64_t* grad_output_ptrs = args_tensor.data_ptr(); + int64_t* grad_input_ptrs = args_tensor.data_ptr() + group_size; + int64_t* indices_ptrs = args_tensor.data_ptr() + 2 * group_size; + + int64_t group_grad_input_numel = 0; + std::vector grad_input_numels; + grad_input_numels.reserve(group_size); + + // We need to store contiguous gradients outside the for-loop to guarantee + // that the contiguous tensors will outlive the kernel computation + std::vector> grad_output_contigs; + grad_output_contigs.reserve(group_size); - auto indices_group = std::vector( - all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size); - - // Retrieve saved data - TORCH_CHECK(saved_data.device() == at::kCPU); - TORCH_CHECK(saved_data.is_contiguous()); - int64_t* saved_data_ptr = saved_data.data_ptr(); - // Check that the size is the same - TORCH_CHECK(saved_data_ptr[0] == group_size); - const bool use_var_cols = saved_data_ptr[1]; - int64_t* warp_offsets_group = reinterpret_cast(saved_data_ptr[2]); - int32_t* num_cols_group = reinterpret_cast(saved_data_ptr[3]); - int64_t total_num_warps = saved_data_ptr[4]; - - // We checked in forward that all output rows are the same for all member - // in the group - const int num_input_rows = grad_output_group[0].size(0); - - std::vector outputs; - // Returning 3 outputs: - // 1) group_size Variable()'s for indices - // 2) group_size gradients for inputs - // 3) 1 Variable() for group_size - outputs.reserve(group_size * 2 + 1); - - // 1) Add group_size Variable()'s for indices - // c10::irange cannot be used in here as it - // triggers a build error of i being an unused variable. - // Add empty tensor with zero size here to make __torch_dispatch__ work for - // the backward op. Those empty tensors will be replaced with - // torch::autograd::Variable() outside of the op call. 
- for (auto i = 0; i < group_size; i++) { - outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); - } + for (const auto i : c10::irange(group_size)) { + const auto& grad = grad_output_group[i]; + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad, first_indices); - // Allocate Tensor for ptrs of grad output and input, and indices - Tensor args_tensor = at::empty( - {group_size * 3}, - at::TensorOptions().dtype(at::kLong).pinned_memory(true)); - // Ensure that args_tensor is contiguous - TORCH_CHECK(args_tensor.is_contiguous()); - int64_t* grad_output_ptrs = args_tensor.data_ptr(); - int64_t* grad_input_ptrs = args_tensor.data_ptr() + group_size; - int64_t* indices_ptrs = args_tensor.data_ptr() + 2 * group_size; - - int64_t group_grad_input_numel = 0; - std::vector grad_input_numels; - grad_input_numels.reserve(group_size); - - // We need to store contiguous gradients outside the for-loop to guarantee - // that the contiguous tensors will outlive the kernel computation - std::vector> grad_output_contigs; - grad_output_contigs.reserve(group_size); - - for (const auto i : c10::irange(group_size)) { - const auto& grad = grad_output_group[i]; - TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad, first_indices); - - // Store grad contigs to keep them alive during the kernel computation - grad_output_contigs.push_back(grad.expect_contiguous()); - - // Compute the total number of elements for all grad_inputs - int64_t grad_input_numel = output_shape_group[i * output_dim]; - for (auto j = (i * output_dim) + 1; j < (i + 1) * output_dim; j++) { - grad_input_numel *= output_shape_group[j]; - } - grad_input_numels.push_back(grad_input_numel); - group_grad_input_numel += grad_input_numel; - - // Put all grad output/input pointers in an array - grad_output_ptrs[i] = - reinterpret_cast(grad_output_contigs[i]->data_ptr()); - } + // Store grad contigs to keep them alive during the kernel computation + grad_output_contigs.push_back(grad.expect_contiguous()); - // Allocate a big tensor to avoid calling many small elementwise kernels - const auto group_grad_input = - at::zeros({group_grad_input_numel}, fwd_input.options()); - TORCH_CHECK(group_grad_input.is_contiguous()); + // Compute the total number of elements for all grad_inputs + int64_t grad_input_numel = output_shape_group[i * output_dim]; + for (auto j = (i * output_dim) + 1; j < (i + 1) * output_dim; j++) { + grad_input_numel *= output_shape_group[j]; + } + grad_input_numels.push_back(grad_input_numel); + group_grad_input_numel += grad_input_numel; - // Split to output_group - auto output_group = group_grad_input.split(grad_input_numels, 0); + // Put all grad output/input pointers in an array + grad_output_ptrs[i] = + reinterpret_cast(grad_output_contigs[i]->data_ptr()); + } - TORCH_CHECK(output_group.size() == static_cast(group_size)); + // Allocate a big tensor to avoid calling many small elementwise kernels + const auto group_grad_input = + at::zeros({group_grad_input_numel}, fwd_input.options()); + TORCH_CHECK(group_grad_input.is_contiguous()); - // Reshape grad inputs and obtain their pointers - for (int i = 0; i < group_size; i++) { - const auto grad_input_shape = std::vector( - output_shape_group.begin() + i * output_dim, - output_shape_group.begin() + (i + 1) * output_dim); - output_group[i] = output_group[i].reshape(grad_input_shape); - TORCH_CHECK(output_group[i].is_contiguous()); - grad_input_ptrs[i] = - reinterpret_cast(output_group[i].data_ptr()); + // Split to output_group + auto output_group = 
group_grad_input.split(grad_input_numels, 0); - // 2) Add group_size gradients for inputs - outputs.push_back(output_group[i]); - } + TORCH_CHECK(output_group.size() == static_cast(group_size)); - // Calculate indices_ptrs - std::vector> index_contigs; - index_contigs.reserve(group_size); - for (const auto i : c10::irange(group_size)) { - const auto& indices = indices_group[i]; - index_contigs.push_back(indices.expect_contiguous()); - indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); - } + // Reshape grad inputs and obtain their pointers + for (int i = 0; i < group_size; i++) { + const auto grad_input_shape = std::vector( + output_shape_group.begin() + i * output_dim, + output_shape_group.begin() + (i + 1) * output_dim); + output_group[i] = output_group[i].reshape(grad_input_shape); + TORCH_CHECK(output_group[i].is_contiguous()); + grad_input_ptrs[i] = reinterpret_cast(output_group[i].data_ptr()); - // Transfer grad output pointers to GPU - args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true); - - group_index_select_or_add_cuda( - args_tensor.data_ptr(), - args_tensor.data_ptr() + group_size, - args_tensor.data_ptr() + 2 * group_size, - warp_offsets_group, - num_cols_group, - fwd_input.scalar_type(), - first_indices.scalar_type(), - fwd_input.device().index(), - num_input_rows, - total_num_warps, - group_size, - /*use_index_select=*/false, - use_var_cols); - - return outputs; + // 2) Add group_size gradients for inputs + outputs.push_back(output_group[i]); } - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_output_group) { - TORCH_CHECK(grad_output_group.size() >= 2); - if (grad_output_group.size() == 2) { - // empty outputs - return torch::autograd::variable_list(1); - } - // remove redundant grads - auto group_size = grad_output_group.size() - 2; - grad_output_group.resize(group_size); - - auto saved_tensors = ctx->get_saved_variables(); - TORCH_CHECK(saved_tensors.size() == group_size + 3); - auto output_shape_group = - ctx->saved_data["input_shape_group"].toSymIntVector(); - grad_output_group.insert( - grad_output_group.end(), saved_tensors.begin(), saved_tensors.end()); - static auto backward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow( - "fbgemm::group_index_select_dim0_gpu_backward", "") - .typed(); - auto res = backward_op.call(grad_output_group, output_shape_group); - // 1) Add group_size Variable()'s for indices - // Replace all empty tensors with Variable(). This must be done after the - // op.call to make __torch_dispatch__ work for the backward op. 
- std::fill( - res.begin(), res.begin() + group_size, torch::autograd::Variable()); - // 3) Add 1 Variable() for group_size - res.push_back({}); - return res; + // Calculate indices_ptrs + std::vector> index_contigs; + index_contigs.reserve(group_size); + for (const auto i : c10::irange(group_size)) { + const auto& indices = indices_group[i]; + index_contigs.push_back(indices.expect_contiguous()); + indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); } -}; + + // Transfer grad output pointers to GPU + args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true); + + group_index_select_or_add_cuda( + args_tensor.data_ptr(), + args_tensor.data_ptr() + group_size, + args_tensor.data_ptr() + 2 * group_size, + warp_offsets_group, + num_cols_group, + fwd_input.scalar_type(), + first_indices.scalar_type(), + fwd_input.device().index(), + num_input_rows, + total_num_warps, + group_size, + /*use_index_select=*/false, + use_var_cols); + + return outputs; +} Tensor pack_segments_cuda( const Tensor& t_in, @@ -654,45 +558,6 @@ Tensor index_select_dim0_gpu( user_skip_indices_sorting_fwd && !c10::InferenceMode::is_enabled())[0]; } -torch::autograd::variable_list group_index_select_dim0_gpu_impl( - at::TensorList all_indices_input, - const int64_t group_size) { - return GroupIndexSelectDim0GPUOp::apply(all_indices_input, group_size); -} - -torch::autograd::variable_list group_index_select_dim0_gpu( - at::TensorList input_group, - at::TensorList indices_group) { - const auto group_size = indices_group.size(); - std::vector output_group; - - if (group_size == 0) { - return std::vector(); - } - - // Pack input_group and indices_group into TensorList - std::vector all_indices_input_vec; - all_indices_input_vec.reserve(group_size * 2); - - for (const Tensor& index : indices_group) { - all_indices_input_vec.push_back(index); - } - for (const Tensor& input : input_group) { - all_indices_input_vec.push_back(input); - } - - at::TensorList all_indices_input_tensor = all_indices_input_vec; - - static auto forward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") - .typed(); - auto res = forward_op.call(all_indices_input_tensor, group_size); - TORCH_CHECK(res.size() == group_size + 2); - // only return the outputs (the first group_size elements) - res.resize(group_size); - return res; -} } // namespace fbgemm_gpu TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { @@ -721,17 +586,17 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { DISPATCH_TO_CUDA("index_select_dim0", fbgemm_gpu::index_select_dim0_gpu); DISPATCH_TO_CUDA( "group_index_select_dim0_gpu_impl", - fbgemm_gpu::GroupIndexSelectDim0GPUOp::forward_impl); + fbgemm_gpu::group_index_select_dim0_forward_impl_gpu); DISPATCH_TO_CUDA( "group_index_select_dim0_gpu_backward", - fbgemm_gpu::GroupIndexSelectDim0GPUOp::backward_impl); + fbgemm_gpu::group_index_select_dim0_backward_impl_gpu); DISPATCH_TO_CUDA( - "group_index_select_dim0", fbgemm_gpu::group_index_select_dim0_gpu); + "group_index_select_dim0", fbgemm_gpu::group_index_select_dim0); } TORCH_LIBRARY_IMPL(fbgemm, AutogradCUDA, m) { - m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0_gpu); + m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0); m.impl( "group_index_select_dim0_gpu_impl", - &fbgemm_gpu::group_index_select_dim0_gpu_impl); + &fbgemm_gpu::group_index_select_dim0_autograd_impl); } diff --git a/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu 
b/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu index 9037b7c09..c899bbf9b 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu @@ -62,18 +62,21 @@ DLL_PUBLIC Tensor pack_segments_backward_cuda( CUDA_DEVICE_GUARD(data); + const auto data_contig = data.expect_contiguous(); + Tensor unpacked_tensor; // The output tensor AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "unpack_segments_cuda", [&] { const auto* const lengths_data = lengths.data_ptr(); // Create output tensor of appropriate dimensions - auto shape = data.sizes().vec(); + auto shape = data_contig->sizes().vec(); shape.erase(shape.begin()); shape[0] = total_length; - unpacked_tensor = at::empty(shape, data.options()); + unpacked_tensor = at::empty(shape, data_contig->options()); - if (!(data.size(0) && data.size(1))) { // TODO: What does this mean? + if (!(data_contig->size(0) && + data_contig->size(1))) { // TODO: What does this mean? return; } @@ -82,10 +85,11 @@ DLL_PUBLIC Tensor pack_segments_backward_cuda( auto lps_data = lengths_prefix_sum.data_ptr(); FBGEMM_DISPATCH_ALL_TYPES( - data.scalar_type(), "unpack_segments_cuda-unpacking", [&] { + data_contig->scalar_type(), "unpack_segments_cuda-unpacking", [&] { const auto num_seq = lengths.size(0); - const auto cell_size = data.numel() / (data.size(0) * data.size(1)); - const auto* const data_ptr = data.data_ptr(); + const auto cell_size = data_contig->numel() / + (data_contig->size(0) * data_contig->size(1)); + const auto* const data_ptr = data_contig->data_ptr(); auto* const out_data = unpacked_tensor.data_ptr(); unpack_segments_cuda_kernel diff --git a/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu b/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu index c3eb40819..a4efd4c21 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu @@ -13,52 +13,6 @@ using Tensor = at::Tensor; using namespace fbgemm_gpu; -DLL_PUBLIC std::tuple adjust_info_B_num_bits( - int32_t B, - int32_t T) { - int32_t info_B_num_bits = DEFAULT_INFO_B_NUM_BITS; - uint32_t info_B_mask = DEFAULT_INFO_B_MASK; - uint32_t max_T = MAX_T; - uint32_t max_B = MAX_B; - bool invalid_T = T > max_T; - bool invalid_B = B > max_B; - - TORCH_CHECK( - !(invalid_T && invalid_B), - "Not enough infos bits to accommodate T and B. Default num bits = ", - DEFAULT_INFO_NUM_BITS); - - if (invalid_T) { - // Reduce info_B_num_bits - while (invalid_T && !invalid_B && info_B_num_bits > 0) { - info_B_num_bits--; - max_T = ((max_T + 1) << 1) - 1; - max_B = ((max_B + 1) >> 1) - 1; - invalid_T = T > max_T; - invalid_B = B > max_B; - } - } else if (invalid_B) { - // Increase info_B_num_bits - while (!invalid_T && invalid_B && info_B_num_bits < DEFAULT_INFO_NUM_BITS) { - info_B_num_bits++; - max_T = ((max_T + 1) >> 1) - 1; - max_B = ((max_B + 1) << 1) - 1; - invalid_T = T > max_T; - invalid_B = B > max_B; - } - } - - TORCH_CHECK( - !invalid_T && !invalid_B, - "Not enough infos bits to accommodate T and B. 
Default num bits = ", - DEFAULT_INFO_NUM_BITS); - - // Recompute info_B_mask using new info_B_num_bits - info_B_mask = (1u << info_B_num_bits) - 1; - - return {info_B_num_bits, info_B_mask}; -} - DLL_PUBLIC std::tuple get_infos_metadata(Tensor unused, int64_t B, int64_t T) { return adjust_info_B_num_bits(B, T); diff --git a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp index 4e8407fb1..8902e1c44 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp +++ b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp @@ -36,35 +36,6 @@ generate_vbe_metadata_meta( } // namespace TORCH_LIBRARY_FRAGMENT(fbgemm, m) { - m.def( - "transpose_embedding_input(" - " Tensor hash_size_cumsum, " - " int total_hash_size_bits, " - " Tensor indices, " - " Tensor offsets, " - " bool nobag=False, " - " Tensor? vbe_b_t_map=None, " - " int info_B_num_bits=26, " - " int info_B_mask=0x2FFFFFF, " - " int total_unique_indices=-1, " - " bool is_index_select=False, " - " Tensor? total_L_offsets=None, " - " int fixed_L_per_warp=0, " - " int num_warps_per_feature=0" - ") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"); - m.def("get_infos_metadata(Tensor unused, int B, int T) -> (int, int)"); - m.def( - "generate_vbe_metadata(" - " Tensor B_offsets, " - " Tensor B_offsets_rank_per_feature, " - " Tensor output_offsets_feature_rank, " - " Tensor D_offsets, " - " int D, " - " bool nobag, " - " SymInt max_B_feature_rank, " - " int info_B_num_bits, " - " SymInt total_B" - ") -> (Tensor, Tensor)"); DISPATCH_TO_CUDA("transpose_embedding_input", transpose_embedding_input); DISPATCH_TO_CUDA("get_infos_metadata", get_infos_metadata); DISPATCH_TO_CUDA("generate_vbe_metadata", generate_vbe_metadata); diff --git a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp new file mode 100644 index 000000000..654a3c3ed --- /dev/null +++ b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp @@ -0,0 +1,119 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include "fbgemm_gpu/split_embeddings_utils.h" +#include "fbgemm_gpu/utils/ops_utils.h" + +using Tensor = at::Tensor; + +DLL_PUBLIC std::tuple adjust_info_B_num_bits( + int32_t B, + int32_t T) { + int32_t info_B_num_bits = DEFAULT_INFO_B_NUM_BITS; + uint32_t info_B_mask = DEFAULT_INFO_B_MASK; + uint32_t max_T = MAX_T; + uint32_t max_B = MAX_B; + bool invalid_T = T > max_T; + bool invalid_B = B > max_B; + + TORCH_CHECK( + !(invalid_T && invalid_B), + "Not enough infos bits to accommodate T and B. Default num bits = ", + DEFAULT_INFO_NUM_BITS); + + if (invalid_T) { + // Reduce info_B_num_bits + while (invalid_T && !invalid_B && info_B_num_bits > 0) { + info_B_num_bits--; + max_T = ((max_T + 1) << 1) - 1; + max_B = ((max_B + 1) >> 1) - 1; + invalid_T = T > max_T; + invalid_B = B > max_B; + } + } else if (invalid_B) { + // Increase info_B_num_bits + while (!invalid_T && invalid_B && info_B_num_bits < DEFAULT_INFO_NUM_BITS) { + info_B_num_bits++; + max_T = ((max_T + 1) >> 1) - 1; + max_B = ((max_B + 1) << 1) - 1; + invalid_T = T > max_T; + invalid_B = B > max_B; + } + } + + TORCH_CHECK( + !invalid_T && !invalid_B, + "Not enough infos bits to accommodate T and B. 
Default num bits = ", + DEFAULT_INFO_NUM_BITS); + + // Recompute info_B_mask using new info_B_num_bits + info_B_mask = (1u << info_B_num_bits) - 1; + + return {info_B_num_bits, info_B_mask}; +} + +namespace { + +std::tuple +generate_vbe_metadata_cpu( + const Tensor& B_offsets, + const Tensor& B_offsets_rank_per_feature, + const Tensor& output_offsets_feature_rank, + const Tensor& D_offsets, + const int64_t D, + const bool nobag, + const c10::SymInt max_B_feature_rank, + const int64_t info_B_num_bits, + const c10::SymInt total_B) { + Tensor row_output_offsets = output_offsets_feature_rank; + Tensor b_t_map = B_offsets_rank_per_feature; + return {row_output_offsets, b_t_map}; +} + +std::tuple +get_infos_metadata_cpu(Tensor unused, int64_t B, int64_t T) { + return adjust_info_B_num_bits(B, T); +} + +} // namespace + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def( + "transpose_embedding_input(" + " Tensor hash_size_cumsum, " + " int total_hash_size_bits, " + " Tensor indices, " + " Tensor offsets, " + " bool nobag=False, " + " Tensor? vbe_b_t_map=None, " + " int info_B_num_bits=26, " + " int info_B_mask=0x2FFFFFF, " + " int total_unique_indices=-1, " + " bool is_index_select=False, " + " Tensor? total_L_offsets=None, " + " int fixed_L_per_warp=0, " + " int num_warps_per_feature=0" + ") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"); + m.def("get_infos_metadata(Tensor unused, int B, int T) -> (int, int)"); + m.def( + "generate_vbe_metadata(" + " Tensor B_offsets, " + " Tensor B_offsets_rank_per_feature, " + " Tensor output_offsets_feature_rank, " + " Tensor D_offsets, " + " int D, " + " bool nobag, " + " SymInt max_B_feature_rank, " + " int info_B_num_bits, " + " SymInt total_B" + ") -> (Tensor, Tensor)"); + DISPATCH_TO_CPU("generate_vbe_metadata", generate_vbe_metadata_cpu); + DISPATCH_TO_CPU("get_infos_metadata", get_infos_metadata_cpu); +} diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp index fdc91b0e9..ca5d3b6cb 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp @@ -7,8 +7,8 @@ */ #include "kv_db_table_batched_embeddings.h" +#include #include -#include #include #include "common/time/Time.h" #include "kv_db_cuda_utils.h" diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h index 93495f2da..d6e0c9180 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h @@ -37,7 +37,7 @@ #include #include -#include +#include #include "fbgemm_gpu/split_embeddings_cache/cachelib_cache.h" #include "fbgemm_gpu/utils/dispatch_macros.h" diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu index d371c2845..340d616e6 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu @@ -343,7 +343,7 @@ ssd_cache_populate_actions_cuda( bool gather_cache_stats, std::optional ssd_cache_stats, const bool lock_cache_line, - const c10::optional& lxu_cache_locking_counter) { + const std::optional& lxu_cache_locking_counter) { 
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( linear_indices, lxu_cache_state, lru_state, lxu_cache_locking_counter); diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index b4ed00d31..73e6973c6 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -30,7 +30,7 @@ ssd_cache_populate_actions_cuda( bool gather_cache_stats, std::optional ssd_cache_stats, const bool lock_cache_line, - const c10::optional& lxu_cache_locking_counter); + const std::optional& lxu_cache_locking_counter); /// @ingroup embedding-ssd /// @@ -396,24 +396,26 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder { public: explicit KVTensorWrapper( c10::intrusive_ptr db, - c10::intrusive_ptr snapshot_handle, std::vector shape, int64_t dtype, - int64_t row_offset) - : db_(db->impl_), - snapshot_handle_(std::move(snapshot_handle)), - shape_(std::move(shape)), - row_offset_(row_offset) { + int64_t row_offset, + std::optional> + snapshot_handle) + : db_(db->impl_), shape_(std::move(shape)), row_offset_(row_offset) { CHECK_EQ(shape_.size(), 2) << "Only 2D emb tensors are supported"; options_ = at::TensorOptions() .dtype(static_cast(dtype)) .device(at::kCPU) .layout(at::kStrided); + if (snapshot_handle.has_value()) { + snapshot_handle_ = std::move(snapshot_handle.value()); + } } at::Tensor narrow(int64_t dim, int64_t start, int64_t length) { CHECK_EQ(dim, 0) << "Only narrow on dim 0 is supported"; CHECK_EQ(db_->get_max_D(), shape_[1]); + CHECK_TRUE(snapshot_handle_ != nullptr); auto t = at::empty(c10::IntArrayRef({length, db_->get_max_D()}), options_); db_->get_range_from_snapshot( t, start + row_offset_, length, snapshot_handle_->handle); @@ -422,6 +424,16 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder { return t.narrow(1, 0, shape_[1]); } + void set_range( + int64_t dim, + const int64_t start, + const int64_t length, + const at::Tensor& weights) { + CHECK_EQ(dim, 0) << "Only set_range on dim 0 is supported"; + CHECK_EQ(db_->get_max_D(), shape_[1]); + db_->set_range(weights, start + row_offset_, length); + } + c10::IntArrayRef size() { return shape_; } @@ -537,21 +549,25 @@ static auto kv_tensor_wrapper = .def( torch::init< c10::intrusive_ptr, - c10::intrusive_ptr, std::vector, int64_t, - int64_t>(), + int64_t, + std::optional< + c10::intrusive_ptr>>(), "", {torch::arg("db"), - torch::arg("snapshot_handle"), torch::arg("shape"), torch::arg("dtype"), - torch::arg("row_offset")}) + torch::arg("row_offset"), + // snapshot must be provided for reading + // not needed for writing + torch::arg("snapshot_handle") = std::nullopt}) .def( "narrow", &KVTensorWrapper::narrow, "", {torch::arg("dim"), torch::arg("start"), torch::arg("length")}) + .def("set_range", &KVTensorWrapper::set_range) .def_property("dtype_str", &KVTensorWrapper::dtype_str) .def_property( "shape", diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 23b6f1e89..f14897854 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -11,10 +11,10 @@ #include #include +#include +#include +#include #include -#include -#include -#include #include #ifdef FBGEMM_FBCODE #include 
"common/strings/UUID.h" @@ -528,13 +528,21 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { const SnapshotHandle* snapshot_handle) { const auto seq_indices = at::arange(start, start + length, at::TensorOptions().dtype(at::kLong)); - int64_t* count_ = new int64_t[1]; - count_[0] = length; - const auto count = at::from_blob(count_, {1}, at::kLong); + const auto count = at::tensor({length}, at::ScalarType::Long); folly::coro::blockingWait( get_kv_db_async_impl(seq_indices, weights, count, snapshot_handle)); } + void set_range( + const at::Tensor& weights, + const int64_t start, + const int64_t length) { + const auto seq_indices = + at::arange(start, start + length, at::TensorOptions().dtype(at::kLong)); + const auto count = at::tensor({length}, at::ScalarType::Long); + folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count)); + } + int64_t get_max_D() { return max_D_; } @@ -668,9 +676,11 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { snapshot_ptr_t snapshot = snapshot_handle == nullptr ? nullptr : snapshot_handle->get_snapshot_for_shard(shard); + auto local_ro = ro_; + local_ro.snapshot = snapshot; tasks.emplace_back( folly::coro::co_invoke( - [this, &indices, &weights, count_, shard, snapshot]() mutable + [this, &indices, &weights, count_, shard, local_ro]() mutable -> folly::coro::Task { FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( weights.scalar_type(), "ssd_get", [&] { @@ -734,10 +744,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { values.resize(keys.size()); statuses.resize(keys.size()); - // Set a snapshot if it is available - ro_.snapshot = snapshot; dbs_[shard]->MultiGet( - ro_, + local_ro, keys.size(), cfs.data(), keys.data(), diff --git a/fbgemm_gpu/test/batched_unary_embeddings_test.py b/fbgemm_gpu/test/batched_unary_embeddings_test.py index 3f1124eff..37d0aaf7b 100644 --- a/fbgemm_gpu/test/batched_unary_embeddings_test.py +++ b/fbgemm_gpu/test/batched_unary_embeddings_test.py @@ -26,14 +26,10 @@ from test_utils import gpu_unavailable except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") from fbgemm_gpu.test.test_utils import gpu_unavailable + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + # Relative tolerances # pyre-fixme[5]: Global expression must be annotated. 
diff --git a/fbgemm_gpu/test/jagged/common.py b/fbgemm_gpu/test/jagged/common.py index 3bdcacf98..491176f76 100644 --- a/fbgemm_gpu/test/jagged/common.py +++ b/fbgemm_gpu/test/jagged/common.py @@ -24,13 +24,7 @@ open_source: bool = getattr(fbgemm_gpu, "open_source", False) if not open_source: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") - + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") suppressed_list: List[HealthCheck] = ( [HealthCheck.differing_executors] diff --git a/fbgemm_gpu/test/layout_transform_ops_test.py b/fbgemm_gpu/test/layout_transform_ops_test.py index 658d773f3..57a7b0263 100644 --- a/fbgemm_gpu/test/layout_transform_ops_test.py +++ b/fbgemm_gpu/test/layout_transform_ops_test.py @@ -22,14 +22,10 @@ from test_utils import gpu_unavailable except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") from fbgemm_gpu.test.test_utils import gpu_unavailable + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + MAX_EXAMPLES = 20 diff --git a/fbgemm_gpu/test/quantize/common.py b/fbgemm_gpu/test/quantize/common.py index 5333cc893..6a720a174 100644 --- a/fbgemm_gpu/test/quantize/common.py +++ b/fbgemm_gpu/test/quantize/common.py @@ -23,12 +23,7 @@ from fbgemm_gpu import open_source # noqa: F401 except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") # Eigen/Python round 0.5 away from 0, Numpy rounds to even round_to_nearest: Callable[[npt.NDArray], npt.NDArray] = np.vectorize(round) diff --git a/fbgemm_gpu/test/release/utils.py b/fbgemm_gpu/test/release/utils.py index 005cf38b5..4218ebc95 100644 --- a/fbgemm_gpu/test/release/utils.py +++ b/fbgemm_gpu/test/release/utils.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+# pyre-unsafe + import inspect import typing from typing import Iterable, List, Optional, Sequence, Union # noqa: F401 diff --git a/fbgemm_gpu/test/sparse/common.py b/fbgemm_gpu/test/sparse/common.py index 69b6e3477..8abddca75 100644 --- a/fbgemm_gpu/test/sparse/common.py +++ b/fbgemm_gpu/test/sparse/common.py @@ -23,12 +23,7 @@ open_source: bool = getattr(fbgemm_gpu, "open_source", False) if not open_source: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops") suppressed_list: List[HealthCheck] = ( diff --git a/fbgemm_gpu/test/sparse/failures_dict.json b/fbgemm_gpu/test/sparse/failures_dict.json index 40cfacc06..fb2fcf85e 100644 --- a/fbgemm_gpu/test/sparse/failures_dict.json +++ b/fbgemm_gpu/test/sparse/failures_dict.json @@ -2,6 +2,16 @@ "_description": "This is a dict containing failures for tests autogenerated by generate_opcheck_tests. For more details, please see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit", "_version": 1, "data": { + "fb::pack_segments": { + "PackedSegmentsTest.test_aot_dispatch_dynamic__test_pack_segments_noncontig": { + "comment": "", + "status": "xfail" + }, + "PackedSegmentsTest.test_faketensor__test_pack_segments_noncontig": { + "comment": "", + "status": "xfail" + } + }, "fbgemm::asynchronous_complete_cumsum": {}, "fbgemm::asynchronous_exclusive_cumsum": {}, "fbgemm::asynchronous_inclusive_cumsum": {}, diff --git a/fbgemm_gpu/test/sparse/pack_segments_test.py b/fbgemm_gpu/test/sparse/pack_segments_test.py index dd5319277..095ea4377 100644 --- a/fbgemm_gpu/test/sparse/pack_segments_test.py +++ b/fbgemm_gpu/test/sparse/pack_segments_test.py @@ -23,9 +23,9 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_available + from test_utils import gpu_available, gpu_unavailable else: - from fbgemm_gpu.test.test_utils import gpu_available + from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable def get_n_rand_num_summing_to_k(n: int, k: int) -> npt.NDArray: @@ -47,6 +47,15 @@ def get_n_rand_num_summing_to_k(n: int, k: int) -> npt.NDArray: # pyre-fixme[2] # pyre-fixme[24] def torch_compiled(model: Callable, **kwargs) -> Callable: + """A helper function to apply torch.compile if python < 3.12. + + Args: + model: The model to be compiled. + kwargs: The arguments to be passed to torch.compile. + + Returns: + The model. + """ if sys.version_info < (3, 12, 0): return torch.compile(model, **kwargs) else: @@ -60,6 +69,17 @@ def _pack_segments_ref( tensor: torch.Tensor, max_length: Optional[int] = None, ) -> npt.NDArray: + """ + This function is a reference implementation of pack_segments. + + Args: + lengths (Tensor): The lengths of tensor. + tensor (Tensor): The tensor to be packed. + max_length (Optional[int]): The maximum length of the packed tensor. + + Returns: + The packed tensor. + """ lengths = lengths.numpy() sections = np.split(tensor, np.cumsum(lengths)) max_length = np.max(lengths, initial=0) if max_length is None else max_length @@ -106,6 +126,22 @@ def test_pack_segments( dtype: torch.dtype, torch_compile: bool, ) -> None: + """ + This function tests pack_segments ops compared to the reference implementation. 
+ Both CPU and GPU (if available) are tested. + + Args: + n - The number of rows in the input tensor + k - The number of columns in the input tensor + batch_size - The size of the leading (batch) dimension of the input tensor + divisions - The number of segments to be packed + dtype - The data type + torch_compile - Whether to use torch.compile + + Returns: + None + """ + input_raw = np.random.rand(batch_size, n, k) input_data = torch.tensor(input_raw, dtype=dtype, requires_grad=True) lengths = torch.tensor( @@ -209,6 +245,23 @@ def test_pack_segments_smaller_max_len( dtype: torch.dtype, torch_compile: bool, ) -> None: + """ + This function tests the pack_segments op with an explicitly set max_length. + Both CPU and GPU (if available) are tested. + + Args: + n - The number of rows in the input tensor + k - The number of columns in the input tensor + batch_size - The size of the leading (batch) dimension of the input tensor + divisions - The number of segments to be packed + max_length - The maximum length the segments are packed to + dtype - The data type + torch_compile - Whether to use torch.compile + + Returns: + None + """ + input_raw = np.random.rand(batch_size, n, k) input_data = torch.tensor(input_raw, dtype=dtype) lengths = torch.tensor( @@ -264,6 +317,20 @@ def test_pack_segments_meta_backend( divisions: int, dtype: torch.dtype, ) -> None: + """ + This function tests the pack_segments op with the meta backend. + + Args: + n - The number of rows in the input tensor + k - The number of columns in the input tensor + batch_size - The size of the leading (batch) dimension of the input tensor + divisions - The number of segments to be packed + dtype - The data type + + Returns: + None + """ + input_raw = np.random.rand(batch_size, n, k) input_data = torch.tensor( input_raw, dtype=torch.float32, requires_grad=True @@ -281,6 +348,109 @@ def test_pack_segments_meta_backend( # verify forward assert packed_tensor.size() == torch.Tensor(packed_ref).size() + @unittest.skipIf(*gpu_unavailable) + @given( + n=st.integers(2, 10), + k=st.integers(2, 10), + batch_size=st.integers(1, 30), + divisions=st.integers(1, 10), + dtype=st.sampled_from( + [ + torch.float, + torch.half, + ] + ), + torch_compile=st.booleans(), + use_cpu=st.booleans(), + ) + @settings(deadline=None) + def test_pack_segments_noncontig( + self, + n: int, + k: int, + batch_size: int, + divisions: int, + dtype: torch.dtype, + torch_compile: bool, + use_cpu: bool, + ) -> None: + """ + This function tests the pack_segments op when the gradients passed to backward are non-contiguous.
+ + Args: + n - The number of rows in the input tensor + k - The number of columns in the input tensor + batch_size - The size of the leading (batch) dimension of the input tensor + divisions - The number of segments to be packed + dtype - The data type + torch_compile - Whether to use torch.compile + use_cpu - Whether to use CPU or GPU + + Returns: + None + """ + + input_raw = np.random.rand(batch_size, n, k) + # create the CPU (reference) and GPU inputs from the same data + input_data_ref = torch.tensor(input_raw, dtype=dtype, requires_grad=True) + input_data = torch.tensor(input_raw, dtype=dtype, requires_grad=True).cuda() + # retain grad to compare gradients of the inputs later + input_data.retain_grad() + input_data_ref.retain_grad() + + # set lengths + lengths = torch.tensor( + get_n_rand_num_summing_to_k(divisions, batch_size), + dtype=torch.int, + ) + max_length = lengths.max().item() + + packed_ref = torch.ops.fbgemm.pack_segments( + t_in=input_data_ref, lengths=lengths, max_length=max_length + ) + packed_ref.retain_grad() + + # pack segments on the GPU using the fbgemm operator + packed_tensor = torch.ops.fbgemm.pack_segments( + t_in=input_data, lengths=lengths.cuda(), max_length=max_length + ) + packed_tensor.retain_grad() + + # verify forward + self.assertTrue(torch.equal(packed_tensor.cpu(), packed_ref)) + + # create non-contiguous grads by taking strided views into a larger tensor + shape = tuple(x * 2 for x in packed_ref.shape) + grads = torch.tensor( + np.random.uniform(low=0.01, high=0.5, size=shape).astype(np.float32) + ).to(dtype) + grad_noncontig_cpu = grads.as_strided(packed_ref.shape, grads.stride()) + grad_noncontig_cuda = grads.cuda().as_strided(packed_ref.shape, grads.stride()) + + self.assertTrue( + not ( + grad_noncontig_cpu.is_contiguous() + and grad_noncontig_cuda.is_contiguous() + ), + msg="Expected grads to be non-contiguous but they are contiguous", + ) + + # run backward with the non-contiguous grads + packed_ref.backward(grad_noncontig_cpu) + packed_tensor.backward(grad_noncontig_cuda) + self.assertTrue( + torch.equal(packed_tensor.cpu(), packed_ref), + msg="Expected packed tensors to be equal but they are not", + ) + + # verify backward input gradients + self.assertTrue( + # pyre-fixme[16]: Optional type has no attribute `cpu`. + # pyre-fixme[6]: For 2nd param expected `Tensor` but got `Optional[Tensor]`. + torch.equal(input_data.grad.cpu(), input_data_ref.grad.cpu()), + msg="Expected input gradients to be equal but they are not", + ) + extend_test_class(PackedSegmentsTest) diff --git a/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py b/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py index 1530f1e75..f1277ccac 100644 --- a/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py +++ b/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py @@ -76,7 +76,7 @@ def test_read_tensor_using_wrapper_from_db(self) -> None: # create a view tensor wrapper snapshot = ssd_db.create_snapshot() tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper( - ssd_db, snapshot, [E, D], weights.dtype, 0 + ssd_db, [E, D], weights.dtype, 0, snapshot ) self.assertEqual(tensor_wrapper.shape, [E, D]) @@ -100,3 +100,72 @@ def test_read_tensor_using_wrapper_from_db(self) -> None: del tensor_wrapper del snapshot self.assertEqual(ssd_db.get_snapshot_count(), 0) + + def test_write_tensor_to_db(self) -> None: + E = int(1e4) # num total rows + D = 128 # emb dimension + N = 1000 # window size + weights_precision = SparseType.FP32 + weights_dtype = weights_precision.as_dtype() + + with tempfile.TemporaryDirectory() as ssd_directory: + # pyre-fixme[16]: Module `classes` has no attribute `fbgemm`.
+ ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper( + ssd_directory, + 8, # num_shards + 8, # num_threads + 0, # ssd_memtable_flush_period, + 0, # ssd_memtable_flush_offset, + 4, # ssd_l0_files_per_compact, + D, # embedding_dim + 0, # ssd_rate_limit_mbps, + 1, # ssd_size_ratio, + 8, # ssd_compaction_trigger, + 536870912, # 512MB ssd_write_buffer_size, + 8, # ssd_max_write_buffer_num, + -0.01, # ssd_uniform_init_lower + 0.01, # ssd_uniform_init_upper + 32, # row_storage_bitwidth + 10 * (2**20), # block cache size + ) + + weights = torch.arange(N * D, dtype=weights_dtype).view(N, D) + output_weights = torch.empty_like(weights) + + # no snapshot needed for writing to rocksdb + tensor_wrapper0 = torch.classes.fbgemm.KVTensorWrapper( + ssd_db, [E, D], weights.dtype, 0 + ) + step = N + for i in range(0, E, step): + tensor_wrapper0.set_range(0, i, step, weights) + + # force waiting for set to complete + indices = torch.arange(step) + for i in range(0, E, step): + ssd_db.get(i + indices, output_weights, torch.tensor(indices.shape[0])) + + # create a view tensor wrapper + snapshot = ssd_db.create_snapshot() + tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper( + ssd_db, [E, D], weights.dtype, 0, snapshot + ) + self.assertEqual(tensor_wrapper.shape, [E, D]) + + # table has a total of E rows + # load 1000 rows at a time + step = 1000 + for i in range(0, E, step): + narrowed = tensor_wrapper.narrow(0, i, step) + self.assertTrue( + torch.equal(narrowed, weights), + msg=( + f"Tensor value mismatch :\n" + f"actual\n{narrowed}\n\nexpected\n{weights}" + ), + ) + + del tensor_wrapper0 + del tensor_wrapper + del snapshot + self.assertEqual(ssd_db.get_snapshot_count(), 0) diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py index adb96daaa..2db48594d 100644 --- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py @@ -26,6 +26,7 @@ CounterBasedRegularizationDefinition, CounterWeightDecayMode, CowClipDefinition, + EnsembleModeDefinition, GradSumDecay, LearningRateMode, SplitTableBatchedEmbeddingBagsCodegen, @@ -33,6 +34,7 @@ TailIdThreshold, WeightDecayMode, ) + from fbgemm_gpu.tbe.utils import ( b_indices, get_table_batched_offsets_from_dense, @@ -307,21 +309,22 @@ def execute_backward_optimizers_( # noqa C901 optimizer_kwargs["eta"] = eta if optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD: - (eps, step_ema, step_swap, step_start, step_mode, momentum) = ( + (eps, step_ema, step_swap, step_start, step_mode) = ( 1e-4, 1.0, 1.0, - -1.0, + 0.0, StepMode.USE_ITER, - 0.8, ) optimizer_kwargs["eps"] = eps - optimizer_kwargs["step_ema"] = step_ema - optimizer_kwargs["step_swap"] = step_swap - optimizer_kwargs["step_start"] = step_start - optimizer_kwargs["step_mode"] = step_mode - optimizer_kwargs["momentum"] = momentum optimizer_kwargs["optimizer_state_dtypes"] = optimizer_state_dtypes + optimizer_kwargs["ensemble_mode"] = EnsembleModeDefinition( + step_ema=step_ema, + step_swap=step_swap, + step_start=step_start, + step_ema_coef=momentum, + step_mode=step_mode, + ) cc = emb_op( embedding_specs=[ @@ -555,14 +558,14 @@ def execute_backward_optimizers_( # noqa C901 if optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD: for t in range(T): iter_ = cc.iter.item() - (m2, m1, prev_iter, row_counter) = split_optimizer_states[t] + (m1, m2) = split_optimizer_states[t] if (m1.dtype == torch.float) and (m2.dtype == torch.float): tol = 1.0e-4 else: tol = 1.0e-2 # Some optimizers have 
non-float momentums - dense_cpu_grad = bs[t].weight.grad.cpu().to_dense() - m2_ref = dense_cpu_grad.pow(2).mean(dim=1) + m2_ref = torch.mul(bs[t].weight.cpu(), 1.0 - momentum) + weights_ref = m2_ref.mul(1.0) torch.testing.assert_close( m2.float().cpu().index_select(dim=0, index=xs[t].view(-1).cpu()), m2_ref.float() @@ -571,15 +574,8 @@ def execute_backward_optimizers_( # noqa C901 atol=tol, rtol=tol, ) - v_hat_t = m2_ref.view(m2_ref.numel(), 1) - weights_new = split_weights[t] - weights_ref = torch.addcdiv( - bs[t].weight.cpu(), - value=-lr, - tensor1=dense_cpu_grad, - tensor2=v_hat_t.sqrt_().add_(eps), - ) - m1_ref = torch.mul(weights_ref, 1.0 - momentum) + dense_cpu_grad = bs[t].weight.grad.cpu().to_dense() + m1_ref = dense_cpu_grad.pow(2).mean(dim=1) torch.testing.assert_close( m1.float().cpu().index_select(dim=0, index=xs[t].view(-1).cpu()), m1_ref.float() @@ -588,7 +584,14 @@ def execute_backward_optimizers_( # noqa C901 atol=tol, rtol=tol, ) - weights_ref = m1_ref.div(1.0) + v_hat_t = m1_ref.view(m1_ref.numel(), 1) + weights_new = split_weights[t] + weights_ref = torch.addcdiv( + weights_ref, + value=-lr, + tensor1=dense_cpu_grad, + tensor2=v_hat_t.sqrt_().add_(eps), + ) torch.testing.assert_close( weights_new.index_select(dim=0, index=xs[t].view(-1)).cpu(), weights_ref.index_select(dim=0, index=xs[t].view(-1).cpu()), @@ -599,9 +602,7 @@ def execute_backward_optimizers_( # noqa C901 optimizer_states_dict = get_optimizer_states[t] assert set(optimizer_states_dict.keys()) == { "sum", - "exp_avg", - "prev_iter", - "row_counter", + "sparse_ema", } if optimizer in (OptimType.PARTIAL_ROWWISE_LAMB, OptimType.LAMB):
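Two API changes exercised in the hunks above may need matching updates at other call sites: KVTensorWrapper now takes its optional snapshot handle as the trailing constructor argument, and the ensemble rowwise Adagrad hyperparameters are passed as a single EnsembleModeDefinition instead of separate kwargs. The sketch below is illustrative only and not part of the diff; it assumes the fbgemm_gpu extension is loaded, the import module path is assumed to match the test's import block, and the numeric values are placeholders taken from the tests.

import torch

# Module path assumed from backward_optimizers_test.py's import block; adjust if it differs.
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    EnsembleModeDefinition,
    StepMode,
)


def make_kv_tensor_wrapper(ssd_db, shape, dtype, snapshot=None):
    # Argument order per the updated tests: (db, shape, dtype, 0[, snapshot]).
    # The snapshot handle moved from the second position to the trailing,
    # optional position, so write paths can omit it entirely.
    if snapshot is None:
        return torch.classes.fbgemm.KVTensorWrapper(ssd_db, shape, dtype, 0)
    return torch.classes.fbgemm.KVTensorWrapper(ssd_db, shape, dtype, 0, snapshot)


# Ensemble rowwise Adagrad: the per-field kwargs (step_ema, step_swap, step_start,
# step_mode, momentum) are replaced by one EnsembleModeDefinition object
# (values here are placeholders mirroring the test).
optimizer_kwargs = {
    "eps": 1e-4,
    "ensemble_mode": EnsembleModeDefinition(
        step_ema=1.0,
        step_swap=1.0,
        step_start=0.0,
        step_ema_coef=0.8,
        step_mode=StepMode.USE_ITER,
    ),
}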