Skip to content

Commit

Permalink
Tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
psakievich committed Sep 19, 2023
1 parent 005de2d commit bf884f8
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 17 deletions.
39 changes: 26 additions & 13 deletions include/SmartField.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ class SmartField
{
};

// LEGACY implementation, HOST only, data type has to be a reference b/c the
// LEGACY implementation, HOST only, data type has to be a reference b/c the
// stk::mesh::field ctor is not public
//
// This Type should be used as close to a bucket loop as possible, and not
// This Type should be used as close to a bucket loop as possible, and not
// stored as a class member since sync/modify are marked in the ctor/dtor
template <typename FieldType, typename ACCESS>
class SmartField<
Expand Down Expand Up @@ -108,13 +108,15 @@ class SmartField<

// --- Default Accessors
template <typename A = ACCESS>
inline
typename std::enable_if_t<!std::is_same<A, READ>::value, T>&
get(const stk::mesh::Entity& entity) const
{
return *stk::mesh::field_data(stkField_, entity);
}

template <typename A = ACCESS>
inline
typename std::enable_if_t<!std::is_same<A, READ>::value, T>&
operator()(const stk::mesh::Entity& entity) const
{
Expand All @@ -123,13 +125,15 @@ class SmartField<

// --- Const Accessors
template <typename A = ACCESS>
inline
const typename std::enable_if_t<std::is_same<A, READ>::value, T>&
get(const stk::mesh::Entity& entity) const
{
return *stk::mesh::field_data(stkField_, entity);
}

template <typename A = ACCESS>
inline
const typename std::enable_if_t<std::is_same<A, READ>::value, T>&
operator()(const stk::mesh::Entity& entity) const
{
Expand All @@ -153,7 +157,7 @@ class SmartField<
// These should always be used as part of lambda/functor captures
// using copy by value.
//
// SFINAE is used to remove KOKKOS_FUNCTION type decorators for HOST MEMSPACE
// SFINAE is used to remove KOKKOS_INLINE_FUNCTION type decorators for HOST MEMSPACE
template <typename FieldType, typename MEMSPACE, typename ACCESS>
class SmartField<
FieldType,
Expand Down Expand Up @@ -212,16 +216,18 @@ class SmartField<
}

//************************************************************
// Host functions (Remove KOKKOS_FUNCTION decorators)
// Host functions (Remove KOKKOS_INLINE_FUNCTION decorators)
//************************************************************
template <typename M = MEMSPACE>
inline
std::enable_if_t<std::is_same<M, HOST>::value, unsigned> get_ordinal() const
{
return stkField_.get_ordinal();
}

// --- Default Accessors
template <typename A = ACCESS, typename M = MEMSPACE>
inline
std::enable_if_t<
std::is_same<M, HOST>::value && !std::is_same<A, READ>::value,
T>&
Expand All @@ -231,6 +237,7 @@ class SmartField<
}

template <typename MeshIndex, typename A = ACCESS, typename M = MEMSPACE>
inline
std::enable_if_t<
std::is_same<M, HOST>::value && !std::is_same<A, READ>::value,
T>&
Expand All @@ -240,6 +247,7 @@ class SmartField<
}

template <typename A = ACCESS, typename M = MEMSPACE>
inline
std::enable_if_t<
std::is_same<M, HOST>::value && !std::is_same<A, READ>::value,
T>&
Expand All @@ -249,6 +257,7 @@ class SmartField<
}

template <typename MeshIndex, typename A = ACCESS, typename M = MEMSPACE>
inline
std::enable_if_t<
std::is_same<M, HOST>::value && !std::is_same<A, READ>::value,
T>&
Expand All @@ -259,6 +268,7 @@ class SmartField<

// --- Const Accessors
template <typename A = ACCESS, typename M = MEMSPACE>
inline
const std::enable_if_t<
std::is_same<M, HOST>::value && std::is_same<A, READ>::value,
T>&
Expand All @@ -268,6 +278,7 @@ class SmartField<
}

template <typename MeshIndex, typename A = ACCESS, typename M = MEMSPACE>
inline
const std::enable_if_t<
std::is_same<M, HOST>::value && std::is_same<A, READ>::value,
T>&
Expand All @@ -277,6 +288,7 @@ class SmartField<
}

template <typename A = ACCESS, typename M = MEMSPACE>
inline
const std::enable_if_t<
std::is_same<M, HOST>::value && std::is_same<A, READ>::value,
T>&
Expand All @@ -286,6 +298,7 @@ class SmartField<
}

template <typename MeshIndex, typename A = ACCESS, typename M = MEMSPACE>
inline
const std::enable_if_t<
std::is_same<M, HOST>::value && std::is_same<A, READ>::value,
T>&
Expand All @@ -298,15 +311,15 @@ class SmartField<
// Device functions
//************************************************************
template <typename M = MEMSPACE>
KOKKOS_FUNCTION std::enable_if_t<std::is_same<M, DEVICE>::value, unsigned>
KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<M, DEVICE>::value, unsigned>
get_ordinal() const
{
return stkField_.get_ordinal();
}

// --- Default Accessors
template <typename A = ACCESS, typename M = MEMSPACE>
KOKKOS_FUNCTION std::enable_if_t<
KOKKOS_INLINE_FUNCTION std::enable_if_t<
std::is_same<M, DEVICE>::value && !std::is_same<A, READ>::value,
T>&
get(stk::mesh::FastMeshIndex& index, int component) const
Expand All @@ -315,7 +328,7 @@ class SmartField<
}

template <typename MeshIndex, typename A = ACCESS, typename M = MEMSPACE>
KOKKOS_FUNCTION std::enable_if_t<
KOKKOS_INLINE_FUNCTION std::enable_if_t<
std::is_same<M, DEVICE>::value && !std::is_same<A, READ>::value,
T>&
get(MeshIndex index, int component) const
Expand All @@ -324,7 +337,7 @@ class SmartField<
}

template <typename A = ACCESS, typename M = MEMSPACE>
KOKKOS_FUNCTION std::enable_if_t<
KOKKOS_INLINE_FUNCTION std::enable_if_t<
std::is_same<M, DEVICE>::value && !std::is_same<A, READ>::value,
T>&
operator()(const stk::mesh::FastMeshIndex& index, int component) const
Expand All @@ -333,7 +346,7 @@ class SmartField<
}

template <typename MeshIndex, typename A = ACCESS, typename M = MEMSPACE>
KOKKOS_FUNCTION std::enable_if_t<
KOKKOS_INLINE_FUNCTION std::enable_if_t<
std::is_same<M, DEVICE>::value && !std::is_same<A, READ>::value,
T>&
operator()(const MeshIndex index, int component) const
Expand All @@ -343,7 +356,7 @@ class SmartField<

// --- Const Accessors
template <typename A = ACCESS, typename M = MEMSPACE>
KOKKOS_FUNCTION const std::enable_if_t<
KOKKOS_INLINE_FUNCTION const std::enable_if_t<
std::is_same<M, DEVICE>::value && std::is_same<A, READ>::value,
T>&
get(stk::mesh::FastMeshIndex& index, int component) const
Expand All @@ -352,7 +365,7 @@ class SmartField<
}

template <typename MeshIndex, typename A = ACCESS, typename M = MEMSPACE>
KOKKOS_FUNCTION const std::enable_if_t<
KOKKOS_INLINE_FUNCTION const std::enable_if_t<
std::is_same<M, DEVICE>::value && std::is_same<A, READ>::value,
T>&
get(MeshIndex index, int component) const
Expand All @@ -361,7 +374,7 @@ class SmartField<
}

template <typename A = ACCESS, typename M = MEMSPACE>
KOKKOS_FUNCTION const std::enable_if_t<
KOKKOS_INLINE_FUNCTION const std::enable_if_t<
std::is_same<M, DEVICE>::value && std::is_same<A, READ>::value,
T>&
operator()(const stk::mesh::FastMeshIndex& index, int component) const
Expand All @@ -370,7 +383,7 @@ class SmartField<
}

template <typename MeshIndex, typename A = ACCESS, typename M = MEMSPACE>
KOKKOS_FUNCTION const std::enable_if_t<
KOKKOS_INLINE_FUNCTION const std::enable_if_t<
std::is_same<M, DEVICE>::value && std::is_same<A, READ>::value,
T>&
operator()(const MeshIndex index, int component) const
Expand Down
12 changes: 9 additions & 3 deletions src/SmartField.C
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,17 @@ using namespace tags;
template class SmartField<stk::mesh::Field<T>, HOST, WRITE_ALL>; \
template class SmartField<stk::mesh::Field<T>, HOST, READ_WRITE>

EXPLICIT_TYPE_INSTANTIATOR_NGP(double);
EXPLICIT_TYPE_INSTANTIATOR_NGP(int);
EXPLICIT_TYPE_INSTANTIATOR_NGP(double);
EXPLICIT_TYPE_INSTANTIATOR_NGP(stk::mesh::EntityId);
EXPLICIT_TYPE_INSTANTIATOR_NGP(HypreIntType);

// Hypre Integer types
// What to do about HYPRE int vs long vs long long here?
/* #ifdef NALU_USES_HYPRE */
/* typedef HYPRE_Int HypreIntType; */
/* EXPLICIT_TYPE_INSTANTIATOR_NGP(HypreIntType); */
/* EXPLICIT_TYPE_INSTANTIATOR_LEGACY(HypreIDFieldType); */
/* #endif */

EXPLICIT_TYPE_INSTANTIATOR_LEGACY(ScalarFieldType);
EXPLICIT_TYPE_INSTANTIATOR_LEGACY(VectorFieldType);
Expand All @@ -39,7 +46,6 @@ EXPLICIT_TYPE_INSTANTIATOR_LEGACY(GenericIntFieldType);
EXPLICIT_TYPE_INSTANTIATOR_LEGACY(TpetIDFieldType);
EXPLICIT_TYPE_INSTANTIATOR_LEGACY(LocalIdFieldType);
EXPLICIT_TYPE_INSTANTIATOR_LEGACY(GlobalIdFieldType);
EXPLICIT_TYPE_INSTANTIATOR_LEGACY(HypreIDFieldType);
EXPLICIT_TYPE_INSTANTIATOR_LEGACY(ScalarIntFieldType);

} // namespace sierra::nalu
2 changes: 1 addition & 1 deletion unit_tests/UnitTestSmartField.C
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ TEST_F(TestSmartField, device_write_clear_mod_with_lambda)

ASSERT_TRUE(ngpField_->need_sync_to_device());

auto sPtr = MakeSmartField<DEVICE, WRITE>()(*ngpField_);
auto sPtr = MakeSmartField<DEVICE, WRITE_ALL>()(*ngpField_);
lambda_ordinal(sPtr);

EXPECT_FALSE(ngpField_->need_sync_to_device());
Expand Down

0 comments on commit bf884f8

Please sign in to comment.