Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions eval/compiler/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -419,10 +419,12 @@ cc_library(
"//runtime:function_overload_reference",
"//runtime:function_registry",
"//runtime:type_registry",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)

Expand Down
7 changes: 3 additions & 4 deletions eval/compiler/flat_expr_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1637,12 +1637,11 @@ class FlatExprVisitor : public cel::AstVisitor {
// Establish the search criteria for a given function.
bool receiver_style = call_expr->has_target();
size_t num_args = call_expr->args().size() + (receiver_style ? 1 : 0);
auto arguments_matcher = ArgumentsMatcher(num_args);

// First, search for lazily defined function overloads.
// Lazy functions shadow eager functions with the same signature.
auto lazy_overloads = resolver_.FindLazyOverloads(
function, call_expr->has_target(), arguments_matcher, expr->id());
function, call_expr->has_target(), num_args, expr->id());
if (!lazy_overloads.empty()) {
auto depth = RecursionEligible();
if (depth.has_value()) {
Expand All @@ -1659,8 +1658,8 @@ class FlatExprVisitor : public cel::AstVisitor {
}

// Second, search for eagerly defined function overloads.
auto overloads = resolver_.FindOverloads(function, receiver_style,
arguments_matcher, expr->id());
auto overloads =
resolver_.FindOverloads(function, receiver_style, num_args, expr->id());
if (overloads.empty()) {
// Create a warning that the overload could not be found. Depending on the
// builder_warnings configuration, this could result in termination of the
Expand Down
80 changes: 63 additions & 17 deletions eval/compiler/resolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@

#include "eval/compiler/resolver.h"

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/no_destructor.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
Expand All @@ -27,6 +30,7 @@
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "common/kind.h"
#include "common/type.h"
#include "common/type_reflector.h"
Expand Down Expand Up @@ -96,37 +100,41 @@ std::vector<std::string> Resolver::FullyQualifiedNames(absl::string_view name,
// and handle the case where this id is in the reference map as either a
// function name or identifier name.
std::vector<std::string> names;
// Handle the case where the name contains a leading '.' indicating it is
// already fully-qualified.
if (absl::StartsWith(name, ".")) {
std::string fully_qualified_name = std::string(name.substr(1));
names.push_back(fully_qualified_name);
return names;
}

// namespace prefixes is guaranteed to contain at least empty string, so this
// function will always produce at least one result.
for (const auto& prefix : namespace_prefixes_) {
auto prefixes = GetPrefixesFor(name);
for (const auto& prefix : prefixes) {
std::string fully_qualified_name = absl::StrCat(prefix, name);
names.push_back(fully_qualified_name);
}
return names;
}

absl::Span<const std::string> Resolver::GetPrefixesFor(
absl::string_view& name) const {
static const absl::NoDestructor<std::string> kEmptyPrefix("");
if (absl::StartsWith(name, ".")) {
name = name.substr(1);
return absl::MakeConstSpan(kEmptyPrefix.get(), 1);
}
return namespace_prefixes_;
}

absl::optional<cel::Value> Resolver::FindConstant(absl::string_view name,
int64_t expr_id) const {
auto names = FullyQualifiedNames(name, expr_id);
for (const auto& name : names) {
auto prefixes = GetPrefixesFor(name);
for (const auto& prefix : prefixes) {
std::string qualified_name = absl::StrCat(prefix, name);
// Attempt to resolve the fully qualified name to a known enum.
auto enum_entry = enum_value_map_.find(name);
auto enum_entry = enum_value_map_.find(qualified_name);
if (enum_entry != enum_value_map_.end()) {
return enum_entry->second;
}
// Conditionally resolve fully qualified names as type values if the option
// to do so is configured in the expression builder. If the type name is
// not qualified, then it too may be returned as a constant value.
if (resolve_qualified_type_identifiers_ || !absl::StrContains(name, ".")) {
auto type_value = type_reflector_.FindType(name);
if (resolve_qualified_type_identifiers_ ||
!absl::StrContains(qualified_name, ".")) {
auto type_value = type_reflector_.FindType(qualified_name);
if (type_value.ok() && type_value->has_value()) {
return TypeValue(**type_value);
}
Expand Down Expand Up @@ -157,6 +165,27 @@ std::vector<cel::FunctionOverloadReference> Resolver::FindOverloads(
return funcs;
}

std::vector<cel::FunctionOverloadReference> Resolver::FindOverloads(
absl::string_view name, bool receiver_style, size_t arity,
int64_t expr_id) const {
std::vector<cel::FunctionOverloadReference> funcs;
auto prefixes = GetPrefixesFor(name);
for (const auto& prefix : prefixes) {
std::string qualified_name = absl::StrCat(prefix, name);
// Only one set of overloads is returned along the namespace hierarchy as
// the function name resolution follows the same behavior as variable name
// resolution, meaning the most specific definition wins. This is different
// from how C++ namespaces work, as they will accumulate the overload set
// over the namespace hierarchy.
funcs = function_registry_.FindStaticOverloadsByArity(
qualified_name, receiver_style, arity);
if (!funcs.empty()) {
return funcs;
}
}
return funcs;
}

std::vector<cel::FunctionRegistry::LazyOverload> Resolver::FindLazyOverloads(
absl::string_view name, bool receiver_style,
const std::vector<cel::Kind>& types, int64_t expr_id) const {
Expand All @@ -173,10 +202,27 @@ std::vector<cel::FunctionRegistry::LazyOverload> Resolver::FindLazyOverloads(
return funcs;
}

std::vector<cel::FunctionRegistry::LazyOverload> Resolver::FindLazyOverloads(
absl::string_view name, bool receiver_style, size_t arity,
int64_t expr_id) const {
std::vector<cel::FunctionRegistry::LazyOverload> funcs;
auto prefixes = GetPrefixesFor(name);
for (const auto& prefix : prefixes) {
std::string qualified_name = absl::StrCat(prefix, name);
funcs = function_registry_.FindLazyOverloadsByArity(name, receiver_style,
arity);
if (!funcs.empty()) {
return funcs;
}
}
return funcs;
}

absl::StatusOr<absl::optional<std::pair<std::string, cel::Type>>>
Resolver::FindType(absl::string_view name, int64_t expr_id) const {
auto qualified_names = FullyQualifiedNames(name, expr_id);
for (auto& qualified_name : qualified_names) {
auto prefixes = GetPrefixesFor(name);
for (auto& prefix : prefixes) {
std::string qualified_name = absl::StrCat(prefix, name);
CEL_ASSIGN_OR_RETURN(auto maybe_type,
type_reflector_.FindType(qualified_name));
if (maybe_type.has_value()) {
Expand Down
12 changes: 12 additions & 0 deletions eval/compiler/resolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_
#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_

#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
Expand All @@ -24,6 +25,7 @@
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "common/kind.h"
#include "common/type_reflector.h"
#include "common/value.h"
Expand Down Expand Up @@ -75,18 +77,28 @@ class Resolver {
absl::string_view name, bool receiver_style,
const std::vector<cel::Kind>& types, int64_t expr_id = -1) const;

std::vector<cel::FunctionRegistry::LazyOverload> FindLazyOverloads(
absl::string_view name, bool receiver_style, size_t arity,
int64_t expr_id = -1) const;

// FindOverloads returns the set, possibly empty, of eager function overloads
// matching the given function signature.
std::vector<cel::FunctionOverloadReference> FindOverloads(
absl::string_view name, bool receiver_style,
const std::vector<cel::Kind>& types, int64_t expr_id = -1) const;

std::vector<cel::FunctionOverloadReference> FindOverloads(
absl::string_view name, bool receiver_style, size_t arity,
int64_t expr_id = -1) const;

// FullyQualifiedNames returns the set of fully qualified names which may be
// derived from the base_name within the specified expression container.
std::vector<std::string> FullyQualifiedNames(absl::string_view base_name,
int64_t expr_id = -1) const;

private:
absl::Span<const std::string> GetPrefixesFor(absl::string_view& name) const;

std::vector<std::string> namespace_prefixes_;
absl::flat_hash_map<std::string, cel::Value> enum_value_map_;
const cel::FunctionRegistry& function_registry_;
Expand Down
65 changes: 59 additions & 6 deletions runtime/function_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "runtime/function_registry.h"

#include <cstddef>
#include <memory>
#include <string>
#include <utility>
Expand Down Expand Up @@ -134,6 +135,27 @@ FunctionRegistry::FindStaticOverloads(absl::string_view name,
return matched_funcs;
}

std::vector<cel::FunctionOverloadReference>
FunctionRegistry::FindStaticOverloadsByArity(absl::string_view name,
bool receiver_style,
size_t arity) const {
std::vector<cel::FunctionOverloadReference> matched_funcs;

auto overloads = functions_.find(name);
if (overloads == functions_.end()) {
return matched_funcs;
}

for (const auto& overload : overloads->second.static_overloads) {
if (overload.descriptor->receiver_style() == receiver_style &&
overload.descriptor->types().size() == arity) {
matched_funcs.push_back({*overload.descriptor, *overload.implementation});
}
}

return matched_funcs;
}

std::vector<FunctionRegistry::LazyOverload> FunctionRegistry::FindLazyOverloads(
absl::string_view name, bool receiver_style,
absl::Span<const cel::Kind> types) const {
Expand All @@ -153,6 +175,27 @@ std::vector<FunctionRegistry::LazyOverload> FunctionRegistry::FindLazyOverloads(
return matched_funcs;
}

std::vector<FunctionRegistry::LazyOverload>
FunctionRegistry::FindLazyOverloadsByArity(absl::string_view name,
bool receiver_style,
size_t arity) const {
std::vector<FunctionRegistry::LazyOverload> matched_funcs;

auto overloads = functions_.find(name);
if (overloads == functions_.end()) {
return matched_funcs;
}

for (const auto& entry : overloads->second.lazy_overloads) {
if (entry.descriptor->receiver_style() == receiver_style &&
entry.descriptor->types().size() == arity) {
matched_funcs.push_back({*entry.descriptor, *entry.function_provider});
}
}

return matched_funcs;
}

absl::node_hash_map<std::string, std::vector<const cel::FunctionDescriptor*>>
FunctionRegistry::ListFunctions() const {
absl::node_hash_map<std::string, std::vector<const cel::FunctionDescriptor*>>
Expand All @@ -177,12 +220,22 @@ FunctionRegistry::ListFunctions() const {

bool FunctionRegistry::DescriptorRegistered(
const cel::FunctionDescriptor& descriptor) const {
return !(FindStaticOverloads(descriptor.name(), descriptor.receiver_style(),
descriptor.types())
.empty()) ||
!(FindLazyOverloads(descriptor.name(), descriptor.receiver_style(),
descriptor.types())
.empty());
auto overloads = functions_.find(descriptor.name());
if (overloads == functions_.end()) {
return false;
}
const RegistryEntry& entry = overloads->second;
for (const auto& static_ovl : entry.static_overloads) {
if (static_ovl.descriptor->ShapeMatches(descriptor)) {
return true;
}
}
for (const auto& lazy_ovl : entry.lazy_overloads) {
if (lazy_ovl.descriptor->ShapeMatches(descriptor)) {
return true;
}
}
return false;
}

bool FunctionRegistry::ValidateNonStrictOverload(
Expand Down
8 changes: 8 additions & 0 deletions runtime/function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_REGISTRY_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_REGISTRY_H_

#include <cstddef>
#include <memory>
#include <string>
#include <utility>
Expand Down Expand Up @@ -83,6 +84,9 @@ class FunctionRegistry {
absl::string_view name, bool receiver_style,
absl::Span<const cel::Kind> types) const;

std::vector<cel::FunctionOverloadReference> FindStaticOverloadsByArity(
absl::string_view name, bool receiver_style, size_t arity) const;

// Find subset of cel::Function providers that match overload conditions.
// As types may not be available during expression compilation,
// further narrowing of this subset will happen at evaluation stage.
Expand All @@ -98,6 +102,10 @@ class FunctionRegistry {
absl::string_view name, bool receiver_style,
absl::Span<const cel::Kind> types) const;

std::vector<LazyOverload> FindLazyOverloadsByArity(absl::string_view name,
bool receiver_style,
size_t arity) const;

// Retrieve list of registered function descriptors. This includes both
// static and lazy functions.
absl::node_hash_map<std::string, std::vector<const cel::FunctionDescriptor*>>
Expand Down