Supercell.GUT/Supercell.GUT.Server/Protocol/Messaging.cs

125 lines
3.7 KiB
C#
Raw Normal View History

using Microsoft.Extensions.Logging;
using Supercell.GUT.Server.Network.Connection;
using Supercell.GUT.Titan.Encryption;
using Supercell.GUT.Titan.Logic.Message;
namespace Supercell.GUT.Server.Protocol;
/// <summary>
/// Frames, encrypts/decrypts, encodes/decodes and dispatches <see cref="PiranhaMessage"/> instances for a connection.
/// Each frame is a 7-byte big-endian header (2-byte message type, 3-byte payload length, 2-byte message version)
/// followed by the payload. Payloads pass through an <see cref="RC4Encrypter"/> once <see cref="InitEncryption"/>
/// has been called; before that they are sent and received as-is.
/// </summary>
internal class Messaging : IConnectionListener
{
    /// <summary>Size in bytes of the frame header: type (2) + length (3) + version (2).</summary>
    private const int HeaderSize = 7;

    /// <summary>Largest payload length representable by the 3-byte length field of the header.</summary>
    private const int MaxPayloadLength = 0xFFFFFF;

    private readonly ILogger _logger;
    private readonly LogicMessageFactory _factory;

    private IConnectionListener.SendCallback? _sendCallback;
    private IConnectionListener.ReceiveCallback? _receiveCallback;

    private RC4Encrypter? _encrypter;
    private RC4Encrypter? _decrypter;

    /// <param name="factory">Creates message instances from the message type read from each frame header.</param>
    /// <param name="logger">Used to report frames of unknown message type.</param>
    public Messaging(LogicMessageFactory factory, ILogger<Messaging> logger)
    {
        _logger = logger;
        _factory = factory;
    }

    /// <summary>
    /// Enables payload encryption for both directions. Two separate <see cref="RC4Encrypter"/> instances are
    /// created from the same key and nonce, one for inbound and one for outbound payloads.
    /// NOTE(review): presumably each instance keeps its own keystream position, as a stream cipher would.
    /// </summary>
    public void InitEncryption(string key, string nonce)
    {
        _decrypter = new RC4Encrypter(key, nonce);
        _encrypter = new RC4Encrypter(key, nonce);
    }

    /// <summary>
    /// Processes every complete frame at the start of <paramref name="buffer"/>: decrypts, decodes and forwards
    /// each known message to the receive callback. Frames of unknown type are logged and skipped.
    /// </summary>
    /// <param name="buffer">Received bytes; only the first <paramref name="size"/> bytes are valid.</param>
    /// <param name="size">Number of valid bytes at the start of <paramref name="buffer"/>.</param>
    /// <returns>
    /// Number of bytes consumed from the start of <paramref name="buffer"/>. A trailing partial frame is not
    /// consumed, so the caller can retry once more data has arrived.
    /// </returns>
    /// <exception cref="InvalidOperationException">A known message was decoded but <see cref="RecvCallback"/> was never set.</exception>
    public async ValueTask<int> OnReceive(Memory<byte> buffer, int size)
    {
        int consumedBytes = 0;

        while (size >= HeaderSize)
        {
            ReadHeader(buffer.Span, out int messageType, out int length, out int messageVersion);

            int frameLength = HeaderSize + length;
            if (size < frameLength)
            {
                // Incomplete frame: leave it for the next call.
                break;
            }

            byte[] payload = buffer.Slice(HeaderSize, length).ToArray();

            // Advance past this frame only. `buffer` is re-sliced on every iteration, so it must move by the
            // current frame's length, not by the running total in consumedBytes.
            buffer = buffer[frameLength..];
            size -= frameLength;
            consumedBytes += frameLength;

            // Decrypt before the type lookup so every frame's payload goes through the decrypter, including
            // frames that are about to be ignored (keeps a stateful cipher in step with the sender).
            _decrypter?.Encrypt(payload);

            PiranhaMessage? message = _factory.CreateMessageByType(messageType);
            if (message is null)
            {
                _logger.LogWarning("Ignoring message of unknown type {messageType}", messageType);
                continue;
            }

            message.MessageVersion = (short)messageVersion;
            message.ByteStream.SetByteArray(payload, length);
            message.Decode();

            IConnectionListener.ReceiveCallback receiveCallback = _receiveCallback
                ?? throw new InvalidOperationException($"{nameof(RecvCallback)} must be set before messages are received.");
            await receiveCallback(message);
        }

        return consumedBytes;
    }

    /// <summary>
    /// Encodes <paramref name="message"/> if its byte stream is still empty, encrypts the payload (when encryption
    /// is initialised), prefixes the 7-byte header and passes the complete frame to the send callback.
    /// </summary>
    /// <exception cref="ArgumentException">The encoded payload is longer than the header's 3-byte length field allows.</exception>
    /// <exception cref="InvalidOperationException"><see cref="OnSend"/> was never set.</exception>
    public async Task Send(PiranhaMessage message)
    {
        // Encode only if nothing has been written to the stream yet.
        if (message.ByteStream.Offset == 0)
        {
            message.Encode();
        }

        byte[] payload = message.ByteStream.ByteArray!.Take(message.ByteStream.Offset).ToArray();

        // Both checks happen before encryption so a rejected send does not advance the encrypter's state.
        if (payload.Length > MaxPayloadLength)
        {
            throw new ArgumentException(
                $"Encoded payload of {payload.Length} bytes exceeds the maximum frame payload of {MaxPayloadLength} bytes.",
                nameof(message));
        }

        IConnectionListener.SendCallback sendCallback = _sendCallback
            ?? throw new InvalidOperationException($"{nameof(OnSend)} must be set before messages are sent.");

        _encrypter?.Encrypt(payload);

        byte[] frame = new byte[HeaderSize + payload.Length];
        WriteHeader(frame, message, payload.Length);
        payload.CopyTo(frame, HeaderSize);

        await sendCallback(frame);
    }

    /// <summary>Callback that receives each fully framed outbound byte array.</summary>
    public IConnectionListener.SendCallback OnSend
    {
        set => _sendCallback = value;
    }

    /// <summary>Callback that receives each decoded inbound message.</summary>
    public IConnectionListener.ReceiveCallback RecvCallback
    {
        set => _receiveCallback = value;
    }

    /// <summary>Reads the big-endian 7-byte frame header at the start of <paramref name="buffer"/>.</summary>
    /// <param name="buffer">Must contain at least <see cref="HeaderSize"/> bytes.</param>
    private static void ReadHeader(ReadOnlySpan<byte> buffer, out int messageType, out int encodingLength, out int messageVersion)
    {
        messageType = buffer[0] << 8 | buffer[1];
        encodingLength = buffer[2] << 16 | buffer[3] << 8 | buffer[4];
        messageVersion = buffer[5] << 8 | buffer[6];
    }

    /// <summary>Writes the big-endian 7-byte frame header for <paramref name="message"/> into <paramref name="buffer"/>.</summary>
    /// <param name="length">Payload length; only its low 24 bits are written.</param>
    private static void WriteHeader(Span<byte> buffer, PiranhaMessage message, int length)
    {
        int messageType = message.GetMessageType();
        int messageVersion = message.MessageVersion;

        buffer[0] = (byte)(messageType >> 8);
        buffer[1] = (byte)messageType;
        buffer[2] = (byte)(length >> 16);
        buffer[3] = (byte)(length >> 8);
        buffer[4] = (byte)length;
        buffer[5] = (byte)(messageVersion >> 8);
        buffer[6] = (byte)messageVersion;
    }
}