ryujinx/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs
TSRBerry ba5c0cf5d8
Bsd: Implement Select (#4017)
* bsd: Add gdkchan's Select implementation

Co-authored-by: TSRBerry <20988865+tsrberry@users.noreply.github.com>

* bsd: Fix Select() causing a crash with an ArgumentException

.NET Sockets have to be used for the Select() call

* bsd: Make Select more generic

* bsd: Adjust namespaces and remove unused imports

* bsd: Fix NullReferenceException in Select

Co-authored-by: gdkchan <gab.dark.100@gmail.com>
2022-12-12 14:59:31 +01:00

252 lines
7.4 KiB
C#

using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl;
using Ryujinx.HLE.HOS.Services.Ssl.Types;
using System;
using System.IO;
using System.Net.Security;
using System.Net.Sockets;
using System.Runtime.ExceptionServices;
using System.Security.Authentication;
namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
{
    /// <summary>
    /// SSL connection implemented with a .NET <see cref="SslStream"/> layered over an emulated BSD socket
    /// (which must be a <see cref="ManagedSocket"/> for <see cref="Handshake"/> to succeed).
    /// </summary>
    class SslManagedSocketConnection : ISslConnectionBase
    {
        /// <summary>BSD file descriptor of the underlying socket; closed by <see cref="Dispose"/>.</summary>
        public int SocketFd { get; }

        /// <summary>The underlying socket the SSL stream is layered over.</summary>
        public ISocket Socket { get; }

        private readonly BsdContext _bsdContext;
        private readonly SslVersion _sslVersion;

        // Created by Handshake; null until a handshake has been attempted.
        private SslStream _stream;

        // Blocking state of Socket saved by StartSslOperation and restored by EndSslOperation.
        private bool _isBlockingSocket;

        // Stream read timeout saved by StartSslReadOperation (non-blocking sockets only).
        private int _previousReadTimeout;

        public SslManagedSocketConnection(BsdContext bsdContext, SslVersion sslVersion, int socketFd, ISocket socket)
        {
            _bsdContext = bsdContext;
            _sslVersion = sslVersion;
            SocketFd = socketFd;
            Socket = socket;
        }

        /// <summary>
        /// Saves the socket's blocking mode and forces it to blocking, as required while SslStream drives the socket.
        /// Must be paired with <see cref="EndSslOperation"/>.
        /// </summary>
        private void StartSslOperation()
        {
            // Save blocking state
            _isBlockingSocket = Socket.Blocking;

            // Force blocking for SslStream
            Socket.Blocking = true;
        }

        /// <summary>Restores the blocking mode saved by <see cref="StartSslOperation"/>.</summary>
        private void EndSslOperation()
        {
            // Restore blocking state
            Socket.Blocking = _isBlockingSocket;
        }

        /// <summary>
        /// Like <see cref="StartSslOperation"/>, but for a socket that was non-blocking it also shortens the
        /// stream read timeout to 1 ms so the forced-blocking read returns quickly.
        /// Must be paired with <see cref="EndSslReadOperation"/>.
        /// </summary>
        private void StartSslReadOperation()
        {
            StartSslOperation();

            if (!_isBlockingSocket)
            {
                _previousReadTimeout = _stream.ReadTimeout;

                _stream.ReadTimeout = 1;
            }
        }

        /// <summary>Undoes <see cref="StartSslReadOperation"/>: restores the read timeout, then the blocking mode.</summary>
        private void EndSslReadOperation()
        {
            if (!_isBlockingSocket)
            {
                _stream.ReadTimeout = _previousReadTimeout;
            }

            EndSslOperation();
        }

        // NOTE: We silence warnings about TLS 1.0 and 1.1 as games will likely use it.
#pragma warning disable SYSLIB0039
        /// <summary>
        /// Maps the guest's requested <see cref="SslVersion"/> (only the bits in <c>SslVersion.VersionMask</c>)
        /// to the .NET <see cref="SslProtocols"/> to offer during the handshake.
        /// </summary>
        /// <exception cref="NotImplementedException">The masked version is not one of the handled values.</exception>
        private static SslProtocols TranslateSslVersion(SslVersion version)
        {
            return (version & SslVersion.VersionMask) switch
            {
                SslVersion.Auto => SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Tls13,
                SslVersion.TlsV10 => SslProtocols.Tls,
                SslVersion.TlsV11 => SslProtocols.Tls11,
                SslVersion.TlsV12 => SslProtocols.Tls12,
                SslVersion.TlsV13 => SslProtocols.Tls13,
                _ => throw new NotImplementedException(version.ToString()),
            };
        }
#pragma warning restore SYSLIB0039

        /// <summary>
        /// Performs the client-side TLS handshake for <paramref name="hostName"/> over the underlying socket.
        /// </summary>
        /// <remarks>
        /// Authentication failures propagate as exceptions from <see cref="SslStream.AuthenticateAsClient(string, System.Security.Cryptography.X509Certificates.X509CertificateCollection, SslProtocols, bool)"/>;
        /// in every case the socket's original blocking mode is restored.
        /// </remarks>
        public ResultCode Handshake(string hostName)
        {
            StartSslOperation();

            try
            {
                // Release any stream left over from a previous handshake attempt.
                _stream?.Dispose();

                // ownsSocket: false -> disposing the stream never closes the BSD socket; Dispose() closes it via the fd.
                _stream = new SslStream(new NetworkStream(((ManagedSocket)Socket).Socket, false), false, null, null);
                _stream.AuthenticateAsClient(hostName, null, TranslateSslVersion(_sslVersion), false);
            }
            finally
            {
                // Restore the guest-visible blocking mode even if the handshake threw.
                EndSslOperation();
            }

            return ResultCode.Success;
        }

        /// <summary>Peek is not supported; always reports <see cref="ResultCode.WouldBlock"/> with a count of -1.</summary>
        public ResultCode Peek(out int peekCount, Memory<byte> buffer)
        {
            // NOTE: We cannot support that on .NET SSL API.
            // As Nintendo's curl implementation detail check if a connection is alive via Peek, we just return that it would block to let it know that it's alive.
            peekCount = -1;

            return ResultCode.WouldBlock;
        }

        /// <summary>Number of pending decrypted bytes; unsupported, so always 0.</summary>
        public int Pending()
        {
            // Unsupported
            return 0;
        }

        /// <summary>
        /// Maps a subset of WinSock errors to guest <see cref="ResultCode"/>s. A timeout is reported as
        /// <see cref="ResultCode.Timeout"/> for blocking sockets and <see cref="ResultCode.WouldBlock"/> otherwise.
        /// </summary>
        /// <returns><c>true</c> if <paramref name="error"/> was translated; otherwise <c>false</c> and <see cref="ResultCode.Success"/>.</returns>
        private static bool TryTranslateWinSockError(bool isBlocking, WsaError error, out ResultCode resultCode)
        {
            switch (error)
            {
                case WsaError.WSAETIMEDOUT:
                    resultCode = isBlocking ? ResultCode.Timeout : ResultCode.WouldBlock;
                    return true;
                case WsaError.WSAECONNABORTED:
                    resultCode = ResultCode.ConnectionAbort;
                    return true;
                case WsaError.WSAECONNRESET:
                    resultCode = ResultCode.ConnectionReset;
                    return true;
                default:
                    resultCode = ResultCode.Success;
                    return false;
            }
        }

        /// <summary>
        /// Attempts to translate an <see cref="IOException"/> raised by the SSL stream into a guest result code.
        /// Only exceptions wrapping a <see cref="SocketException"/> with an error known to
        /// <see cref="TryTranslateWinSockError"/> are translated. Must be called before the blocking state is restored.
        /// </summary>
        private bool TryTranslateIOException(IOException exception, out ResultCode resultCode)
        {
            if (exception.InnerException is SocketException socketException)
            {
                WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode;

                return TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out resultCode);
            }

            resultCode = ResultCode.Success;

            return false;
        }

        /// <summary>
        /// If <paramref name="exception"/> wraps a <see cref="SocketException"/>, rethrows that inner exception with
        /// its original stack trace preserved; otherwise returns normally so the caller can rethrow the outer one.
        /// </summary>
        private static void RethrowInnerSocketException(IOException exception)
        {
            if (exception.InnerException is SocketException socketException)
            {
                ExceptionDispatchInfo.Capture(socketException).Throw();
            }
        }

        /// <summary>
        /// Reads decrypted data into <paramref name="buffer"/>.
        /// </summary>
        /// <returns>
        /// <see cref="ResultCode.WouldBlock"/> (count -1) if the socket has nothing readable, a translated socket error
        /// (count -1), or <see cref="ResultCode.Success"/> with the number of bytes read.
        /// </returns>
        /// <exception cref="SocketException">Untranslated socket error raised during the read.</exception>
        /// <exception cref="IOException">Any other I/O failure raised by the SSL stream.</exception>
        public ResultCode Read(out int readCount, Memory<byte> buffer)
        {
            // NOTE(review): Poll only observes the raw socket; plaintext already buffered inside SslStream
            // is not visible here — confirm this cannot report WouldBlock while decrypted data is pending.
            if (!Socket.Poll(0, SelectMode.SelectRead))
            {
                readCount = -1;

                return ResultCode.WouldBlock;
            }

            StartSslReadOperation();

            try
            {
                readCount = _stream.Read(buffer.Span);
            }
            catch (IOException exception)
            {
                readCount = -1;

                if (TryTranslateIOException(exception, out ResultCode result))
                {
                    return result;
                }

                // Surface the underlying socket error as before, without losing its stack trace.
                RethrowInnerSocketException(exception);

                throw;
            }
            finally
            {
                EndSslReadOperation();
            }

            return ResultCode.Success;
        }

        /// <summary>
        /// Encrypts and writes all of <paramref name="buffer"/> to the socket.
        /// </summary>
        /// <returns>
        /// <see cref="ResultCode.WouldBlock"/> (count 0) if the socket is not writable, a translated socket error
        /// (count -1), or <see cref="ResultCode.Success"/> with the full buffer length.
        /// </returns>
        /// <exception cref="SocketException">Untranslated socket error raised during the write.</exception>
        /// <exception cref="IOException">Any other I/O failure raised by the SSL stream.</exception>
        public ResultCode Write(out int writtenCount, ReadOnlyMemory<byte> buffer)
        {
            if (!Socket.Poll(0, SelectMode.SelectWrite))
            {
                writtenCount = 0;

                return ResultCode.WouldBlock;
            }

            StartSslOperation();

            try
            {
                _stream.Write(buffer.Span);
            }
            catch (IOException exception)
            {
                writtenCount = -1;

                if (TryTranslateIOException(exception, out ResultCode result))
                {
                    return result;
                }

                // Surface the underlying socket error as before, without losing its stack trace.
                RethrowInnerSocketException(exception);

                throw;
            }
            finally
            {
                EndSslOperation();
            }

            // .NET API doesn't provide the size written, assume all written.
            writtenCount = buffer.Length;

            return ResultCode.Success;
        }

        /// <summary>
        /// Copies the raw data of the peer certificate negotiated during <see cref="Handshake"/> into
        /// <paramref name="certificates"/>. Only that single certificate is returned, and <paramref name="hostname"/> is unused.
        /// </summary>
        /// <returns>
        /// <see cref="ResultCode.CertBufferTooSmall"/> if the buffer cannot hold the data (the required size is still
        /// reported through <paramref name="storageSize"/>); otherwise <see cref="ResultCode.Success"/>.
        /// </returns>
        public ResultCode GetServerCertificate(string hostname, Span<byte> certificates, out uint storageSize, out uint certificateCount)
        {
            byte[] rawCertData = _stream.RemoteCertificate.GetRawCertData();

            storageSize = (uint)rawCertData.Length;
            certificateCount = 1;

            if (rawCertData.Length > certificates.Length)
            {
                return ResultCode.CertBufferTooSmall;
            }

            rawCertData.CopyTo(certificates);

            return ResultCode.Success;
        }

        /// <summary>
        /// Releases the SSL stream (if a handshake created one) and closes the BSD file descriptor.
        /// </summary>
        public void Dispose()
        {
            // The inner NetworkStream was created with ownsSocket: false, so this frees the SSL state
            // without closing the socket; CloseFileDescriptor below closes it.
            _stream?.Dispose();

            _bsdContext.CloseFileDescriptor(SocketFd);
        }
    }
}