Make fbgemm::int_nbit_split_embedding_codegen_lookup_function pt2_compliant (pytorch#2231)

Summary:
Pull Request resolved: pytorch#2231

The previous abstract impl was completely bogus. This diff fixes it.

Reviewed By: williamwen42

Differential Revision: D52265254

fbshipit-source-id: 93d630c57c862030d9afa333dfedd4dcd33013d0
zou3519 authored and facebook-github-bot committed Dec 20, 2023
1 parent 9b208da commit 2f5e16e
Showing 4 changed files with 72 additions and 39 deletions.
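For context: the "abstract impl" the summary refers to is the Python-side meta/fake kernel that PT2 tracing (torch.compile, FakeTensor) calls instead of the real kernel to work out output shapes, dtypes, and devices. The sketch below shows the registration pattern with a hypothetical op mylib::row_sum — the library name, op, and kernels are invented for illustration and are not part of fbgemm; torch.library.impl_abstract is the API this diff uses (PyTorch 2.1/2.2 era; later releases rename it register_fake).

import torch
from torch.library import Library, impl_abstract

# Hypothetical library/op purely for illustration; not an fbgemm op.
lib = Library("mylib", "FRAGMENT")
lib.define("row_sum(Tensor x) -> Tensor")


def row_sum_cpu(x: torch.Tensor) -> torch.Tensor:
    # Real CPU kernel: sums each row.
    return x.sum(dim=1)


lib.impl("row_sum", row_sum_cpu, "CPU")


@impl_abstract("mylib::row_sum")
def row_sum_abstract(x: torch.Tensor) -> torch.Tensor:
    # Abstract impl: only shape/dtype/device logic, no data access, so
    # FakeTensor and torch.compile can trace through the op.
    return x.new_empty((x.size(0),))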
9 changes: 8 additions & 1 deletion fbgemm_gpu/codegen/embedding_forward_quantized_host_cpu.cpp
@@ -18,6 +18,7 @@
#endif
#include <torch/serialize/input-archive.h>
#include <torch/serialize/output-archive.h>
#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

@@ -243,8 +244,14 @@ Tensor pruned_array_lookup_cpu(
    Tensor index_remappings_offsets);

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
+#ifdef HAS_IMPL_ABSTRACT_PYSTUB
+  m.impl_abstract_pystub(
+      "fbgemm_gpu.sparse_ops",
+      "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py");
+#endif
  m.def(
-      "int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment = None, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1) -> Tensor");
+      "int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment = None, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1) -> Tensor",
+      {PT2_COMPLIANT_TAG});
  DISPATCH_TO_CPU(
      "int_nbit_split_embedding_codegen_lookup_function",
      int_nbit_split_embedding_codegen_lookup_function_cpu);
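Two things in the hunk above make the op PT2-compliant: m.impl_abstract_pystub(...) tells the dispatcher which Python module (fbgemm_gpu.sparse_ops) holds the abstract impl, and {PT2_COMPLIANT_TAG} attaches torch.Tag.pt2_compliant_tag to the schema. A quick sanity check from Python — assuming an fbgemm_gpu build whose import loads the operator library — could look like this sketch:

import torch
import fbgemm_gpu  # noqa: F401  # assumption: importing this loads the fbgemm ops

op = torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function.default
# OpOverload.tags lists the torch.Tag values attached at registration time.
assert torch.Tag.pt2_compliant_tag in op.tags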
88 changes: 56 additions & 32 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -26,6 +26,7 @@
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu")

+import torch.utils._pytree as pytree
from torch import Tensor


@@ -239,6 +240,20 @@ def expand_into_jagged_permute_meta(
return output_permute


+def check_all_same_device(*tensors: Optional[Tensor]) -> None:
+    # pyre-ignore[9]
+    tensors, _ = pytree.tree_flatten(tensors)
+    if len(tensors) == 0:
+        return
+    first_tensor: Optional[Tensor] = None
+    for tensor in tensors:
+        if tensor is None:
+            continue
+        if first_tensor is None:
+            first_tensor = tensor
+        torch._check(tensor.device == first_tensor.device)


@impl_abstract("fbgemm::int_nbit_split_embedding_codegen_lookup_function")
def int_nbit_split_embedding_codegen_lookup_function_meta(
    dev_weights: torch.Tensor,
@@ -257,47 +272,57 @@ def int_nbit_split_embedding_codegen_lookup_function_meta(
    offsets: torch.Tensor,
    pooling_mode: int,
    indice_weights: Optional[torch.Tensor] = None,
-    output_dtype_int: Optional[int] = None,
+    output_dtype_int: int = 1,
    lxu_cache_weights: Optional[torch.Tensor] = None,
    lxu_cache_locations: Optional[torch.Tensor] = None,
    row_alignment: Optional[int] = None,
    max_float8_D: Optional[int] = None,
    fp8_exponent_bits: Optional[int] = None,
    fp8_exponent_bias: Optional[int] = None,
) -> Tensor:
-    T = D_offsets.numel() - 1
-    B = (offsets.size(0) - 1) // T
-    output_dtype = torch.float32
-    torch._check(
-        output_dtype_int in (0, 1, 5),
-        lambda: f"expected output_dtype to be fp32, fp16 or bf16, got {indices.dtype}",
+    check_all_same_device(
+        dev_weights,
+        uvm_weights,
+        weights_placements,
+        weights_offsets,
+        weights_tys,
+        D_offsets,
+        indices,
+        offsets,
+        indice_weights,
    )
-    if output_dtype_int == SparseType.FP32.value:
-        output_dtype = torch.float32
-    elif output_dtype_int == SparseType.FP16.value:
-        output_dtype = torch.float16
-    elif output_dtype_int == SparseType.BF16.value:
-        output_dtype = torch.bfloat16
+    output_dtype = SparseType.from_int(output_dtype_int).as_dtype()
+    kINT8QparamsBytes = 8

    if pooling_mode == PoolingMode.NONE:
-        # pyre-ignore
-        offsets_last: int = offsets[-1].item()
-        total_D_T: int = total_D // T
-        torch._check_is_size(offsets[-1].item())
-        torch._check_is_size(total_D_T)
-        torch._check_is_size(B)
-        return dev_weights.new_empty(
-            [offsets_last, total_D_T],
-            dtype=output_dtype,
-            device=torch.device("meta"),
+        D = max(
+            [
+                max_int2_D,
+                max_int4_D,
+                max_int8_D,
+                max_float16_D,
+                max_float32_D,
+                max_float8_D if max_float8_D is not None else 0,
+            ]
        )
-    torch._check_is_size(B)
-    torch._check_is_size(total_D)
-    return dev_weights.new_empty(
-        (B, total_D),
-        dtype=output_dtype,
-        device=torch.device("meta"),
-    )
+        total_L = indices.numel()
+        T = weights_offsets.numel()
+        torch._check(D > 0)
+        adjusted_D = D
+        if SparseType.from_int(output_dtype_int) == SparseType.INT8:
+            adjusted_D += T * kINT8QparamsBytes
+        output = dev_weights.new_empty([total_L, adjusted_D], dtype=output_dtype)
+        return output

+    T = D_offsets.numel() - 1
+    torch._check(T > 0)
+    torch._check(total_D > 0)
+    B = (offsets.size(0) - 1) // T
+    total_adjusted_D = total_D
+    if SparseType.from_int(output_dtype_int) == SparseType.INT8:
+        total_adjusted_D += T * kINT8QparamsBytes
+    output = dev_weights.new_empty([B, total_adjusted_D], dtype=output_dtype)
+    return output


@impl_abstract("fbgemm::block_bucketize_sparse_features")
@@ -404,8 +429,7 @@ def dense_to_jagged_forward(
if not total_L:
total_L = torch.library.get_ctx().new_dynamic_size()
return dense.new_zeros(
-        total_L,
-        dense.size()[-1],
+        [total_L, dense.size()[-1]],
dtype=dense.dtype,
device=dense.device,
layout=dense.layout,
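The rewritten meta function computes output sizes purely from argument shapes. In the pooled case (any mode other than PoolingMode.NONE) the result is [B, total_adjusted_D], with T = D_offsets.numel() - 1, B = (offsets.size(0) - 1) // T, and total_adjusted_D equal to total_D plus T * kINT8QparamsBytes when the output dtype is INT8. A small worked example with made-up sizes:

# Illustrative numbers only; they mirror the arithmetic in the meta impl above.
T = 4                         # tables, i.e. D_offsets.numel() - 1
total_D = 256                 # sum of per-table embedding dims
offsets_numel = 1 + 32 * T    # offsets holds B * T + 1 entries
B = (offsets_numel - 1) // T  # -> 32

kINT8QparamsBytes = 8
total_adjusted_D = total_D + T * kINT8QparamsBytes  # INT8 output only -> 288

print(B, total_D, total_adjusted_D)  # 32 256 288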
12 changes: 6 additions & 6 deletions fbgemm_gpu/test/failures_dict_fast.json
@@ -113,27 +113,27 @@
  "fbgemm::int_nbit_split_embedding_codegen_lookup_function": {
    "SplitTableBatchedEmbeddingsTest.test_faketensor__test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function": {
      "comment": "",
-      "status": "xfail"
+      "status": "xsuccess"
    },
    "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_cache_miss_counter": {
      "comment": "",
-      "status": "xfail"
+      "status": "xsuccess"
    },
    "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_direct_mapped_uvm_cache_stats": {
      "comment": "",
-      "status": "xfail"
+      "status": "xsuccess"
    },
    "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_fused_pooled_emb_quant": {
      "comment": "",
-      "status": "xfail"
+      "status": "xsuccess"
    },
    "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_uvm_cache": {
      "comment": "",
-      "status": "xfail"
+      "status": "xsuccess"
    },
    "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_uvm_cache_stats": {
      "comment": "",
-      "status": "xfail"
+      "status": "xsuccess"
    }
  },
  "fbgemm::int_nbit_split_embedding_uvm_caching_codegen_lookup_function": {
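These entries are consumed by PyTorch's generated op-check tests (torch.testing._internal.optests.generate_opcheck_tests); flipping a status from "xfail" to "xsuccess" records that the fake-tensor test for this op is now expected to pass while keeping it tracked. The same kind of check can be run by hand; in recent PyTorch releases the public entry point is torch.library.opcheck. A hedged sketch, reusing the hypothetical mylib::row_sum op defined near the top of this page:

import torch

x = torch.randn(8, 4)
# Runs schema, fake/meta-kernel, and autograd-registration checks on one sample call.
torch.library.opcheck(torch.ops.mylib.row_sum.default, (x,))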
2 changes: 2 additions & 0 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -69,6 +69,8 @@
from hypothesis.strategies import composite
from torch import Tensor

+torch.ops.import_module("fbgemm_gpu.sparse_ops")

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)

