Skip to content

Commit

Permalink
Merge Pull Request #12790 from kliegeois/Trilinos/getAutomaticNSubparts
Browse files Browse the repository at this point in the history
Automatically Merged using Trilinos Pull Request AutoTester
PR Title: b'Ifpack2: add getAutomaticNSubparts'
PR Author: kliegeois
  • Loading branch information
trilinos-autotester authored Mar 1, 2024
2 parents 7379acb + 29e57ff commit a469209
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 15 deletions.
2 changes: 1 addition & 1 deletion packages/ifpack2/src/Ifpack2_BlockTriDiContainer_def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ namespace Ifpack2 {
const bool useSeqMethod = false;
const bool overlapCommAndComp = false;
initInternal(matrix, importer, overlapCommAndComp, useSeqMethod);
n_subparts_per_part_ = 1;
n_subparts_per_part_ = -1;
IFPACK2_BLOCKHELPER_TIMER_FENCE(typename BlockHelperDetails::ImplType<MatrixType>::execution_space)
}

Expand Down
112 changes: 98 additions & 14 deletions packages/ifpack2/src/Ifpack2_BlockTriDiContainer_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -842,22 +842,83 @@ namespace Ifpack2 {
return Teuchos::null;
}

///
/// Cost-model term for one TRSM on a single block (presumably a BLAS-style
/// triangular solve -- confirm against the solve kernels).
///
/// \param block_size  Dimension of one square block.
/// \return block_size squared.
///
template<typename local_ordinal_type>
local_ordinal_type costTRSM(const local_ordinal_type block_size) {
  const auto entries_per_block = block_size * block_size;
  return entries_per_block;
}

///
/// Cost-model term for one GEMV on a single block (presumably a BLAS-style
/// matrix-vector product -- confirm against the solve kernels).
///
/// \param block_size  Dimension of one square block.
/// \return Twice block_size squared.
///
template<typename local_ordinal_type>
local_ordinal_type costGEMV(const local_ordinal_type block_size) {
  const auto twice_block_size = 2 * block_size;
  return twice_block_size * block_size;
}

///
/// Cost-model term for solving one block tridiagonal system.
///
/// The estimate charges two TRSM terms per block row and two GEMV terms per
/// pair of consecutive block rows.
///
/// \param subline_length  Number of block rows in the system.
/// \param block_size      Dimension of one square block.
/// \return 2*subline_length*costTRSM(block_size)
///         + 2*(subline_length-1)*costGEMV(block_size).
///
template<typename local_ordinal_type>
local_ordinal_type costTriDiagSolve(const local_ordinal_type subline_length, const local_ordinal_type block_size) {
  const auto n_trsm = 2 * subline_length;
  const auto n_gemv = 2 * (subline_length - 1);
  return n_trsm * costTRSM(block_size) + n_gemv * costGEMV(block_size);
}

///
/// Cost-model estimate, per team, of applying the block tridiagonal solve when
/// every part (line) is split into n_subparts_per_part subparts.
///
/// \param num_parts            Number of parts (lines).
/// \param num_teams            Number of teams the work is spread over. Values
///                             below 1 are treated as 1, because the caller
///                             derives this value with an integer division that
///                             can yield 0.
/// \param line_length          Number of block rows in a line.
/// \param block_size           Dimension of one square block.
/// \param n_subparts_per_part  Candidate number of subparts per part.
///
/// \return The estimated cost, or INT_MAX when the candidate split is not
///         feasible (fewer than one subpart requested, or a subpart would
///         contain less than one block row).
///
template<typename local_ordinal_type>
local_ordinal_type costSolveSchur(const local_ordinal_type num_parts,
                                  const local_ordinal_type num_teams,
                                  const local_ordinal_type line_length,
                                  const local_ordinal_type block_size,
                                  const local_ordinal_type n_subparts_per_part) {
  // Fewer than one subpart is meaningless and would divide by zero below.
  if (n_subparts_per_part < 1) {
    return INT_MAX;
  }

  // Two block rows are set aside per interface between consecutive subparts
  // (presumably the rows that form the Schur complement -- confirm); the
  // remaining rows are shared evenly among the subparts, rounded up.
  const local_ordinal_type subline_length = ceil(double(line_length - (n_subparts_per_part-1) * 2) / n_subparts_per_part);
  if (subline_length < 1) {
    return INT_MAX;
  }

  // Dividing by a non-positive team count yields inf/NaN (or a negative cost),
  // and converting inf/NaN to an integer is undefined behavior, so clamp to 1.
  const local_ordinal_type n_teams = num_teams < 1 ? local_ordinal_type(1) : num_teams;

  // Number of sublines each team processes, and the cost of their solves.
  const local_ordinal_type p_n_sublines = ceil(double(n_subparts_per_part)*num_parts/n_teams);
  const local_ordinal_type p_costApplyAinv = p_n_sublines * costTriDiagSolve(subline_length,block_size);

  // Without a split only the plain tridiagonal solves are charged; the
  // remaining terms are not needed.
  if (n_subparts_per_part == 1) {
    return p_costApplyAinv;
  }

  const local_ordinal_type p_n_lines = ceil(double(num_parts)/n_teams);
  const local_ordinal_type p_n_sublines_2 = ceil(double(n_subparts_per_part-1)*num_parts/n_teams);

  const local_ordinal_type p_costApplyE = p_n_sublines_2 * subline_length * 2 * costGEMV(block_size);
  // One tridiagonal solve of 2*(n_subparts_per_part-1) block rows per line.
  const local_ordinal_type p_costApplyS = p_n_lines * costTriDiagSolve((n_subparts_per_part-1)*2,block_size);
  const local_ordinal_type p_costApplyC = p_n_sublines_2 * 2 * costGEMV(block_size);

  return p_costApplyE + p_costApplyS + p_costApplyAinv + p_costApplyC;
}

///
/// Choose the number of subparts per part using the costSolveSchur estimate.
///
/// Starting from one subpart, the count is increased by one for as long as the
/// next count has a strictly smaller estimated cost; the search stops at the
/// first count whose successor is not cheaper.
///
/// \param num_parts    Number of parts (lines).
/// \param num_teams    Number of teams the work is spread over.
/// \param line_length  Number of block rows in a line.
/// \param block_size   Dimension of one square block.
/// \return The selected number of subparts per part (at least 1).
///
template<typename local_ordinal_type>
local_ordinal_type getAutomaticNSubparts(const local_ordinal_type num_parts,
                                         const local_ordinal_type num_teams,
                                         const local_ordinal_type line_length,
                                         const local_ordinal_type block_size) {
  local_ordinal_type best_n_subparts = 1;
  local_ordinal_type best_cost = costSolveSchur(num_parts, num_teams, line_length, block_size, best_n_subparts);

  for (;;) {
    const local_ordinal_type candidate_n_subparts = best_n_subparts + 1;
    const local_ordinal_type candidate_cost = costSolveSchur(num_parts, num_teams, line_length, block_size, candidate_n_subparts);

    // Keep the current count as soon as one more subpart stops being strictly cheaper.
    if (!(best_cost > candidate_cost)) {
      break;
    }

    best_n_subparts = candidate_n_subparts;
    best_cost = candidate_cost;
  }

  return best_n_subparts;
}

///
/// Forward declaration only. createPartInterface below names
/// SolveTridiagsDefaultModeAndAlgo<...>::recommended_team_size, so the
/// template must be declared before that point.
/// NOTE(review): the definition is expected further down in this file -- confirm.
///
template<typename ArgActiveExecutionMemorySpace>
struct SolveTridiagsDefaultModeAndAlgo;

///
/// setup part interface using the container partitions array
///
template<typename MatrixType>
BlockHelperDetails::PartInterface<MatrixType>
createPartInterface(const Teuchos::RCP<const typename BlockHelperDetails::ImplType<MatrixType>::tpetra_block_crs_matrix_type> &A,
const Teuchos::Array<Teuchos::Array<typename BlockHelperDetails::ImplType<MatrixType>::local_ordinal_type> > &partitions,
const typename BlockHelperDetails::ImplType<MatrixType>::local_ordinal_type n_subparts_per_part) {
const typename BlockHelperDetails::ImplType<MatrixType>::local_ordinal_type n_subparts_per_part_in) {
IFPACK2_BLOCKHELPER_TIMER("createPartInterface");
using impl_type = BlockHelperDetails::ImplType<MatrixType>;
using local_ordinal_type = typename impl_type::local_ordinal_type;
using local_ordinal_type_1d_view = typename impl_type::local_ordinal_type_1d_view;
using local_ordinal_type_2d_view = typename impl_type::local_ordinal_type_2d_view;
using size_type = typename impl_type::size_type;

const auto blocksize = A->getBlockSize();
constexpr int vector_length = impl_type::vector_length;
constexpr int internal_vector_length = impl_type::internal_vector_length;

const auto comm = A->getRowMap()->getComm();

Expand All @@ -867,6 +928,40 @@ namespace Ifpack2 {
const local_ordinal_type A_n_lclrows = A->getLocalNumRows();
const local_ordinal_type nparts = jacobi ? A_n_lclrows : partitions.size();

typedef std::pair<local_ordinal_type,local_ordinal_type> size_idx_pair_type;
std::vector<size_idx_pair_type> partsz(nparts);

if (!jacobi) {
for (local_ordinal_type i=0;i<nparts;++i)
partsz[i] = size_idx_pair_type(partitions[i].size(), i);
std::sort(partsz.begin(), partsz.end(),
[] (const size_idx_pair_type& x, const size_idx_pair_type& y) {
return x.first > y.first;
});
}

local_ordinal_type n_subparts_per_part;
if (n_subparts_per_part_in == -1) {
// If the number of subparts is set to -1, the user lets the algorithm
// decide the value automatically.
using execution_space = typename impl_type::execution_space;

const int line_length = partsz[0].first;

const local_ordinal_type team_size =
SolveTridiagsDefaultModeAndAlgo<typename execution_space::memory_space>::
recommended_team_size(blocksize, vector_length, internal_vector_length);

const local_ordinal_type num_teams = execution_space().concurrency() / (team_size * vector_length);

n_subparts_per_part = getAutomaticNSubparts(nparts, num_teams, line_length, blocksize);

printf("Automatically chosen n_subparts_per_part = %d for nparts = %d, num_teams = %d, team_size = %d, line_length = %d, and blocksize = %d;\n", n_subparts_per_part, nparts, num_teams, team_size, line_length, blocksize);
}
else {
n_subparts_per_part = n_subparts_per_part_in;
}

// Total number of sub lines:
const local_ordinal_type n_sub_parts = nparts * n_subparts_per_part;
// Total number of sub lines + the Schur complement blocks.
Expand Down Expand Up @@ -896,14 +991,6 @@ namespace Ifpack2 {
// reorder parts to maximize simd packing efficiency
p.resize(nparts);

typedef std::pair<local_ordinal_type,local_ordinal_type> size_idx_pair_type;
std::vector<size_idx_pair_type> partsz(nparts);
for (local_ordinal_type i=0;i<nparts;++i)
partsz[i] = size_idx_pair_type(partitions[i].size(), i);
std::sort(partsz.begin(), partsz.end(),
[] (const size_idx_pair_type& x, const size_idx_pair_type& y) {
return x.first > y.first;
});
for (local_ordinal_type i=0;i<nparts;++i)
p[i] = partsz[i].second;

Expand Down Expand Up @@ -2074,9 +2161,6 @@ namespace Ifpack2 {
};
#endif

template<typename ArgActiveExecutionMemorySpace>
struct SolveTridiagsDefaultModeAndAlgo;

template<typename impl_type, typename WWViewType>
KOKKOS_INLINE_FUNCTION
void
Expand Down Expand Up @@ -3251,7 +3335,7 @@ namespace Ifpack2 {

{
#ifdef IFPACK2_BLOCKTRIDICONTAINER_USE_PRINTF
printf("Star ComputeSchurTag\n");
printf("Start ComputeSchurTag\n");
#endif
IFPACK2_BLOCKHELPER_TIMER("BlockTriDi::NumericPhase::ComputeSchurTag");
writeBTDValuesToFile(part2packrowidx0_sub.extent(0), scalar_values_schur, "before_schur.mm");
Expand All @@ -3270,7 +3354,7 @@ namespace Ifpack2 {

{
#ifdef IFPACK2_BLOCKTRIDICONTAINER_USE_PRINTF
printf("Star FactorizeSchurTag\n");
printf("Start FactorizeSchurTag\n");
#endif
IFPACK2_BLOCKHELPER_TIMER("BlockTriDi::NumericPhase::FactorizeSchurTag");
Kokkos::TeamPolicy<execution_space,FactorizeSchurTag>
Expand Down

0 comments on commit a469209

Please sign in to comment.