diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs new file mode 100644 index 0000000..3ab54d7 --- /dev/null +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ExpressionResolverBenchmark.cs @@ -0,0 +1,56 @@ +using System.Linq.Expressions; +using System.Reflection; +using BenchmarkDotNet.Attributes; +using EntityFrameworkCore.Projectables.Benchmarks.Helpers; +using EntityFrameworkCore.Projectables.Services; + +namespace EntityFrameworkCore.Projectables.Benchmarks +{ + /// + /// Micro-benchmarks in + /// isolation (no EF Core overhead) to directly compare the static registry path against + /// the reflection-based path (). + /// + [MemoryDiagnoser] + public class ExpressionResolverBenchmark + { + private static readonly MemberInfo _propertyMember = + typeof(TestEntity).GetProperty(nameof(TestEntity.IdPlus1))!; + + private static readonly MemberInfo _methodMember = + typeof(TestEntity).GetMethod(nameof(TestEntity.IdPlus1Method))!; + + private static readonly MemberInfo _methodWithParamMember = + typeof(TestEntity).GetMethod(nameof(TestEntity.IdPlusDelta), new[] { typeof(int) })!; + + private readonly ProjectionExpressionResolver _resolver = new(); + + // ── Registry (source-generated) path ───────────────────────────────── + + [Benchmark(Baseline = true)] + public LambdaExpression? ResolveProperty_Registry() + => _resolver.FindGeneratedExpression(_propertyMember); + + [Benchmark] + public LambdaExpression? ResolveMethod_Registry() + => _resolver.FindGeneratedExpression(_methodMember); + + [Benchmark] + public LambdaExpression? ResolveMethodWithParam_Registry() + => _resolver.FindGeneratedExpression(_methodWithParamMember); + + // ── Reflection path ─────────────────────────────────────────────────── + + [Benchmark] + public LambdaExpression? ResolveProperty_Reflection() + => ProjectionExpressionResolver.FindGeneratedExpressionViaReflection(_propertyMember); + + [Benchmark] + public LambdaExpression? ResolveMethod_Reflection() + => ProjectionExpressionResolver.FindGeneratedExpressionViaReflection(_methodMember); + + [Benchmark] + public LambdaExpression? ResolveMethodWithParam_Reflection() + => ProjectionExpressionResolver.FindGeneratedExpressionViaReflection(_methodWithParamMember); + } +} diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/Helpers/TestEntity.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/Helpers/TestEntity.cs index 4bc741c..68a04d8 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/Helpers/TestEntity.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/Helpers/TestEntity.cs @@ -15,5 +15,8 @@ public class TestEntity [Projectable] public int IdPlus1Method() => Id + 1; + + [Projectable] + public int IdPlusDelta(int delta) => Id + delta; } } diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/PlainOverhead.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/PlainOverhead.cs index d064b7d..b9cda85 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/PlainOverhead.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/PlainOverhead.cs @@ -9,6 +9,7 @@ namespace EntityFrameworkCore.Projectables.Benchmarks { + [MemoryDiagnoser] public class PlainOverhead { [Benchmark(Baseline = true)] diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableExtensionMethods.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableExtensionMethods.cs index fdb0f9f..467d470 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableExtensionMethods.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableExtensionMethods.cs @@ -9,6 +9,7 @@ namespace EntityFrameworkCore.Projectables.Benchmarks { + [MemoryDiagnoser] public class ProjectableExtensionMethods { const int innerLoop = 10000; diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableMethods.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableMethods.cs index 785b52c..0618627 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableMethods.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableMethods.cs @@ -8,6 +8,7 @@ namespace EntityFrameworkCore.Projectables.Benchmarks { + [MemoryDiagnoser] public class ProjectableMethods { const int innerLoop = 10000; diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableProperties.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableProperties.cs index 6f2fa57..3a2e2be 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableProperties.cs +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ProjectableProperties.cs @@ -8,6 +8,7 @@ namespace EntityFrameworkCore.Projectables.Benchmarks { + [MemoryDiagnoser] public class ProjectableProperties { const int innerLoop = 10000; diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ResolverOverhead.cs b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ResolverOverhead.cs new file mode 100644 index 0000000..3876641 --- /dev/null +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/ResolverOverhead.cs @@ -0,0 +1,72 @@ +using System.Linq; +using BenchmarkDotNet.Attributes; +using EntityFrameworkCore.Projectables.Benchmarks.Helpers; +using Microsoft.EntityFrameworkCore; + +namespace EntityFrameworkCore.Projectables.Benchmarks +{ + /// + /// Measures the per-DbContext cold-start cost of resolver lookup by creating a new + /// on every iteration. The previous benchmarks reuse a single + /// DbContext for 10 000 iterations, so the resolver cache is warm after the first query — + /// these benchmarks expose the cost of the very first query per context. + /// + [MemoryDiagnoser] + public class ResolverOverhead + { + const int Iterations = 1000; + + /// Baseline: no projectables, new DbContext per query. + [Benchmark(Baseline = true)] + public void WithoutProjectables_FreshDbContext() + { + for (int i = 0; i < Iterations; i++) + { + using var dbContext = new TestDbContext(false); + dbContext.Entities.Select(x => x.Id + 1).ToQueryString(); + } + } + + /// + /// New DbContext per query with a projectable property. + /// After the registry is in place this should approach baseline overhead. + /// + [Benchmark] + public void WithProjectables_FreshDbContext_Property() + { + for (int i = 0; i < Iterations; i++) + { + using var dbContext = new TestDbContext(true, false); + dbContext.Entities.Select(x => x.IdPlus1).ToQueryString(); + } + } + + /// + /// New DbContext per query with a projectable method. + /// After the registry is in place this should approach baseline overhead. + /// + [Benchmark] + public void WithProjectables_FreshDbContext_Method() + { + for (int i = 0; i < Iterations; i++) + { + using var dbContext = new TestDbContext(true, false); + dbContext.Entities.Select(x => x.IdPlus1Method()).ToQueryString(); + } + } + + /// + /// New DbContext per query with a projectable method that takes a parameter, + /// exercising parameter-type disambiguation in the registry key. + /// + [Benchmark] + public void WithProjectables_FreshDbContext_MethodWithParam() + { + for (int i = 0; i < Iterations; i++) + { + using var dbContext = new TestDbContext(true, false); + dbContext.Entities.Select(x => x.IdPlusDelta(5)).ToQueryString(); + } + } + } +} diff --git a/src/EntityFrameworkCore.Projectables.Generator/IsExternalInit.cs b/src/EntityFrameworkCore.Projectables.Generator/IsExternalInit.cs new file mode 100644 index 0000000..bd4930e --- /dev/null +++ b/src/EntityFrameworkCore.Projectables.Generator/IsExternalInit.cs @@ -0,0 +1,6 @@ +// Polyfill for C# 9 record types when targeting netstandard2.0 or netstandard2.1 +// The compiler requires this type to exist in order to use init-only setters (used by records). +namespace System.Runtime.CompilerServices +{ + internal sealed class IsExternalInit { } +} diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs new file mode 100644 index 0000000..37a45e0 --- /dev/null +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectableRegistryEntry.cs @@ -0,0 +1,61 @@ +using System.Collections.Immutable; +using System.Linq; + +namespace EntityFrameworkCore.Projectables.Generator +{ + /// + /// Incremental-pipeline-safe representation of a single projectable member. + /// Contains only primitive types and an equatable wrapper around + /// so that structural value equality works correctly across incremental generation steps. + /// + sealed internal record ProjectableRegistryEntry( + string DeclaringTypeFullName, + string MemberKind, + string MemberLookupName, + string GeneratedClassFullName, + bool IsGenericClass, + bool IsGenericMethod, + EquatableImmutableArray ParameterTypeNames + ); + + /// + /// A structural-equality wrapper around of strings. + /// uses reference equality by default, which breaks + /// Roslyn's incremental-source-generator caching when the same logical array is + /// produced by two different steps. This wrapper provides element-wise equality so + /// that incremental steps are correctly cached and skipped. + /// + internal readonly struct EquatableImmutableArray : System.IEquatable + { + public static readonly EquatableImmutableArray Empty = new(ImmutableArray.Empty); + + public readonly ImmutableArray Array; + + public EquatableImmutableArray(ImmutableArray array) + { + Array = array; + } + + public bool IsDefaultOrEmpty => Array.IsDefaultOrEmpty; + + public bool Equals(EquatableImmutableArray other) => + Array.SequenceEqual(other.Array); + + public override bool Equals(object? obj) => + obj is EquatableImmutableArray other && Equals(other); + + public override int GetHashCode() + { + unchecked + { + var hash = 17; + foreach (var s in Array) + hash = hash * 31 + (s?.GetHashCode() ?? 0); + return hash; + } + } + + public static implicit operator ImmutableArray(EquatableImmutableArray e) => e.Array; + public static implicit operator EquatableImmutableArray(ImmutableArray a) => new(a); + } +} diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs index 440caa5..0c83141 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectionExpressionGenerator.cs @@ -3,6 +3,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; +using System.Collections.Immutable; using System.Text; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; @@ -32,7 +33,7 @@ public class ProjectionExpressionGenerator : IIncrementalGenerator public void Initialize(IncrementalGeneratorInitializationContext context) { // Do a simple filter for members - IncrementalValuesProvider memberDeclarations = context.SyntaxProvider + var memberDeclarations = context.SyntaxProvider .ForAttributeWithMetadataName( ProjectablesAttributeName, predicate: static (s, _) => s is MemberDeclarationSyntax, @@ -40,16 +41,28 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .WithComparer(new MemberDeclarationSyntaxEqualityComparer()); // Combine the selected enums with the `Compilation` - IncrementalValuesProvider<(MemberDeclarationSyntax, Compilation)> compilationAndMemberPairs = memberDeclarations + var compilationAndMemberPairs = memberDeclarations .Combine(context.CompilationProvider) .WithComparer(new MemberDeclarationSyntaxAndCompilationEqualityComparer()); // Generate the source using the compilation and enums context.RegisterImplementationSourceOutput(compilationAndMemberPairs, static (spc, source) => Execute(source.Item1, source.Item2, spc)); + + // Build the projection registry: collect all entries and emit a single registry file + var registryEntries = + compilationAndMemberPairs.Select( + static (pair, _) => ExtractRegistryEntry(pair.Item1, pair.Item2)); + + var allEntries = + registryEntries.Collect(); + + context.RegisterImplementationSourceOutput( + allEntries, + static (spc, entries) => EmitRegistry(entries, spc)); } - static SyntaxTriviaList BuildSourceDocComment(ConstructorDeclarationSyntax ctor, Compilation compilation) + private static SyntaxTriviaList BuildSourceDocComment(ConstructorDeclarationSyntax ctor, Compilation compilation) { var chain = CollectConstructorChain(ctor, compilation); @@ -90,7 +103,7 @@ void AddLine(string text) /// then its delegate's delegate, …). Stops when a delegated constructor has no source /// available in the compilation (e.g. a compiler-synthesised parameterless constructor). /// - static IReadOnlyList CollectConstructorChain( + private static List CollectConstructorChain( ConstructorDeclarationSyntax ctor, Compilation compilation) { var result = new List { ctor }; @@ -101,7 +114,9 @@ static IReadOnlyList CollectConstructorChain( { var semanticModel = compilation.GetSemanticModel(current.SyntaxTree); if (semanticModel.GetSymbolInfo(initializer).Symbol is not IMethodSymbol delegated) + { break; + } var delegatedSyntax = delegated.DeclaringSyntaxReferences .Select(r => r.GetSyntax()) @@ -109,7 +124,9 @@ static IReadOnlyList CollectConstructorChain( .FirstOrDefault(); if (delegatedSyntax is null || !visited.Add(delegatedSyntax)) + { break; + } result.Add(delegatedSyntax); current = delegatedSyntax; @@ -118,7 +135,7 @@ static IReadOnlyList CollectConstructorChain( return result; } - static void Execute(MemberDeclarationSyntax member, Compilation compilation, SourceProductionContext context) + private static void Execute(MemberDeclarationSyntax member, Compilation compilation, SourceProductionContext context) { var projectable = ProjectableInterpreter.GetDescriptor(compilation, member, context); @@ -146,40 +163,40 @@ static void Execute(MemberDeclarationSyntax member, Compilation compilation, Sou .WithLeadingTrivia(member is ConstructorDeclarationSyntax ctor ? BuildSourceDocComment(ctor, compilation) : TriviaList()) .AddMembers( MethodDeclaration( - GenericName( - Identifier("global::System.Linq.Expressions.Expression"), - TypeArgumentList( - SingletonSeparatedList( - (TypeSyntax)GenericName( - Identifier("global::System.Func"), - GetLambdaTypeArgumentListSyntax(projectable) + GenericName( + Identifier("global::System.Linq.Expressions.Expression"), + TypeArgumentList( + SingletonSeparatedList( + (TypeSyntax)GenericName( + Identifier("global::System.Func"), + GetLambdaTypeArgumentListSyntax(projectable) + ) ) ) - ) - ), - "Expression" - ) - .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))) - .WithTypeParameterList(projectable.TypeParameterList) - .WithConstraintClauses(projectable.ConstraintClauses ?? List()) - .WithBody( - Block( - ReturnStatement( - ParenthesizedLambdaExpression( - projectable.ParametersList ?? ParameterList(), - null, - projectable.ExpressionBody + ), + "Expression" + ) + .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))) + .WithTypeParameterList(projectable.TypeParameterList) + .WithConstraintClauses(projectable.ConstraintClauses ?? List()) + .WithBody( + Block( + ReturnStatement( + ParenthesizedLambdaExpression( + projectable.ParametersList ?? ParameterList(), + null, + projectable.ExpressionBody + ) ) ) ) - ) ); #nullable disable var compilationUnit = CompilationUnit(); - foreach (var usingDirective in projectable.UsingDirectives) + foreach (var usingDirective in projectable.UsingDirectives!) { compilationUnit = compilationUnit.AddUsings(usingDirective); } @@ -226,5 +243,419 @@ static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescrip return lambdaTypeArguments; } } + +#nullable restore + + /// + /// Extracts a from a member declaration. + /// Returns null when the member does not have [Projectable], is an extension member, + /// or cannot be represented in the registry (e.g. a generic class member or generic method). + /// + private static ProjectableRegistryEntry? ExtractRegistryEntry(MemberDeclarationSyntax member, Compilation compilation) + { + var semanticModel = compilation.GetSemanticModel(member.SyntaxTree); + var memberSymbol = semanticModel.GetDeclaredSymbol(member); + + if (memberSymbol is null) + { + return null; + } + + // Verify [Projectable] attribute + var projectableAttributeTypeSymbol = compilation.GetTypeByMetadataName("EntityFrameworkCore.Projectables.ProjectableAttribute"); + var projectableAttribute = memberSymbol.GetAttributes() + .FirstOrDefault(x => x.AttributeClass?.Name == "ProjectableAttribute"); + + if (projectableAttribute is null || + !SymbolEqualityComparer.Default.Equals(projectableAttribute.AttributeClass, projectableAttributeTypeSymbol)) + { + return null; + } + + // Skip C# 14 extension type members — they require special handling (fall back to reflection) + if (memberSymbol.ContainingType is { IsExtension: true }) + { + return null; + } + + var containingType = memberSymbol.ContainingType; + var isGenericClass = containingType.TypeParameters.Length > 0; + + // Determine member kind and lookup name + string memberKind; + string memberLookupName; + var parameterTypeNames = ImmutableArray.Empty; + var methodTypeParamCount = 0; + var isGenericMethod = false; + + if (memberSymbol is IMethodSymbol methodSymbol) + { + isGenericMethod = methodSymbol.TypeParameters.Length > 0; + methodTypeParamCount = methodSymbol.TypeParameters.Length; + + if (methodSymbol.MethodKind is MethodKind.Constructor or MethodKind.StaticConstructor) + { + memberKind = "Constructor"; + memberLookupName = "_ctor"; + } + else + { + memberKind = "Method"; + memberLookupName = memberSymbol.Name; + } + + parameterTypeNames = [ + ..methodSymbol.Parameters.Select(p => p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)) + ]; + } + else + { + memberKind = "Property"; + memberLookupName = memberSymbol.Name; + } + + // Build the generated class name using the same logic as Execute + var classNamespace = containingType.ContainingNamespace.IsGlobalNamespace + ? null + : containingType.ContainingNamespace.ToDisplayString(); + + var nestedTypePath = GetRegistryNestedTypePath(containingType); + + var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName( + classNamespace, + nestedTypePath, + memberLookupName, + parameterTypeNames.IsEmpty ? null : parameterTypeNames); + + var generatedClassFullName = "EntityFrameworkCore.Projectables.Generated." + generatedClassName; + + var declaringTypeFullName = containingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + return new ProjectableRegistryEntry( + DeclaringTypeFullName: declaringTypeFullName, + MemberKind: memberKind, + MemberLookupName: memberLookupName, + GeneratedClassFullName: generatedClassFullName, + IsGenericClass: isGenericClass, + IsGenericMethod: isGenericMethod, + ParameterTypeNames: parameterTypeNames); + } + + private static IEnumerable GetRegistryNestedTypePath(INamedTypeSymbol typeSymbol) + { + if (typeSymbol.ContainingType is not null) + { + foreach (var name in GetRegistryNestedTypePath(typeSymbol.ContainingType)) + { + yield return name; + } + } + yield return typeSymbol.Name; + } + + /// + /// Emits the ProjectionRegistry.g.cs file that aggregates all projectable members + /// into a single static dictionary keyed by . + /// Uses SyntaxFactory for the class/method/field structure, consistent with . + /// The generated Build() method uses a shared Register helper to avoid repeating + /// the lookup boilerplate for every entry. + /// + private static void EmitRegistry(ImmutableArray entries, SourceProductionContext context) + { + // Build the per-entry Register(...) statements first so we can bail out early + // if every entry is generic (they all fall back to reflection, no registry needed). + var entryStatements = entries + .Where(e => e is not null) + .Select(e => BuildRegistryEntryStatement(e!)) + .Where(s => s is not null) + .Select(s => s!) + .ToList(); + + if (entryStatements.Count == 0) + { + return; + } + + // Build() body: + // const BindingFlags allFlags = ...; + // var map = new Dictionary(); + // Register(map, typeof(T).GetXxx(...), "ClassName"); ← one line per entry + // return map; + var buildStatements = new List + { + // const BindingFlags allFlags = BindingFlags.Public | BindingFlags.NonPublic | ...; + LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("BindingFlags")) + .AddVariables( + VariableDeclarator("allFlags") + .WithInitializer(EqualsValueClause( + ParseExpression( + "BindingFlags.Public | BindingFlags.NonPublic | " + + "BindingFlags.Instance | BindingFlags.Static"))))) + .WithModifiers(TokenList(Token(SyntaxKind.ConstKeyword))), + + // var map = new Dictionary(); + LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("var")) + .AddVariables( + VariableDeclarator("map") + .WithInitializer(EqualsValueClause( + ObjectCreationExpression( + ParseTypeName("Dictionary")) + .WithArgumentList(ArgumentList()))))), + }; + + buildStatements.AddRange(entryStatements); + buildStatements.Add(ReturnStatement(IdentifierName("map"))); + + var classSyntax = ClassDeclaration("ProjectionRegistry") + .WithModifiers(TokenList( + Token(SyntaxKind.InternalKeyword), + Token(SyntaxKind.StaticKeyword))) + .AddAttributeLists(AttributeList().AddAttributes(_editorBrowsableAttribute)) + .AddMembers( + // private static readonly Dictionary _map = Build(); + FieldDeclaration( + VariableDeclaration(ParseTypeName("Dictionary")) + .AddVariables( + VariableDeclarator("_map") + .WithInitializer(EqualsValueClause( + InvocationExpression(IdentifierName("Build")) + .WithArgumentList(ArgumentList()))))) + .WithModifiers(TokenList( + Token(SyntaxKind.PrivateKeyword), + Token(SyntaxKind.StaticKeyword), + Token(SyntaxKind.ReadOnlyKeyword))), + + // public static LambdaExpression TryGet(MemberInfo member) + MethodDeclaration(ParseTypeName("LambdaExpression"), "TryGet") + .WithModifiers(TokenList( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword))) + .AddParameterListParameters( + Parameter(Identifier("member")) + .WithType(ParseTypeName("MemberInfo"))) + .WithBody(Block( + LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("var")) + .AddVariables( + VariableDeclarator("handle") + .WithInitializer(EqualsValueClause( + InvocationExpression(IdentifierName("GetHandle")) + .AddArgumentListArguments( + Argument(IdentifierName("member"))))))), + ReturnStatement( + ParseExpression( + "handle.HasValue && _map.TryGetValue(handle.Value, out var expr) ? expr : null")))), + + // private static nint? GetHandle(MemberInfo member) => member switch { ... }; + MethodDeclaration(ParseTypeName("nint?"), "GetHandle") + .WithModifiers(TokenList( + Token(SyntaxKind.PrivateKeyword), + Token(SyntaxKind.StaticKeyword))) + .AddParameterListParameters( + Parameter(Identifier("member")) + .WithType(ParseTypeName("MemberInfo"))) + .WithExpressionBody(ArrowExpressionClause( + SwitchExpression(IdentifierName("member")) + .WithArms(SeparatedList( + new SyntaxNodeOrToken[] + { + SwitchExpressionArm( + DeclarationPattern( + ParseTypeName("MethodInfo"), + SingleVariableDesignation(Identifier("m"))), + ParseExpression("m.MethodHandle.Value")), + Token(SyntaxKind.CommaToken), + SwitchExpressionArm( + DeclarationPattern( + ParseTypeName("PropertyInfo"), + SingleVariableDesignation(Identifier("p"))), + ParseExpression("p.GetMethod?.MethodHandle.Value")), + Token(SyntaxKind.CommaToken), + SwitchExpressionArm( + DeclarationPattern( + ParseTypeName("ConstructorInfo"), + SingleVariableDesignation(Identifier("c"))), + ParseExpression("c.MethodHandle.Value")), + Token(SyntaxKind.CommaToken), + SwitchExpressionArm( + DiscardPattern(), + LiteralExpression(SyntaxKind.NullLiteralExpression)) + })))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)), + + // private static Dictionary Build() { ... } + MethodDeclaration(ParseTypeName("Dictionary"), "Build") + .WithModifiers(TokenList( + Token(SyntaxKind.PrivateKeyword), + Token(SyntaxKind.StaticKeyword))) + .WithBody(Block(buildStatements)), + + // private static void Register(Dictionary map, MethodBase m, string exprClass) + BuildRegisterHelperMethod()); + + var compilationUnit = CompilationUnit() + .AddUsings( + UsingDirective(ParseName("System")), + UsingDirective(ParseName("System.Collections.Generic")), + UsingDirective(ParseName("System.Linq.Expressions")), + UsingDirective(ParseName("System.Reflection"))) + .AddMembers( + NamespaceDeclaration(ParseName("EntityFrameworkCore.Projectables.Generated")) + .AddMembers(classSyntax)) + .WithLeadingTrivia(TriviaList( + Comment("// "), + Trivia(NullableDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true)))); + + context.AddSource("ProjectionRegistry.g.cs", + SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8)); + } + + /// + /// Builds a single compact Register(map, typeof(T).GetXxx(...), "ClassName") + /// statement for one projectable entry in Build(). + /// Returns for generic class/method entries (they fall back to reflection). + /// + private static ExpressionStatementSyntax? BuildRegistryEntryStatement(ProjectableRegistryEntry entry) + { + if (entry.IsGenericClass || entry.IsGenericMethod) + { + return null; + } + + // typeof(DeclaringType).GetProperty/Method/Constructor(name, allFlags, ...) + ExpressionSyntax? memberCallExpr = entry.MemberKind switch + { + // typeof(T).GetProperty("Name", allFlags)?.GetMethod + "Property" => ConditionalAccessExpression( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), + IdentifierName("GetProperty"))) + .AddArgumentListArguments( + Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, + Literal(entry.MemberLookupName))), + Argument(IdentifierName("allFlags"))), + MemberBindingExpression(IdentifierName("GetMethod"))), + + // typeof(T).GetMethod("Name", allFlags, null, new Type[] {...}, null) + "Method" => InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), + IdentifierName("GetMethod"))) + .AddArgumentListArguments( + Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, + Literal(entry.MemberLookupName))), + Argument(IdentifierName("allFlags")), + Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)), + Argument(BuildTypeArrayExpr(entry.ParameterTypeNames)), + Argument(LiteralExpression(SyntaxKind.NullLiteralExpression))), + + // typeof(T).GetConstructor(allFlags, null, new Type[] {...}, null) + "Constructor" => InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + TypeOfExpression(ParseTypeName(entry.DeclaringTypeFullName)), + IdentifierName("GetConstructor"))) + .AddArgumentListArguments( + Argument(IdentifierName("allFlags")), + Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)), + Argument(BuildTypeArrayExpr(entry.ParameterTypeNames)), + Argument(LiteralExpression(SyntaxKind.NullLiteralExpression))), + + _ => null + }; + + if (memberCallExpr is null) + { + return null; + } + + // Register(map, , ""); + return ExpressionStatement( + InvocationExpression(IdentifierName("Register")) + .AddArgumentListArguments( + Argument(IdentifierName("map")), + Argument(memberCallExpr), + Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, + Literal(entry.GeneratedClassFullName))))); + } + + /// + /// Builds the Register private static helper method that all per-entry calls delegate to. + /// It handles the null checks and the common reflection lookup pattern once, centrally. + /// + private static MethodDeclarationSyntax BuildRegisterHelperMethod() => + // private static void Register(Dictionary map, MethodBase m, string exprClass) + // { + // if (m is null) return; + // var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + // var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + // if (exprMethod is not null) + // map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + // } + MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "Register") + .WithModifiers(TokenList( + Token(SyntaxKind.PrivateKeyword), + Token(SyntaxKind.StaticKeyword))) + .AddParameterListParameters( + Parameter(Identifier("map")) + .WithType(ParseTypeName("Dictionary")), + Parameter(Identifier("m")) + .WithType(ParseTypeName("MethodBase")), + Parameter(Identifier("exprClass")) + .WithType(PredefinedType(Token(SyntaxKind.StringKeyword)))) + .WithBody(Block( + // if (m is null) return; + IfStatement( + IsPatternExpression( + IdentifierName("m"), + ConstantPattern(LiteralExpression(SyntaxKind.NullLiteralExpression))), + ReturnStatement()), + // var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("var")) + .AddVariables( + VariableDeclarator("exprType") + .WithInitializer(EqualsValueClause( + ParseExpression("m.DeclaringType?.Assembly.GetType(exprClass)"))))), + // var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("var")) + .AddVariables( + VariableDeclarator("exprMethod") + .WithInitializer(EqualsValueClause( + ParseExpression( + @"exprType?.GetMethod(""Expression"", BindingFlags.Static | BindingFlags.NonPublic)"))))), + // if (exprMethod is not null) + // map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!; + IfStatement( + ParseExpression("exprMethod is not null"), + ExpressionStatement( + ParseExpression( + "map[m.MethodHandle.Value] = (LambdaExpression)exprMethod.Invoke(null, null)!"))))); + + /// + /// Builds the typeof(...)-array expression used for reflection method/constructor lookup. + /// Returns global::System.Type.EmptyTypes when there are no parameters. + /// + private static ExpressionSyntax BuildTypeArrayExpr(ImmutableArray parameterTypeNames) + { + if (parameterTypeNames.IsEmpty) + { + return ParseExpression("global::System.Type.EmptyTypes"); + } + + var typeofExprs = parameterTypeNames + .Select(name => (ExpressionSyntax)TypeOfExpression(ParseTypeName(name))) + .ToArray(); + + return ArrayCreationExpression( + ArrayType(ParseTypeName("global::System.Type")) + .AddRankSpecifiers(ArrayRankSpecifier())) + .WithInitializer( + InitializerExpression(SyntaxKind.ArrayInitializerExpression, + SeparatedList(typeofExprs))); + } } -} +} \ No newline at end of file diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs index 16e547a..e342a6e 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs @@ -21,27 +21,19 @@ public sealed class ProjectableExpressionReplacer : ExpressionVisitor private readonly bool _trackingByDefault; private IEntityType? _entityType; - private readonly MethodInfo _select; - private readonly MethodInfo _where; + // Extract MethodInfo via expression trees (trim-safe; computed once per AppDomain) + private static readonly MethodInfo _select = + ((MethodCallExpression)((Expression, IQueryable>>) + (q => q.Select(x => x))).Body).Method.GetGenericMethodDefinition(); + + private static readonly MethodInfo _where = + ((MethodCallExpression)((Expression, IQueryable>>) + (q => q.Where(x => true))).Body).Method.GetGenericMethodDefinition(); public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver, bool trackByDefault = false) { _trackingByDefault = trackByDefault; _resolver = projectionExpressionResolver; - _select = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public) - .Where(x => x.Name == nameof(Queryable.Select)) - .First(x => - x.GetParameters().Last().ParameterType // Expression> - .GetGenericArguments().First() // Func - .GetGenericArguments().Length == 2 // Separate between Func and Func - ); - _where = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public) - .Where(x => x.Name == nameof(Queryable.Where)) - .First(x => - x.GetParameters().Last().ParameterType // Expression> - .GetGenericArguments().First() // Func - .GetGenericArguments().Length == 2 // Separate between Func and Func - ); } bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out LambdaExpression? reflectedExpression) diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs index 7e26d69..a4b8cd6 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -8,6 +9,35 @@ namespace EntityFrameworkCore.Projectables.Services { public sealed class ProjectionExpressionResolver : IProjectionExpressionResolver { + // We never store null in the dictionary; assemblies without a registry use a sentinel delegate. + private static readonly Func _nullRegistry = static _ => null!; + private static readonly ConcurrentDictionary> _assemblyRegistries = new(); + + /// + /// Looks up the generated ProjectionRegistry class in an assembly (once, then caches it). + /// Returns a delegate that calls TryGet(MemberInfo) on the registry, or null if the registry + /// is not present in that assembly (e.g. if the source generator was not run against it). + /// + private static Func? GetAssemblyRegistry(Assembly assembly) + { + var registry = _assemblyRegistries.GetOrAdd(assembly, static asm => + { + var registryType = asm.GetType("EntityFrameworkCore.Projectables.Generated.ProjectionRegistry"); + var tryGetMethod = registryType?.GetMethod("TryGet", BindingFlags.Static | BindingFlags.Public); + + if (tryGetMethod is null) + { + // Use sentinel to indicate "no registry for this assembly" + return _nullRegistry; + } + + return (Func)Delegate.CreateDelegate(typeof(Func), tryGetMethod); + }); + + // Translate sentinel back to null for callers, preserving existing behavior. + return ReferenceEquals(registry, _nullRegistry) ? null : registry; + } + public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo) { var projectableAttribute = projectableMemberInfo.GetCustomAttribute() @@ -62,165 +92,186 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo static LambdaExpression? GetExpressionFromGeneratedType(MemberInfo projectableMemberInfo) { var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here"); + + // Fast path: check the per-assembly static registry (generated by source generator). + // The first call per assembly does a reflection lookup to find the registry class and + // caches it as a delegate; subsequent calls use the cached delegate for an O(1) dictionary lookup. + var registry = GetAssemblyRegistry(declaringType.Assembly); + var registeredExpr = registry?.Invoke(projectableMemberInfo); - // Keep track of the original declaring type's generic arguments for later use - var originalDeclaringType = declaringType; - - // For generic types, use the generic type definition to match the generated name - // which is based on the open generic type - if (declaringType.IsGenericType && !declaringType.IsGenericTypeDefinition) - { - declaringType = declaringType.GetGenericTypeDefinition(); - } - - // Get parameter types for method overload disambiguation - // Use the same format as Roslyn's SymbolDisplayFormat.FullyQualifiedFormat - // which uses C# keywords for primitive types (int, string, etc.) - string[]? parameterTypeNames = null; - string memberLookupName = projectableMemberInfo.Name; - if (projectableMemberInfo is MethodInfo method) - { - // For generic methods, use the generic definition to get parameter types - // This ensures type parameters like TEntity are used instead of concrete types - var methodToInspect = method.IsGenericMethod ? method.GetGenericMethodDefinition() : method; - - parameterTypeNames = methodToInspect.GetParameters() - .Select(p => GetFullTypeName(p.ParameterType)) - .ToArray(); - } - else if (projectableMemberInfo is ConstructorInfo ctor) + return registeredExpr ?? + // Slow path: reflection fallback for open-generic class members and generic methods + // that are not yet in the registry. + FindGeneratedExpressionViaReflection(projectableMemberInfo); + } + } + + /// + /// Resolves the for a [Projectable] member using the + /// reflection-based slow path only, bypassing the static registry. + /// Useful for benchmarking and for members not yet in the registry (e.g. open-generic types). + /// + public static LambdaExpression? FindGeneratedExpressionViaReflection(MemberInfo projectableMemberInfo) + { + var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here"); + + // Keep track of the original declaring type's generic arguments for later use + var originalDeclaringType = declaringType; + + // For generic types, use the generic type definition to match the generated name + // which is based on the open generic type + if (declaringType.IsGenericType && !declaringType.IsGenericTypeDefinition) + { + declaringType = declaringType.GetGenericTypeDefinition(); + } + + // Get parameter types for method overload disambiguation + // Use the same format as Roslyn's SymbolDisplayFormat.FullyQualifiedFormat + // which uses C# keywords for primitive types (int, string, etc.) + string[]? parameterTypeNames = null; + string memberLookupName = projectableMemberInfo.Name; + if (projectableMemberInfo is MethodInfo method) + { + // For generic methods, use the generic definition to get parameter types + // This ensures type parameters like TEntity are used instead of concrete types + var methodToInspect = method.IsGenericMethod ? method.GetGenericMethodDefinition() : method; + + parameterTypeNames = methodToInspect.GetParameters() + .Select(p => GetFullTypeName(p.ParameterType)) + .ToArray(); + } + else if (projectableMemberInfo is ConstructorInfo ctor) + { + // Constructors are stored under the synthetic name "_ctor" + memberLookupName = "_ctor"; + parameterTypeNames = ctor.GetParameters() + .Select(p => GetFullTypeName(p.ParameterType)) + .ToArray(); + } + + var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), memberLookupName, parameterTypeNames); + + var expressionFactoryType = declaringType.Assembly.GetType(generatedContainingTypeName); + + if (expressionFactoryType is not null) + { + if (expressionFactoryType.IsGenericTypeDefinition) { - // Constructors are stored under the synthetic name "_ctor" - memberLookupName = "_ctor"; - parameterTypeNames = ctor.GetParameters() - .Select(p => GetFullTypeName(p.ParameterType)) - .ToArray(); + expressionFactoryType = expressionFactoryType.MakeGenericType(originalDeclaringType.GenericTypeArguments); } - - var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), memberLookupName, parameterTypeNames); - var expressionFactoryType = declaringType.Assembly.GetType(generatedContainingTypeName); + var expressionFactoryMethod = expressionFactoryType.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); - if (expressionFactoryType is not null) + var methodGenericArguments = projectableMemberInfo switch { + MethodInfo methodInfo => methodInfo.GetGenericArguments(), + _ => null + }; + + if (expressionFactoryMethod is not null) { - if (expressionFactoryType.IsGenericTypeDefinition) + if (methodGenericArguments is { Length: > 0 }) { - expressionFactoryType = expressionFactoryType.MakeGenericType(originalDeclaringType.GenericTypeArguments); + expressionFactoryMethod = expressionFactoryMethod.MakeGenericMethod(methodGenericArguments); } - var expressionFactoryMethod = expressionFactoryType.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); - - var methodGenericArguments = projectableMemberInfo switch { - MethodInfo methodInfo => methodInfo.GetGenericArguments(), - _ => null - }; + return expressionFactoryMethod.Invoke(null, null) as LambdaExpression ?? throw new InvalidOperationException("Expected lambda"); + } + } - if (expressionFactoryMethod is not null) - { - if (methodGenericArguments is { Length: > 0 }) - { - expressionFactoryMethod = expressionFactoryMethod.MakeGenericMethod(methodGenericArguments); - } + return null; + } - return expressionFactoryMethod.Invoke(null, null) as LambdaExpression ?? throw new InvalidOperationException("Expected lambda"); - } - } + private static string GetFullTypeName(Type type) + { + // Handle generic type parameters (e.g., T, TEntity) + if (type.IsGenericParameter) + { + return type.Name; + } - return null; + // Handle nullable value types (e.g., int? -> int?) + var underlyingType = Nullable.GetUnderlyingType(type); + if (underlyingType != null) + { + return $"{GetFullTypeName(underlyingType)}?"; } - - static string GetFullTypeName(Type type) + + // Handle array types + if (type.IsArray) { - // Handle generic type parameters (e.g., T, TEntity) - if (type.IsGenericParameter) + var elementType = type.GetElementType(); + if (elementType == null) { + // Fallback for edge cases where GetElementType() might return null return type.Name; } - - // Handle nullable value types (e.g., int? -> int?) - var underlyingType = Nullable.GetUnderlyingType(type); - if (underlyingType != null) - { - return $"{GetFullTypeName(underlyingType)}?"; - } - - // Handle array types - if (type.IsArray) - { - var elementType = type.GetElementType(); - if (elementType == null) - { - // Fallback for edge cases where GetElementType() might return null - return type.Name; - } - - var rank = type.GetArrayRank(); - var elementTypeName = GetFullTypeName(elementType); - - if (rank == 1) - { - return $"{elementTypeName}[]"; - } - else - { - var commas = new string(',', rank - 1); - return $"{elementTypeName}[{commas}]"; - } - } - - // Map primitive types to their C# keyword equivalents to match Roslyn's output - var typeKeyword = GetCSharpKeyword(type); - if (typeKeyword != null) + + var rank = type.GetArrayRank(); + var elementTypeName = GetFullTypeName(elementType); + + if (rank == 1) { - return typeKeyword; + return $"{elementTypeName}[]"; } - - // For generic types, construct the full name matching Roslyn's format - if (type.IsGenericType) + else { - var genericTypeDef = type.GetGenericTypeDefinition(); - var genericArgs = type.GetGenericArguments(); - var baseName = genericTypeDef.FullName ?? genericTypeDef.Name; - - // Remove the `n suffix (e.g., `1, `2) - var backtickIndex = baseName.IndexOf('`'); - if (backtickIndex > 0) - { - baseName = baseName.Substring(0, backtickIndex); - } - - var args = string.Join(", ", genericArgs.Select(GetFullTypeName)); - return $"{baseName}<{args}>"; + var commas = new string(',', rank - 1); + return $"{elementTypeName}[{commas}]"; } - - if (type.FullName != null) + } + + // Map primitive types to their C# keyword equivalents to match Roslyn's output + var typeKeyword = GetCSharpKeyword(type); + if (typeKeyword != null) + { + return typeKeyword; + } + + // For generic types, construct the full name matching Roslyn's format + if (type.IsGenericType) + { + var genericTypeDef = type.GetGenericTypeDefinition(); + var genericArgs = type.GetGenericArguments(); + var baseName = genericTypeDef.FullName ?? genericTypeDef.Name; + + // Remove the `n suffix (e.g., `1, `2) + var backtickIndex = baseName.IndexOf('`'); + if (backtickIndex > 0) { - // Replace + with . for nested types to match Roslyn's format - return type.FullName.Replace('+', '.'); + baseName = baseName.Substring(0, backtickIndex); } - - return type.Name; + + var args = string.Join(", ", genericArgs.Select(GetFullTypeName)); + return $"{baseName}<{args}>"; } - - static string? GetCSharpKeyword(Type type) + + if (type.FullName != null) { - if (type == typeof(bool)) return "bool"; - if (type == typeof(byte)) return "byte"; - if (type == typeof(sbyte)) return "sbyte"; - if (type == typeof(char)) return "char"; - if (type == typeof(decimal)) return "decimal"; - if (type == typeof(double)) return "double"; - if (type == typeof(float)) return "float"; - if (type == typeof(int)) return "int"; - if (type == typeof(uint)) return "uint"; - if (type == typeof(long)) return "long"; - if (type == typeof(ulong)) return "ulong"; - if (type == typeof(short)) return "short"; - if (type == typeof(ushort)) return "ushort"; - if (type == typeof(object)) return "object"; - if (type == typeof(string)) return "string"; - return null; + // Replace + with . for nested types to match Roslyn's format + return type.FullName.Replace('+', '.'); } + + return type.Name; + } + + private static string? GetCSharpKeyword(Type type) + { + if (type == typeof(bool)) return "bool"; + if (type == typeof(byte)) return "byte"; + if (type == typeof(sbyte)) return "sbyte"; + if (type == typeof(char)) return "char"; + if (type == typeof(decimal)) return "decimal"; + if (type == typeof(double)) return "double"; + if (type == typeof(float)) return "float"; + if (type == typeof(int)) return "int"; + if (type == typeof(uint)) return "uint"; + if (type == typeof(long)) return "long"; + if (type == typeof(ulong)) return "ulong"; + if (type == typeof(short)) return "short"; + if (type == typeof(ushort)) return "ushort"; + if (type == typeof(object)) return "object"; + if (type == typeof(string)) return "string"; + return null; } } } diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs index 9a791e2..acfc52a 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTestsBase.cs @@ -1,5 +1,6 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using System.Collections.Immutable; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; @@ -17,6 +18,47 @@ protected ProjectionExpressionGeneratorTestsBase(ITestOutputHelper testOutputHel _testOutputHelper = testOutputHelper; } + /// + /// Wraps and exposes + /// as a filtered view that excludes the generated ProjectionRegistry.g.cs file. + /// This keeps all existing tests working without modification after the registry was added. + /// + protected sealed class TestGeneratorRunResult + { + private readonly GeneratorDriverRunResult _inner; + + public TestGeneratorRunResult(GeneratorDriverRunResult inner) + { + _inner = inner; + } + + /// + /// Diagnostics from the generator run. + /// + public ImmutableArray Diagnostics => _inner.Diagnostics; + + /// + /// Generated trees excluding ProjectionRegistry.g.cs. + /// Existing tests use this and should continue to work without modification. + /// + public ImmutableArray GeneratedTrees => + _inner.GeneratedTrees + .Where(t => !t.FilePath.EndsWith("ProjectionRegistry.g.cs", StringComparison.Ordinal)) + .ToImmutableArray(); + + /// + /// All generated trees including ProjectionRegistry.g.cs. + /// Use this in new tests that need to verify the registry. + /// + public ImmutableArray AllGeneratedTrees => _inner.GeneratedTrees; + + /// + /// The generated ProjectionRegistry.g.cs tree, or null if it was not generated. + /// + public SyntaxTree? RegistryTree => + _inner.GeneratedTrees.FirstOrDefault(t => t.FilePath.EndsWith("ProjectionRegistry.g.cs", StringComparison.Ordinal)); + } + protected Compilation CreateCompilation([StringSyntax("csharp")] string source) { var references = Basic.Reference.Assemblies. @@ -58,7 +100,7 @@ protected Compilation CreateCompilation([StringSyntax("csharp")] string source) return compilation; } - protected GeneratorDriverRunResult RunGenerator(Compilation compilation) + protected TestGeneratorRunResult RunGenerator(Compilation compilation) { _testOutputHelper.WriteLine("Running generator and updating compilation..."); @@ -67,7 +109,8 @@ protected GeneratorDriverRunResult RunGenerator(Compilation compilation) .Create(subject) .RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out _); - var result = driver.GetRunResult(); + var rawResult = driver.GetRunResult(); + var result = new TestGeneratorRunResult(rawResult); if (result.Diagnostics.IsEmpty) { @@ -83,7 +126,7 @@ protected GeneratorDriverRunResult RunGenerator(Compilation compilation) } } - foreach (var newSyntaxTree in result.GeneratedTrees) + foreach (var newSyntaxTree in result.AllGeneratedTrees) { _testOutputHelper.WriteLine($"Produced syntax tree with path produced: {newSyntaxTree.FilePath}"); _testOutputHelper.WriteLine(newSyntaxTree.GetText().ToString()); @@ -91,7 +134,7 @@ protected GeneratorDriverRunResult RunGenerator(Compilation compilation) // Verify that the generated code compiles without errors var hasGeneratorErrors = result.Diagnostics.Any(d => d.Severity == DiagnosticSeverity.Error); - if (!hasGeneratorErrors && result.GeneratedTrees.Length > 0) + if (!hasGeneratorErrors && result.AllGeneratedTrees.Length > 0) { _testOutputHelper.WriteLine("Checking that generated code compiles..."); diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs new file mode 100644 index 0000000..27fe1f8 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/RegistryTests.cs @@ -0,0 +1,191 @@ +using Xunit; +using Xunit.Abstractions; + +namespace EntityFrameworkCore.Projectables.Generator.Tests; + +public class RegistryTests : ProjectionExpressionGeneratorTestsBase +{ + public RegistryTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { } + + [Fact] + public void NoProjectables_NoRegistry() + { + var compilation = CreateCompilation(@"class C { }"); + var result = RunGenerator(compilation); + + Assert.Null(result.RegistryTree); + } + + [Fact] + public void SingleProperty_RegistryContainsEntry() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int IdPlus1 => Id + 1; + } +}"); + var result = RunGenerator(compilation); + + Assert.NotNull(result.RegistryTree); + var src = result.RegistryTree!.GetText().ToString(); + + Assert.Contains("ProjectionRegistry", src); + // Uses the compact Register helper — not a repeated block + Assert.Contains("private static void Register(", src); + Assert.Contains("Register(map,", src); + Assert.Contains("GetProperty(\"IdPlus1\"", src); + Assert.Contains("Foo_C_IdPlus1", src); + } + + [Fact] + public void SingleMethod_RegistryContainsEntry() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int AddDelta(int delta) => Id + delta; + } +}"); + var result = RunGenerator(compilation); + + Assert.NotNull(result.RegistryTree); + var src = result.RegistryTree!.GetText().ToString(); + + Assert.Contains("GetMethod(\"AddDelta\"", src); + Assert.Contains("typeof(int)", src); + Assert.Contains("Foo_C_AddDelta_P0_int", src); + } + + [Fact] + public void MultipleProjectables_AllRegistered() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int IdPlus1 => Id + 1; + [Projectable] + public int AddDelta(int delta) => Id + delta; + } +}"); + var result = RunGenerator(compilation); + + Assert.NotNull(result.RegistryTree); + var src = result.RegistryTree!.GetText().ToString(); + + // Two separate Register(map, ...) calls — one per projectable + Assert.Contains("GetProperty(\"IdPlus1\"", src); + Assert.Contains("GetMethod(\"AddDelta\"", src); + // Each entry is a single line, not a repeated multi-line block + var registerCallCount = CountOccurrences(src, "Register(map,"); + Assert.Equal(2, registerCallCount); + } + + [Fact] + public void GenericClass_NotIncludedInRegistry() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int IdPlus1 => Id + 1; + } +}"); + var result = RunGenerator(compilation); + + // Generic class members fall back to reflection — no registry emitted + Assert.Null(result.RegistryTree); + } + + [Fact] + public void Registry_ConstBindingFlagsUsedInBuild() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int IdPlus1 => Id + 1; + } +}"); + var result = RunGenerator(compilation); + + Assert.NotNull(result.RegistryTree); + var src = result.RegistryTree!.GetText().ToString(); + + // Build() uses a single const BindingFlags instead of repeating the flags per entry + Assert.Contains("const BindingFlags allFlags", src); + Assert.Contains("allFlags", src); + } + + [Fact] + public void Registry_RegisterHelperUsesDeclaringTypeAssembly() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int IdPlus1 => Id + 1; + } +}"); + var result = RunGenerator(compilation); + + Assert.NotNull(result.RegistryTree); + var src = result.RegistryTree!.GetText().ToString(); + + // Register helper derives the assembly from m.DeclaringType (no typeof repeated per entry) + Assert.Contains("m.DeclaringType?.Assembly.GetType(exprClass)", src); + } + + [Fact] + public void MethodOverloads_BothRegistered() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; +namespace Foo { + class C { + public int Id { get; set; } + [Projectable] + public int Add(int delta) => Id + delta; + [Projectable] + public long Add(long delta) => Id + delta; + } +}"); + var result = RunGenerator(compilation); + + Assert.NotNull(result.RegistryTree); + var src = result.RegistryTree!.GetText().ToString(); + + // Both overloads registered by parameter-type disambiguation + Assert.Contains("typeof(int)", src); + Assert.Contains("typeof(long)", src); + var registerCallCount = CountOccurrences(src, "Register(map,"); + Assert.Equal(2, registerCallCount); + } + + private static int CountOccurrences(string text, string pattern) + { + int count = 0; + int index = 0; + while ((index = text.IndexOf(pattern, index, System.StringComparison.Ordinal)) >= 0) + { + count++; + index += pattern.Length; + } + return count; + } +}