Skip to content

Commit cb65b32

Browse files
zhztheplayerJkSelf
authored andcommitted
Register merge extract companion agg functions without suffix
1 parent 8fbed16 commit cb65b32

File tree

4 files changed

+49
-45
lines changed

4 files changed

+49
-45
lines changed

velox/exec/AggregateCompanionAdapter.cpp

Lines changed: 38 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -245,10 +245,13 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
245245
const core::QueryConfig& config)
246246
-> std::unique_ptr<Aggregate> {
247247
if (auto func = getAggregateFunctionEntry(name)) {
248+
core::AggregationNode::Step usedStep{
249+
core::AggregationNode::Step::kPartial};
248250
if (!exec::isRawInput(step)) {
249-
step = core::AggregationNode::Step::kIntermediate;
251+
usedStep = core::AggregationNode::Step::kIntermediate;
250252
}
251-
auto fn = func->factory(step, argTypes, resultType, config);
253+
auto fn =
254+
func->factory(usedStep, argTypes, resultType, config);
252255
VELOX_CHECK_NOT_NULL(fn);
253256
return std::make_unique<
254257
AggregateCompanionAdapter::PartialFunction>(
@@ -366,56 +369,50 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
366369
const std::string& name,
367370
const std::vector<AggregateFunctionSignaturePtr>& signatures,
368371
bool overwrite) {
372+
bool registered = false;
369373
if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures(
370374
signatures)) {
371-
return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite);
375+
registered |=
376+
registerMergeExtractFunctionWithSuffix(name, signatures, overwrite);
372377
}
373378

374379
auto mergeExtractSignatures =
375380
CompanionSignatures::mergeExtractFunctionSignatures(signatures);
376381
if (mergeExtractSignatures.empty()) {
377-
return false;
382+
return registered;
378383
}
379384

380385
auto mergeExtractFunctionName =
381386
CompanionSignatures::mergeExtractFunctionName(name);
382-
return exec::registerAggregateFunction(
383-
mergeExtractFunctionName,
384-
std::move(mergeExtractSignatures),
385-
[name, mergeExtractFunctionName](
386-
core::AggregationNode::Step /*step*/,
387-
const std::vector<TypePtr>& argTypes,
388-
const TypePtr& resultType,
389-
const core::QueryConfig& config)
390-
-> std::unique_ptr<Aggregate> {
391-
const auto& [originalResultType, _] =
392-
resolveAggregateFunction(mergeExtractFunctionName, argTypes);
393-
if (!originalResultType) {
394-
// TODO: limitation -- result type must be resolveable given
395-
// intermediate type of the original UDAF.
396-
VELOX_UNREACHABLE(
397-
"Signatures whose result types are not resolvable given intermediate types should have been excluded.");
398-
}
399-
400-
if (auto func = getAggregateFunctionEntry(name)) {
401-
auto fn = func->factory(
402-
core::AggregationNode::Step::kFinal,
403-
argTypes,
404-
originalResultType,
405-
config);
406-
VELOX_CHECK_NOT_NULL(fn);
407-
return std::make_unique<
408-
AggregateCompanionAdapter::MergeExtractFunction>(
409-
std::move(fn), resultType);
410-
}
411-
VELOX_FAIL(
412-
"Original aggregation function {} not found: {}",
413-
name,
414-
mergeExtractFunctionName);
415-
},
416-
/*registerCompanionFunctions*/ false,
417-
overwrite)
418-
.mainFunction;
387+
registered |=
388+
exec::registerAggregateFunction(
389+
mergeExtractFunctionName,
390+
std::move(mergeExtractSignatures),
391+
[name, mergeExtractFunctionName](
392+
core::AggregationNode::Step /*step*/,
393+
const std::vector<TypePtr>& argTypes,
394+
const TypePtr& resultType,
395+
const core::QueryConfig& config) -> std::unique_ptr<Aggregate> {
396+
if (auto func = getAggregateFunctionEntry(name)) {
397+
auto fn = func->factory(
398+
core::AggregationNode::Step::kFinal,
399+
argTypes,
400+
resultType,
401+
config);
402+
VELOX_CHECK_NOT_NULL(fn);
403+
return std::make_unique<
404+
AggregateCompanionAdapter::MergeExtractFunction>(
405+
std::move(fn), resultType);
406+
}
407+
VELOX_FAIL(
408+
"Original aggregation function {} not found: {}",
409+
name,
410+
mergeExtractFunctionName);
411+
},
412+
/*registerCompanionFunctions*/ false,
413+
overwrite)
414+
.mainFunction;
415+
return registered;
419416
}
420417

421418
bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix(

velox/functions/sparksql/aggregates/BloomFilterAggAggregate.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,9 @@ class BloomFilterAggAggregate : public exec::Aggregate {
288288
} // namespace
289289

290290
exec::AggregateRegistrationResult registerBloomFilterAggAggregate(
291-
const std::string& name) {
291+
const std::string& name,
292+
bool withCompanionFunctions,
293+
bool overwrite) {
292294
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
293295
exec::AggregateFunctionSignatureBuilder()
294296
.argumentType("bigint")
@@ -318,6 +320,8 @@ exec::AggregateRegistrationResult registerBloomFilterAggAggregate(
318320
const TypePtr& resultType,
319321
const core::QueryConfig& config) -> std::unique_ptr<exec::Aggregate> {
320322
return std::make_unique<BloomFilterAggAggregate>(resultType, config);
321-
});
323+
},
324+
withCompanionFunctions,
325+
overwrite);
322326
}
323327
} // namespace facebook::velox::functions::aggregate::sparksql

velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
namespace facebook::velox::functions::aggregate::sparksql {
2424

2525
exec::AggregateRegistrationResult registerBloomFilterAggAggregate(
26-
const std::string& name);
26+
const std::string& name,
27+
bool withCompanionFunctions,
28+
bool overwrite);
2729

2830
} // namespace facebook::velox::functions::aggregate::sparksql

velox/functions/sparksql/aggregates/Register.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ void registerAggregateFunctions(
3939
registerFirstLastAggregates(prefix, withCompanionFunctions, overwrite);
4040
registerMinMaxByAggregates(prefix, withCompanionFunctions, overwrite);
4141
registerBitwiseXorAggregate(prefix, withCompanionFunctions, overwrite);
42-
registerBloomFilterAggAggregate(prefix + "bloom_filter_agg");
42+
registerBloomFilterAggAggregate(
43+
prefix + "bloom_filter_agg", withCompanionFunctions, overwrite);
4344
registerAverage(prefix + "avg", withCompanionFunctions, overwrite);
4445
registerSum(prefix + "sum", withCompanionFunctions, overwrite);
4546
}

0 commit comments

Comments
 (0)