From cb4ef1d1fb805b2d96db09dc814d324a054ab621 Mon Sep 17 00:00:00 2001
From: Supadchaya Puangpontip
Date: Tue, 1 Oct 2024 12:01:07 -0700
Subject: [PATCH] Add split_embeddings_utils_cpu

Summary: Add split_embeddings_utils_cpu to enable adjust_info_B_num_bits and generate_vbe_metadata on the CPU build

Reviewed By: q10

Differential Revision: D63711688
---
 fbgemm_gpu/FbgemmGpu.cmake                        |   1 +
 .../get_infos_metadata.cu                         |  46 -------
 .../split_embeddings_utils.cpp                    |  46 -------
 .../split_embeddings_utils_cpu.cpp                | 119 ++++++++++++++++++
 4 files changed, 120 insertions(+), 92 deletions(-)
 create mode 100644 fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp

diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
index ccf4805cd..dd23dca85 100644
--- a/fbgemm_gpu/FbgemmGpu.cmake
+++ b/fbgemm_gpu/FbgemmGpu.cmake
@@ -480,6 +480,7 @@ set(fbgemm_gpu_sources_static_cpu
     src/split_embeddings_cache/lru_cache_populate_byte.cpp
     src/split_embeddings_cache/lxu_cache.cpp
     src/split_embeddings_cache/split_embeddings_cache_ops.cpp
+    src/split_embeddings_utils/split_embeddings_utils_cpu.cpp
     codegen/training/index_select/batch_index_select_dim0_ops.cpp
     codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
 
diff --git a/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu b/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu
index c3eb40819..a4efd4c21 100644
--- a/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu
+++ b/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu
@@ -13,52 +13,6 @@
 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;
 
-DLL_PUBLIC std::tuple<int32_t, uint32_t> adjust_info_B_num_bits(
-    int32_t B,
-    int32_t T) {
-  int32_t info_B_num_bits = DEFAULT_INFO_B_NUM_BITS;
-  uint32_t info_B_mask = DEFAULT_INFO_B_MASK;
-  uint32_t max_T = MAX_T;
-  uint32_t max_B = MAX_B;
-  bool invalid_T = T > max_T;
-  bool invalid_B = B > max_B;
-
-  TORCH_CHECK(
-      !(invalid_T && invalid_B),
-      "Not enough infos bits to accommodate T and B. Default num bits = ",
-      DEFAULT_INFO_NUM_BITS);
-
-  if (invalid_T) {
-    // Reduce info_B_num_bits
-    while (invalid_T && !invalid_B && info_B_num_bits > 0) {
-      info_B_num_bits--;
-      max_T = ((max_T + 1) << 1) - 1;
-      max_B = ((max_B + 1) >> 1) - 1;
-      invalid_T = T > max_T;
-      invalid_B = B > max_B;
-    }
-  } else if (invalid_B) {
-    // Increase info_B_num_bits
-    while (!invalid_T && invalid_B && info_B_num_bits < DEFAULT_INFO_NUM_BITS) {
-      info_B_num_bits++;
-      max_T = ((max_T + 1) >> 1) - 1;
-      max_B = ((max_B + 1) << 1) - 1;
-      invalid_T = T > max_T;
-      invalid_B = B > max_B;
-    }
-  }
-
-  TORCH_CHECK(
-      !invalid_T && !invalid_B,
-      "Not enough infos bits to accommodate T and B. Default num bits = ",
-      DEFAULT_INFO_NUM_BITS);
-
-  // Recompute info_B_mask using new info_B_num_bits
-  info_B_mask = (1u << info_B_num_bits) - 1;
-
-  return {info_B_num_bits, info_B_mask};
-}
-
 DLL_PUBLIC std::tuple<int64_t, int64_t>
 get_infos_metadata(Tensor unused, int64_t B, int64_t T) {
   return adjust_info_B_num_bits(B, T);
diff --git a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp
index 4ae9ae0f7..8902e1c44 100644
--- a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp
+++ b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp
@@ -33,58 +33,12 @@ generate_vbe_metadata_meta(
   return {row_output_offsets, b_t_map};
 }
 
-std::tuple<Tensor, Tensor>
-generate_vbe_metadata_cpu(
-    const Tensor& B_offsets,
-    const Tensor& B_offsets_rank_per_feature,
-    const Tensor& output_offsets_feature_rank,
-    const Tensor& D_offsets,
-    const int64_t D,
-    const bool nobag,
-    const c10::SymInt max_B_feature_rank,
-    const int64_t info_B_num_bits,
-    const c10::SymInt total_B) {
-  Tensor row_output_offsets = output_offsets_feature_rank;
-  Tensor b_t_map = B_offsets_rank_per_feature;
-  return {row_output_offsets, b_t_map};
-}
-
 } // namespace
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
-  m.def(
-      "transpose_embedding_input("
-      "    Tensor hash_size_cumsum, "
-      "    int total_hash_size_bits, "
-      "    Tensor indices, "
-      "    Tensor offsets, "
-      "    bool nobag=False, "
-      "    Tensor? vbe_b_t_map=None, "
-      "    int info_B_num_bits=26, "
-      "    int info_B_mask=0x2FFFFFF, "
-      "    int total_unique_indices=-1, "
-      "    bool is_index_select=False, "
-      "    Tensor? total_L_offsets=None, "
-      "    int fixed_L_per_warp=0, "
-      "    int num_warps_per_feature=0"
-      ") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
-  m.def("get_infos_metadata(Tensor unused, int B, int T) -> (int, int)");
-  m.def(
-      "generate_vbe_metadata("
-      "    Tensor B_offsets, "
-      "    Tensor B_offsets_rank_per_feature, "
-      "    Tensor output_offsets_feature_rank, "
-      "    Tensor D_offsets, "
-      "    int D, "
-      "    bool nobag, "
-      "    SymInt max_B_feature_rank, "
-      "    int info_B_num_bits, "
-      "    SymInt total_B"
-      ") -> (Tensor, Tensor)");
   DISPATCH_TO_CUDA("transpose_embedding_input", transpose_embedding_input);
   DISPATCH_TO_CUDA("get_infos_metadata", get_infos_metadata);
   DISPATCH_TO_CUDA("generate_vbe_metadata", generate_vbe_metadata);
-  DISPATCH_TO_CPU("generate_vbe_metadata", generate_vbe_metadata_cpu);
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
diff --git a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp
new file mode 100644
index 000000000..654a3c3ed
--- /dev/null
+++ b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp
@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <torch/library.h>
+#include "fbgemm_gpu/split_embeddings_utils.h"
+#include "fbgemm_gpu/utils/ops_utils.h"
+
+using Tensor = at::Tensor;
+
+DLL_PUBLIC std::tuple<int32_t, uint32_t> adjust_info_B_num_bits(
+    int32_t B,
+    int32_t T) {
+  int32_t info_B_num_bits = DEFAULT_INFO_B_NUM_BITS;
+  uint32_t info_B_mask = DEFAULT_INFO_B_MASK;
+  uint32_t max_T = MAX_T;
+  uint32_t max_B = MAX_B;
+  bool invalid_T = T > max_T;
+  bool invalid_B = B > max_B;
+
+  TORCH_CHECK(
+      !(invalid_T && invalid_B),
+      "Not enough infos bits to accommodate T and B. Default num bits = ",
+      DEFAULT_INFO_NUM_BITS);
+
+  if (invalid_T) {
+    // Reduce info_B_num_bits
+    while (invalid_T && !invalid_B && info_B_num_bits > 0) {
+      info_B_num_bits--;
+      max_T = ((max_T + 1) << 1) - 1;
+      max_B = ((max_B + 1) >> 1) - 1;
+      invalid_T = T > max_T;
+      invalid_B = B > max_B;
+    }
+  } else if (invalid_B) {
+    // Increase info_B_num_bits
+    while (!invalid_T && invalid_B && info_B_num_bits < DEFAULT_INFO_NUM_BITS) {
+      info_B_num_bits++;
+      max_T = ((max_T + 1) >> 1) - 1;
+      max_B = ((max_B + 1) << 1) - 1;
+      invalid_T = T > max_T;
+      invalid_B = B > max_B;
+    }
+  }
+
+  TORCH_CHECK(
+      !invalid_T && !invalid_B,
+      "Not enough infos bits to accommodate T and B. Default num bits = ",
+      DEFAULT_INFO_NUM_BITS);
+
+  // Recompute info_B_mask using new info_B_num_bits
+  info_B_mask = (1u << info_B_num_bits) - 1;
+
+  return {info_B_num_bits, info_B_mask};
+}
+
+namespace {
+
+std::tuple<Tensor, Tensor>
+generate_vbe_metadata_cpu(
+    const Tensor& B_offsets,
+    const Tensor& B_offsets_rank_per_feature,
+    const Tensor& output_offsets_feature_rank,
+    const Tensor& D_offsets,
+    const int64_t D,
+    const bool nobag,
+    const c10::SymInt max_B_feature_rank,
+    const int64_t info_B_num_bits,
+    const c10::SymInt total_B) {
+  Tensor row_output_offsets = output_offsets_feature_rank;
+  Tensor b_t_map = B_offsets_rank_per_feature;
+  return {row_output_offsets, b_t_map};
+}
+
+std::tuple<int64_t, int64_t>
+get_infos_metadata_cpu(Tensor unused, int64_t B, int64_t T) {
+  return adjust_info_B_num_bits(B, T);
+}
+
+} // namespace
+
+TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
+  m.def(
+      "transpose_embedding_input("
+      "    Tensor hash_size_cumsum, "
+      "    int total_hash_size_bits, "
+      "    Tensor indices, "
+      "    Tensor offsets, "
+      "    bool nobag=False, "
+      "    Tensor? vbe_b_t_map=None, "
+      "    int info_B_num_bits=26, "
+      "    int info_B_mask=0x2FFFFFF, "
+      "    int total_unique_indices=-1, "
+      "    bool is_index_select=False, "
+      "    Tensor? total_L_offsets=None, "
+      "    int fixed_L_per_warp=0, "
+      "    int num_warps_per_feature=0"
+      ") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
+  m.def("get_infos_metadata(Tensor unused, int B, int T) -> (int, int)");
+  m.def(
+      "generate_vbe_metadata("
+      "    Tensor B_offsets, "
+      "    Tensor B_offsets_rank_per_feature, "
+      "    Tensor output_offsets_feature_rank, "
+      "    Tensor D_offsets, "
+      "    int D, "
+      "    bool nobag, "
+      "    SymInt max_B_feature_rank, "
+      "    int info_B_num_bits, "
+      "    SymInt total_B"
+      ") -> (Tensor, Tensor)");
+  DISPATCH_TO_CPU("generate_vbe_metadata", generate_vbe_metadata_cpu);
+  DISPATCH_TO_CPU("get_infos_metadata", get_infos_metadata_cpu);
+}
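
Appendix (illustration, not part of the patch): adjust_info_B_num_bits, relocated above into split_embeddings_utils_cpu.cpp, splits a fixed-width "info" word between the table index T and the batch index B, moving one bit at a time from whichever side has slack to whichever side does not fit. The standalone C++ sketch below mirrors those rebalancing loops under stated assumptions: kInfoNumBits = 32 and kInfoBNumBits = 26 stand in for DEFAULT_INFO_NUM_BITS and DEFAULT_INFO_B_NUM_BITS, max_T/max_B are derived from them the way MAX_T/MAX_B appear to be, and the TORCH_CHECK error handling is omitted.

// Standalone sketch (editor's illustration, not part of the patch). The
// constants below are assumptions standing in for DEFAULT_INFO_NUM_BITS and
// DEFAULT_INFO_B_NUM_BITS in the FBGEMM headers; error checks are omitted.
#include <cstdint>
#include <cstdio>
#include <tuple>

namespace {

constexpr int32_t kInfoNumBits = 32; // assumed total width of the info word
constexpr int32_t kInfoBNumBits = 26; // assumed default bits reserved for B

std::tuple<int32_t, uint32_t> adjust_info_B_num_bits_sketch(int32_t B, int32_t T) {
  int32_t info_B_num_bits = kInfoBNumBits;
  // B gets info_B_num_bits bits; T gets whatever is left of the info word.
  uint32_t max_B = (1u << info_B_num_bits) - 1;
  uint32_t max_T = (1u << (kInfoNumBits - info_B_num_bits)) - 1;
  bool invalid_T = static_cast<uint32_t>(T) > max_T;
  bool invalid_B = static_cast<uint32_t>(B) > max_B;

  // Shrink B's share while T does not fit (the patch's "Reduce info_B_num_bits"
  // branch); the loop guards encode the original if/else-if conditions.
  while (invalid_T && !invalid_B && info_B_num_bits > 0) {
    info_B_num_bits--;
    max_T = ((max_T + 1) << 1) - 1;
    max_B = ((max_B + 1) >> 1) - 1;
    invalid_T = static_cast<uint32_t>(T) > max_T;
    invalid_B = static_cast<uint32_t>(B) > max_B;
  }
  // Grow B's share while B does not fit (the "Increase info_B_num_bits" branch).
  while (!invalid_T && invalid_B && info_B_num_bits < kInfoNumBits) {
    info_B_num_bits++;
    max_T = ((max_T + 1) >> 1) - 1;
    max_B = ((max_B + 1) << 1) - 1;
    invalid_T = static_cast<uint32_t>(T) > max_T;
    invalid_B = static_cast<uint32_t>(B) > max_B;
  }

  // Mask selecting the (possibly adjusted) B bits, as in the moved function.
  const uint32_t info_B_mask = (1u << info_B_num_bits) - 1;
  return {info_B_num_bits, info_B_mask};
}

} // namespace

int main() {
  // T = 200 does not fit in the 32 - 26 = 6 bits initially left for T
  // (max 63), so two bits move from B to T and info_B_num_bits becomes 24.
  auto [bits, mask] = adjust_info_B_num_bits_sketch(/*B=*/1024, /*T=*/200);
  std::printf("info_B_num_bits=%d info_B_mask=0x%x\n", bits, static_cast<unsigned>(mask));
  return 0;
}

Built as C++17, this prints info_B_num_bits=24 info_B_mask=0xffffff: two bits move from B's share to T's so that T = 200 fits, which is the adjustment the CPU build now performs through the get_infos_metadata dispatch added by this patch.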