using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Net.Sockets;
using System.Net;
using Esiur.Misc;
using Esiur.Engine;
using System.Threading;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using Esiur.Resource;
using System.Threading.Tasks;
using Esiur.Data;

namespace Esiur.Net.Sockets
{
    /// <summary>
    /// TCP socket whose traffic is carried over an <see cref="SslStream"/>.
    /// An instance is one of:
    /// a listener (bound and listening; hands out accepted sockets via <see cref="Accept"/>),
    /// an accepted connection created with <c>authenticateAsServer</c>, or
    /// an outgoing client created with the parameterless constructor and <see cref="Connect"/>.
    /// </summary>
    public class SSLSocket : ISocket
    {
        Socket sock;
        byte[] receiveBuffer;
        readonly NetworkBuffer receiveNetworkBuffer = new NetworkBuffer();

        // Guards asyncSending and sendBufferQueue. Both are touched by Send (caller thread)
        // and by the DataSent continuation (thread pool).
        readonly object sendLock = new object();
        readonly Queue<byte[]> sendBufferQueue = new Queue<byte[]>();
        bool asyncSending;

        SocketState state = SocketState.Initial;

        // Becomes 1 the first time Close() runs its teardown, so teardown happens exactly once
        // even though several paths set Terminated and then call Close().
        int closeStarted;

        // Set when Begin() is called on an accepted socket before its TLS handshake has run.
        // Authenticated() then starts reading once the handshake completes.
        bool beginAfterHandshake;

        public event ISocketReceiveEvent OnReceive;
        public event ISocketConnectEvent OnConnect;
        public event ISocketCloseEvent OnClose;
        public event DestroyedEvent OnDestroy;

        SslStream ssl;
        X509Certificate2 cert;
        bool server;
        string hostname;

        /// <summary>True once the socket has been closed or terminated.</summary>
        private bool IsShutDown => state == SocketState.Closed || state == SocketState.Terminated;

        /// <summary>
        /// Continuation of the TCP connect started by <see cref="Connect"/>: creates the
        /// TLS stream if needed and starts the handshake.
        /// </summary>
        private void Connected(Task t)
        {
            if (IsShutDown)
                return;

            if (t.IsFaulted)
            {
                Terminate(t.Exception);
                return;
            }

            try
            {
                // Sockets built with the parameterless constructor have no TLS stream yet.
                if (ssl == null)
                    ssl = new SslStream(new NetworkStream(sock));

                StartAuthentication();
            }
            catch (Exception ex)
            {
                Terminate(ex);
            }
        }

        /// <summary>
        /// Starts the TLS handshake in the role selected by <c>server</c>;
        /// <see cref="Authenticated"/> runs when it finishes.
        /// </summary>
        private void StartAuthentication()
        {
            Task handshake = server
                ? ssl.AuthenticateAsServerAsync(cert)
                : ssl.AuthenticateAsClientAsync(hostname);

            handshake.ContinueWith(Authenticated);
        }

        /// <summary>
        /// Starts an asynchronous client connection; the TLS handshake follows automatically.
        /// </summary>
        /// <returns>false if starting the connect threw synchronously; true otherwise.</returns>
        public bool Connect(string hostname, ushort port)
        {
            try
            {
                this.hostname = hostname;
                server = false;
                state = SocketState.Connecting;
                sock.ConnectAsync(hostname, port).ContinueWith(Connected);
                return true;
            }
            catch
            {
                return false;
            }
        }

        /// <summary>
        /// Continuation of every write. Sends the next queued message, or clears
        /// <c>asyncSending</c> when the queue is empty.
        /// </summary>
        private void DataSent(Task task)
        {
            if (task.IsFaulted)
            {
                SendFailed(task.Exception);
                return;
            }

            try
            {
                // Check, dequeue and clear the flag under the same lock Send uses. Otherwise a
                // message enqueued between the Count check and "asyncSending = false" would
                // sit in the queue forever.
                lock (sendLock)
                {
                    if (sendBufferQueue.Count > 0)
                    {
                        byte[] data = sendBufferQueue.Dequeue();
                        ssl.WriteAsync(data, 0, data.Length).ContinueWith(DataSent);
                    }
                    else
                    {
                        asyncSending = false;
                    }
                }
            }
            catch (Exception ex)
            {
                SendFailed(ex);
            }
        }

        /// <summary>
        /// Common handling for a failed write: reset the sending flag, terminate if the
        /// underlying socket is gone, and log.
        /// </summary>
        private void SendFailed(Exception ex)
        {
            lock (sendLock)
                asyncSending = false;

            if (state != SocketState.Closed && !sock.Connected)
            {
                state = SocketState.Terminated;
                Close();
            }

            Global.Log("SSLSocket", LogType.Error, ex.ToString());
        }

        public IPEndPoint LocalEndPoint
        {
            get { return (IPEndPoint)sock.LocalEndPoint; }
        }

        /// <summary>Creates an unconnected client socket; call <see cref="Connect"/> next.</summary>
        public SSLSocket()
        {
            sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
            receiveBuffer = new byte[sock.ReceiveBufferSize];
        }

        /// <summary>
        /// Creates a listening socket bound to <paramref name="localEndPoint"/>. Accepted
        /// connections authenticate as server with <paramref name="certificate"/>.
        /// </summary>
        public SSLSocket(IPEndPoint localEndPoint, X509Certificate2 certificate)
        {
            // create the socket
            sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);

            state = SocketState.Listening;

            // bind
            sock.Bind(localEndPoint);

            // start listening
            sock.Listen(UInt16.MaxValue);

            cert = certificate;
        }

        public IPEndPoint RemoteEndPoint
        {
            get { return (IPEndPoint)sock.RemoteEndPoint; }
        }

        public SocketState State
        {
            get { return state; }
        }

        /// <summary>
        /// Wraps an already-connected socket. The TLS handshake is not started here; for
        /// <paramref name="authenticateAsServer"/> sockets it is started by <see cref="Begin"/>.
        /// </summary>
        public SSLSocket(Socket Socket, X509Certificate2 certificate, bool authenticateAsServer)
        {
            cert = certificate;
            sock = Socket;
            receiveBuffer = new byte[sock.ReceiveBufferSize];

            ssl = new SslStream(new NetworkStream(sock));

            server = authenticateAsServer;
        }

        /// <summary>
        /// Closes the connection. The first call shuts the socket down, raises
        /// <see cref="OnClose"/>, and then releases the TLS stream and the socket. Later calls
        /// only update the state.
        /// </summary>
        public void Close()
        {
            if (state != SocketState.Closed && state != SocketState.Terminated)
                state = SocketState.Closed;

            // Callers commonly set Terminated and then call Close(). Gate the teardown on its own
            // flag rather than the state, so it still runs once in that case.
            if (Interlocked.Exchange(ref closeStarted, 1) == 1)
                return;

            // Shutdown is only meaningful on a connected socket; on a listener or a dropped
            // connection it throws.
            if (sock.Connected)
            {
                try
                {
                    sock.Shutdown(SocketShutdown.Both);
                }
                catch (Exception ex) when (ex is SocketException || ex is ObjectDisposedException)
                {
                    state = SocketState.Terminated;
                }
            }

            try
            {
                // Raised before the socket is released, so handlers can still read the endpoints.
                OnClose?.Invoke();
            }
            finally
            {
                ssl?.Dispose();
                sock.Close();
            }
        }

        public void Send(byte[] message)
        {
            Send(message, 0, message.Length);
        }

        /// <summary>
        /// Writes <paramref name="size"/> bytes of <paramref name="message"/> starting at
        /// <paramref name="offset"/>. If a write is already in flight, the slice is queued and
        /// written by <see cref="DataSent"/> in order.
        /// </summary>
        public void Send(byte[] message, int offset, int size)
        {
            lock (sendLock)
            {
                if (asyncSending)
                {
                    sendBufferQueue.Enqueue(message.Clip((uint)offset, (uint)size));
                    return;
                }

                asyncSending = true;

                try
                {
                    ssl.WriteAsync(message, offset, size).ContinueWith(DataSent);
                }
                catch (ObjectDisposedException ex)
                {
                    // The stream was released by Close(). Reset the flag so it does not stay
                    // stuck at true, and log instead of throwing into the caller.
                    asyncSending = false;
                    Global.Log("SSLSocket", LogType.Error, ex.ToString());
                }
            }
        }

        /// <summary>
        /// Continuation of the TLS handshake. On success it marks the socket Established,
        /// raises <see cref="OnConnect"/>, and starts reading. That happens for clients, or for
        /// accepted sockets whose <see cref="Begin"/> requested it.
        /// </summary>
        void Authenticated(Task task)
        {
            if (IsShutDown)
                return;

            if (task.IsFaulted)
            {
                // A failed handshake must never be reported as an established connection.
                Terminate(task.Exception);
                return;
            }

            try
            {
                state = SocketState.Established;

                OnConnect?.Invoke();

                if (!server || beginAfterHandshake)
                    Begin();
            }
            catch (Exception ex)
            {
                Terminate(ex);
            }
        }

        /// <summary>Marks the socket terminated, closes it and logs <paramref name="ex"/>.</summary>
        private void Terminate(Exception ex)
        {
            state = SocketState.Terminated;
            Close();
            Global.Log("SSLSocket", LogType.Error, ex.ToString());
        }

        /// <summary>
        /// Continuation of each read. Appends the received bytes to the network buffer,
        /// raises <see cref="OnReceive"/>, and issues the next read while still Established.
        /// A read of 0 bytes (peer closed) closes the socket.
        /// </summary>
        private void DataReceived(Task<int> task)
        {
            try
            {
                if (state == SocketState.Closed || state == SocketState.Terminated)
                    return;

                if (task.Result <= 0)
                {
                    Close();
                    return;
                }

                receiveNetworkBuffer.Write(receiveBuffer, 0, (uint)task.Result);

                OnReceive?.Invoke(receiveNetworkBuffer);

                if (state == SocketState.Established)
                    ssl.ReadAsync(receiveBuffer, 0, receiveBuffer.Length).ContinueWith(DataReceived);
            }
            catch (Exception ex)
            {
                if (state != SocketState.Closed && !sock.Connected)
                {
                    state = SocketState.Terminated;
                    Close();
                }

                Global.Log("SSLSocket", LogType.Error, ex.ToString());
            }
        }

        /// <summary>
        /// Starts receiving. On an Established socket it issues the first read. On an accepted
        /// server socket that has not authenticated yet, it starts the TLS handshake, and
        /// reading begins once the handshake succeeds.
        /// </summary>
        /// <returns>true if reading (or the handshake preceding it) was started.</returns>
        public bool Begin()
        {
            if (state == SocketState.Established)
            {
                ssl.ReadAsync(receiveBuffer, 0, receiveBuffer.Length).ContinueWith(DataReceived);
                return true;
            }

            // Accepted sockets are never authenticated anywhere else; without this branch they
            // stay in Initial and Begin() could never succeed for them.
            if (server && state == SocketState.Initial && ssl != null)
            {
                beginAfterHandshake = true;
                state = SocketState.Connecting;

                try
                {
                    StartAuthentication();
                    return true;
                }
                catch (Exception ex)
                {
                    Terminate(ex);
                    return false;
                }
            }

            return false;
        }

        public bool Trigger(ResourceTrigger trigger)
        {
            return true;
        }

        public void Destroy()
        {
            Close();
            OnDestroy?.Invoke(this);
        }

        /// <summary>
        /// Accepts one incoming connection on a listening socket. The reply is triggered with
        /// a new server-side <see cref="SSLSocket"/>, or with null if the accept failed.
        /// </summary>
        /// <returns>The pending reply, or null if the accept could not be started.</returns>
        public AsyncReply Accept()
        {
            var reply = new AsyncReply();

            try
            {
                // Do not pass a second argument here: ContinueWith(action, null) binds to the
                // TaskScheduler overload and throws ArgumentNullException.
                sock.AcceptAsync().ContinueWith(x =>
                {
                    try
                    {
                        reply.Trigger(new SSLSocket(x.Result, cert, true));
                    }
                    catch
                    {
                        reply.Trigger(null);
                    }
                });
            }
            catch
            {
                state = SocketState.Terminated;
                return null;
            }

            return reply;
        }
    }
}