Skip to content

Commit

Permalink
Fix auto-vec int8 CPU STBE (#2878)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2878

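Fix the test failure in deeplearning/fbgemm/fbgemm_gpu/test/tbe:nbit_forward - test_nbit_forward_cpu_seq_int8. In the no_bag path of EmbeddingSpMDM8Bit_autovec, the `out += output_stride` advance moves out of the inner block to the end of the loop body, so the output pointer advances on every iteration.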

Reviewed By: q10, spcyppt

Differential Revision: D60126202

fbshipit-source-id: 44123ba0c4f6aed8e49c54ebbc887ff0eaaedc7e
  • Loading branch information
Wei Su authored and facebook-github-bot committed Jul 24, 2024
1 parent c79aa37 commit 92e5c33
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
1 change: 0 additions & 1 deletion src/EmbeddingSpMDM.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,6 @@ typename EmbeddingSpMDMKernelSignature<inType, indxType, offsetType, outType>::
no_bag,
is_bf16_out);
};

} else {
return [=](int64_t output_size,
int64_t index_size,
Expand Down
11 changes: 8 additions & 3 deletions src/EmbeddingSpMDMAutovec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,19 @@ bool EmbeddingSpMDM8Bit_autovec(
IndexType current = 0;

if (no_bag) {
#pragma unroll 4
// The compiler may report this constant as unused even though the unroll pragmas below reference it.
[[maybe_unused]] constexpr int unroll_factor = 4;
#if defined(__clang__)
#pragma unroll unroll_factor
#elif defined(__GNUC__)
#pragma GCC unroll unroll_factor
#endif
for (int m = 0; m < output_size; ++m) {
const auto idx = indices[m];

if (idx < 0 || idx >= data_size) {
return false;
}

if constexpr (isOutput8bit) {
const uint8_t* input_row_ptr = input + input_stride * idx;
memcpy(out, input_row_ptr, sizeof(uint8_t) * input_stride);
Expand Down Expand Up @@ -140,8 +145,8 @@ bool EmbeddingSpMDM8Bit_autovec(
std::fma(scale, (float)input[input_offset + j], buf[j] + bias);
}
fill_output(out, buf.data(), block_size, is_bf16_out);
out += output_stride;
}
out += output_stride;
} // m
return true;
} // no_bag
Expand Down

0 comments on commit 92e5c33

Please sign in to comment.