Skip to content

Commit

Permalink
feat(fuzzer): Update "tryFlipJoinSides" functions to handle multi-joi…
Browse files Browse the repository at this point in the history
…ns (facebookincubator#11938)

Summary:

The function should traverse the plan tree and recursively the sides of all join nodes that are eligible to be flipped.

Differential Revision: D67606686
  • Loading branch information
Daniel Hunte authored and facebook-github-bot committed Dec 23, 2024
1 parent 1bdcb0b commit d2f8649
Showing 1 changed file with 51 additions and 14 deletions.
65 changes: 51 additions & 14 deletions velox/exec/fuzzer/JoinFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -606,66 +606,103 @@ std::optional<core::JoinType> tryFlipJoinType(core::JoinType joinType) {
}

// Returns a plan with flipped join sides of the input hash join node. If the
// join type doesn't allow flipping, returns a nullptr.
// inputs of the join node are other hash join nodes, recursively flip the join
// sides of those join nodes as well. If the join type doesn't allow flipping,
// returns a nullptr.
core::PlanNodePtr tryFlipJoinSides(const core::HashJoinNode& joinNode) {
// Null-aware right semi project join doesn't support filter.
if (joinNode.filter() &&
joinNode.joinType() == core::JoinType::kLeftSemiProject &&
joinNode.isNullAware()) {
return nullptr;
}
core::PlanNodePtr left = joinNode.sources()[0];
core::PlanNodePtr right = joinNode.sources()[1];
if (auto leftJoinInput = std::dynamic_pointer_cast<const core::HashJoinNode>(
joinNode.sources()[0])) {
left = tryFlipJoinSides(*leftJoinInput);
}
if (auto rightJoinInput = std::dynamic_pointer_cast<const core::HashJoinNode>(
joinNode.sources()[1])) {
right = tryFlipJoinSides(*rightJoinInput);
}
auto flippedJoinType = tryFlipJoinType(joinNode.joinType());
if (!flippedJoinType.has_value()) {
if (!flippedJoinType) {
return nullptr;
}

return std::make_shared<core::HashJoinNode>(
joinNode.id(),
flippedJoinType.value(),
*flippedJoinType,
joinNode.isNullAware(),
joinNode.rightKeys(),
joinNode.leftKeys(),
joinNode.filter(),
joinNode.sources()[1],
joinNode.sources()[0],
right,
left,
joinNode.outputType());
}

// Returns a plan with flipped join sides of the input merge join node. If the
// inputs of the join node are other merge join nodes, recursively flip the join
// sides of those join nodes as well. If the
// join type doesn't allow flipping, returns a nullptr.
core::PlanNodePtr tryFlipJoinSides(const core::MergeJoinNode& joinNode) {
// Merge join only supports inner and left join, so only inner join can be
// flipped.
if (joinNode.joinType() != core::JoinType::kInner) {
return nullptr;
}
auto flippedJoinType = core::JoinType::kInner;
core::PlanNodePtr left = joinNode.sources()[0];
core::PlanNodePtr right = joinNode.sources()[1];
if (auto leftJoinInput = std::dynamic_pointer_cast<const core::MergeJoinNode>(
joinNode.sources()[0])) {
left = tryFlipJoinSides(*leftJoinInput);
}
if (auto rightJoinInput =
std::dynamic_pointer_cast<const core::MergeJoinNode>(
joinNode.sources()[1])) {
right = tryFlipJoinSides(*rightJoinInput);
}

return std::make_shared<core::MergeJoinNode>(
joinNode.id(),
flippedJoinType,
core::JoinType::kInner,
joinNode.rightKeys(),
joinNode.leftKeys(),
joinNode.filter(),
joinNode.sources()[1],
joinNode.sources()[0],
right,
left,
joinNode.outputType());
}

// Returns a plan with flipped join sides of the input nested loop join node. If
// the join type doesn't allow flipping, returns a nullptr.
// the inputs of the join node are other nested loop join nodes, recursively
// flip the join sides of those join nodes as well. If the join type doesn't
// allow flipping, returns a nullptr.
core::PlanNodePtr tryFlipJoinSides(const core::NestedLoopJoinNode& joinNode) {
auto flippedJoinType = tryFlipJoinType(joinNode.joinType());
if (!flippedJoinType.has_value()) {
if (!flippedJoinType) {
return nullptr;
}

core::PlanNodePtr left = joinNode.sources()[0];
core::PlanNodePtr right = joinNode.sources()[1];
if (auto leftJoinInput =
std::dynamic_pointer_cast<const core::NestedLoopJoinNode>(
joinNode.sources()[0])) {
left = tryFlipJoinSides(*leftJoinInput);
}
if (auto rightJoinInput =
std::dynamic_pointer_cast<const core::NestedLoopJoinNode>(
joinNode.sources()[1])) {
right = tryFlipJoinSides(*rightJoinInput);
}
return std::make_shared<core::NestedLoopJoinNode>(
joinNode.id(),
flippedJoinType.value(),
joinNode.joinCondition(),
joinNode.sources()[1],
joinNode.sources()[0],
right,
left,
joinNode.outputType());
}

Expand Down

0 comments on commit d2f8649

Please sign in to comment.