@@ -245,10 +245,13 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
245245 const core::QueryConfig& config)
246246 -> std::unique_ptr<Aggregate> {
247247 if (auto func = getAggregateFunctionEntry (name)) {
248+ core::AggregationNode::Step usedStep{
249+ core::AggregationNode::Step::kPartial };
248250 if (!exec::isRawInput (step)) {
249- step = core::AggregationNode::Step::kIntermediate ;
251+ usedStep = core::AggregationNode::Step::kIntermediate ;
250252 }
251- auto fn = func->factory (step, argTypes, resultType, config);
253+ auto fn =
254+ func->factory (usedStep, argTypes, resultType, config);
252255 VELOX_CHECK_NOT_NULL (fn);
253256 return std::make_unique<
254257 AggregateCompanionAdapter::PartialFunction>(
@@ -366,56 +369,50 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
366369 const std::string& name,
367370 const std::vector<AggregateFunctionSignaturePtr>& signatures,
368371 bool overwrite) {
372+ bool registered = false ;
369373 if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures (
370374 signatures)) {
371- return registerMergeExtractFunctionWithSuffix (name, signatures, overwrite);
375+ registered |=
376+ registerMergeExtractFunctionWithSuffix (name, signatures, overwrite);
372377 }
373378
374379 auto mergeExtractSignatures =
375380 CompanionSignatures::mergeExtractFunctionSignatures (signatures);
376381 if (mergeExtractSignatures.empty ()) {
377- return false ;
382+ return registered ;
378383 }
379384
380385 auto mergeExtractFunctionName =
381386 CompanionSignatures::mergeExtractFunctionName (name);
382- return exec::registerAggregateFunction (
383- mergeExtractFunctionName,
384- std::move (mergeExtractSignatures),
385- [name, mergeExtractFunctionName](
386- core::AggregationNode::Step /* step*/ ,
387- const std::vector<TypePtr>& argTypes,
388- const TypePtr& resultType,
389- const core::QueryConfig& config)
390- -> std::unique_ptr<Aggregate> {
391- const auto & [originalResultType, _] =
392- resolveAggregateFunction (mergeExtractFunctionName, argTypes);
393- if (!originalResultType) {
394- // TODO: limitation -- result type must be resolveable given
395- // intermediate type of the original UDAF.
396- VELOX_UNREACHABLE (
397- " Signatures whose result types are not resolvable given intermediate types should have been excluded." );
398- }
399-
400- if (auto func = getAggregateFunctionEntry (name)) {
401- auto fn = func->factory (
402- core::AggregationNode::Step::kFinal ,
403- argTypes,
404- originalResultType,
405- config);
406- VELOX_CHECK_NOT_NULL (fn);
407- return std::make_unique<
408- AggregateCompanionAdapter::MergeExtractFunction>(
409- std::move (fn), resultType);
410- }
411- VELOX_FAIL (
412- " Original aggregation function {} not found: {}" ,
413- name,
414- mergeExtractFunctionName);
415- },
416- /* registerCompanionFunctions*/ false ,
417- overwrite)
418- .mainFunction ;
387+ registered |=
388+ exec::registerAggregateFunction (
389+ mergeExtractFunctionName,
390+ std::move (mergeExtractSignatures),
391+ [name, mergeExtractFunctionName](
392+ core::AggregationNode::Step /* step*/ ,
393+ const std::vector<TypePtr>& argTypes,
394+ const TypePtr& resultType,
395+ const core::QueryConfig& config) -> std::unique_ptr<Aggregate> {
396+ if (auto func = getAggregateFunctionEntry (name)) {
397+ auto fn = func->factory (
398+ core::AggregationNode::Step::kFinal ,
399+ argTypes,
400+ resultType,
401+ config);
402+ VELOX_CHECK_NOT_NULL (fn);
403+ return std::make_unique<
404+ AggregateCompanionAdapter::MergeExtractFunction>(
405+ std::move (fn), resultType);
406+ }
407+ VELOX_FAIL (
408+ " Original aggregation function {} not found: {}" ,
409+ name,
410+ mergeExtractFunctionName);
411+ },
412+ /* registerCompanionFunctions*/ false ,
413+ overwrite)
414+ .mainFunction ;
415+ return registered;
419416}
420417
421418bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix (
0 commit comments