Skip to content

Commit

Permalink
[SYCL][ESIMD] Fix load_2d inconsistency when reading non native types…
Browse files Browse the repository at this point in the history
… with VNNI transforms (#15584)
  • Loading branch information
fineg74 authored Oct 2, 2024
1 parent 1f12cae commit e076c04
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
11 changes: 5 additions & 6 deletions sycl/include/sycl/ext/intel/esimd/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4070,7 +4070,7 @@ __ESIMD_API simd<T, N> load_2d_impl(const T *Ptr, unsigned SurfaceWidth,
uintptr_t Addr = reinterpret_cast<uintptr_t>(Ptr);
constexpr lsc_data_order Transpose =
Transposed ? lsc_data_order::transpose : lsc_data_order::nontranspose;
simd<RawT, ActualN> Raw =
simd<T, ActualN> Raw =
__esimd_lsc_load2d_stateless<RawT, L1H, L2H, DS, Transpose, NBlocks,
BlockWidth, BlockHeight, Transformed,
ActualN>(Mask.data(), Addr, SurfaceWidth,
Expand All @@ -4096,17 +4096,16 @@ __ESIMD_API simd<T, N> load_2d_impl(const T *Ptr, unsigned SurfaceWidth,
// +----+----+----+----+----+----+-----+-----+
// * signifies the padded element.

simd<RawT, DstElements> Dst;
simd<T, DstElements> Dst;

for (auto i = 0; i < NBlocks; i++) {
auto DstBlock =
Dst.template select<DstBlockElements, 1>(i * DstBlockElements);

auto RawBlock = Raw.template select<GRFBlockSize, 1>(i * GRFBlockPitch);
DstBlock =
RawBlock.template bit_cast_view<RawT, GRFColSize, GRFRowPitch>()
.template select<GRFColSize, 1, GRFRowSize, 1>(0, 0)
.template bit_cast_view<RawT>();
DstBlock = RawBlock.template bit_cast_view<T, GRFColSize, GRFRowPitch>()
.template select<GRFColSize, 1, GRFRowSize, 1>(0, 0)
.template bit_cast_view<T>();
}

return Dst;
Expand Down
3 changes: 3 additions & 0 deletions sycl/test-e2e/ESIMD/lsc/lsc_load_2d_compare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,15 @@ int main() {
result |= test<uint16_t>();
result |= test<uint8_t>();
result |= test<sycl::half>();
result |= test<bf16>();

result |= test<float, true>();
result |= test<uint32_t, true>();

result |= test<uint16_t, false, true>();
result |= test<uint8_t, false, true>();
result |= test<sycl::half, false, true>();
result |= test<bf16, false, true>();

std::cout << (result ? "FAILED" : "passed") << std::endl;
return 0;
Expand Down

0 comments on commit e076c04

Please sign in to comment.