diff --git a/velox/exec/fuzzer/JoinFuzzer.cpp b/velox/exec/fuzzer/JoinFuzzer.cpp index d23b7a1cd72d..5d3f424f9e5c 100644 --- a/velox/exec/fuzzer/JoinFuzzer.cpp +++ b/velox/exec/fuzzer/JoinFuzzer.cpp @@ -606,7 +606,9 @@ std::optional tryFlipJoinType(core::JoinType joinType) { } // Returns a plan with flipped join sides of the input hash join node. If the -// join type doesn't allow flipping, returns a nullptr. +// inputs of the join node are other hash join nodes, recursively flip the join +// sides of those join nodes as well. If the join type doesn't allow flipping, +// returns a nullptr. core::PlanNodePtr tryFlipJoinSides(const core::HashJoinNode& joinNode) { // Null-aware right semi project join doesn't support filter. if (joinNode.filter() && @@ -614,24 +616,36 @@ core::PlanNodePtr tryFlipJoinSides(const core::HashJoinNode& joinNode) { joinNode.isNullAware()) { return nullptr; } + core::PlanNodePtr left = joinNode.sources()[0]; + core::PlanNodePtr right = joinNode.sources()[1]; + if (auto leftJoinInput = std::dynamic_pointer_cast( + joinNode.sources()[0])) { + left = tryFlipJoinSides(*leftJoinInput); + } + if (auto rightJoinInput = std::dynamic_pointer_cast( + joinNode.sources()[1])) { + right = tryFlipJoinSides(*rightJoinInput); + } auto flippedJoinType = tryFlipJoinType(joinNode.joinType()); - if (!flippedJoinType.has_value()) { + if (!flippedJoinType) { return nullptr; } return std::make_shared( joinNode.id(), - flippedJoinType.value(), + *flippedJoinType, joinNode.isNullAware(), joinNode.rightKeys(), joinNode.leftKeys(), joinNode.filter(), - joinNode.sources()[1], - joinNode.sources()[0], + right, + left, joinNode.outputType()); } // Returns a plan with flipped join sides of the input merge join node. If the +// inputs of the join node are other merge join nodes, recursively flip the join +// sides of those join nodes as well. If the // join type doesn't allow flipping, returns a nullptr. core::PlanNodePtr tryFlipJoinSides(const core::MergeJoinNode& joinNode) { // Merge join only supports inner and left join, so only inner join can be @@ -639,33 +653,56 @@ core::PlanNodePtr tryFlipJoinSides(const core::MergeJoinNode& joinNode) { if (joinNode.joinType() != core::JoinType::kInner) { return nullptr; } - auto flippedJoinType = core::JoinType::kInner; + core::PlanNodePtr left = joinNode.sources()[0]; + core::PlanNodePtr right = joinNode.sources()[1]; + if (auto leftJoinInput = std::dynamic_pointer_cast( + joinNode.sources()[0])) { + left = tryFlipJoinSides(*leftJoinInput); + } + if (auto rightJoinInput = + std::dynamic_pointer_cast( + joinNode.sources()[1])) { + right = tryFlipJoinSides(*rightJoinInput); + } return std::make_shared( joinNode.id(), - flippedJoinType, + core::JoinType::kInner, joinNode.rightKeys(), joinNode.leftKeys(), joinNode.filter(), - joinNode.sources()[1], - joinNode.sources()[0], + right, + left, joinNode.outputType()); } // Returns a plan with flipped join sides of the input nested loop join node. If -// the join type doesn't allow flipping, returns a nullptr. +// the inputs of the join node are other nested loop join nodes, recursively +// flip the join sides of those join nodes as well. If the join type doesn't +// allow flipping, returns a nullptr. core::PlanNodePtr tryFlipJoinSides(const core::NestedLoopJoinNode& joinNode) { auto flippedJoinType = tryFlipJoinType(joinNode.joinType()); - if (!flippedJoinType.has_value()) { + if (!flippedJoinType) { return nullptr; } - + core::PlanNodePtr left = joinNode.sources()[0]; + core::PlanNodePtr right = joinNode.sources()[1]; + if (auto leftJoinInput = + std::dynamic_pointer_cast( + joinNode.sources()[0])) { + left = tryFlipJoinSides(*leftJoinInput); + } + if (auto rightJoinInput = + std::dynamic_pointer_cast( + joinNode.sources()[1])) { + right = tryFlipJoinSides(*rightJoinInput); + } return std::make_shared( joinNode.id(), flippedJoinType.value(), joinNode.joinCondition(), - joinNode.sources()[1], - joinNode.sources()[0], + right, + left, joinNode.outputType()); }