diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs
new file mode 100644
index 0000000..3ab54d7
--- /dev/null
+++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs
@@ -0,0 +1,56 @@
+using System.Linq.Expressions;
+using System.Reflection;
+using BenchmarkDotNet.Attributes;
+using EntityFrameworkCore.Projectables.Benchmarks.Helpers;
+using EntityFrameworkCore.Projectables.Services;
+
+namespace EntityFrameworkCore.Projectables.Benchmarks
+{
+ ///
+ /// Micro-benchmarks in
+ /// isolation (no EF Core overhead) to directly compare the static registry path against
+ /// the reflection-based path ().
+ ///
+ [MemoryDiagnoser]
+ public class ExpressionResolverBenchmark
+ {
+ private static readonly MemberInfo _propertyMember =
+ typeof(TestEntity).GetProperty(nameof(TestEntity.IdPlus1))!;
+
+ private static readonly MemberInfo _methodMember =
+ typeof(TestEntity).GetMethod(nameof(TestEntity.IdPlus1Method))!;
+
+ private static readonly MemberInfo _methodWithParamMember =
+ typeof(TestEntity).GetMethod(nameof(TestEntity.IdPlusDelta), new[] { typeof(int) })!;
+
+ private readonly ProjectionExpressionResolver _resolver = new();
+
+ // ── Registry (source-generated) path ─────────────────────────────────
+
+ [Benchmark(Baseline = true)]
+ public LambdaExpression? ResolveProperty_Registry()
+ => _resolver.FindGeneratedExpression(_propertyMember);
+
+ [Benchmark]
+ public LambdaExpression? ResolveMethod_Registry()
+ => _resolver.FindGeneratedExpression(_methodMember);
+
+ [Benchmark]
+ public LambdaExpression? ResolveMethodWithParam_Registry()
+ => _resolver.FindGeneratedExpression(_methodWithParamMember);
+
+ // ── Reflection path ───────────────────────────────────────────────────
+
+ [Benchmark]
+ public LambdaExpression? ResolveProperty_Reflection()
+ => ProjectionExpressionResolver.FindGeneratedExpressionViaReflection(_propertyMember);
+
+ [Benchmark]
+ public LambdaExpression? ResolveMethod_Reflection()
+ => ProjectionExpressionResolver.FindGeneratedExpressionViaReflection(_methodMember);
+
+ [Benchmark]
+ public LambdaExpression? ResolveMethodWithParam_Reflection()
+ => ProjectionExpressionResolver.FindGeneratedExpressionViaReflection(_methodWithParamMember);
+ }
+}
diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/Helpers/TestEntity.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/Helpers/TestEntity.cs
index 4bc741c..68a04d8 100644
--- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/Helpers/TestEntity.cs
+++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/Helpers/TestEntity.cs
@@ -15,5 +15,8 @@ public class TestEntity
[Projectable]
public int IdPlus1Method() => Id + 1;
+
+ [Projectable]
+ public int IdPlusDelta(int delta) => Id + delta;
}
}
diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/PlainOverhead.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/PlainOverhead.cs
index d064b7d..b9cda85 100644
--- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/PlainOverhead.cs
+++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/PlainOverhead.cs
@@ -9,6 +9,7 @@
namespace EntityFrameworkCore.Projectables.Benchmarks
{
+ [MemoryDiagnoser]
public class PlainOverhead
{
[Benchmark(Baseline = true)]
diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableExtensionMethods.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableExtensionMethods.cs
index fdb0f9f..467d470 100644
--- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableExtensionMethods.cs
+++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableExtensionMethods.cs
@@ -9,6 +9,7 @@
namespace EntityFrameworkCore.Projectables.Benchmarks
{
+ [MemoryDiagnoser]
public class ProjectableExtensionMethods
{
const int innerLoop = 10000;
diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableMethods.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableMethods.cs
index 785b52c..0618627 100644
--- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableMethods.cs
+++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableMethods.cs
@@ -8,6 +8,7 @@
namespace EntityFrameworkCore.Projectables.Benchmarks
{
+ [MemoryDiagnoser]
public class ProjectableMethods
{
const int innerLoop = 10000;
diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableProperties.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableProperties.cs
index 6f2fa57..3a2e2be 100644
--- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableProperties.cs
+++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableProperties.cs
@@ -8,6 +8,7 @@
namespace EntityFrameworkCore.Projectables.Benchmarks
{
+ [MemoryDiagnoser]
public class ProjectableProperties
{
const int innerLoop = 10000;
diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ResolverOverhead.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ResolverOverhead.cs
new file mode 100644
index 0000000..3876641
--- /dev/null
+++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ResolverOverhead.cs
@@ -0,0 +1,72 @@
+using System.Linq;
+using BenchmarkDotNet.Attributes;
+using EntityFrameworkCore.Projectables.Benchmarks.Helpers;
+using Microsoft.EntityFrameworkCore;
+
+namespace EntityFrameworkCore.Projectables.Benchmarks
+{
+ ///
+ /// Measures the per-DbContext cold-start cost of resolver lookup by creating a new
+ /// on every iteration. The previous benchmarks reuse a single
+ /// DbContext for 10 000 iterations, so the resolver cache is warm after the first query —
+ /// these benchmarks expose the cost of the very first query per context.
+ ///
+ [MemoryDiagnoser]
+ public class ResolverOverhead
+ {
+ const int Iterations = 1000;
+
+ /// Baseline: no projectables, new DbContext per query.
+ [Benchmark(Baseline = true)]
+ public void WithoutProjectables_FreshDbContext()
+ {
+ for (int i = 0; i < Iterations; i++)
+ {
+ using var dbContext = new TestDbContext(false);
+ dbContext.Entities.Select(x => x.Id + 1).ToQueryString();
+ }
+ }
+
+ ///
+ /// New DbContext per query with a projectable property.
+ /// After the registry is in place this should approach baseline overhead.
+ ///
+ [Benchmark]
+ public void WithProjectables_FreshDbContext_Property()
+ {
+ for (int i = 0; i < Iterations; i++)
+ {
+ using var dbContext = new TestDbContext(true, false);
+ dbContext.Entities.Select(x => x.IdPlus1).ToQueryString();
+ }
+ }
+
+ ///
+ /// New DbContext per query with a projectable method.
+ /// After the registry is in place this should approach baseline overhead.
+ ///
+ [Benchmark]
+ public void WithProjectables_FreshDbContext_Method()
+ {
+ for (int i = 0; i < Iterations; i++)
+ {
+ using var dbContext = new TestDbContext(true, false);
+ dbContext.Entities.Select(x => x.IdPlus1Method()).ToQueryString();
+ }
+ }
+
+ ///
+ /// New DbContext per query with a projectable method that takes a parameter,
+ /// exercising parameter-type disambiguation in the registry key.
+ ///
+ [Benchmark]
+ public void WithProjectables_FreshDbContext_MethodWithParam()
+ {
+ for (int i = 0; i < Iterations; i++)
+ {
+ using var dbContext = new TestDbContext(true, false);
+ dbContext.Entities.Select(x => x.IdPlusDelta(5)).ToQueryString();
+ }
+ }
+ }
+}
diff --git a/src/EntityFrameworkCore.Projectables.Generator/IsExternalInit.cs b/src/EntityFrameworkCore.Projectables.Generator/IsExternalInit.cs
new file mode 100644
index 0000000..bd4930e
--- /dev/null
+++ b/src/EntityFrameworkCore.Projectables.Generator/IsExternalInit.cs
@@ -0,0 +1,6 @@
+// Polyfill for C# 9 record types when targeting netstandard2.0 or netstandard2.1
+// The compiler requires this type to exist in order to use init-only setters (used by records).
+namespace System.Runtime.CompilerServices
+{
+ internal sealed class IsExternalInit { }
+}
diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs
new file mode 100644
index 0000000..37a45e0
--- /dev/null
+++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs
@@ -0,0 +1,61 @@
+using System.Collections.Immutable;
+using System.Linq;
+
+namespace EntityFrameworkCore.Projectables.Generator
+{
+ ///
+ /// Incremental-pipeline-safe representation of a single projectable member.
+ /// Contains only primitive types and an equatable wrapper around
+ /// so that structural value equality works correctly across incremental generation steps.
+ ///
+ sealed internal record ProjectableRegistryEntry(
+ string DeclaringTypeFullName,
+ string MemberKind,
+ string MemberLookupName,
+ string GeneratedClassFullName,
+ bool IsGenericClass,
+ bool IsGenericMethod,
+ EquatableImmutableArray ParameterTypeNames
+ );
+
+ ///
+ /// A structural-equality wrapper around of strings.
+ /// uses reference equality by default, which breaks
+ /// Roslyn's incremental-source-generator caching when the same logical array is
+ /// produced by two different steps. This wrapper provides element-wise equality so
+ /// that incremental steps are correctly cached and skipped.
+ ///
+ internal readonly struct EquatableImmutableArray : System.IEquatable
+ {
+ public static readonly EquatableImmutableArray Empty = new(ImmutableArray.Empty);
+
+ public readonly ImmutableArray Array;
+
+ public EquatableImmutableArray(ImmutableArray array)
+ {
+ Array = array;
+ }
+
+ public bool IsDefaultOrEmpty => Array.IsDefaultOrEmpty;
+
+ public bool Equals(EquatableImmutableArray other) =>
+ Array.SequenceEqual(other.Array);
+
+ public override bool Equals(object? obj) =>
+ obj is EquatableImmutableArray other && Equals(other);
+
+ public override int GetHashCode()
+ {
+ unchecked
+ {
+ var hash = 17;
+ foreach (var s in Array)
+ hash = hash * 31 + (s?.GetHashCode() ?? 0);
+ return hash;
+ }
+ }
+
+ public static implicit operator ImmutableArray(EquatableImmutableArray e) => e.Array;
+ public static implicit operator EquatableImmutableArray(ImmutableArray a) => new(a);
+ }
+}
diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs
index 440caa5..0c83141 100644
--- a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs
+++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs
@@ -3,6 +3,7 @@
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
+using System.Collections.Immutable;
using System.Text;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
@@ -32,7 +33,7 @@ public class ProjectionExpressionGenerator : IIncrementalGenerator
public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Do a simple filter for members
- IncrementalValuesProvider memberDeclarations = context.SyntaxProvider
+ var memberDeclarations = context.SyntaxProvider
.ForAttributeWithMetadataName(
ProjectablesAttributeName,
predicate: static (s, _) => s is MemberDeclarationSyntax,
@@ -40,16 +41,28 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.WithComparer(new MemberDeclarationSyntaxEqualityComparer());
// Combine the selected enums with the `Compilation`
- IncrementalValuesProvider<(MemberDeclarationSyntax, Compilation)> compilationAndMemberPairs = memberDeclarations
+ var compilationAndMemberPairs = memberDeclarations
.Combine(context.CompilationProvider)
.WithComparer(new MemberDeclarationSyntaxAndCompilationEqualityComparer());
// Generate the source using the compilation and enums
context.RegisterImplementationSourceOutput(compilationAndMemberPairs,
static (spc, source) => Execute(source.Item1, source.Item2, spc));
+
+ // Build the projection registry: collect all entries and emit a single registry file
+ var registryEntries =
+ compilationAndMemberPairs.Select(
+ static (pair, _) => ExtractRegistryEntry(pair.Item1, pair.Item2));
+
+ var allEntries =
+ registryEntries.Collect();
+
+ context.RegisterImplementationSourceOutput(
+ allEntries,
+ static (spc, entries) => EmitRegistry(entries, spc));
}
- static SyntaxTriviaList BuildSourceDocComment(ConstructorDeclarationSyntax ctor, Compilation compilation)
+ private static SyntaxTriviaList BuildSourceDocComment(ConstructorDeclarationSyntax ctor, Compilation compilation)
{
var chain = CollectConstructorChain(ctor, compilation);
@@ -90,7 +103,7 @@ void AddLine(string text)
/// then its delegate's delegate, …). Stops when a delegated constructor has no source
/// available in the compilation (e.g. a compiler-synthesised parameterless constructor).
///
- static IReadOnlyList CollectConstructorChain(
+ private static List CollectConstructorChain(
ConstructorDeclarationSyntax ctor, Compilation compilation)
{
var result = new List { ctor };
@@ -101,7 +114,9 @@ static IReadOnlyList CollectConstructorChain(
{
var semanticModel = compilation.GetSemanticModel(current.SyntaxTree);
if (semanticModel.GetSymbolInfo(initializer).Symbol is not IMethodSymbol delegated)
+ {
break;
+ }
var delegatedSyntax = delegated.DeclaringSyntaxReferences
.Select(r => r.GetSyntax())
@@ -109,7 +124,9 @@ static IReadOnlyList CollectConstructorChain(
.FirstOrDefault();
if (delegatedSyntax is null || !visited.Add(delegatedSyntax))
+ {
break;
+ }
result.Add(delegatedSyntax);
current = delegatedSyntax;
@@ -118,7 +135,7 @@ static IReadOnlyList CollectConstructorChain(
return result;
}
- static void Execute(MemberDeclarationSyntax member, Compilation compilation, SourceProductionContext context)
+ private static void Execute(MemberDeclarationSyntax member, Compilation compilation, SourceProductionContext context)
{
var projectable = ProjectableInterpreter.GetDescriptor(compilation, member, context);
@@ -146,40 +163,40 @@ static void Execute(MemberDeclarationSyntax member, Compilation compilation, Sou
.WithLeadingTrivia(member is ConstructorDeclarationSyntax ctor ? BuildSourceDocComment(ctor, compilation) : TriviaList())
.AddMembers(
MethodDeclaration(
- GenericName(
- Identifier("global::System.Linq.Expressions.Expression"),
- TypeArgumentList(
- SingletonSeparatedList(
- (TypeSyntax)GenericName(
- Identifier("global::System.Func"),
- GetLambdaTypeArgumentListSyntax(projectable)
+ GenericName(
+ Identifier("global::System.Linq.Expressions.Expression"),
+ TypeArgumentList(
+ SingletonSeparatedList(
+ (TypeSyntax)GenericName(
+ Identifier("global::System.Func"),
+ GetLambdaTypeArgumentListSyntax(projectable)
+ )
)
)
- )
- ),
- "Expression"
- )
- .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
- .WithTypeParameterList(projectable.TypeParameterList)
- .WithConstraintClauses(projectable.ConstraintClauses ?? List())
- .WithBody(
- Block(
- ReturnStatement(
- ParenthesizedLambdaExpression(
- projectable.ParametersList ?? ParameterList(),
- null,
- projectable.ExpressionBody
+ ),
+ "Expression"
+ )
+ .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
+ .WithTypeParameterList(projectable.TypeParameterList)
+ .WithConstraintClauses(projectable.ConstraintClauses ?? List())
+ .WithBody(
+ Block(
+ ReturnStatement(
+ ParenthesizedLambdaExpression(
+ projectable.ParametersList ?? ParameterList(),
+ null,
+ projectable.ExpressionBody
+ )
)
)
)
- )
);
#nullable disable
var compilationUnit = CompilationUnit();
- foreach (var usingDirective in projectable.UsingDirectives)
+ foreach (var usingDirective in projectable.UsingDirectives!)
{
compilationUnit = compilationUnit.AddUsings(usingDirective);
}
@@ -226,5 +243,419 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip
return lambdaTypeArguments;
}
}
+
+#nullable restore
+
+ ///
+ /// Extracts a from a member declaration.
+ /// Returns null when the member does not have [Projectable], is an extension member,
+ /// or cannot be represented in the registry (e.g. a generic class member or generic method).
+ ///
+ private static ProjectableRegistryEntry? ExtractRegistryEntry(MemberDeclarationSyntax member, Compilation compilation)
+ {
+ var semanticModel = compilation.GetSemanticModel(member.SyntaxTree);
+ var memberSymbol = semanticModel.GetDeclaredSymbol(member);
+
+ if (memberSymbol is null)
+ {
+ return null;
+ }
+
+ // Verify [Projectable] attribute
+ var projectableAttributeTypeSymbol = compilation.GetTypeByMetadataName("EntityFrameworkCore.Projectables.ProjectableAttribute");
+ var projectableAttribute = memberSymbol.GetAttributes()
+ .FirstOrDefault(x => x.AttributeClass?.Name == "ProjectableAttribute");
+
+ if (projectableAttribute is null ||
+ !SymbolEqualityComparer.Default.Equals(projectableAttribute.AttributeClass, projectableAttributeTypeSymbol))
+ {
+ return null;
+ }
+
+ // Skip C# 14 extension type members — they require special handling (fall back to reflection)
+ if (memberSymbol.ContainingType is { IsExtension: true })
+ {
+ return null;
+ }
+
+ var containingType = memberSymbol.ContainingType;
+ var isGenericClass = containingType.TypeParameters.Length > 0;
+
+ // Determine member kind and lookup name
+ string memberKind;
+ string memberLookupName;
+ var parameterTypeNames = ImmutableArray.Empty;
+ var methodTypeParamCount = 0;
+ var isGenericMethod = false;
+
+ if (memberSymbol is IMethodSymbol methodSymbol)
+ {
+ isGenericMethod = methodSymbol.TypeParameters.Length > 0;
+ methodTypeParamCount = methodSymbol.TypeParameters.Length;
+
+ if (methodSymbol.MethodKind is MethodKind.Constructor or MethodKind.StaticConstructor)
+ {
+ memberKind = "Constructor";
+ memberLookupName = "_ctor";
+ }
+ else
+ {
+ memberKind = "Method";
+ memberLookupName = memberSymbol.Name;
+ }
+
+ parameterTypeNames = [
+ ..methodSymbol.Parameters.Select(p => p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))
+ ];
+ }
+ else
+ {
+ memberKind = "Property";
+ memberLookupName = memberSymbol.Name;
+ }
+
+ // Build the generated class name using the same logic as Execute
+ var classNamespace = containingType.ContainingNamespace.IsGlobalNamespace
+ ? null
+ : containingType.ContainingNamespace.ToDisplayString();
+
+ var nestedTypePath = GetRegistryNestedTypePath(containingType);
+
+ var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(
+ classNamespace,
+ nestedTypePath,
+ memberLookupName,
+ parameterTypeNames.IsEmpty ? null : parameterTypeNames);
+
+ var generatedClassFullName = "EntityFrameworkCore.Projectables.Generated." + generatedClassName;
+
+ var declaringTypeFullName = containingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
+
+ return new ProjectableRegistryEntry(
+ DeclaringTypeFullName: declaringTypeFullName,
+ MemberKind: memberKind,
+ MemberLookupName: memberLookupName,
+ GeneratedClassFullName: generatedClassFullName,
+ IsGenericClass: isGenericClass,
+ IsGenericMethod: isGenericMethod,
+ ParameterTypeNames: parameterTypeNames);
+ }
+
+ private static IEnumerable GetRegistryNestedTypePath(INamedTypeSymbol typeSymbol)
+ {
+ if (typeSymbol.ContainingType is not null)
+ {
+ foreach (var name in GetRegistryNestedTypePath(typeSymbol.ContainingType))
+ {
+ yield return name;
+ }
+ }
+ yield return typeSymbol.Name;
+ }
+
+ ///
+ /// Emits the ProjectionRegistry.g.cs file that aggregates all projectable members
+ /// into a single static dictionary keyed by .
+ /// Uses SyntaxFactory for the class/method/field structure, consistent with .
+ /// The generated Build() method uses a shared Register helper to avoid repeating
+ /// the lookup boilerplate for every entry.
+ ///
+ private static void EmitRegistry(ImmutableArray entries, SourceProductionContext context)
+ {
+ // Build the per-entry Register(...) statements first so we can bail out early
+ // if every entry is generic (they all fall back to reflection, no registry needed).
+ var entryStatements = entries
+ .Where(e => e is not null)
+ .Select(e => BuildRegistryEntryStatement(e!))
+ .Where(s => s is not null)
+ .Select(s => s!)
+ .ToList();
+
+ if (entryStatements.Count == 0)
+ {
+ return;
+ }
+
+ // Build() body:
+ // const BindingFlags allFlags = ...;
+ // var map = new Dictionary();
+ // Register(map, typeof(T).GetXxx(...), "ClassName"); ← one line per entry
+ // return map;
+ var buildStatements = new List
+ {
+ // const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | ...;
+ LocalDeclarationStatement(
+ VariableDeclaration(ParseTypeName("BindingFlags"))
+ .AddVariables(
+ VariableDeclarator("allFlags")
+ .WithInitializer(EqualsValueClause(
+ ParseExpression(
+ "BindingFlags.Public | BindingFlags.NonPublic | " +
+ "BindingFlags.Instance | BindingFlags.Static")))))
+ .WithModifiers(TokenList(Token(SyntaxKind.ConstKeyword))),
+
+ // var map = new Dictionary();
+ LocalDeclarationStatement(
+ VariableDeclaration(ParseTypeName("var"))
+ .AddVariables(
+ VariableDeclarator("map")
+ .WithInitializer(EqualsValueClause(
+ ObjectCreationExpression(
+ ParseTypeName("Dictionary"))
+ .WithArgumentList(ArgumentList()))))),
+ };
+
+ buildStatements.AddRange(entryStatements);
+ buildStatements.Add(ReturnStatement(IdentifierName("map")));
+
+ var classSyntax = ClassDeclaration("ProjectionRegistry")
+ .WithModifiers(TokenList(
+ Token(SyntaxKind.InternalKeyword),
+ Token(SyntaxKind.StaticKeyword)))
+ .AddAttributeLists(AttributeList().AddAttributes(_editorBrowsableAttribute))
+ .AddMembers(
+ // private static readonly Dictionary _map = Build();
+ FieldDeclaration(
+ VariableDeclaration(ParseTypeName("Dictionary"))
+ .AddVariables(
+ VariableDeclarator("_map")
+ .WithInitializer(EqualsValueClause(
+ InvocationExpression(IdentifierName("Build"))
+ .WithArgumentList(ArgumentList())))))
+ .WithModifiers(TokenList(
+ Token(SyntaxKind.PrivateKeyword),
+ Token(SyntaxKind.StaticKeyword),
+ Token(SyntaxKind.ReadOnlyKeyword))),
+
+ // public static LambdaExpression TryGet(MemberInfo member)
+ MethodDeclaration(ParseTypeName("LambdaExpression"), "TryGet")
+ .WithModifiers(TokenList(
+ Token(SyntaxKind.PublicKeyword),
+ Token(SyntaxKind.StaticKeyword)))
+ .AddParameterListParameters(
+ Parameter(Identifier("member"))
+ .WithType(ParseTypeName("MemberInfo")))
+ .WithBody(Block(
+ LocalDeclarationStatement(
+ VariableDeclaration(ParseTypeName("var"))
+ .AddVariables(
+ VariableDeclarator("handle")
+ .WithInitializer(EqualsValueClause(
+ InvocationExpression(IdentifierName("GetHandle"))
+ .AddArgumentListArguments(
+ Argument(IdentifierName("member"))))))),
+ ReturnStatement(
+ ParseExpression(
+ "handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null")))),
+
+ // private static nint? GetHandle(MemberInfo member) => member switch { ... };
+ MethodDeclaration(ParseTypeName("nint?"), "GetHandle")
+ .WithModifiers(TokenList(
+ Token(SyntaxKind.PrivateKeyword),
+ Token(SyntaxKind.StaticKeyword)))
+ .AddParameterListParameters(
+ Parameter(Identifier("member"))
+ .WithType(ParseTypeName("MemberInfo")))
+ .WithExpressionBody(ArrowExpressionClause(
+ SwitchExpression(IdentifierName("member"))
+ .WithArms(SeparatedList(
+ new SyntaxNodeOrToken[]
+ {
+ SwitchExpressionArm(
+ DeclarationPattern(
+ ParseTypeName("MethodInfo"),
+ SingleVariableDesignation(Identifier("m"))),
+ ParseExpression("m.MethodHandle.Value")),
+ Token(SyntaxKind.CommaToken),
+ SwitchExpressionArm(
+ DeclarationPattern(
+ ParseTypeName("PropertyInfo"),
+ SingleVariableDesignation(Identifier("p"))),
+ ParseExpression("p.GetMethod?.MethodHandle.Value")),
+ Token(SyntaxKind.CommaToken),
+ SwitchExpressionArm(
+ DeclarationPattern(
+ ParseTypeName("ConstructorInfo"),
+ SingleVariableDesignation(Identifier("c"))),
+ ParseExpression("c.MethodHandle.Value")),
+ Token(SyntaxKind.CommaToken),
+ SwitchExpressionArm(
+ DiscardPattern(),
+ LiteralExpression(SyntaxKind.NullLiteralExpression))
+ }))))
+ .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)),
+
+ // private static Dictionary Build() { ... }
+ MethodDeclaration(ParseTypeName("Dictionary"), "Build")
+ .WithModifiers(TokenList(
+ Token(SyntaxKind.PrivateKeyword),
+ Token(SyntaxKind.StaticKeyword)))
+ .WithBody(Block(buildStatements)),
+
+ // private static void Register(Dictionary map, MethodBase m, string exprClass)
+ BuildRegisterHelperMethod());
+
+ var compilationUnit = CompilationUnit()
+ .AddUsings(
+ UsingDirective(ParseName("System")),
+ UsingDirective(ParseName("System.Collections.Generic")),
+ UsingDirective(ParseName("System.Linq.Expressions")),
+ UsingDirective(ParseName("System.Reflection")))
+ .AddMembers(
+ NamespaceDeclaration(ParseName("EntityFrameworkCore.Projectables.Generated"))
+ .AddMembers(classSyntax))
+ .WithLeadingTrivia(TriviaList(
+ Comment("// "),
+ Trivia(NullableDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true))));
+
+ context.AddSource("ProjectionRegistry.g.cs",
+ SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8));
+ }
+
+ ///
+ /// Builds a single compact Register(map, typeof(T).GetXxx(...), "ClassName")
+ /// statement for one projectable entry in Build().
+ /// Returns for generic class/method entries (they fall back to reflection).
+ ///
+ private static ExpressionStatementSyntax? BuildRegistryEntryStatement(ProjectableRegistryEntry entry)
+ {
+ if (entry.IsGenericClass || entry.IsGenericMethod)
+ {
+ return null;
+ }
+
+ // typeof(DeclaringType).GetProperty/Method/Constructor(name, allFlags, ...)
+ ExpressionSyntax? memberCallExpr = entry.MemberKind switch
+ {
+ // typeof(T).GetProperty("Name", allFlags)?.GetMethod
+ "Property" => ConditionalAccessExpression(
+ InvocationExpression(
+ MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+ TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)),
+ IdentifierName("GetProperty")))
+ .AddArgumentListArguments(
+ Argument(LiteralExpression(SyntaxKind.StringLiteralExpression,
+ Literal(entry.MemberLookupName))),
+ Argument(IdentifierName("allFlags"))),
+ MemberBindingExpression(IdentifierName("GetMethod"))),
+
+ // typeof(T).GetMethod("Name", allFlags, null, new Type[] {...}, null)
+ "Method" => InvocationExpression(
+ MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+ TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)),
+ IdentifierName("GetMethod")))
+ .AddArgumentListArguments(
+ Argument(LiteralExpression(SyntaxKind.StringLiteralExpression,
+ Literal(entry.MemberLookupName))),
+ Argument(IdentifierName("allFlags")),
+ Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)),
+ Argument(BuildTypeArrayExpr(entry.ParameterTypeNames)),
+ Argument(LiteralExpression(SyntaxKind.NullLiteralExpression))),
+
+ // typeof(T).GetConstructor(allFlags, null, new Type[] {...}, null)
+ "Constructor" => InvocationExpression(
+ MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+ TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)),
+ IdentifierName("GetConstructor")))
+ .AddArgumentListArguments(
+ Argument(IdentifierName("allFlags")),
+ Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)),
+ Argument(BuildTypeArrayExpr(entry.ParameterTypeNames)),
+ Argument(LiteralExpression(SyntaxKind.NullLiteralExpression))),
+
+ _ => null
+ };
+
+ if (memberCallExpr is null)
+ {
+ return null;
+ }
+
+ // Register(map, , "");
+ return ExpressionStatement(
+ InvocationExpression(IdentifierName("Register"))
+ .AddArgumentListArguments(
+ Argument(IdentifierName("map")),
+ Argument(memberCallExpr),
+ Argument(LiteralExpression(SyntaxKind.StringLiteralExpression,
+ Literal(entry.GeneratedClassFullName)))));
+ }
+
+ ///
+ /// Builds the Register private static helper method that all per-entry calls delegate to.
+ /// It handles the null checks and the common reflection lookup pattern once, centrally.
+ ///
+ private static MethodDeclarationSyntax BuildRegisterHelperMethod() =>
+ // private static void Register(Dictionary map, MethodBase m, string exprClass)
+ // {
+ // if (m is null) return;
+ // var exprType = m.DeclaringType?.Assembly.GetType(exprClass);
+ // var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic);
+ // if (exprMethod is not null)
+ // map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!;
+ // }
+ MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "Register")
+ .WithModifiers(TokenList(
+ Token(SyntaxKind.PrivateKeyword),
+ Token(SyntaxKind.StaticKeyword)))
+ .AddParameterListParameters(
+ Parameter(Identifier("map"))
+ .WithType(ParseTypeName("Dictionary")),
+ Parameter(Identifier("m"))
+ .WithType(ParseTypeName("MethodBase")),
+ Parameter(Identifier("exprClass"))
+ .WithType(PredefinedType(Token(SyntaxKind.StringKeyword))))
+ .WithBody(Block(
+ // if (m is null) return;
+ IfStatement(
+ IsPatternExpression(
+ IdentifierName("m"),
+ ConstantPattern(LiteralExpression(SyntaxKind.NullLiteralExpression))),
+ ReturnStatement()),
+ // var exprType = m.DeclaringType?.Assembly.GetType(exprClass);
+ LocalDeclarationStatement(
+ VariableDeclaration(ParseTypeName("var"))
+ .AddVariables(
+ VariableDeclarator("exprType")
+ .WithInitializer(EqualsValueClause(
+ ParseExpression("m.DeclaringType?.Assembly.GetType(exprClass)"))))),
+ // var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic);
+ LocalDeclarationStatement(
+ VariableDeclaration(ParseTypeName("var"))
+ .AddVariables(
+ VariableDeclarator("exprMethod")
+ .WithInitializer(EqualsValueClause(
+ ParseExpression(
+ @"exprType?.GetMethod(""Expression"", BindingFlags.Static | BindingFlags.NonPublic)"))))),
+ // if (exprMethod is not null)
+ // map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!;
+ IfStatement(
+ ParseExpression("exprMethod is not null"),
+ ExpressionStatement(
+ ParseExpression(
+ "map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!")))));
+
+ ///
+ /// Builds the typeof(...)-array expression used for reflection method/constructor lookup.
+ /// Returns global::System.Type.EmptyTypes when there are no parameters.
+ ///
+ private static ExpressionSyntax BuildTypeArrayExpr(ImmutableArray parameterTypeNames)
+ {
+ if (parameterTypeNames.IsEmpty)
+ {
+ return ParseExpression("global::System.Type.EmptyTypes");
+ }
+
+ var typeofExprs = parameterTypeNames
+ .Select(name => (ExpressionSyntax)TypeOfExpression(ParseTypeName(name)))
+ .ToArray();
+
+ return ArrayCreationExpression(
+ ArrayType(ParseTypeName("global::System.Type"))
+ .AddRankSpecifiers(ArrayRankSpecifier()))
+ .WithInitializer(
+ InitializerExpression(SyntaxKind.ArrayInitializerExpression,
+ SeparatedList(typeofExprs)));
+ }
}
-}
+}
\ No newline at end of file
diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs
index 16e547a..e342a6e 100644
--- a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs
+++ b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs
@@ -21,27 +21,19 @@ public sealed class ProjectableExpressionReplacer : ExpressionVisitor
private readonly bool _trackingByDefault;
private IEntityType? _entityType;
- private readonly MethodInfo _select;
- private readonly MethodInfo _where;
+ // Extract MethodInfo via expression trees (trim-safe; computed once per AppDomain)
+ private static readonly MethodInfo _select =
+ ((MethodCallExpression)((Expression, IQueryable