Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions axiom/optimizer/DerivedTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,23 +459,18 @@ importExpr(ExprCP expr, const ColumnVector& outer, const ExprVector& inner) {
case PlanType::kCallExpr: {
auto children = expr->children();
ExprVector newChildren(children.size());
FunctionSet functions;
bool anyChange = false;
for (auto i = 0; i < children.size(); ++i) {
newChildren[i] = importExpr(children[i]->as<Expr>(), outer, inner);
anyChange |= newChildren[i] != children[i];
if (newChildren[i]->isFunction()) {
functions = functions | newChildren[i]->as<Call>()->functions();
}
}

if (!anyChange) {
return expr;
}

const auto* call = expr->as<Call>();
return make<Call>(
call->name(), call->value(), std::move(newChildren), functions);
return make<Call>(call->name(), call->value(), std::move(newChildren));
}
default:
VELOX_UNREACHABLE("{}", expr->toString());
Expand Down
5 changes: 5 additions & 0 deletions axiom/optimizer/FunctionRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ class FunctionSet {
return 0 != (set_ & item);
}

FunctionSet& operator|=(const FunctionSet& other) {
set_ |= other.set_;
return *this;
}

/// Unions 'this' and 'other' and returns the result.
FunctionSet operator|(const FunctionSet& other) const {
return FunctionSet(set_ | other.set_);
Expand Down
8 changes: 2 additions & 6 deletions axiom/optimizer/JoinSample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,7 @@ template <typename... T>
ExprCP
makeCall(std::string_view name, const velox::TypePtr& type, T... inputs) {
return make<Call>(
toName(name),
Value(toType(type), 1),
ExprVector{inputs...},
FunctionSet{});
toName(name), Value(toType(type), 1), ExprVector{inputs...});
}

Value bigintValue() {
Expand Down Expand Up @@ -144,8 +141,7 @@ std::shared_ptr<runner::Runner> prepareSampleRunner(
hashes.emplace_back(makeCall(kHash, velox::BIGINT(), key));
}

ExprCP hash =
make<Call>(toName(kHashMix), bigintValue(), hashes, FunctionSet{});
ExprCP hash = make<Call>(toName(kHashMix), bigintValue(), hashes);

ColumnCP hashColumn = make<Column>(toName("hash"), nullptr, hash->value());
RelationOpPtr project = make<Project>(
Expand Down
6 changes: 1 addition & 5 deletions axiom/optimizer/Optimization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,6 @@ void Optimization::addAggregation(
agg->name(),
agg->value(),
std::move(args),
agg->functions(),
agg->isDistinct(),
condition,
agg->intermediateType(),
Expand Down Expand Up @@ -1857,10 +1856,7 @@ ExprCP Optimization::combineLeftDeep(Name func, const ExprVector& exprs) {
ExprCP result = copy[0];
for (auto i = 1; i < copy.size(); ++i) {
result = toGraph_.deduppedCall(
func,
result->value(),
ExprVector{result, copy[i]},
result->functions() | copy[i]->functions());
func, result->value(), ExprVector{result, copy[i]});
}
return result;
}
Expand Down
67 changes: 60 additions & 7 deletions axiom/optimizer/QueryGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "axiom/optimizer/Optimization.h"
#include "axiom/optimizer/PlanUtils.h"
#include "velox/expression/ScopedVarSetter.h"
#include "velox/functions/FunctionRegistry.h"

namespace facebook::axiom::optimizer {

Expand Down Expand Up @@ -94,21 +95,40 @@ std::string Column::toString() const {
return fmt::format("{}.{}", cname, name_);
}

Call::Call(
PlanType type,
Name name,
const Value& value,
ExprVector args,
FunctionSet functions)
namespace {

// Helper to efficiently collect function sets from arguments.
// Only calls the virtual functions() method for Call expressions.
FunctionSet availableFunctions(const ExprCP& arg) {
if (arg->is(PlanType::kCallExpr) || arg->is(PlanType::kAggregateExpr)) {
return arg->as<Call>()->functions();
}

return FunctionSet(0);
}

} // namespace

Call::Call(PlanType type, Name name, const Value& value, ExprVector args)
: Expr(type, value),
name_(name),
args_(std::move(args)),
functions_(functions),
functions_(),
metadata_(functionMetadata(name_)) {
if (metadata_) {
functions_ |= metadata_->functionSet;
}

const auto deterministic = velox::isDeterministic(name);
if (deterministic.has_value() && !deterministic.value()) {
functions_ |= FunctionSet(FunctionSet::kNonDeterministic);
}

for (auto arg : args_) {
columns_.unionSet(arg->columns());
subexpressions_.unionSet(arg->subexpressions());
subexpressions_.add(arg);
functions_ |= availableFunctions(arg);
}
}

Expand All @@ -121,6 +141,39 @@ std::string Call::toString() const {
return out.str();
}

Aggregate::Aggregate(
Name name,
const Value& value,
ExprVector args,
bool isDistinct,
ExprCP condition,
const velox::Type* intermediateType,
ExprVector orderKeys,
OrderTypeVector orderTypes)
: Call(PlanType::kAggregateExpr, name, value, std::move(args)),
isDistinct_(isDistinct),
condition_(condition),
intermediateType_(intermediateType),
orderKeys_(std::move(orderKeys)),
orderTypes_(std::move(orderTypes)) {
VELOX_CHECK_EQ(orderKeys_.size(), orderTypes_.size());
functions_ |= FunctionSet(FunctionSet::kAggregate);

for (const auto& arg : this->args()) {
rawInputType_.push_back(arg->value().type);
}

if (condition_) {
columns_.unionSet(condition_->columns());
functions_ |= availableFunctions(condition_);
}

for (auto& key : orderKeys_) {
columns_.unionSet(key->columns());
functions_ |= availableFunctions(key);
}
}

std::string Aggregate::toString() const {
std::stringstream out;
out << name() << "(";
Expand Down
44 changes: 8 additions & 36 deletions axiom/optimizer/QueryGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,15 +245,10 @@ using FunctionMetadataCP = const FunctionMetadata*;
/// subexpressions.
class Call : public Expr {
public:
Call(
PlanType type,
Name name,
const Value& value,
ExprVector args,
FunctionSet functions);
Call(PlanType type, Name name, const Value& value, ExprVector args);

Call(Name name, Value value, ExprVector args, FunctionSet functions)
: Call(PlanType::kCallExpr, name, value, std::move(args), functions) {}
Call(Name name, Value value, ExprVector args)
: Call(PlanType::kCallExpr, name, value, std::move(args)) {}

Name name() const {
return name_;
Expand Down Expand Up @@ -289,16 +284,17 @@ class Call : public Expr {
return metadata_;
}

protected:
// Set of functions used in 'this' and 'args'.
FunctionSet functions_;

private:
// name of function.
Name const name_;

// Arguments.
const ExprVector args_;

// Set of functions used in 'this' and 'args'.
const FunctionSet functions_;

FunctionMetadataCP metadata_;
};

Expand Down Expand Up @@ -822,35 +818,11 @@ class Aggregate : public Call {
Name name,
const Value& value,
ExprVector args,
FunctionSet functions,
bool isDistinct,
ExprCP condition,
const velox::Type* intermediateType,
ExprVector orderKeys,
OrderTypeVector orderTypes)
: Call(
PlanType::kAggregateExpr,
name,
value,
std::move(args),
functions | FunctionSet::kAggregate),
isDistinct_(isDistinct),
condition_(condition),
intermediateType_(intermediateType),
orderKeys_(std::move(orderKeys)),
orderTypes_(std::move(orderTypes)) {
VELOX_CHECK_EQ(orderKeys_.size(), orderTypes_.size());

for (auto& arg : this->args()) {
rawInputType_.push_back(arg->value().type);
}
if (condition_) {
columns_.unionSet(condition_->columns());
}
for (auto& key : orderKeys_) {
columns_.unionSet(key->columns());
}
}
OrderTypeVector orderTypes);

ExprCP condition() const {
return condition_;
Expand Down
35 changes: 8 additions & 27 deletions axiom/optimizer/ToGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ ExprCP ToGraph::tryFoldConstant(
const ExprVector& literals) {
try {
Value value(toType(returnType), 1);
auto* veraxExpr = make<Call>(
PlanType::kCallExpr, toName(callName), value, literals, FunctionSet());
auto* veraxExpr =
make<Call>(PlanType::kCallExpr, toName(callName), value, literals);
auto typedExpr = queryCtx()->optimization()->toTypedExpr(veraxExpr);
auto exprSet = evaluator_.compile(typedExpr);
auto first = exprSet->exprs().front().get();
Expand Down Expand Up @@ -478,17 +478,15 @@ ExprCP ToGraph::makeGettersOverSkyline(
subscriptLiteral(subscriptType->kind(), step)),
};

expr = make<Call>(
subscript_, Value(valueType, 1), std::move(args), FunctionSet());
expr = make<Call>(subscript_, Value(valueType, 1), std::move(args));
break;
}

case StepKind::kCardinality: {
expr = make<Call>(
cardinality_,
Value(toType(velox::BIGINT()), 1),
ExprVector{expr},
FunctionSet());
ExprVector{expr});
break;
}
default:
Expand Down Expand Up @@ -591,19 +589,15 @@ void ToGraph::canonicalizeCall(Name& name, ExprVector& args) {
}
}

ExprCP ToGraph::deduppedCall(
Name name,
Value value,
ExprVector args,
FunctionSet flags) {
ExprCP ToGraph::deduppedCall(Name name, Value value, ExprVector args) {
canonicalizeCall(name, args);
ExprDedupKey key = {name, args};

auto [it, emplaced] = functionDedup_.try_emplace(key);
if (it->second) {
return it->second;
}
auto* call = make<Call>(name, value, std::move(args), flags);
auto* call = make<Call>(name, value, std::move(args));
if (emplaced && !call->containsNonDeterministic()) {
it->second = call;
}
Expand Down Expand Up @@ -705,7 +699,6 @@ ExprCP ToGraph::translateExpr(const lp::ExprPtr& expr) {
: nullptr;

if (call || specialForm) {
FunctionSet funcs;
const auto& inputs = expr->inputs();
ExprVector args;
args.reserve(inputs.size());
Expand All @@ -717,9 +710,6 @@ ExprCP ToGraph::translateExpr(const lp::ExprPtr& expr) {
args.emplace_back(arg);
allConstant &= arg->is(PlanType::kLiteralExpr);
cardinality = std::max(cardinality, arg->value().cardinality);
if (arg->is(PlanType::kCallExpr)) {
funcs = funcs | arg->as<Call>()->functions();
}
}

auto name = call ? toName(callName)
Expand All @@ -730,9 +720,8 @@ ExprCP ToGraph::translateExpr(const lp::ExprPtr& expr) {
}
}

funcs = funcs | functionBits(name);
auto* callExpr = deduppedCall(
name, Value(toType(expr->type()), cardinality), std::move(args), funcs);
name, Value(toType(expr->type()), cardinality), std::move(args));
return callExpr;
}

Expand Down Expand Up @@ -810,15 +799,11 @@ std::optional<ExprCP> ToGraph::translateSubfieldFunction(
const auto& inputs = call->inputs();
ExprVector args(inputs.size());
float cardinality = 1;
FunctionSet funcs;
for (auto i = 0; i < inputs.size(); ++i) {
const auto& input = inputs[i];
if (allUsed || usedArgs.contains(i)) {
args[i] = translateExpr(input);
cardinality = std::max(cardinality, args[i]->value().cardinality);
if (args[i]->is(PlanType::kCallExpr)) {
funcs = funcs | args[i]->as<Call>()->functions();
}
} else {
// Make a null of the type for the unused arg to keep the tree valid.
const auto& inputType = input->type();
Expand All @@ -829,7 +814,6 @@ std::optional<ExprCP> ToGraph::translateSubfieldFunction(
}

auto* name = toName(velox::exec::sanitizeName(call->name()));
funcs = funcs | functionBits(name);

if (metadata->explode) {
auto map = metadata->explode(call, paths);
Expand All @@ -855,7 +839,7 @@ std::optional<ExprCP> ToGraph::translateSubfieldFunction(
}
}
auto* callExpr =
make<Call>(name, Value(toType(call->type()), cardinality), args, funcs);
make<Call>(name, Value(toType(call->type()), cardinality), args);
return callExpr;
}

Expand Down Expand Up @@ -1021,10 +1005,8 @@ AggregationPlanCP ToGraph::translateAggregation(const lp::AggregateNode& agg) {
const auto& aggregate = agg.aggregates()[i];
ExprVector args = translateExprs(aggregate->inputs());

FunctionSet funcs;
std::vector<velox::TypePtr> argTypes;
for (auto& arg : args) {
funcs = funcs | arg->functions();
argTypes.push_back(toTypePtr(arg->value().type));
}
ExprCP condition = nullptr;
Expand Down Expand Up @@ -1082,7 +1064,6 @@ AggregationPlanCP ToGraph::translateAggregation(const lp::AggregateNode& agg) {
aggName,
finalValue,
std::move(args),
funcs,
isDistinct,
condition,
accumulatorType,
Expand Down
3 changes: 1 addition & 2 deletions axiom/optimizer/ToGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,7 @@ class ToGraph {

/// Creates or returns pre-existing function call with name+args. If
/// deterministic, a new ExprCP is remembered for reuse.
ExprCP
deduppedCall(Name name, Value value, ExprVector args, FunctionSet flags);
ExprCP deduppedCall(Name name, Value value, ExprVector args);

/// True if 'expr' is of the form a = b where a depends on one of 'tables' and
/// b on the other. If true, returns the side depending on tables[0] in 'left'
Expand Down