using System.Collections.Immutable;
using System.Linq.Expressions;
using System.Reflection;
using GameServer.Controllers.Attributes;
using GameServer.Network.Messages;
using GameServer.Systems.Event;
using Google.Protobuf;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Protocol;

namespace GameServer.Controllers.Factory;

/// <summary>
/// Compiled entry point for a <see cref="GameEventAttribute"/>-annotated controller method.
/// </summary>
/// <param name="serviceProvider">Provider used to resolve the controller and every handler argument.</param>
internal delegate Task GameEventHandler(IServiceProvider serviceProvider);

/// <summary>
/// Discovers handler methods on every non-abstract <see cref="Controller"/> type in the executing assembly and
/// compiles them once into delegates:
/// <list type="bullet">
///   <item><description>RPC handlers: <see cref="NetEventAttribute"/> methods returning <see cref="ResponseMessage"/> or <c>Task&lt;ResponseMessage&gt;</c>.</description></item>
///   <item><description>Push handlers: <see cref="NetEventAttribute"/> methods returning <c>void</c> or <see cref="Task"/>.</description></item>
///   <item><description>Game-event handlers: <see cref="GameEventAttribute"/> methods returning <c>void</c> or <see cref="Task"/>.</description></item>
/// </list>
/// Each compiled delegate resolves its controller via <c>GetRequiredService</c> on every invocation (static handler
/// methods need no controller). Method parameters assignable to <see cref="IMessage"/> are parsed from the incoming
/// payload (RPC and push handlers only); every other parameter is resolved from the service provider.
/// </summary>
internal class EventHandlerFactory
{
    // GetRequiredService<T>(IServiceProvider); closed over a concrete T at each call site.
    private static readonly MethodInfo s_getRequiredService = typeof(ServiceProviderServiceExtensions)
        .GetMethod(nameof(ServiceProviderServiceExtensions.GetRequiredService), [typeof(IServiceProvider)])!;

    // Task.FromResult<ResponseMessage>; lifts synchronous RPC handlers to Task<ResponseMessage>.
    private static readonly MethodInfo s_responseFromResult = typeof(Task)
        .GetMethod(nameof(Task.FromResult))!
        .MakeGenericMethod(typeof(ResponseMessage));

    private readonly ImmutableDictionary<MessageId, RpcHandler> _rpcHandlers;
    private readonly ImmutableDictionary<MessageId, PushHandler> _pushHandlers;

    // Immutable arrays (rather than List<T>) so callers of GetEventHandlers cannot mutate the shared registry.
    private readonly ImmutableDictionary<GameEventType, ImmutableArray<GameEventHandler>> _eventHandlers;

    /// <summary>Scans the executing assembly for controllers and compiles all of their handlers.</summary>
    /// <param name="logger">Receives a summary of how many handlers were registered.</param>
    /// <exception cref="InvalidOperationException">
    /// Two RPC (or two push) handlers claim the same <see cref="MessageId"/>, or a message parameter type has no
    /// usable <c>Parser</c>.
    /// </exception>
    public EventHandlerFactory(ILogger<EventHandlerFactory> logger)
    {
        // Materialized once: the result feeds three separate registration passes.
        Type[] controllerTypes = Assembly.GetExecutingAssembly().GetTypes()
            .Where(t => t.IsAssignableTo(typeof(Controller)) && !t.IsAbstract)
            .ToArray();

        _rpcHandlers = RegisterRpcHandlers(controllerTypes);
        _pushHandlers = RegisterPushHandlers(controllerTypes);
        _eventHandlers = RegisterEventHandlers(controllerTypes);

        logger.LogInformation("Registered {rpc_count} rpc handlers, {push_count} push handlers, {event_count} event handlers",
            _rpcHandlers.Count, _pushHandlers.Count, _eventHandlers.Values.Sum(handlers => handlers.Length));
    }

    /// <summary>Returns the RPC handler registered for <paramref name="messageId"/>, or <c>null</c> if there is none.</summary>
    public RpcHandler? GetRpcHandler(MessageId messageId) =>
        _rpcHandlers.TryGetValue(messageId, out RpcHandler? handler) ? handler : null;

    /// <summary>Returns the push handler registered for <paramref name="messageId"/>, or <c>null</c> if there is none.</summary>
    public PushHandler? GetPushHandler(MessageId messageId) =>
        _pushHandlers.TryGetValue(messageId, out PushHandler? handler) ? handler : null;

    /// <summary>Returns every handler registered for <paramref name="eventType"/>; empty when there are none.</summary>
    public IEnumerable<GameEventHandler> GetEventHandlers(GameEventType eventType) =>
        _eventHandlers.TryGetValue(eventType, out ImmutableArray<GameEventHandler> handlers)
            ? handlers
            : ImmutableArray<GameEventHandler>.Empty;

    /// <summary>
    /// Compiles every <see cref="GameEventAttribute"/> method returning <c>void</c> or <see cref="Task"/>, grouped by
    /// <see cref="GameEventAttribute.Type"/>. Several handlers may share one event type.
    /// </summary>
    private static ImmutableDictionary<GameEventType, ImmutableArray<GameEventHandler>> RegisterEventHandlers(IEnumerable<Type> controllerTypes)
    {
        Dictionary<GameEventType, ImmutableArray<GameEventHandler>.Builder> handlersByType = new();

        foreach (var (controllerType, method, attribute) in
                 FindHandlerMethods<GameEventAttribute>(controllerTypes, typeof(Task), typeof(void)))
        {
            ParameterExpression serviceProviderParam = Expression.Parameter(typeof(IServiceProvider));

            Expression body = CompletedTaskIfVoid(BuildControllerCall(controllerType, method, serviceProviderParam, dataParam: null));
            GameEventHandler handler = Expression.Lambda<GameEventHandler>(body, serviceProviderParam).Compile();

            if (!handlersByType.TryGetValue(attribute.Type, out ImmutableArray<GameEventHandler>.Builder? handlers))
            {
                handlers = ImmutableArray.CreateBuilder<GameEventHandler>();
                handlersByType.Add(attribute.Type, handlers);
            }

            handlers.Add(handler);
        }

        return handlersByType.ToImmutableDictionary(pair => pair.Key, pair => pair.Value.ToImmutable());
    }

    /// <summary>
    /// Compiles every <see cref="NetEventAttribute"/> method returning <see cref="ResponseMessage"/> or
    /// <c>Task&lt;ResponseMessage&gt;</c>. Synchronous results are wrapped with <c>Task.FromResult</c>.
    /// </summary>
    private static ImmutableDictionary<MessageId, RpcHandler> RegisterRpcHandlers(IEnumerable<Type> controllerTypes)
    {
        var builder = ImmutableDictionary.CreateBuilder<MessageId, RpcHandler>();

        foreach (var (controllerType, method, attribute) in
                 FindHandlerMethods<NetEventAttribute>(controllerTypes, typeof(Task<ResponseMessage>), typeof(ResponseMessage)))
        {
            ParameterExpression serviceProviderParam = Expression.Parameter(typeof(IServiceProvider));
            ParameterExpression dataParam = Expression.Parameter(typeof(ReadOnlySpan<byte>));

            Expression body = BuildControllerCall(controllerType, method, serviceProviderParam, dataParam);
            if (body.Type == typeof(ResponseMessage))
            {
                // Allow non-async methods as well.
                body = Expression.Call(s_responseFromResult, body);
            }

            RpcHandler handler = Expression.Lambda<RpcHandler>(body, serviceProviderParam, dataParam).Compile();
            AddUnique(builder, attribute.MessageId, handler, controllerType, method);
        }

        return builder.ToImmutable();
    }

    /// <summary>
    /// Compiles every <see cref="NetEventAttribute"/> method returning <c>void</c> or <see cref="Task"/>.
    /// <c>void</c> methods are made to return <see cref="Task.CompletedTask"/>.
    /// </summary>
    private static ImmutableDictionary<MessageId, PushHandler> RegisterPushHandlers(IEnumerable<Type> controllerTypes)
    {
        var builder = ImmutableDictionary.CreateBuilder<MessageId, PushHandler>();

        foreach (var (controllerType, method, attribute) in
                 FindHandlerMethods<NetEventAttribute>(controllerTypes, typeof(Task), typeof(void)))
        {
            ParameterExpression serviceProviderParam = Expression.Parameter(typeof(IServiceProvider));
            ParameterExpression dataParam = Expression.Parameter(typeof(ReadOnlySpan<byte>));

            Expression body = CompletedTaskIfVoid(BuildControllerCall(controllerType, method, serviceProviderParam, dataParam));
            PushHandler handler = Expression.Lambda<PushHandler>(body, serviceProviderParam, dataParam).Compile();
            AddUnique(builder, attribute.MessageId, handler, controllerType, method);
        }

        return builder.ToImmutable();
    }

    /// <summary>
    /// Yields each public method of <paramref name="controllerTypes"/> that carries <typeparamref name="TAttribute"/>
    /// and whose return type is one of <paramref name="allowedReturnTypes"/>. Other methods are skipped.
    /// </summary>
    private static IEnumerable<(Type ControllerType, MethodInfo Method, TAttribute Attribute)> FindHandlerMethods<TAttribute>(
        IEnumerable<Type> controllerTypes, params Type[] allowedReturnTypes)
        where TAttribute : Attribute
    {
        foreach (Type controllerType in controllerTypes)
        {
            foreach (MethodInfo method in controllerType.GetMethods())
            {
                if (method.GetCustomAttribute<TAttribute>() is { } attribute && allowedReturnTypes.Contains(method.ReturnType))
                {
                    yield return (controllerType, method, attribute);
                }
            }
        }
    }

    /// <summary>
    /// Adds <paramref name="handler"/> under <paramref name="messageId"/>. Throws a descriptive
    /// <see cref="InvalidOperationException"/> naming the offending method when the id is already registered.
    /// </summary>
    private static void AddUnique<THandler>(ImmutableDictionary<MessageId, THandler>.Builder builder,
        MessageId messageId, THandler handler, Type controllerType, MethodInfo method)
    {
        if (builder.ContainsKey(messageId))
        {
            throw new InvalidOperationException(
                $"Duplicate handler for message {messageId}: {controllerType.FullName}.{method.Name} conflicts with an earlier registration.");
        }

        builder.Add(messageId, handler);
    }

    /// <summary>
    /// Turns a <c>void</c> call into a block that evaluates to <see cref="Task.CompletedTask"/>; any other call is
    /// returned unchanged. This lets non-async methods be used as handlers.
    /// </summary>
    private static Expression CompletedTaskIfVoid(Expression call) =>
        call.Type == typeof(void)
            ? Expression.Block(call, Expression.Constant(Task.CompletedTask, typeof(Task)))
            : call;

    /// <summary>
    /// Builds <c>serviceProvider.GetRequiredService&lt;TController&gt;().Method(args...)</c>, or
    /// <c>TController.Method(args...)</c> for static methods. The expression's type is the method's return type.
    /// </summary>
    private static MethodCallExpression BuildControllerCall(Type controllerType, MethodInfo method,
        ParameterExpression serviceProviderParam, ParameterExpression? dataParam)
    {
        // Static methods must be called with a null instance; instance handlers resolve their controller per invocation.
        Expression? controller = method.IsStatic
            ? null
            : Expression.Call(s_getRequiredService.MakeGenericMethod(controllerType), serviceProviderParam);

        return Expression.Call(controller, method, FetchArgumentsForMethod(method, serviceProviderParam, dataParam));
    }

    /// <summary>
    /// Builds one argument expression per parameter of <paramref name="method"/>.
    /// When <paramref name="dataParam"/> is supplied, parameters assignable to <see cref="IMessage"/> are parsed from
    /// it. Every other parameter is resolved with <c>GetRequiredService</c>.
    /// </summary>
    private static List<Expression> FetchArgumentsForMethod(MethodInfo method, Expression serviceProviderParam, Expression? dataParam = null)
    {
        List<Expression> arguments = [];
        foreach (ParameterInfo param in method.GetParameters())
        {
            if (dataParam != null && param.ParameterType.IsAssignableTo(typeof(IMessage)))
            {
                arguments.Add(BuildParseCall(param.ParameterType, dataParam));
            }
            else
            {
                arguments.Add(Expression.Call(s_getRequiredService.MakeGenericMethod(param.ParameterType), serviceProviderParam));
            }
        }

        return arguments;
    }

    /// <summary>
    /// Builds <c>TMessage.Parser.ParseFrom(data)</c> for a protobuf message type. The parser instance is read once and
    /// captured as a constant.
    /// </summary>
    /// <exception cref="InvalidOperationException">The type lacks a public static <c>Parser</c> or a span-based <c>ParseFrom</c>.</exception>
    private static MethodCallExpression BuildParseCall(Type messageType, Expression dataParam)
    {
        PropertyInfo parserProperty = messageType.GetProperty("Parser", BindingFlags.Static | BindingFlags.Public)
            ?? throw new InvalidOperationException($"Message type {messageType} has no public static Parser property.");

        MethodInfo parseFrom = parserProperty.PropertyType.GetMethod(nameof(MessageParser.ParseFrom), [typeof(ReadOnlySpan<byte>)])
            ?? throw new InvalidOperationException($"Parser of {messageType} has no ParseFrom(ReadOnlySpan<byte>) overload.");

        return Expression.Call(Expression.Constant(parserProperty.GetValue(null), parserProperty.PropertyType), parseFrom, dataParam);
    }
}