Skip to content

Commit bfa6bb2

Browse files
author
alexander.marek
committed
#223 - allow for custom registrations
1 parent 75a47e9 commit bfa6bb2

File tree

72 files changed

+5554
-802
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+5554
-802
lines changed

src/Mediator.SourceGenerator/Implementation/Models/NotificationMessageHandlerModel.cs

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Mediator.SourceGenerator.Extensions;
1+
using Mediator.SourceGenerator.Extensions;
22

33
namespace Mediator.SourceGenerator;
44

@@ -19,8 +19,10 @@ public NotificationMessageHandlerModel(NotificationMessageHandler handler, Compi
1919

2020
if (!handler.Symbol.IsGenericType)
2121
{
22-
var concreteRegistration =
23-
$"services.TryAdd(new {sd}(typeof({concreteSymbol}), typeof({concreteSymbol}), {analyzer.ServiceLifetime}));";
22+
var concreteRegistration = $"""
23+
if (!IsHandlerAlreadyRegistered(existingRegistrations, typeof({concreteSymbol}), typeof({concreteSymbol})))
24+
services.TryAdd(new {sd}(typeof({concreteSymbol}), typeof({concreteSymbol}), {analyzer.ServiceLifetime}));
25+
""";
2426
builder.Add(concreteRegistration);
2527
}
2628

@@ -29,14 +31,19 @@ public NotificationMessageHandlerModel(NotificationMessageHandler handler, Compi
2931
var requestType = message.Symbol.GetTypeSymbolFullName();
3032
if (handler.Symbol.IsGenericType)
3133
{
32-
var concreteRegistration =
33-
$"services.TryAdd(new {sd}(typeof({concreteSymbol}<{requestType}>), typeof({concreteSymbol}<{requestType}>), {analyzer.ServiceLifetime}));";
34+
var concreteRegistration = $"""
35+
if (!IsHandlerAlreadyRegistered(existingRegistrations, typeof({concreteSymbol}<{requestType}>), typeof({concreteSymbol}<{requestType}>)))
36+
services.TryAdd(new {sd}(typeof({concreteSymbol}<{requestType}>), typeof({concreteSymbol}<{requestType}>), {analyzer.ServiceLifetime}));
37+
""";
3438
builder.Add(concreteRegistration);
3539
}
36-
var getExpression =
37-
$"GetRequiredService<{concreteSymbol}{(handler.Symbol.IsGenericType ? $"<{requestType}>" : "")}>()";
38-
var registration =
39-
$"services.Add(new {sd}(typeof({interfaceSymbol}<{requestType}>), {getExpression}, {analyzer.ServiceLifetime}));";
40+
41+
var concreteImpl = $"{concreteSymbol}{(handler.Symbol.IsGenericType ? $"<{requestType}>" : "")}";
42+
var getExpression = $"GetRequiredService<{concreteImpl}>()";
43+
var registration = $"""
44+
if (!IsHandlerAlreadyRegistered(existingRegistrations, typeof({interfaceSymbol}<{requestType}>), typeof({concreteImpl})))
45+
services.Add(new {sd}(typeof({interfaceSymbol}<{requestType}>), {getExpression}, {analyzer.ServiceLifetime}));
46+
""";
4047
builder.Add(registration);
4148
}
4249

src/Mediator.SourceGenerator/Implementation/resources/Mediator.sbn-cs

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ namespace Microsoft.Extensions.DependencyInjection
4949
throw new global::System.Exception(errMsg);
5050
}
5151

52+
// Build cache of existing registrations for efficient lookup
53+
var existingRegistrations = BuildRegistrationCache(services);
54+
5255
{{~ if ServiceLifetimeIsTransient || ServiceLifetimeIsScoped ~}}
5356
services.Add(new {{ SD }}(typeof(global::{{ MediatorNamespace }}.Mediator), typeof(global::{{ MediatorNamespace }}.Mediator), {{ ServiceLifetime }}));
5457
services.TryAdd(new {{ SD }}(typeof(global::Mediator.IMediator), typeof(global::{{ MediatorNamespace }}.Mediator), {{ ServiceLifetime }}));
@@ -61,7 +64,7 @@ namespace Microsoft.Extensions.DependencyInjection
6164
services.TryAdd(new {{ SD }}(typeof(global::Mediator.IPublisher), sp => sp.GetRequiredService<global::{{ MediatorNamespace }}.Mediator>(), {{ ServiceLifetime }}));
6265
{{~ end ~}}
6366
{{~ if (object.size RequestMessages) > 0 ~}}
64-
67+
6568
// Register handlers for request messages
6669
{{~ for message in RequestMessages ~}}
6770
{{ message.Handler.ServiceRegistration }}
@@ -70,14 +73,14 @@ namespace Microsoft.Extensions.DependencyInjection
7073
{{~ end ~}}
7174
{{~ end ~}}
7275
{{~ if (object.size NotificationMessages) > 0 ~}}
73-
76+
7477
// Register handlers and wrappers for notification messages
7578
{{~ for message in NotificationMessages ~}}
7679
services.Add(new {{ SD }}(typeof({{ message.HandlerWrapperTypeNameWithGenericTypeArguments }}), typeof({{ message.HandlerWrapperTypeNameWithGenericTypeArguments }}), {{ SingletonServiceLifetime }}));
7780
{{~ end ~}}
7881
{{~ end ~}}
7982
{{~ if (object.size NotificationMessageHandlers) > 0 ~}}
80-
83+
8184
// Register notification handlers
8285
{{~ for handler in NotificationMessageHandlers ~}}
8386
{{~ for registration in handler.ServiceRegistrations ~}}
@@ -86,15 +89,15 @@ namespace Microsoft.Extensions.DependencyInjection
8689
{{~ end ~}}
8790
{{~ end ~}}
8891
{{~ if (object.size PipelineBehaviors) > 0 ~}}
89-
92+
9093
// Register pipeline behaviors configured through options
9194
{{~ for behavior in PipelineBehaviors ~}}
9295
{{~ for registration in behavior.ServiceRegistrations ~}}
9396
{{ registration }}
9497
{{~ end ~}}
9598
{{~ end ~}}
9699
{{~ end ~}}
97-
100+
98101
// Register the notification publisher that was configured
99102
{{~ if ServiceLifetimeIsScoped || ServiceLifetimeIsTransient ~}}
100103
services.Add(new {{ SD }}(typeof({{ NotificationPublisherType.FullName }}), typeof({{ NotificationPublisherType.FullName }}), {{ ServiceLifetime }}));
@@ -103,18 +106,68 @@ namespace Microsoft.Extensions.DependencyInjection
103106
services.Add(new {{ SD }}(typeof({{ NotificationPublisherType.FullName }}), typeof({{ NotificationPublisherType.FullName }}), {{ SingletonServiceLifetime }}));
104107
services.TryAdd(new {{ SD }}(typeof(global::Mediator.INotificationPublisher), sp => sp.GetRequiredService<{{ NotificationPublisherType.FullName }}>(), {{ SingletonServiceLifetime }}));
105108
{{~ end ~}}
106-
109+
107110
// Register internal components
108111
services.Add(new {{ SD }}(typeof(global::{{ InternalsNamespace }}.IContainerProbe), typeof(global::{{ InternalsNamespace }}.ContainerProbe0), {{ ServiceLifetime }}));
109112
services.Add(new {{ SD }}(typeof(global::{{ InternalsNamespace }}.IContainerProbe), typeof(global::{{ InternalsNamespace }}.ContainerProbe1), {{ ServiceLifetime }}));
110113
services.Add(new {{ SD }}(typeof(global::{{ InternalsNamespace }}.ContainerMetadata), typeof(global::{{ InternalsNamespace }}.ContainerMetadata), {{ SingletonServiceLifetime }}));
111-
114+
112115
return services;
113-
114-
{{~ if HasNotifications ~}}
116+
117+
{{~ if HasNotifications ~}}
115118
[global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]
116119
static global::System.Func<global::System.IServiceProvider, T> GetRequiredService<T>() where T : notnull => sp => sp.GetRequiredService<T>();
117-
{{~ end ~}}
120+
{{~ end ~}}
121+
}
122+
123+
/// <summary>
124+
/// Builds a cache of existing service registrations for efficient duplicate detection.
125+
/// Maps service types to their registered implementation types.
126+
/// </summary>
127+
/// <param name="services">The service collection to analyze</param>
128+
/// <returns>Dictionary mapping service types to sets of implementation types</returns>
129+
private static global::System.Collections.Generic.Dictionary<global::System.Type, global::System.Collections.Generic.HashSet<global::System.Type>>
130+
BuildRegistrationCache(IServiceCollection services)
131+
{
132+
var cache = new global::System.Collections.Generic.Dictionary<global::System.Type, global::System.Collections.Generic.HashSet<global::System.Type>>();
133+
134+
foreach (var service in services)
135+
{
136+
if (service.ServiceType == null) continue;
137+
138+
if (!cache.ContainsKey(service.ServiceType))
139+
{
140+
cache[service.ServiceType] = new global::System.Collections.Generic.HashSet<global::System.Type>();
141+
}
142+
143+
// Handle different ServiceDescriptor registration patterns
144+
if (service.ImplementationType != null)
145+
{
146+
cache[service.ServiceType].Add(service.ImplementationType);
147+
}
148+
else if (service.ImplementationInstance != null)
149+
{
150+
cache[service.ServiceType].Add(service.ImplementationInstance.GetType());
151+
}
152+
}
153+
154+
return cache;
155+
}
156+
157+
/// <summary>
158+
/// Checks if a handler registration already exists in the service collection.
159+
/// </summary>
160+
/// <param name="existingRegistrations">Cache of existing registrations</param>
161+
/// <param name="serviceType">The service interface type</param>
162+
/// <param name="implementationType">The concrete implementation type</param>
163+
/// <returns>True if the handler is already registered</returns>
164+
private static bool IsHandlerAlreadyRegistered(
165+
global::System.Collections.Generic.Dictionary<global::System.Type, global::System.Collections.Generic.HashSet<global::System.Type>> existingRegistrations,
166+
global::System.Type serviceType,
167+
global::System.Type implementationType)
168+
{
169+
return existingRegistrations.ContainsKey(serviceType) &&
170+
existingRegistrations[serviceType].Contains(implementationType);
118171
}
119172
}
120173
}

test/Mediator.SourceGenerator.Tests/_snapshots/BasicTests.Multiple_Notification_Handlers_One_Class#Mediator.g.verified.cs

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,34 +50,98 @@ public static IServiceCollection AddMediator(this IServiceCollection services, g
5050
throw new global::System.Exception(errMsg);
5151
}
5252

53+
// Build cache of existing registrations for efficient lookup
54+
var existingRegistrations = BuildRegistrationCache(services);
55+
5356
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.Mediator), typeof(global::Mediator.Mediator), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
5457
services.TryAdd(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.IMediator), sp => sp.GetRequiredService<global::Mediator.Mediator>(), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
5558
services.TryAdd(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.ISender), sp => sp.GetRequiredService<global::Mediator.Mediator>(), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
5659
services.TryAdd(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.IPublisher), sp => sp.GetRequiredService<global::Mediator.Mediator>(), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
57-
60+
5861
// Register handlers and wrappers for notification messages
5962
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.Internals.NotificationHandlerWrapper<global::TestCode.Notification0>), typeof(global::Mediator.Internals.NotificationHandlerWrapper<global::TestCode.Notification0>), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
6063
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.Internals.NotificationHandlerWrapper<global::TestCode.Notification1>), typeof(global::Mediator.Internals.NotificationHandlerWrapper<global::TestCode.Notification1>), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
61-
64+
6265
// Register notification handlers
63-
services.TryAdd(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::TestCode.RequestHandler), typeof(global::TestCode.RequestHandler), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
64-
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.INotificationHandler<global::TestCode.Notification0>), GetRequiredService<global::TestCode.RequestHandler>(), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
65-
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.INotificationHandler<global::TestCode.Notification1>), GetRequiredService<global::TestCode.RequestHandler>(), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
66-
66+
if (!IsHandlerAlreadyRegistered(existingRegistrations, typeof(global::TestCode.RequestHandler), typeof(global::TestCode.RequestHandler)))
67+
services.TryAdd(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::TestCode.RequestHandler), typeof(global::TestCode.RequestHandler), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
68+
if (!IsHandlerAlreadyRegistered(existingRegistrations, typeof(global::Mediator.INotificationHandler<global::TestCode.Notification0>), typeof(global::TestCode.RequestHandler)))
69+
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.INotificationHandler<global::TestCode.Notification0>), GetRequiredService<global::TestCode.RequestHandler>(), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
70+
if (!IsHandlerAlreadyRegistered(existingRegistrations, typeof(global::Mediator.INotificationHandler<global::TestCode.Notification1>), typeof(global::TestCode.RequestHandler)))
71+
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.INotificationHandler<global::TestCode.Notification1>), GetRequiredService<global::TestCode.RequestHandler>(), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
72+
6773
// Register the notification publisher that was configured
6874
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.ForeachAwaitPublisher), typeof(global::Mediator.ForeachAwaitPublisher), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
6975
services.TryAdd(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.INotificationPublisher), sp => sp.GetRequiredService<global::Mediator.ForeachAwaitPublisher>(), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
70-
76+
7177
// Register internal components
7278
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.Internals.IContainerProbe), typeof(global::Mediator.Internals.ContainerProbe0), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
7379
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.Internals.IContainerProbe), typeof(global::Mediator.Internals.ContainerProbe1), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
7480
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.Internals.ContainerMetadata), typeof(global::Mediator.Internals.ContainerMetadata), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
75-
81+
7682
return services;
77-
83+
7884
[global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]
7985
static global::System.Func<global::System.IServiceProvider, T> GetRequiredService<T>() where T : notnull => sp => sp.GetRequiredService<T>();
8086
}
87+
88+
/// <summary>
89+
/// Builds a cache of existing service registrations for efficient duplicate detection.
90+
/// Maps service types to their registered implementation types.
91+
/// </summary>
92+
/// <param name="services">The service collection to analyze</param>
93+
/// <returns>Dictionary mapping service types to sets of implementation types</returns>
94+
private static global::System.Collections.Generic.Dictionary<global::System.Type, global::System.Collections.Generic.HashSet<global::System.Type>>
95+
BuildRegistrationCache(IServiceCollection services)
96+
{
97+
var cache = new global::System.Collections.Generic.Dictionary<global::System.Type, global::System.Collections.Generic.HashSet<global::System.Type>>();
98+
99+
foreach (var service in services)
100+
{
101+
if (service.ServiceType == null) continue;
102+
103+
if (!cache.ContainsKey(service.ServiceType))
104+
{
105+
cache[service.ServiceType] = new global::System.Collections.Generic.HashSet<global::System.Type>();
106+
}
107+
108+
// Handle different ServiceDescriptor registration patterns
109+
if (service.ImplementationType != null)
110+
{
111+
cache[service.ServiceType].Add(service.ImplementationType);
112+
}
113+
else if (service.ImplementationInstance != null)
114+
{
115+
cache[service.ServiceType].Add(service.ImplementationInstance.GetType());
116+
}
117+
else if (service.ImplementationFactory is {} implFac
118+
&& implFac.Method.ReturnType.IsAssignableTo(service.ServiceType)
119+
&& implFac.Method.ReturnType.IsClass
120+
&& implFac.Method.ReturnType != service.ServiceType)
121+
{
122+
// For factory registrations, mark service type as occupied
123+
cache[service.ServiceType].Add(implFac.Method.ReturnType);
124+
}
125+
}
126+
127+
return cache;
128+
}
129+
130+
/// <summary>
131+
/// Checks if a handler registration already exists in the service collection.
132+
/// </summary>
133+
/// <param name="existingRegistrations">Cache of existing registrations</param>
134+
/// <param name="serviceType">The service interface type</param>
135+
/// <param name="implementationType">The concrete implementation type</param>
136+
/// <returns>True if the handler is already registered</returns>
137+
private static bool IsHandlerAlreadyRegistered(
138+
global::System.Collections.Generic.Dictionary<global::System.Type, global::System.Collections.Generic.HashSet<global::System.Type>> existingRegistrations,
139+
global::System.Type serviceType,
140+
global::System.Type implementationType)
141+
{
142+
return existingRegistrations.ContainsKey(serviceType) &&
143+
existingRegistrations[serviceType].Contains(implementationType);
144+
}
81145
}
82146
}
83147

0 commit comments

Comments
 (0)