Skip to content

Commit 981c79c

Browse files
pashandor789 authored and meta-codesync[bot] committed
feat: Use aggregate function metadata to drop DISTINCT/ORDER BY (#490)
Summary:
- Optimize agg(distinct x) into agg(x) if 'agg' ignores duplicates. E.g. bool_and(distinct x) -> bool_and(x).
- Optimize agg(x order by y) into agg(x) if 'agg' is not order sensitive. E.g. sum(x order by y) -> sum(x).

Pull Request resolved: #490
Reviewed By: kKPulla
Differential Revision: D84598420
Pulled By: mbasmanova
fbshipit-source-id: c16a18c5459a1767eb365a8550d72e6086663be5
1 parent 54c79fe commit 981c79c

File tree

3 files changed

+284
-11
lines changed

3 files changed

+284
-11
lines changed

axiom/optimizer/ToGraph.cpp

Lines changed: 16 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
#include "axiom/optimizer/Optimization.h"
2323
#include "axiom/optimizer/Plan.h"
2424
#include "axiom/optimizer/PlanUtils.h"
25+
#include "velox/exec/Aggregate.h"
2526
#include "velox/exec/AggregateFunctionRegistry.h"
2627
#include "velox/expression/ConstantExpr.h"
2728
#include "velox/expression/Expr.h"
@@ -1031,14 +1032,24 @@ AggregationPlanCP ToGraph::translateAggregation(const lp::AggregateNode& agg) {
10311032
condition = translateExpr(aggregate->filter());
10321033
}
10331034

1034-
auto [orderKeys, orderTypes] = dedupOrdering(aggregate->ordering());
1035+
const auto& metadata =
1036+
velox::exec::getAggregateFunctionMetadata(aggregate->name());
10351037

1036-
if (aggregate->isDistinct() && !orderKeys.empty()) {
1038+
const bool isDistinct =
1039+
!metadata.ignoreDuplicates && aggregate->isDistinct();
1040+
1041+
ExprVector orderKeys;
1042+
OrderTypeVector orderTypes;
1043+
if (metadata.orderSensitive) {
1044+
std::tie(orderKeys, orderTypes) = dedupOrdering(aggregate->ordering());
1045+
}
1046+
1047+
if (isDistinct && !orderKeys.empty()) {
10371048
VELOX_FAIL(
10381049
"DISTINCT with ORDER BY in same aggregation expression isn't supported yet");
10391050
}
10401051

1041-
if (aggregate->isDistinct()) {
1052+
if (isDistinct) {
10421053
const auto& options = queryCtx()->optimization()->runnerOptions();
10431054
VELOX_CHECK(
10441055
options.numWorkers == 1 && options.numDrivers == 1,
@@ -1056,12 +1067,7 @@ AggregationPlanCP ToGraph::translateAggregation(const lp::AggregateNode& agg) {
10561067
auto name = toName(agg.outputNames()[channel]);
10571068

10581069
AggregateDedupKey key{
1059-
aggName,
1060-
aggregate->isDistinct(),
1061-
condition,
1062-
args,
1063-
orderKeys,
1064-
orderTypes};
1070+
aggName, isDistinct, condition, args, orderKeys, orderTypes};
10651071

10661072
auto it = uniqueAggregates.try_emplace(key).first;
10671073
if (it->second) {
@@ -1077,7 +1083,7 @@ AggregationPlanCP ToGraph::translateAggregation(const lp::AggregateNode& agg) {
10771083
finalValue,
10781084
std::move(args),
10791085
funcs,
1080-
aggregate->isDistinct(),
1086+
isDistinct,
10811087
condition,
10821088
accumulatorType,
10831089
std::move(orderKeys),

axiom/optimizer/tests/HiveAggregationQueriesTest.cpp

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,5 +219,272 @@ TEST_F(HiveAggregationQueriesTest, distinctWithOrderBy) {
219219
"DISTINCT with ORDER BY in same aggregation expression isn't supported yet");
220220
}
221221

222+
TEST_F(HiveAggregationQueriesTest, ignoreDuplicates) {
223+
lp::PlanBuilder::Context context(exec::test::kHiveConnectorId);
224+
auto logicalPlan =
225+
lp::PlanBuilder(context)
226+
.tableScan("nation")
227+
.aggregate(
228+
{},
229+
{"bool_and(DISTINCT n_nationkey % 2 = 0)",
230+
"bool_or(DISTINCT n_regionkey % 2 = 0)",
231+
"bool_and(n_nationkey % 2 = 0)",
232+
"bool_or(DISTINCT n_nationkey % 2 = 0)",
233+
"bool_and(DISTINCT n_nationkey % 2 = 0) FILTER (WHERE n_nationkey > 10)",
234+
"bool_or(DISTINCT n_nationkey % 2 = 0) FILTER (WHERE n_nationkey < 20)"})
235+
.build();
236+
237+
{
238+
auto plan = toSingleNodePlan(logicalPlan);
239+
240+
auto matcher =
241+
core::PlanMatcherBuilder()
242+
.tableScan("nation")
243+
.project(
244+
{"n_nationkey % 2 = 0 as m1",
245+
"n_regionkey % 2 = 0 as m2",
246+
"n_nationkey > 10 as m3",
247+
"n_nationkey < 20 as m4"})
248+
.singleAggregation(
249+
{},
250+
{"bool_and(m1) as agg1",
251+
"bool_or(m2) as agg2",
252+
"bool_or(m1) as agg3",
253+
"bool_and(m1) FILTER (WHERE m3) as agg4",
254+
"bool_or(m1) FILTER (WHERE m4) as agg5"})
255+
.project({"agg1", "agg2", "agg1", "agg3", "agg4", "agg5"})
256+
.build();
257+
258+
ASSERT_TRUE(matcher->match(plan));
259+
}
260+
261+
{
262+
auto plan = planVelox(logicalPlan).plan;
263+
const auto& fragments = plan->fragments();
264+
ASSERT_EQ(2, fragments.size());
265+
266+
auto matcher = core::PlanMatcherBuilder()
267+
.tableScan("nation")
268+
.project(
269+
{"n_nationkey % 2 = 0 as m1",
270+
"n_regionkey % 2 = 0 as m2",
271+
"n_nationkey > 10 as m3",
272+
"n_nationkey < 20 as m4"})
273+
.partialAggregation(
274+
{},
275+
{"bool_and(m1)",
276+
"bool_or(m2)",
277+
"bool_or(m1)",
278+
"bool_and(m1) FILTER (WHERE m3)",
279+
"bool_or(m1) FILTER (WHERE m4)"})
280+
.partitionedOutput()
281+
.build();
282+
283+
ASSERT_TRUE(matcher->match(fragments.at(0).fragment.planNode));
284+
285+
matcher = core::PlanMatcherBuilder()
286+
.exchange()
287+
.localPartition()
288+
.finalAggregation()
289+
.project()
290+
.build();
291+
292+
ASSERT_TRUE(matcher->match(fragments.at(1).fragment.planNode));
293+
}
294+
295+
auto referencePlan =
296+
exec::test::PlanBuilder()
297+
.tableScan("nation", getSchema("nation"))
298+
.project(
299+
{"n_nationkey % 2 = 0 as m1",
300+
"n_regionkey % 2 = 0 as m2",
301+
"n_nationkey > 10 as m3",
302+
"n_nationkey < 20 as m4"})
303+
.singleAggregation(
304+
{},
305+
{"bool_and(m1) as agg1",
306+
"bool_or(m2) as agg2",
307+
"bool_or(m1) as agg3",
308+
"bool_and(m1) FILTER (WHERE m3) as agg4",
309+
"bool_or(m1) FILTER (WHERE m4) as agg5"})
310+
.project({"agg1", "agg2", "agg1", "agg3", "agg4", "agg5"})
311+
.planNode();
312+
313+
checkSame(logicalPlan, referencePlan);
314+
}
315+
316+
TEST_F(HiveAggregationQueriesTest, orderNonSensitive) {
317+
lp::PlanBuilder::Context context(exec::test::kHiveConnectorId);
318+
auto logicalPlan =
319+
lp::PlanBuilder(context)
320+
.tableScan("nation")
321+
.aggregate(
322+
{},
323+
{"sum(n_nationkey ORDER BY n_regionkey)",
324+
"sum(n_nationkey ORDER BY n_nationkey DESC, n_regionkey)",
325+
"count(n_regionkey ORDER BY n_nationkey)",
326+
"sum(n_nationkey ORDER BY n_regionkey) FILTER (WHERE n_nationkey > 10)",
327+
"count(n_regionkey ORDER BY n_nationkey) FILTER (WHERE n_nationkey < 20)"})
328+
.build();
329+
330+
{
331+
auto plan = toSingleNodePlan(logicalPlan);
332+
333+
auto matcher = core::PlanMatcherBuilder()
334+
.tableScan("nation")
335+
.project(
336+
{"n_nationkey",
337+
"n_regionkey",
338+
"n_nationkey > 10 as m1",
339+
"n_nationkey < 20 as m2"})
340+
.singleAggregation(
341+
{},
342+
{"sum(n_nationkey) as agg1",
343+
"count(n_regionkey) as agg2",
344+
"sum(n_nationkey) FILTER (WHERE m1) as agg3",
345+
"count(n_regionkey) FILTER (WHERE m2) as agg4"})
346+
.project({"agg1", "agg1", "agg2", "agg3", "agg4"})
347+
.build();
348+
349+
ASSERT_TRUE(matcher->match(plan));
350+
}
351+
352+
{
353+
auto plan = planVelox(logicalPlan, {.numWorkers = 4, .numDrivers = 4}).plan;
354+
const auto& fragments = plan->fragments();
355+
ASSERT_EQ(2, fragments.size());
356+
357+
auto matcher = core::PlanMatcherBuilder()
358+
.tableScan("nation")
359+
.project(
360+
{"n_nationkey",
361+
"n_regionkey",
362+
"n_nationkey > 10 as m1",
363+
"n_nationkey < 20 as m2"})
364+
.partialAggregation(
365+
{},
366+
{"sum(n_nationkey)",
367+
"count(n_regionkey)",
368+
"sum(n_nationkey) FILTER (WHERE m1)",
369+
"count(n_regionkey) FILTER (WHERE m2)"})
370+
.partitionedOutput()
371+
.build();
372+
373+
ASSERT_TRUE(matcher->match(fragments.at(0).fragment.planNode));
374+
375+
matcher = core::PlanMatcherBuilder()
376+
.exchange()
377+
.localPartition()
378+
.finalAggregation()
379+
.project()
380+
.build();
381+
382+
ASSERT_TRUE(matcher->match(fragments.at(1).fragment.planNode));
383+
}
384+
385+
auto referencePlan = exec::test::PlanBuilder()
386+
.tableScan("nation", getSchema("nation"))
387+
.project(
388+
{"n_nationkey",
389+
"n_regionkey",
390+
"n_nationkey > 10 as m1",
391+
"n_nationkey < 20 as m2"})
392+
.singleAggregation(
393+
{},
394+
{"sum(n_nationkey) as agg1",
395+
"count(n_regionkey) as agg2",
396+
"sum(n_nationkey) FILTER (WHERE m1) as agg3",
397+
"count(n_regionkey) FILTER (WHERE m2) as agg4"})
398+
.project({"agg1", "agg1", "agg2", "agg3", "agg4"})
399+
.planNode();
400+
401+
checkSame(logicalPlan, referencePlan);
402+
}
403+
404+
TEST_F(HiveAggregationQueriesTest, ignoreDuplicatesXOrderNonSensitive) {
405+
lp::PlanBuilder::Context context(exec::test::kHiveConnectorId);
406+
auto logicalPlan =
407+
lp::PlanBuilder(context)
408+
.tableScan("nation")
409+
.aggregate(
410+
{},
411+
{
412+
"bool_and(DISTINCT n_nationkey % 2 = 0 ORDER BY n_regionkey)",
413+
"bool_or(DISTINCT n_nationkey % 2 = 0 ORDER BY n_regionkey DESC, n_nationkey)",
414+
"bool_and(n_nationkey % 2 = 0 ORDER BY n_regionkey)",
415+
"bool_and(DISTINCT n_nationkey % 2 = 0 ORDER BY n_regionkey) FILTER (WHERE n_nationkey > 10)",
416+
})
417+
.build();
418+
419+
{
420+
auto plan = toSingleNodePlan(logicalPlan);
421+
422+
auto matcher =
423+
core::PlanMatcherBuilder()
424+
.tableScan("nation")
425+
.project({"n_nationkey % 2 = 0 as m1", "n_nationkey > 10 as m2"})
426+
.singleAggregation(
427+
{},
428+
{"bool_and(m1) as agg1",
429+
"bool_or(m1) as agg2",
430+
"bool_and(m1) FILTER (WHERE m2) as agg3"})
431+
.project({"agg1", "agg2", "agg1", "agg3"})
432+
.build();
433+
434+
ASSERT_TRUE(matcher->match(plan));
435+
}
436+
437+
{
438+
auto plan = planVelox(logicalPlan).plan;
439+
const auto& fragments = plan->fragments();
440+
ASSERT_EQ(2, fragments.size());
441+
442+
auto matcher =
443+
core::PlanMatcherBuilder()
444+
.tableScan("nation")
445+
.project({"n_nationkey % 2 = 0 as m1", "n_nationkey > 10 as m2"})
446+
.partialAggregation(
447+
{},
448+
{
449+
"bool_and(m1)",
450+
"bool_or(m1)",
451+
"bool_and(m1) FILTER (WHERE m2)",
452+
})
453+
.partitionedOutput()
454+
.build();
455+
456+
ASSERT_TRUE(matcher->match(fragments.at(0).fragment.planNode));
457+
458+
matcher = core::PlanMatcherBuilder()
459+
.exchange()
460+
.localPartition()
461+
.finalAggregation()
462+
.project()
463+
.build();
464+
465+
ASSERT_TRUE(matcher->match(fragments.at(1).fragment.planNode));
466+
}
467+
468+
auto referencePlan = exec::test::PlanBuilder()
469+
.tableScan("nation", getSchema("nation"))
470+
.project(
471+
{"n_nationkey % 2 = 0 as m1",
472+
"n_nationkey",
473+
"n_regionkey",
474+
"n_nationkey > 10 as m2",
475+
"n_nationkey < 20 as m3"})
476+
.singleAggregation(
477+
{},
478+
{
479+
"bool_and(m1) as agg1",
480+
"bool_or(m1) as agg2",
481+
"bool_and(m1) FILTER (WHERE m2) as agg3",
482+
})
483+
.project({"agg1", "agg2", "agg1", "agg3"})
484+
.planNode();
485+
486+
checkSame(logicalPlan, referencePlan);
487+
}
488+
222489
} // namespace
223490
} // namespace facebook::axiom::optimizer

axiom/optimizer/tests/PlanMatcher.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ class AggregationMatcher : public PlanMatcherImpl<AggregationNode> {
563563

564564
for (auto i = 0; i < aggregates_.size(); ++i) {
565565
auto aggregateExpr = duckdb::parseAggregateExpr(aggregates_[i], {});
566-
auto expected = aggregateExpr.expr;
566+
auto expected = rewriteInputNames(aggregateExpr.expr, newSymbols);
567567
if (expected->alias()) {
568568
newSymbols[expected->alias().value()] = plan.aggregateNames()[i];
569569
}

0 commit comments

Comments
 (0)