Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type error suppressions for upcoming upgrade #3109

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion fbgemm_gpu/bench/sparse_ops_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,21 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:

# Benchmark forward
time_ref, output_ref = benchmark_torch_function(
torch.index_select, (input, 0, offset_indices), **bench_kwargs
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
torch.index_select,
(input, 0, offset_indices),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

input_group = input.split(batch_size, 0)
time, output_group = benchmark_torch_function(
torch.ops.fbgemm.group_index_select_dim0,
(input_group, indices_group),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)
logging.info(
Expand All @@ -306,13 +314,19 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
time_ref, _ = benchmark_torch_function(
functools.partial(output_ref.backward, retain_graph=True),
(grad,),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

# pyre-fixme[6]: For 1st argument expected `Union[List[Tensor],
# typing.Tuple[Tensor, ...]]` but got `Tensor`.
cat_output = torch.cat(output_group)
time, _ = benchmark_torch_function(
functools.partial(cat_output.backward, retain_graph=True),
(grad,),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)
logging.info(
Expand Down Expand Up @@ -714,6 +728,8 @@ def batch_group_index_select_bwd(
time_pyt, out_pyt = benchmark_torch_function(
index_select_fwd_ref,
(inputs, indices),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

Expand All @@ -726,12 +742,16 @@ def batch_group_index_select_bwd(
input_rows,
input_columns,
),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

time_gis, out_gis = benchmark_torch_function(
group_index_select_fwd,
(gis_inputs, indices),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

Expand All @@ -746,6 +766,8 @@ def batch_group_index_select_bwd(
time_bwd_pyt, _ = benchmark_torch_function(
index_select_bwd_ref,
(out_pyt, grads),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

Expand All @@ -756,6 +778,8 @@ def batch_group_index_select_bwd(
concat_grads,
optim_batch,
),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

Expand All @@ -766,6 +790,8 @@ def batch_group_index_select_bwd(
concat_grads,
optim_group,
),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

Expand Down
10 changes: 10 additions & 0 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,9 @@ def run_bench(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor) -> N

time_per_iter = benchmark_requests(
requests_uvm,
# pyre-fixme[6]: For 2nd argument expected `(Tensor, Tensor,
# Optional[Tensor]) -> Tensor` but got `(indices: Tensor, offsets: Tensor,
# per_sample_weights: Tensor) -> None`.
run_bench,
flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
num_warmups=warmup_runs,
Expand Down Expand Up @@ -1922,6 +1925,9 @@ def nbit_uvm(
indices,
offsets,
),
# pyre-fixme[6]: For 3rd argument expected `(Tensor, Tensor,
# Optional[Tensor]) -> None` but got `(indices: Any, offsets: Any,
# indices_weights: Any) -> Tensor`.
lambda indices, offsets, indices_weights: emb_mixed.forward(
indices,
offsets,
Expand Down Expand Up @@ -2409,6 +2415,9 @@ def nbit_cache( # noqa C901
indices,
offsets,
),
# pyre-fixme[6]: For 3rd argument expected `(Tensor, Tensor,
# Optional[Tensor]) -> None` but got `(indices: Any, offsets: Any,
# indices_weights: Any) -> Tensor`.
lambda indices, offsets, indices_weights: emb.forward(
indices,
offsets,
Expand Down Expand Up @@ -3049,6 +3058,7 @@ def device_with_spec( # noqa C901
reuse=reuse,
alpha=alpha,
weighted=weighted,
# pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined.
sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None,
zipf_oversample_ratio=3 if Ls[t] > 5 else 5,
)
Expand Down
6 changes: 5 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/quantize_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@


def none_throws(
optional: Optional[TypeVar("_T")], message: str = "Unexpected `None`"
# pyre-fixme[31]: Expression `typing.Optional[typing.TypeVar("_T")]` is not a
# valid type.
optional: Optional[TypeVar("_T")],
message: str = "Unexpected `None`",
# pyre-fixme[31]: Expression `typing.TypeVar("_T")` is not a valid type.
) -> TypeVar("_T"):
if optional is None:
raise AssertionError(message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,15 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):

embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]]
record_cache_metrics: RecordCacheMetrics
# pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized.
cache_miss_counter: torch.Tensor
# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
uvm_cache_stats: torch.Tensor
# pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
local_uvm_cache_stats: torch.Tensor
# pyre-fixme[13]: Attribute `weights_offsets` is never initialized.
weights_offsets: torch.Tensor
# pyre-fixme[13]: Attribute `weights_placements` is never initialized.
weights_placements: torch.Tensor

def __init__( # noqa C901
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
lxu_cache_locations_empty: Tensor
timesteps_prefetched: List[int]
record_cache_metrics: RecordCacheMetrics
# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
uvm_cache_stats: torch.Tensor
# pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
local_uvm_cache_stats: torch.Tensor
uuid: str
# pyre-fixme[13]: Attribute `last_uvm_cache_print_state` is never initialized.
last_uvm_cache_print_state: torch.Tensor
_vbe_B_offsets: Optional[torch.Tensor]
_vbe_max_B: int
Expand Down
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,6 +1431,7 @@ def forward(
offsets: Tensor,
per_sample_weights: Optional[Tensor] = None,
feature_requires_grad: Optional[Tensor] = None,
# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
) -> Tensor:
indices, offsets, per_sample_weights = self.prepare_inputs(
indices, offsets, per_sample_weights
Expand Down
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/utils/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def dequantize_embs(
weight_ty: SparseType,
use_cpu: bool,
fp8_config: Optional[FP8QuantizationConfig] = None,
# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
) -> torch.Tensor:
print(f"weight_ty: {weight_ty}")
assert (
Expand Down