Add split_embeddings_utils_cpu
Summary: Add split_embeddings_utils_cpu to enable adjust_info_B_num_bits and generate_vbe_metadata in the CPU build.

Reviewed By: q10

Differential Revision: D63711688
spcyppt authored and facebook-github-bot committed Oct 1, 2024
1 parent a44317d commit cb4ef1d
Showing 4 changed files with 120 additions and 92 deletions.
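For context, the two operators touched here are registered through the PyTorch dispatcher, so a CPU-only build can invoke them once the new translation unit is compiled in. Below is a minimal sketch, not part of this commit, of how fbgemm::get_infos_metadata could be called from C++ via the dispatcher; the typed signature is inferred from the schema defined later in the diff and is an assumption.

// Minimal sketch, not part of this commit: calling the CPU-dispatched op
// through the PyTorch dispatcher. Assumes the fbgemm_gpu CPU library has
// already been loaded into the process so the schema is registered.
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <tuple>

std::tuple<int64_t, int64_t> call_get_infos_metadata(int64_t B, int64_t T) {
  // Schema (from the diff below):
  //   get_infos_metadata(Tensor unused, int B, int T) -> (int, int)
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::get_infos_metadata", "")
          .typed<std::tuple<int64_t, int64_t>(
              const at::Tensor&, int64_t, int64_t)>();
  // The tensor argument is unused by the kernel; passing a CPU tensor routes
  // the call to the kernel registered via DISPATCH_TO_CPU in the new file.
  return op.call(at::empty({0}), B, T);
}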
1 change: 1 addition & 0 deletions fbgemm_gpu/FbgemmGpu.cmake
@@ -480,6 +480,7 @@ set(fbgemm_gpu_sources_static_cpu
src/split_embeddings_cache/lru_cache_populate_byte.cpp
src/split_embeddings_cache/lxu_cache.cpp
src/split_embeddings_cache/split_embeddings_cache_ops.cpp
src/split_embeddings_utils/split_embeddings_utils_cpu.cpp
codegen/training/index_select/batch_index_select_dim0_ops.cpp
codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)

46 changes: 0 additions & 46 deletions fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu
@@ -13,52 +13,6 @@
using Tensor = at::Tensor;
using namespace fbgemm_gpu;

DLL_PUBLIC std::tuple<int32_t, uint32_t> adjust_info_B_num_bits(
int32_t B,
int32_t T) {
int32_t info_B_num_bits = DEFAULT_INFO_B_NUM_BITS;
uint32_t info_B_mask = DEFAULT_INFO_B_MASK;
uint32_t max_T = MAX_T;
uint32_t max_B = MAX_B;
bool invalid_T = T > max_T;
bool invalid_B = B > max_B;

TORCH_CHECK(
!(invalid_T && invalid_B),
"Not enough infos bits to accommodate T and B. Default num bits = ",
DEFAULT_INFO_NUM_BITS);

if (invalid_T) {
// Reduce info_B_num_bits
while (invalid_T && !invalid_B && info_B_num_bits > 0) {
info_B_num_bits--;
max_T = ((max_T + 1) << 1) - 1;
max_B = ((max_B + 1) >> 1) - 1;
invalid_T = T > max_T;
invalid_B = B > max_B;
}
} else if (invalid_B) {
// Increase info_B_num_bits
while (!invalid_T && invalid_B && info_B_num_bits < DEFAULT_INFO_NUM_BITS) {
info_B_num_bits++;
max_T = ((max_T + 1) >> 1) - 1;
max_B = ((max_B + 1) << 1) - 1;
invalid_T = T > max_T;
invalid_B = B > max_B;
}
}

TORCH_CHECK(
!invalid_T && !invalid_B,
"Not enough infos bits to accommodate T and B. Default num bits = ",
DEFAULT_INFO_NUM_BITS);

// Recompute info_B_mask using new info_B_num_bits
info_B_mask = (1u << info_B_num_bits) - 1;

return {info_B_num_bits, info_B_mask};
}

DLL_PUBLIC std::tuple<int64_t, int64_t>
get_infos_metadata(Tensor unused, int64_t B, int64_t T) {
return adjust_info_B_num_bits(B, T);
46 changes: 0 additions & 46 deletions fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp
@@ -33,58 +33,12 @@ generate_vbe_metadata_meta(
return {row_output_offsets, b_t_map};
}

std::tuple<Tensor /*row_output_offsets*/, Tensor /*b_t_map*/>
generate_vbe_metadata_cpu(
const Tensor& B_offsets,
const Tensor& B_offsets_rank_per_feature,
const Tensor& output_offsets_feature_rank,
const Tensor& D_offsets,
const int64_t D,
const bool nobag,
const c10::SymInt max_B_feature_rank,
const int64_t info_B_num_bits,
const c10::SymInt total_B) {
Tensor row_output_offsets = output_offsets_feature_rank;
Tensor b_t_map = B_offsets_rank_per_feature;
return {row_output_offsets, b_t_map};
}

} // namespace

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"transpose_embedding_input("
" Tensor hash_size_cumsum, "
" int total_hash_size_bits, "
" Tensor indices, "
" Tensor offsets, "
" bool nobag=False, "
" Tensor? vbe_b_t_map=None, "
" int info_B_num_bits=26, "
" int info_B_mask=0x2FFFFFF, "
" int total_unique_indices=-1, "
" bool is_index_select=False, "
" Tensor? total_L_offsets=None, "
" int fixed_L_per_warp=0, "
" int num_warps_per_feature=0"
") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def("get_infos_metadata(Tensor unused, int B, int T) -> (int, int)");
m.def(
"generate_vbe_metadata("
" Tensor B_offsets, "
" Tensor B_offsets_rank_per_feature, "
" Tensor output_offsets_feature_rank, "
" Tensor D_offsets, "
" int D, "
" bool nobag, "
" SymInt max_B_feature_rank, "
" int info_B_num_bits, "
" SymInt total_B"
") -> (Tensor, Tensor)");
DISPATCH_TO_CUDA("transpose_embedding_input", transpose_embedding_input);
DISPATCH_TO_CUDA("get_infos_metadata", get_infos_metadata);
DISPATCH_TO_CUDA("generate_vbe_metadata", generate_vbe_metadata);
DISPATCH_TO_CPU("generate_vbe_metadata", generate_vbe_metadata_cpu);
}

TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
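The net effect of this file's change is that the operator schemas and the CPU implementations move out of this CUDA-linked translation unit, while the CUDA dispatch registrations stay. The sketch below is not FBGEMM code; it shows the equivalent registration pattern with plain torch::library macros, under the assumption that DISPATCH_TO_CPU / DISPATCH_TO_CUDA are thin wrappers around m.impl for the corresponding dispatch keys. Schema definitions and per-backend kernels can live in different translation units because library fragments are merged at load time.

// Sketch only, with a hypothetical fbgemm_example namespace and op: the
// schema plus the CPU kernel live in a CPU-only file, while a separate CUDA
// file adds its own TORCH_LIBRARY_IMPL block without redefining the schema.
#include <ATen/ATen.h>
#include <torch/library.h>

namespace {
at::Tensor my_op_cpu(const at::Tensor& x) {
  return x.clone();  // placeholder CPU kernel
}
}  // namespace

// CPU translation unit: define the schema once and register the CPU kernel.
TORCH_LIBRARY_FRAGMENT(fbgemm_example, m) {
  m.def("my_op(Tensor x) -> Tensor");
}
TORCH_LIBRARY_IMPL(fbgemm_example, CPU, m) {
  m.impl("my_op", my_op_cpu);
}

// A CUDA translation unit would add only an implementation, not the schema:
//   TORCH_LIBRARY_IMPL(fbgemm_example, CUDA, m) { m.impl("my_op", my_op_cuda); }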
119 changes: 119 additions & 0 deletions fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp
@@ -0,0 +1,119 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <torch/library.h>
#include "fbgemm_gpu/split_embeddings_utils.h"
#include "fbgemm_gpu/utils/ops_utils.h"

using Tensor = at::Tensor;

DLL_PUBLIC std::tuple<int32_t, uint32_t> adjust_info_B_num_bits(
int32_t B,
int32_t T) {
int32_t info_B_num_bits = DEFAULT_INFO_B_NUM_BITS;
uint32_t info_B_mask = DEFAULT_INFO_B_MASK;
uint32_t max_T = MAX_T;
uint32_t max_B = MAX_B;
bool invalid_T = T > max_T;
bool invalid_B = B > max_B;

TORCH_CHECK(
!(invalid_T && invalid_B),
"Not enough infos bits to accommodate T and B. Default num bits = ",
DEFAULT_INFO_NUM_BITS);

if (invalid_T) {
// Reduce info_B_num_bits
while (invalid_T && !invalid_B && info_B_num_bits > 0) {
info_B_num_bits--;
max_T = ((max_T + 1) << 1) - 1;
max_B = ((max_B + 1) >> 1) - 1;
invalid_T = T > max_T;
invalid_B = B > max_B;
}
} else if (invalid_B) {
// Increase info_B_num_bits
while (!invalid_T && invalid_B && info_B_num_bits < DEFAULT_INFO_NUM_BITS) {
info_B_num_bits++;
max_T = ((max_T + 1) >> 1) - 1;
max_B = ((max_B + 1) << 1) - 1;
invalid_T = T > max_T;
invalid_B = B > max_B;
}
}

TORCH_CHECK(
!invalid_T && !invalid_B,
"Not enough infos bits to accommodate T and B. Default num bits = ",
DEFAULT_INFO_NUM_BITS);

// Recompute info_B_mask using new info_B_num_bits
info_B_mask = (1u << info_B_num_bits) - 1;

return {info_B_num_bits, info_B_mask};
}

namespace {

std::tuple<Tensor /*row_output_offsets*/, Tensor /*b_t_map*/>
generate_vbe_metadata_cpu(
const Tensor& B_offsets,
const Tensor& B_offsets_rank_per_feature,
const Tensor& output_offsets_feature_rank,
const Tensor& D_offsets,
const int64_t D,
const bool nobag,
const c10::SymInt max_B_feature_rank,
const int64_t info_B_num_bits,
const c10::SymInt total_B) {
Tensor row_output_offsets = output_offsets_feature_rank;
Tensor b_t_map = B_offsets_rank_per_feature;
return {row_output_offsets, b_t_map};
}

std::tuple<int64_t, int64_t>
get_infos_metadata_cpu(Tensor unused, int64_t B, int64_t T) {
return adjust_info_B_num_bits(B, T);
}

} // namespace

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"transpose_embedding_input("
" Tensor hash_size_cumsum, "
" int total_hash_size_bits, "
" Tensor indices, "
" Tensor offsets, "
" bool nobag=False, "
" Tensor? vbe_b_t_map=None, "
" int info_B_num_bits=26, "
" int info_B_mask=0x2FFFFFF, "
" int total_unique_indices=-1, "
" bool is_index_select=False, "
" Tensor? total_L_offsets=None, "
" int fixed_L_per_warp=0, "
" int num_warps_per_feature=0"
") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def("get_infos_metadata(Tensor unused, int B, int T) -> (int, int)");
m.def(
"generate_vbe_metadata("
" Tensor B_offsets, "
" Tensor B_offsets_rank_per_feature, "
" Tensor output_offsets_feature_rank, "
" Tensor D_offsets, "
" int D, "
" bool nobag, "
" SymInt max_B_feature_rank, "
" int info_B_num_bits, "
" SymInt total_B"
") -> (Tensor, Tensor)");
DISPATCH_TO_CPU("generate_vbe_metadata", generate_vbe_metadata_cpu);
DISPATCH_TO_CPU("get_infos_metadata", get_infos_metadata_cpu);
}
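The core of adjust_info_B_num_bits is a bit-budget negotiation: a fixed-width "info" word is split between a table index (T) and a batch index (B), and bits are shifted from one field to the other until both values fit. The standalone sketch below reproduces that logic with assumed constants, a 32-bit word and a default of 26 B-bits, which matches the info_B_num_bits=26 default in the schema above; the real DEFAULT_INFO_NUM_BITS / DEFAULT_INFO_B_NUM_BITS come from the FBGEMM headers.

// Standalone sketch of the bit rebalancing done by adjust_info_B_num_bits.
// The 32/26 split is an assumption for illustration only.
#include <cstdint>
#include <cstdio>
#include <tuple>

constexpr int32_t kInfoNumBits = 32;   // assumed total info bits
constexpr int32_t kInfoBNumBits = 26;  // assumed default bits reserved for B

static int64_t max_for_bits(int32_t bits) {
  return (int64_t{1} << bits) - 1;  // largest value representable in `bits`
}

std::tuple<int32_t, uint32_t> adjust_bits_sketch(int32_t B, int32_t T) {
  int32_t b_bits = kInfoBNumBits;
  // T does not fit: move bits from B to T, as long as B still fits.
  while (T > max_for_bits(kInfoNumBits - b_bits) && b_bits > 0 &&
         B <= max_for_bits(b_bits - 1)) {
    --b_bits;
  }
  // B does not fit: move bits from T to B, as long as T still fits.
  while (B > max_for_bits(b_bits) && b_bits < kInfoNumBits &&
         T <= max_for_bits(kInfoNumBits - b_bits - 1)) {
    ++b_bits;
  }
  const uint32_t b_mask = static_cast<uint32_t>(max_for_bits(b_bits));
  return {b_bits, b_mask};
}

int main() {
  // With 100 tables, 6 T-bits (max 63) are not enough, so one bit moves
  // from B to T: info_B_num_bits drops from 26 to 25.
  auto [b_bits, b_mask] = adjust_bits_sketch(/*B=*/1024, /*T=*/100);
  std::printf("info_B_num_bits=%d info_B_mask=0x%x\n", b_bits, b_mask);
  return 0;
}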
