Skip to content

Commit 6117867

Browse files
jnthntatumcopybara-github
authored andcommitted
Remove temporary vectors in resolver logic for the planner.
PiperOrigin-RevId: 741157667
1 parent e661775 commit 6117867

File tree

6 files changed

+147
-27
lines changed

6 files changed

+147
-27
lines changed

eval/compiler/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,10 +419,12 @@ cc_library(
419419
"//runtime:function_overload_reference",
420420
"//runtime:function_registry",
421421
"//runtime:type_registry",
422+
"@com_google_absl//absl/base:no_destructor",
422423
"@com_google_absl//absl/container:flat_hash_map",
423424
"@com_google_absl//absl/status:statusor",
424425
"@com_google_absl//absl/strings",
425426
"@com_google_absl//absl/types:optional",
427+
"@com_google_absl//absl/types:span",
426428
],
427429
)
428430

eval/compiler/flat_expr_builder.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1637,12 +1637,11 @@ class FlatExprVisitor : public cel::AstVisitor {
16371637
// Establish the search criteria for a given function.
16381638
bool receiver_style = call_expr->has_target();
16391639
size_t num_args = call_expr->args().size() + (receiver_style ? 1 : 0);
1640-
auto arguments_matcher = ArgumentsMatcher(num_args);
16411640

16421641
// First, search for lazily defined function overloads.
16431642
// Lazy functions shadow eager functions with the same signature.
16441643
auto lazy_overloads = resolver_.FindLazyOverloads(
1645-
function, call_expr->has_target(), arguments_matcher, expr->id());
1644+
function, call_expr->has_target(), num_args, expr->id());
16461645
if (!lazy_overloads.empty()) {
16471646
auto depth = RecursionEligible();
16481647
if (depth.has_value()) {
@@ -1659,8 +1658,8 @@ class FlatExprVisitor : public cel::AstVisitor {
16591658
}
16601659

16611660
// Second, search for eagerly defined function overloads.
1662-
auto overloads = resolver_.FindOverloads(function, receiver_style,
1663-
arguments_matcher, expr->id());
1661+
auto overloads =
1662+
resolver_.FindOverloads(function, receiver_style, num_args, expr->id());
16641663
if (overloads.empty()) {
16651664
// Create a warning that the overload could not be found. Depending on the
16661665
// builder_warnings configuration, this could result in termination of the

eval/compiler/resolver.cc

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414

1515
#include "eval/compiler/resolver.h"
1616

17+
#include <cstddef>
1718
#include <cstdint>
19+
#include <cstdio>
1820
#include <string>
1921
#include <utility>
2022
#include <vector>
2123

24+
#include "absl/base/no_destructor.h"
2225
#include "absl/container/flat_hash_map.h"
2326
#include "absl/status/statusor.h"
2427
#include "absl/strings/match.h"
@@ -27,6 +30,7 @@
2730
#include "absl/strings/string_view.h"
2831
#include "absl/strings/strip.h"
2932
#include "absl/types/optional.h"
33+
#include "absl/types/span.h"
3034
#include "common/kind.h"
3135
#include "common/type.h"
3236
#include "common/type_reflector.h"
@@ -96,37 +100,41 @@ std::vector<std::string> Resolver::FullyQualifiedNames(absl::string_view name,
96100
// and handle the case where this id is in the reference map as either a
97101
// function name or identifier name.
98102
std::vector<std::string> names;
99-
// Handle the case where the name contains a leading '.' indicating it is
100-
// already fully-qualified.
101-
if (absl::StartsWith(name, ".")) {
102-
std::string fully_qualified_name = std::string(name.substr(1));
103-
names.push_back(fully_qualified_name);
104-
return names;
105-
}
106103

107-
// namespace prefixes is guaranteed to contain at least empty string, so this
108-
// function will always produce at least one result.
109-
for (const auto& prefix : namespace_prefixes_) {
104+
auto prefixes = GetPrefixesFor(name);
105+
for (const auto& prefix : prefixes) {
110106
std::string fully_qualified_name = absl::StrCat(prefix, name);
111107
names.push_back(fully_qualified_name);
112108
}
113109
return names;
114110
}
115111

112+
absl::Span<const std::string> Resolver::GetPrefixesFor(
113+
absl::string_view& name) const {
114+
static const absl::NoDestructor<std::string> kEmptyPrefix("");
115+
if (absl::StartsWith(name, ".")) {
116+
name = name.substr(1);
117+
return absl::MakeConstSpan(kEmptyPrefix.get(), 1);
118+
}
119+
return namespace_prefixes_;
120+
}
121+
116122
absl::optional<cel::Value> Resolver::FindConstant(absl::string_view name,
117123
int64_t expr_id) const {
118-
auto names = FullyQualifiedNames(name, expr_id);
119-
for (const auto& name : names) {
124+
auto prefixes = GetPrefixesFor(name);
125+
for (const auto& prefix : prefixes) {
126+
std::string qualified_name = absl::StrCat(prefix, name);
120127
// Attempt to resolve the fully qualified name to a known enum.
121-
auto enum_entry = enum_value_map_.find(name);
128+
auto enum_entry = enum_value_map_.find(qualified_name);
122129
if (enum_entry != enum_value_map_.end()) {
123130
return enum_entry->second;
124131
}
125132
// Conditionally resolve fully qualified names as type values if the option
126133
// to do so is configured in the expression builder. If the type name is
127134
// not qualified, then it too may be returned as a constant value.
128-
if (resolve_qualified_type_identifiers_ || !absl::StrContains(name, ".")) {
129-
auto type_value = type_reflector_.FindType(name);
135+
if (resolve_qualified_type_identifiers_ ||
136+
!absl::StrContains(qualified_name, ".")) {
137+
auto type_value = type_reflector_.FindType(qualified_name);
130138
if (type_value.ok() && type_value->has_value()) {
131139
return TypeValue(**type_value);
132140
}
@@ -157,6 +165,27 @@ std::vector<cel::FunctionOverloadReference> Resolver::FindOverloads(
157165
return funcs;
158166
}
159167

168+
std::vector<cel::FunctionOverloadReference> Resolver::FindOverloads(
169+
absl::string_view name, bool receiver_style, size_t arity,
170+
int64_t expr_id) const {
171+
std::vector<cel::FunctionOverloadReference> funcs;
172+
auto prefixes = GetPrefixesFor(name);
173+
for (const auto& prefix : prefixes) {
174+
std::string qualified_name = absl::StrCat(prefix, name);
175+
// Only one set of overloads is returned along the namespace hierarchy as
176+
// the function name resolution follows the same behavior as variable name
177+
// resolution, meaning the most specific definition wins. This is different
178+
// from how C++ namespaces work, as they will accumulate the overload set
179+
// over the namespace hierarchy.
180+
funcs = function_registry_.FindStaticOverloadsByArity(
181+
qualified_name, receiver_style, arity);
182+
if (!funcs.empty()) {
183+
return funcs;
184+
}
185+
}
186+
return funcs;
187+
}
188+
160189
std::vector<cel::FunctionRegistry::LazyOverload> Resolver::FindLazyOverloads(
161190
absl::string_view name, bool receiver_style,
162191
const std::vector<cel::Kind>& types, int64_t expr_id) const {
@@ -173,10 +202,27 @@ std::vector<cel::FunctionRegistry::LazyOverload> Resolver::FindLazyOverloads(
173202
return funcs;
174203
}
175204

205+
std::vector<cel::FunctionRegistry::LazyOverload> Resolver::FindLazyOverloads(
206+
absl::string_view name, bool receiver_style, size_t arity,
207+
int64_t expr_id) const {
208+
std::vector<cel::FunctionRegistry::LazyOverload> funcs;
209+
auto prefixes = GetPrefixesFor(name);
210+
for (const auto& prefix : prefixes) {
211+
std::string qualified_name = absl::StrCat(prefix, name);
212+
funcs = function_registry_.FindLazyOverloadsByArity(name, receiver_style,
213+
arity);
214+
if (!funcs.empty()) {
215+
return funcs;
216+
}
217+
}
218+
return funcs;
219+
}
220+
176221
absl::StatusOr<absl::optional<std::pair<std::string, cel::Type>>>
177222
Resolver::FindType(absl::string_view name, int64_t expr_id) const {
178-
auto qualified_names = FullyQualifiedNames(name, expr_id);
179-
for (auto& qualified_name : qualified_names) {
223+
auto prefixes = GetPrefixesFor(name);
224+
for (auto& prefix : prefixes) {
225+
std::string qualified_name = absl::StrCat(prefix, name);
180226
CEL_ASSIGN_OR_RETURN(auto maybe_type,
181227
type_reflector_.FindType(qualified_name));
182228
if (maybe_type.has_value()) {

eval/compiler/resolver.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_
1616
#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_
1717

18+
#include <cstddef>
1819
#include <cstdint>
1920
#include <string>
2021
#include <utility>
@@ -24,6 +25,7 @@
2425
#include "absl/status/statusor.h"
2526
#include "absl/strings/string_view.h"
2627
#include "absl/types/optional.h"
28+
#include "absl/types/span.h"
2729
#include "common/kind.h"
2830
#include "common/type_reflector.h"
2931
#include "common/value.h"
@@ -75,18 +77,28 @@ class Resolver {
7577
absl::string_view name, bool receiver_style,
7678
const std::vector<cel::Kind>& types, int64_t expr_id = -1) const;
7779

80+
std::vector<cel::FunctionRegistry::LazyOverload> FindLazyOverloads(
81+
absl::string_view name, bool receiver_style, size_t arity,
82+
int64_t expr_id = -1) const;
83+
7884
// FindOverloads returns the set, possibly empty, of eager function overloads
7985
// matching the given function signature.
8086
std::vector<cel::FunctionOverloadReference> FindOverloads(
8187
absl::string_view name, bool receiver_style,
8288
const std::vector<cel::Kind>& types, int64_t expr_id = -1) const;
8389

90+
std::vector<cel::FunctionOverloadReference> FindOverloads(
91+
absl::string_view name, bool receiver_style, size_t arity,
92+
int64_t expr_id = -1) const;
93+
8494
// FullyQualifiedNames returns the set of fully qualified names which may be
8595
// derived from the base_name within the specified expression container.
8696
std::vector<std::string> FullyQualifiedNames(absl::string_view base_name,
8797
int64_t expr_id = -1) const;
8898

8999
private:
100+
absl::Span<const std::string> GetPrefixesFor(absl::string_view& name) const;
101+
90102
std::vector<std::string> namespace_prefixes_;
91103
absl::flat_hash_map<std::string, cel::Value> enum_value_map_;
92104
const cel::FunctionRegistry& function_registry_;

runtime/function_registry.cc

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "runtime/function_registry.h"
1616

17+
#include <cstddef>
1718
#include <memory>
1819
#include <string>
1920
#include <utility>
@@ -134,6 +135,27 @@ FunctionRegistry::FindStaticOverloads(absl::string_view name,
134135
return matched_funcs;
135136
}
136137

138+
std::vector<cel::FunctionOverloadReference>
139+
FunctionRegistry::FindStaticOverloadsByArity(absl::string_view name,
140+
bool receiver_style,
141+
size_t arity) const {
142+
std::vector<cel::FunctionOverloadReference> matched_funcs;
143+
144+
auto overloads = functions_.find(name);
145+
if (overloads == functions_.end()) {
146+
return matched_funcs;
147+
}
148+
149+
for (const auto& overload : overloads->second.static_overloads) {
150+
if (overload.descriptor->receiver_style() == receiver_style &&
151+
overload.descriptor->types().size() == arity) {
152+
matched_funcs.push_back({*overload.descriptor, *overload.implementation});
153+
}
154+
}
155+
156+
return matched_funcs;
157+
}
158+
137159
std::vector<FunctionRegistry::LazyOverload> FunctionRegistry::FindLazyOverloads(
138160
absl::string_view name, bool receiver_style,
139161
absl::Span<const cel::Kind> types) const {
@@ -153,6 +175,27 @@ std::vector<FunctionRegistry::LazyOverload> FunctionRegistry::FindLazyOverloads(
153175
return matched_funcs;
154176
}
155177

178+
std::vector<FunctionRegistry::LazyOverload>
179+
FunctionRegistry::FindLazyOverloadsByArity(absl::string_view name,
180+
bool receiver_style,
181+
size_t arity) const {
182+
std::vector<FunctionRegistry::LazyOverload> matched_funcs;
183+
184+
auto overloads = functions_.find(name);
185+
if (overloads == functions_.end()) {
186+
return matched_funcs;
187+
}
188+
189+
for (const auto& entry : overloads->second.lazy_overloads) {
190+
if (entry.descriptor->receiver_style() == receiver_style &&
191+
entry.descriptor->types().size() == arity) {
192+
matched_funcs.push_back({*entry.descriptor, *entry.function_provider});
193+
}
194+
}
195+
196+
return matched_funcs;
197+
}
198+
156199
absl::node_hash_map<std::string, std::vector<const cel::FunctionDescriptor*>>
157200
FunctionRegistry::ListFunctions() const {
158201
absl::node_hash_map<std::string, std::vector<const cel::FunctionDescriptor*>>
@@ -177,12 +220,22 @@ FunctionRegistry::ListFunctions() const {
177220

178221
bool FunctionRegistry::DescriptorRegistered(
179222
const cel::FunctionDescriptor& descriptor) const {
180-
return !(FindStaticOverloads(descriptor.name(), descriptor.receiver_style(),
181-
descriptor.types())
182-
.empty()) ||
183-
!(FindLazyOverloads(descriptor.name(), descriptor.receiver_style(),
184-
descriptor.types())
185-
.empty());
223+
auto overloads = functions_.find(descriptor.name());
224+
if (overloads == functions_.end()) {
225+
return false;
226+
}
227+
const RegistryEntry& entry = overloads->second;
228+
for (const auto& static_ovl : entry.static_overloads) {
229+
if (static_ovl.descriptor->ShapeMatches(descriptor)) {
230+
return true;
231+
}
232+
}
233+
for (const auto& lazy_ovl : entry.lazy_overloads) {
234+
if (lazy_ovl.descriptor->ShapeMatches(descriptor)) {
235+
return true;
236+
}
237+
}
238+
return false;
186239
}
187240

188241
bool FunctionRegistry::ValidateNonStrictOverload(

runtime/function_registry.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_REGISTRY_H_
1616
#define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_REGISTRY_H_
1717

18+
#include <cstddef>
1819
#include <memory>
1920
#include <string>
2021
#include <utility>
@@ -83,6 +84,9 @@ class FunctionRegistry {
8384
absl::string_view name, bool receiver_style,
8485
absl::Span<const cel::Kind> types) const;
8586

87+
std::vector<cel::FunctionOverloadReference> FindStaticOverloadsByArity(
88+
absl::string_view name, bool receiver_style, size_t arity) const;
89+
8690
// Find subset of cel::Function providers that match overload conditions.
8791
// As types may not be available during expression compilation,
8892
// further narrowing of this subset will happen at evaluation stage.
@@ -98,6 +102,10 @@ class FunctionRegistry {
98102
absl::string_view name, bool receiver_style,
99103
absl::Span<const cel::Kind> types) const;
100104

105+
std::vector<LazyOverload> FindLazyOverloadsByArity(absl::string_view name,
106+
bool receiver_style,
107+
size_t arity) const;
108+
101109
// Retrieve list of registered function descriptors. This includes both
102110
// static and lazy functions.
103111
absl::node_hash_map<std::string, std::vector<const cel::FunctionDescriptor*>>

0 commit comments

Comments
 (0)