using System;
using System.Buffers;
using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
namespace KcpSharp
{
    /// <summary>
    /// A UDP transport that pumps datagrams from a <see cref="UdpClient"/> into an upper-level connection.
    /// </summary>
    /// <typeparam name="T">The type of the upper-level connection created by <see cref="Activate"/>.</typeparam>
    public abstract class KcpSocketTransport<T> : IKcpTransport, IDisposable where T : class, IKcpConversation
    {
        private readonly UdpClient _udp;
        private readonly int _mtu;
        private T? _connection;
        private CancellationTokenSource? _cts;
        private bool _disposed;

        // Datagrams whose length equals _handshakeSize are routed to _handshakeHandler instead of the connection.
        private int _handshakeSize;
        private Func<UdpReceiveResult, Task>? _handshakeHandler;

        /// <summary>
        /// Construct a transport over the specified UDP client.
        /// </summary>
        /// <param name="udp">The UDP client used to send and receive datagrams. This transport never disposes it.</param>
        /// <param name="mtu">The maximum packet size that can be transmitted. Must be at least 50.</param>
        /// <exception cref="ArgumentNullException"><paramref name="udp"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="mtu"/> is less than 50.</exception>
        protected KcpSocketTransport(UdpClient udp, int mtu)
        {
            _udp = udp ?? throw new ArgumentNullException(nameof(udp));
            if (mtu < 50)
            {
                throw new ArgumentOutOfRangeException(nameof(mtu));
            }
            _mtu = mtu;
        }

        /// <summary>
        /// Get the upper-level connection instance.
        /// </summary>
        /// <exception cref="InvalidOperationException"><see cref="Start"/> has not been called, or the transport has been disposed.</exception>
        public T Connection => _connection ?? throw new InvalidOperationException();

        /// <summary>
        /// Create the upper-level connection instance. Called once by <see cref="Start"/>; must not return null.
        /// </summary>
        /// <returns>The upper-level connection instance.</returns>
        protected abstract T Activate();

        /// <summary>
        /// Allocate a block of memory of at least <paramref name="size"/> bytes.
        /// The <see cref="UdpClient"/>-based receive loop does not call this; it is kept for derived classes.
        /// </summary>
        /// <param name="size">The minimum size of the buffer.</param>
        /// <returns>The allocated memory buffer.</returns>
        protected virtual IMemoryOwner<byte> AllocateBuffer(int size)
        {
#if NEED_POH_SHIM
            return MemoryPool<byte>.Shared.Rent(size);
#else
            return new ArrayMemoryOwner(GC.AllocateUninitializedArray<byte>(size, pinned: true));
#endif
        }

        /// <summary>
        /// Handle an exception that terminated the receive loop.
        /// </summary>
        /// <param name="ex">The exception thrown.</param>
        /// <returns>Whether the error should be ignored. Note: the receive loop currently stops after any exception
        /// that reaches this hook, regardless of the returned value.</returns>
        protected virtual bool HandleException(Exception ex) => false;

        /// <summary>
        /// Create the upper-level connection and start pumping packets from the socket to the upper-level connection.
        /// </summary>
        /// <exception cref="ObjectDisposedException">The transport has been disposed.</exception>
        /// <exception cref="InvalidOperationException"><see cref="Start"/> was already called, or <see cref="Activate"/> returned null.</exception>
        public void Start()
        {
            if (_disposed)
            {
                throw new ObjectDisposedException(nameof(KcpSocketTransport));
            }
            if (_connection is not null)
            {
                throw new InvalidOperationException();
            }
            _connection = Activate();
            if (_connection is null)
            {
                throw new InvalidOperationException();
            }
            _cts = new CancellationTokenSource();
            RunReceiveLoop();
        }

        /// <summary>
        /// Route every received datagram whose length is exactly <paramref name="size"/> bytes to
        /// <paramref name="handshakeHandler"/> instead of the upper-level connection.
        /// </summary>
        /// <param name="size">The datagram length that identifies a handshake packet.</param>
        /// <param name="handshakeHandler">The callback invoked with the received datagram.</param>
        public void SetHandshakeHandler(int size, Func<UdpReceiveResult, Task> handshakeHandler)
        {
            _handshakeSize = size;
            _handshakeHandler = handshakeHandler;
        }

#if NEED_SOCKET_SHIM
        /// <summary>
        /// Send a single packet to <paramref name="endpoint"/>. Packets larger than the MTU, or sent after disposal, are silently dropped.
        /// </summary>
        public async ValueTask SendPacketAsync(Memory<byte> packet, IPEndPoint endpoint, CancellationToken cancellationToken = default)
        {
            if (_disposed)
            {
                return;
            }
            cancellationToken.ThrowIfCancellationRequested();
            if (packet.Length > _mtu)
            {
                return;
            }
            byte[]? rentedArray = null;
            if (!MemoryMarshal.TryGetArray(packet, out ArraySegment<byte> segment))
            {
                // SocketAsyncEventArgs needs an array; copy non-array-backed memory into a pooled one.
                rentedArray = ArrayPool<byte>.Shared.Rent(packet.Length);
                segment = new ArraySegment<byte>(rentedArray, 0, packet.Length);
                packet.CopyTo(segment.AsMemory());
            }
            try
            {
                using var saea = new AwaitableSocketAsyncEventArgs();
                saea.SetBuffer(segment.Array, segment.Offset, segment.Count);
                saea.SocketFlags = SocketFlags.None;
                saea.RemoteEndPoint = endpoint;
                if (_udp.Client.SendToAsync(saea))
                {
                    await saea.WaitAsync().ConfigureAwait(false);
                    saea.Reset();
                }
                if (saea.SocketError != SocketError.Success)
                {
                    throw new SocketException((int)saea.SocketError);
                }
            }
            finally
            {
                if (rentedArray is not null)
                {
                    ArrayPool<byte>.Shared.Return(rentedArray);
                }
            }
        }

        // Receive one datagram into buffer and return the number of bytes received.
        // NOTE(review): not called by RunReceiveLoop, which uses UdpClient.ReceiveAsync(CancellationToken) (.NET 6+);
        // confirm how the shim targets are expected to receive.
        private static async ValueTask<int> ReceiveFromAsync(Socket socket, Memory<byte> buffer, EndPoint endPoint, CancellationToken cancellationToken)
        {
            cancellationToken.ThrowIfCancellationRequested();
            byte[]? rentedArray = null;
            if (!MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte> segment))
            {
                rentedArray = ArrayPool<byte>.Shared.Rent(buffer.Length);
                segment = new ArraySegment<byte>(rentedArray, 0, buffer.Length);
            }
            try
            {
                using var saea = new AwaitableSocketAsyncEventArgs();
                saea.SetBuffer(segment.Array, segment.Offset, segment.Count);
                saea.SocketFlags = SocketFlags.None;
                saea.RemoteEndPoint = endPoint;
                if (socket.ReceiveFromAsync(saea))
                {
                    await saea.WaitAsync().ConfigureAwait(false);
                    saea.Reset();
                }
                if (saea.SocketError != SocketError.Success)
                {
                    throw new SocketException((int)saea.SocketError);
                }
                int bytesTransferred = saea.BytesTransferred;
                if (rentedArray is not null)
                {
                    // Only the received bytes are meaningful; copy them back into the caller's buffer.
                    segment.AsMemory(0, bytesTransferred).CopyTo(buffer);
                }
                return bytesTransferred;
            }
            finally
            {
                if (rentedArray is not null)
                {
                    ArrayPool<byte>.Shared.Return(rentedArray);
                }
            }
        }
#else
        /// <summary>
        /// Send a single packet to <paramref name="endpoint"/>. Packets larger than the MTU, or sent after disposal, are silently dropped.
        /// </summary>
        public ValueTask SendPacketAsync(Memory<byte> packet, IPEndPoint endpoint, CancellationToken cancellationToken = default)
        {
            if (_disposed || packet.Length > _mtu)
            {
                return default;
            }
            // ToArray() snapshots the payload synchronously, so the send no longer depends on the caller's buffer.
            return new ValueTask(_udp.SendAsync(packet.ToArray(), endpoint, cancellationToken).AsTask());
        }
#endif

        // Fire-and-forget loop started by Start(): receives datagrams until the CTS is cancelled
        // or an unrecoverable error occurs, then reports the error via HandleExceptionWrapper.
        private async void RunReceiveLoop()
        {
            CancellationToken cancellationToken = _cts?.Token ?? new CancellationToken(true);
            IKcpConversation? connection = _connection;
            if (connection is null || cancellationToken.IsCancellationRequested)
            {
                return;
            }
            try
            {
                while (!cancellationToken.IsCancellationRequested)
                {
                    UdpReceiveResult result;
                    try
                    {
                        result = await _udp.ReceiveAsync(cancellationToken).ConfigureAwait(false);
                    }
                    catch (SocketException)
                    {
                        // Best effort: per-datagram socket errors (e.g. ConnectionReset on Windows after an
                        // ICMP "port unreachable") are skipped. Cancellation and ObjectDisposedException are
                        // deliberately NOT caught here; swallowing them made a disposed client spin forever.
                        continue;
                    }

                    int bytesReceived = result.Buffer.Length;
                    if (bytesReceived == 0 || bytesReceived > _mtu)
                    {
                        continue;
                    }

                    // Read the handler once so the null check and the invocation see the same delegate.
                    Func<UdpReceiveResult, Task>? handshakeHandler = _handshakeHandler;
                    if (handshakeHandler is not null && bytesReceived == _handshakeSize)
                    {
                        await handshakeHandler(result).ConfigureAwait(false);
                    }
                    else
                    {
                        await connection.InputPakcetAsync(result.Buffer, cancellationToken).ConfigureAwait(false);
                    }
                }
            }
            catch (OperationCanceledException)
            {
                // Do nothing: cancellation is the normal shutdown path.
            }
            catch (Exception ex)
            {
                HandleExceptionWrapper(ex);
            }
        }

        // Invoke HandleException (treating a throwing override as "don't ignore"), then mark the
        // connection's transport as closed and cancel/dispose the receive-loop CTS.
        private bool HandleExceptionWrapper(Exception ex)
        {
            bool result;
            try
            {
                result = HandleException(ex);
            }
            catch
            {
                result = false;
            }
            _connection?.SetTransportClosed();
            CancellationTokenSource? cts = Interlocked.Exchange(ref _cts, null);
            if (cts is not null)
            {
                cts.Cancel();
                cts.Dispose();
            }
            return result;
        }

        /// <summary>
        /// Dispose all the managed and the unmanaged resources used by this instance.
        /// The <see cref="UdpClient"/> passed to the constructor is not disposed.
        /// </summary>
        /// <param name="disposing">If managed resources should be disposed.</param>
        protected virtual void Dispose(bool disposing)
        {
            if (!_disposed)
            {
                if (disposing)
                {
                    CancellationTokenSource? cts = Interlocked.Exchange(ref _cts, null);
                    if (cts is not null)
                    {
                        cts.Cancel();
                        cts.Dispose();
                    }
                    _connection?.Dispose();
                }
                _connection = null;
                _cts = null;
                _disposed = true;
            }
        }

        /// <summary>
        /// Dispose the unmanaged resources used by this instance.
        /// </summary>
        ~KcpSocketTransport()
        {
            Dispose(disposing: false);
        }

        /// <inheritdoc />
        public void Dispose()
        {
            Dispose(disposing: true);
            GC.SuppressFinalize(this);
        }
    }
}