@@ -1622,107 +1622,100 @@ void ToGraph::makeUnionDistributionAndStats(
16221622 }
16231623}
16241624
1625- void ToGraph::translateUnion (const lp::SetNode& set) {
1626- auto * const setDt = currentDt_;
1627- auto initialRenames = std::move (renames_);
1628- QGVector<DerivedTableP> children;
1629- bool isLeftLeaf = true ;
1630- const auto topSetOp = set.operation ();
1625+ void ToGraph::translateUnionInput (
1626+ const folly::F14FastMap<std::string, ExprCP>& renames,
1627+ const lp::LogicalPlanNode& input,
1628+ bool & isLeftLeaf) {
1629+ renames_ = renames;
16311630
1632- auto isUnionLike =
1631+ auto * setDt = currentDt_;
1632+
1633+ auto maybeFlatten =
16331634 [&](const lp::LogicalPlanNode& node) -> const lp::SetNode* {
1634- if (node.kind () == lp::NodeKind::kSet ) {
1635- const auto * set = node.asUnchecked <lp::SetNode>();
1636- if (topSetOp == set->operation ()) {
1637- // Same set operation can be flattened.
1638- return set;
1639- }
1640- if (topSetOp == lp::SetOperation::kUnion &&
1641- set->operation () == lp::SetOperation::kUnionAll ) {
1642- // UNION ALL can be flattened into UNION.
1643- return set;
1644- }
1635+ if (node.kind () != lp::NodeKind::kSet ) {
1636+ return nullptr ;
1637+ }
1638+ const auto * set = node.asUnchecked <lp::SetNode>();
1639+ if (setDt->setOp == set->operation ()) {
1640+ // Same set operation can be flattened.
1641+ return set;
1642+ }
1643+ if (setDt->setOp == lp::SetOperation::kUnion &&
1644+ set->operation () == lp::SetOperation::kUnionAll ) {
1645+ // UNION ALL can be flattened into UNION.
1646+ return set;
16451647 }
1646-
16471648 return nullptr ;
16481649 };
1650+ if (const auto * setNode = maybeFlatten (input)) {
1651+ for (const auto & child : setNode->inputs ()) {
1652+ translateUnionInput (renames, *child, isLeftLeaf);
1653+ }
1654+ } else {
1655+ currentDt_ = newDt ();
1656+ auto * queryDt = makeUnordered (input, kAllAllowedInDt );
1657+ VELOX_DCHECK_NULL (queryDt);
1658+ auto * newDt = currentDt_;
1659+ currentDt_ = setDt;
1660+
1661+ const auto & type = input.outputType ();
16491662
1650- // TODO: use deducing this lambda when C++23 is available.
1651- std::function<void (const lp::LogicalPlanNode&)> addChild;
1663+ if (isLeftLeaf) {
1664+ // This is the left leaf of a union tree.
1665+ for (auto i : usedChannels (input)) {
1666+ const auto & name = type->nameOf (i);
16521667
1653- addChild = [&]( const lp::LogicalPlanNode& input) {
1654- renames_ = initialRenames ;
1668+ ExprCP inner = translateColumn (name);
1669+ newDt-> exprs . push_back (inner) ;
16551670
1656- if (auto * setNode = isUnionLike (input)) {
1657- for (auto & child : setNode->inputs ()) {
1658- addChild (*child);
1671+ // The top dt has the same columns as all the unioned dts.
1672+ const auto * columnName = toName (name);
1673+ auto * outer =
1674+ make<Column>(columnName, setDt, inner->value (), columnName);
1675+ setDt->columns .push_back (outer);
1676+ newDt->columns .push_back (outer);
16591677 }
1678+ isLeftLeaf = false ;
16601679 } else {
1661- currentDt_ = newDt ();
1662- auto * queryDt = makeUnordered (input, kAllAllowedInDt );
1663- VELOX_DCHECK_NULL (queryDt);
1664- auto * newDt = currentDt_;
1665-
1666- const auto & type = input.outputType ();
1667-
1668- if (isLeftLeaf) {
1669- // This is the left leaf of a union tree.
1670- for (auto i : usedChannels (input)) {
1671- const auto & name = type->nameOf (i);
1672-
1673- ExprCP inner = translateColumn (name);
1674- newDt->exprs .push_back (inner);
1675-
1676- // The top dt has the same columns as all the unioned dts.
1677- const auto * columnName = toName (name);
1678- auto * outer =
1679- make<Column>(columnName, setDt, inner->value (), columnName);
1680- setDt->columns .push_back (outer);
1681- newDt->columns .push_back (outer);
1682- }
1683- isLeftLeaf = false ;
1684- } else {
1685- for (auto i : usedChannels (input)) {
1686- ExprCP inner = translateColumn (type->nameOf (i));
1687- newDt->exprs .push_back (inner);
1688- }
1689-
1690- // Same outward facing columns as the top dt of union.
1691- newDt->columns = setDt->columns ;
1680+ for (auto i : usedChannels (input)) {
1681+ ExprCP inner = translateColumn (type->nameOf (i));
1682+ newDt->exprs .push_back (inner);
16921683 }
16931684
1694- newDt-> makeInitialPlan ();
1695- children. push_back ( newDt) ;
1685+ // Same outward facing columns as the top dt of union.
1686+ newDt-> columns = setDt-> columns ;
16961687 }
1697- };
16981688
1699- addChild (set);
1700- currentDt_ = setDt;
1689+ newDt->makeInitialPlan ();
1690+ setDt->children .push_back (newDt);
1691+ }
1692+ }
17011693
1702- setDt-> children = std::move (children);
1703- setDt-> setOp = set. operation ( );
1694+ void ToGraph::translateUnion ( const lp::SetNode& set) {
1695+ auto renames = std::move (renames_ );
17041696
1697+ auto * setDt = currentDt_;
1698+ setDt->setOp = set.operation ();
1699+ bool isLeftLeaf = true ;
1700+ translateUnionInput (renames, set, isLeftLeaf);
17051701 makeUnionDistributionAndStats (setDt);
17061702
1707- renames_ = std::move (initialRenames );
1703+ renames_ = std::move (renames );
17081704 for (const auto * column : setDt->columns ) {
17091705 renames_[column->name ()] = column;
17101706 }
17111707}
17121708
17131709DerivedTableP ToGraph::makeQueryGraph (const lp::LogicalPlanNode& logicalPlan) {
17141710 markAllSubfields (logicalPlan);
1715-
1716- currentDt_ = newDt ();
1717- auto * queryDt = makeQueryGraph (logicalPlan, kAllAllowedInDt );
1718- VELOX_DCHECK_NULL (queryDt);
1711+ wrapInDt (logicalPlan);
17191712 return currentDt_;
17201713}
17211714
17221715DerivedTableP ToGraph::makeUnordered (
1723- const lp::LogicalPlanNode& input ,
1716+ const lp::LogicalPlanNode& node ,
17241717 uint64_t allowedInDt) {
1725- auto * outerDt = makeQueryGraph (input , allowedInDt);
1718+ auto * outerDt = makeQueryGraph (node , allowedInDt);
17261719 if (currentDt_->hasOrderBy () && !currentDt_->hasLimit ()) {
17271720 currentDt_->orderKeys .clear ();
17281721 currentDt_->orderTypes .clear ();
@@ -1731,24 +1724,28 @@ DerivedTableP ToGraph::makeUnordered(
17311724}
17321725
17331726DerivedTableP ToGraph::makeStream (
1734- const lp::LogicalPlanNode& input ,
1727+ const lp::LogicalPlanNode& node ,
17351728 uint64_t allowedInDt) {
1736- auto * outerDt = makeQueryGraph (input , allowedInDt);
1729+ auto * outerDt = makeQueryGraph (node , allowedInDt);
17371730 if (currentDt_->hasLimit ()) {
1738- finalizeDt (input , outerDt);
1731+ finalizeDt (node , outerDt);
17391732 return nullptr ;
17401733 }
17411734 return outerDt;
17421735}
17431736
1737+ void ToGraph::wrapInDt (const lp::LogicalPlanNode& node) {
1738+ currentDt_ = newDt ();
1739+ auto * queryDt = makeQueryGraph (node, kAllAllowedInDt );
1740+ VELOX_DCHECK_NULL (queryDt);
1741+ }
1742+
17441743DerivedTableP ToGraph::makeQueryGraph (
17451744 const lp::LogicalPlanNode& node,
17461745 uint64_t allowedInDt) {
17471746 if (!contains (allowedInDt, node.kind ())) {
17481747 auto * outerDt = currentDt_;
1749- currentDt_ = newDt ();
1750- auto * queryDt = makeQueryGraph (node, kAllAllowedInDt );
1751- VELOX_DCHECK_NULL (queryDt);
1748+ wrapInDt (node);
17521749 return outerDt;
17531750 }
17541751
@@ -1793,20 +1790,22 @@ DerivedTableP ToGraph::makeQueryGraph(
17931790 }
17941791 case lp::NodeKind::kJoin : {
17951792 const auto & join = *node.asUnchecked <lp::JoinNode>();
1793+ const auto & left = *join.left ();
1794+ const auto & right = *join.right ();
17961795 // TODO Allow mixing Unnest with Join in a single DT.
17971796 // https://github.com/facebookincubator/axiom/issues/286
17981797 allowedInDt = makeDtIf (allowedInDt, lp::NodeKind::kUnnest );
17991798 allowedInDt = makeDtIf (allowedInDt, lp::NodeKind::kAggregate );
18001799 allowedInDt = makeDtIf (allowedInDt, lp::NodeKind::kLimit );
1801- if (auto * outerDt = makeUnordered (*join. left () , allowedInDt)) {
1802- finalizeDt (*join. left () , outerDt);
1800+ if (auto * outerDt = makeUnordered (left, allowedInDt)) {
1801+ finalizeDt (left, outerDt);
18031802 }
18041803 if (join.joinType () != lp::JoinType::kInner ||
18051804 queryCtx ()->optimization ()->options ().syntacticJoinOrder ) {
18061805 allowedInDt = makeDtIf (allowedInDt, lp::NodeKind::kJoin );
18071806 }
1808- if (auto * outerDt = makeUnordered (*join. right () , allowedInDt)) {
1809- finalizeDt (*join. right () , outerDt);
1807+ if (auto * outerDt = makeUnordered (right, allowedInDt)) {
1808+ finalizeDt (right, outerDt);
18101809 }
18111810 translateJoin (join);
18121811 return nullptr ;
0 commit comments