File: net\System\Net\WebSockets\ClientWebSocket.cs
Project: ndp\fx\src\System.csproj (System)
//------------------------------------------------------------------------------
// <copyright file="ClientWebSocket.cs" company="Microsoft">
//     Copyright (c) Microsoft Corporation.  All rights reserved.
// </copyright>
//------------------------------------------------------------------------------
 
namespace System.Net.WebSockets
{
    using System;
    using System.Collections.Generic;
    using System.Diagnostics.CodeAnalysis;
    using System.Diagnostics.Contracts;
    using System.Net;
    using System.Security.Cryptography.X509Certificates;
    using System.Threading;
    using System.Threading.Tasks;
 
    /// <summary>
    /// Client-side <see cref="WebSocket"/>. <see cref="ConnectAsync"/> performs the opening
    /// handshake through an <see cref="HttpWebRequest"/>; once it succeeds, send, receive and
    /// close calls are forwarded to an inner <see cref="WebSocket"/> built on the response stream.
    /// </summary>
    public sealed class ClientWebSocket : WebSocket
    {
        private readonly ClientWebSocketOptions options;
        // Null until ConnectAsyncCore has created it from the handshake response.
        private WebSocket innerWebSocket;
        // Cancelled by Dispose(); its registration aborts the handshake request.
        private readonly CancellationTokenSource cts;
 
        // Stages of this class. Interlocked doesn't support enums.
        // Transitions: created -> connecting -> connected, and any stage -> disposed.
        private int state;
        private const int created = 0;
        private const int connecting = 1;
        private const int connected = 2;
        private const int disposed = 3;
 
        static ClientWebSocket() 
        {
            // Register ws: and wss: with WebRequest.Register so that WebRequest.Create returns a 
            // WebSocket capable HttpWebRequest instance.
            WebSocket.RegisterPrefixes();
        }
        
        /// <summary>
        /// Creates an unconnected instance with default <see cref="Options"/>.
        /// </summary>
        public ClientWebSocket() 
        {
            if (Logging.On) Logging.Enter(Logging.WebSockets, this, ".ctor", null);
 
            // Fail up front when the WebSocket protocol component is unavailable.
            if (!WebSocketProtocolComponent.IsSupported)
            {
                WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC();
            }
 
            state = created;
            options = new ClientWebSocketOptions();
            cts = new CancellationTokenSource();
 
            if (Logging.On) Logging.Exit(Logging.WebSockets, this, ".ctor", null);
        }
 
        #region Properties
 
        /// <summary>
        /// Options used for the handshake and session. They are set to read-only by
        /// <see cref="ConnectAsync"/>.
        /// </summary>
        public ClientWebSocketOptions Options { get { return options; } }
 
        /// <summary>
        /// Close status reported by the inner WebSocket, or null while no inner WebSocket exists.
        /// </summary>
        public override WebSocketCloseStatus? CloseStatus
        {
            get 
            {
                if (innerWebSocket != null)
                {
                    return innerWebSocket.CloseStatus;
                }
                return null; 
            }
        }
 
        /// <summary>
        /// Close description reported by the inner WebSocket, or null while no inner WebSocket exists.
        /// </summary>
        public override string CloseStatusDescription
        {
            get
            {
                if (innerWebSocket != null)
                {
                    return innerWebSocket.CloseStatusDescription;
                }
                return null;
            }
        }
 
        /// <summary>
        /// Sub-protocol reported by the inner WebSocket, or null while no inner WebSocket exists.
        /// </summary>
        public override string SubProtocol
        {
            get
            {
                if (innerWebSocket != null)
                {
                    return innerWebSocket.SubProtocol;
                }
                return null;
            }
        }
 
        /// <summary>
        /// State of the inner WebSocket when one exists; otherwise a state derived from this
        /// instance's own stage (None before connecting, Connecting during the handshake,
        /// Closed when disposed before a connection was established).
        /// </summary>
        public override WebSocketState State
        {
            get 
            {
                // state == Connected or Disposed
                if (innerWebSocket != null)
                {
                    return innerWebSocket.State;
                }
                switch (state)
                {
                    case created:
                        return WebSocketState.None;
                    case connecting:
                        return WebSocketState.Connecting;
                    case disposed: // We only get here if disposed before connecting
                        return WebSocketState.Closed;
                    default:
                        Contract.Assert(false, "NotImplemented: " + state);
                        return WebSocketState.Closed;
                }
            }
        }
 
        #endregion Properties
 
        /// <summary>
        /// Starts the opening handshake with the server identified by <paramref name="uri"/>.
        /// Argument and state errors are thrown synchronously, before any task is returned.
        /// </summary>
        /// <param name="uri">Absolute URI whose scheme is ws or wss.</param>
        /// <param name="cancellationToken">
        /// Aborts the handshake request when cancelled. The registration is released as soon as the
        /// handshake finishes, so the token does not affect the established session.
        /// </param>
        /// <returns>A task that completes when the handshake has succeeded or failed.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="uri"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="uri"/> is not absolute, or its scheme is neither ws nor wss.
        /// </exception>
        /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
        /// <exception cref="InvalidOperationException">A connect was already started.</exception>
        public Task ConnectAsync(Uri uri, CancellationToken cancellationToken)
        {
            if (uri == null)
            {
                throw new ArgumentNullException("uri");
            }
            if (!uri.IsAbsoluteUri)
            {
                throw new ArgumentException(SR.GetString(SR.net_uri_NotAbsolute), "uri");
            }
            if (uri.Scheme != Uri.UriSchemeWs && uri.Scheme != Uri.UriSchemeWss)
            {
                throw new ArgumentException(SR.GetString(SR.net_WebSockets_Scheme), "uri");
            }
 
            // Check that we have not started already
            int priorState = Interlocked.CompareExchange(ref state, connecting, created);
            if (priorState == disposed)
            {
                throw new ObjectDisposedException(GetType().FullName);
            }
            else if (priorState != created)
            {
                throw new InvalidOperationException(SR.GetString(SR.net_WebSockets_AlreadyStarted));
            }
            options.SetToReadOnly();
 
            return ConnectAsyncCore(uri, cancellationToken);
        }
 
        // Sends the handshake request, validates the response and creates the inner WebSocket.
        // A WebException is rethrown wrapped in a WebSocketException; every failure disposes
        // this instance and the response (if one was received).
        private async Task ConnectAsyncCore(Uri uri, CancellationToken cancellationToken)
        {
            HttpWebResponse response = null;
            CancellationTokenRegistration connectCancellation = new CancellationTokenRegistration();
            // Any errors from here on out are fatal and this instance will be disposed.
            try
            {
                HttpWebRequest request = CreateAndConfigureRequest(uri);
                if (Logging.On) Logging.Associate(Logging.WebSockets, this, request);
 
                connectCancellation = cancellationToken.Register(AbortRequest, request, false);
 
                response = await request.GetResponseAsync().SuppressContextFlow() as HttpWebResponse;
                Contract.Assert(response != null, "Not an HttpWebResponse");
 
                if (Logging.On) Logging.Associate(Logging.WebSockets, this, response);
 
                string subprotocol = ValidateResponse(request, response);
 
                innerWebSocket = WebSocket.CreateClientWebSocket(response.GetResponseStream(), subprotocol,
                    options.ReceiveBufferSize, options.SendBufferSize, options.KeepAliveInterval, false,
                    options.GetOrCreateBuffer());
 
                if (Logging.On) Logging.Associate(Logging.WebSockets, this, innerWebSocket);
 
                // Change internal state to 'connected' to enable the other methods
                if (Interlocked.CompareExchange(ref state, connected, connecting) != connecting)
                {
                    // Aborted/Disposed during connect. The Dispose() that moved the state to
                    // 'disposed' may have run before innerWebSocket was assigned, and the Dispose()
                    // issued by ConnectExceptionCleanup returns early once the state is 'disposed'.
                    // Release the inner socket here; otherwise nothing would ever dispose it.
                    WebSocket orphanedSocket = innerWebSocket;
                    if (orphanedSocket != null)
                    {
                        orphanedSocket.Dispose();
                    }
                    throw new ObjectDisposedException(GetType().FullName);
                }
            }
            catch (WebException ex)
            {
                ConnectExceptionCleanup(response);
                WebSocketException wex = new WebSocketException(SR.GetString(SR.net_webstatus_ConnectFailure), ex);
                if (Logging.On) Logging.Exception(Logging.WebSockets, this, "ConnectAsync", wex);
                throw wex;
            }
            catch (Exception ex)
            {
                ConnectExceptionCleanup(response);
                if (Logging.On) Logging.Exception(Logging.WebSockets, this, "ConnectAsync", ex);
                throw;
            }
            finally
            {
                // We successfully connected (or failed trying), disengage from this token.  
                // Otherwise any timeout/cancellation would apply to the full session.
                // In the failure case we need to release the reference to HWR.
                connectCancellation.Dispose();
            }
        }
 
        // Disposes this instance and, when one was received, the handshake response.
        private void ConnectExceptionCleanup(HttpWebResponse response)
        {
            Dispose();
            if (response != null)
            {
                response.Dispose();
            }
        }
 
        // Builds the handshake request for 'uri' from the configured options (headers,
        // sub-protocols, credentials, client certificates, proxy, cookies).
        // Throws InvalidOperationException when WebRequest.Create does not return an HttpWebRequest.
        private HttpWebRequest CreateAndConfigureRequest(Uri uri)
        {
            HttpWebRequest request = WebRequest.Create(uri) as HttpWebRequest;
            if (request == null)
            {
                throw new InvalidOperationException(SR.GetString(SR.net_WebSockets_InvalidRegistration));
            }
 
            // Request Headers
            foreach (string key in options.RequestHeaders.Keys)
            {
                request.Headers.Add(key, options.RequestHeaders[key]);
            }
 
            // SubProtocols
            if (options.RequestedSubProtocols.Count > 0)
            {
                request.Headers.Add(HttpKnownHeaderNames.SecWebSocketProtocol,
                    string.Join(", ", options.RequestedSubProtocols));
            }
 
            // Creds
            if (options.UseDefaultCredentials)
            {
                request.UseDefaultCredentials = true;
            }
            else if (options.Credentials != null)
            {
                request.Credentials = options.Credentials;
            }
 
            // Certs
            if (options.InternalClientCertificates != null)
            {
                request.ClientCertificates = options.InternalClientCertificates;
            }
 
            request.Proxy = options.Proxy;
            request.CookieContainer = options.Cookies;
 
            // For Abort/Dispose.  Calling Abort on the request at any point will close the connection.
            cts.Token.Register(AbortRequest, request, false);
 
            return request;
        }
        
        // Validate the response headers and return the sub-protocol.
        // Throws WebSocketException when the status code is not 101, when the Upgrade, Connection
        // or Sec-WebSocket-Accept header does not have the expected value, or when the server
        // selected a sub-protocol that is not among the requested ones.
        // Returns null when the response carries no (or a whitespace-only) Sec-WebSocket-Protocol.
        private string ValidateResponse(HttpWebRequest request, HttpWebResponse response)
        {
            // 101
            if (response.StatusCode != HttpStatusCode.SwitchingProtocols)
            {
                throw new WebSocketException(SR.GetString(SR.net_WebSockets_Connect101Expected, 
                    (int)response.StatusCode));
            }
 
            // Upgrade: websocket
            string upgradeHeader = response.Headers[HttpKnownHeaderNames.Upgrade];
            if (!string.Equals(upgradeHeader, WebSocketHelpers.WebSocketUpgradeToken, 
                StringComparison.OrdinalIgnoreCase))
            {
                throw new WebSocketException(SR.GetString(SR.net_WebSockets_InvalidResponseHeader, 
                    HttpKnownHeaderNames.Upgrade, upgradeHeader));
            }
 
            // Connection: Upgrade
            string connectionHeader = response.Headers[HttpKnownHeaderNames.Connection];
            if (!string.Equals(connectionHeader, HttpKnownHeaderNames.Upgrade,
                StringComparison.OrdinalIgnoreCase))
            {
                throw new WebSocketException(SR.GetString(SR.net_WebSockets_InvalidResponseHeader,
                    HttpKnownHeaderNames.Connection, connectionHeader));
            }
 
            // Sec-WebSocket-Accept derived from request Sec-WebSocket-Key
            // NOTE(review): the comparison ignores case although the value is base64 text;
            // left unchanged here so that handshake acceptance does not change.
            string websocketAcceptHeader = response.Headers[HttpKnownHeaderNames.SecWebSocketAccept];
            string expectedAcceptHeader = WebSocketHelpers.GetSecWebSocketAcceptString(
                request.Headers[HttpKnownHeaderNames.SecWebSocketKey]);
            if (!string.Equals(websocketAcceptHeader, expectedAcceptHeader, StringComparison.OrdinalIgnoreCase))
            {
                throw new WebSocketException(SR.GetString(SR.net_WebSockets_InvalidResponseHeader,
                    HttpKnownHeaderNames.SecWebSocketAccept, websocketAcceptHeader));
            }
 
            // Sec-WebSocket-Protocol matches one from request
            // A missing header is ok.  It's also ok if the client didn't specify any.
            string subProtocol = response.Headers[HttpKnownHeaderNames.SecWebSocketProtocol];
            if (!string.IsNullOrWhiteSpace(subProtocol) && options.RequestedSubProtocols.Count > 0)
            {
                bool foundMatch = false;
                foreach (string requestedSubProtocol in options.RequestedSubProtocols)
                {
                    if (string.Equals(requestedSubProtocol, subProtocol, StringComparison.OrdinalIgnoreCase))
                    {
                        foundMatch = true;
                        break;
                    }
                }
                if (!foundMatch)
                {
                    throw new WebSocketException(SR.GetString(SR.net_WebSockets_AcceptUnsupportedProtocol,
                        string.Join(", ", options.RequestedSubProtocols), subProtocol));
                }
            }
 
            return string.IsNullOrWhiteSpace(subProtocol) ? null : subProtocol; // May be null or valid.
        }
 
        /// <summary>
        /// Forwards the send to the inner WebSocket after checking that this instance is connected.
        /// </summary>
        /// <param name="buffer">Data passed through to the inner WebSocket.</param>
        /// <param name="messageType">Message type passed through to the inner WebSocket.</param>
        /// <param name="endOfMessage">End-of-message flag passed through to the inner WebSocket.</param>
        /// <param name="cancellationToken">Token passed through to the inner WebSocket.</param>
        /// <returns>The task returned by the inner WebSocket.</returns>
        /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
        /// <exception cref="InvalidOperationException">This instance is not connected.</exception>
        public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, 
            CancellationToken cancellationToken)
        {
            ThrowIfNotConnected();
            return innerWebSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken);
        }
 
        /// <summary>
        /// Forwards the receive to the inner WebSocket after checking that this instance is connected.
        /// </summary>
        /// <param name="buffer">Destination buffer passed through to the inner WebSocket.</param>
        /// <param name="cancellationToken">Token passed through to the inner WebSocket.</param>
        /// <returns>The task returned by the inner WebSocket.</returns>
        /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
        /// <exception cref="InvalidOperationException">This instance is not connected.</exception>
        public override Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, 
            CancellationToken cancellationToken)
        {
            ThrowIfNotConnected();
            return innerWebSocket.ReceiveAsync(buffer, cancellationToken);
        }
 
        /// <summary>
        /// Forwards the close to the inner WebSocket after checking that this instance is connected.
        /// </summary>
        /// <param name="closeStatus">Close status passed through to the inner WebSocket.</param>
        /// <param name="statusDescription">Description passed through to the inner WebSocket.</param>
        /// <param name="cancellationToken">Token passed through to the inner WebSocket.</param>
        /// <returns>The task returned by the inner WebSocket.</returns>
        /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
        /// <exception cref="InvalidOperationException">This instance is not connected.</exception>
        public override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, 
            CancellationToken cancellationToken)
        {
            ThrowIfNotConnected();
            return innerWebSocket.CloseAsync(closeStatus, statusDescription, cancellationToken);
        }
 
        /// <summary>
        /// Forwards the output close to the inner WebSocket after checking that this instance is
        /// connected.
        /// </summary>
        /// <param name="closeStatus">Close status passed through to the inner WebSocket.</param>
        /// <param name="statusDescription">Description passed through to the inner WebSocket.</param>
        /// <param name="cancellationToken">Token passed through to the inner WebSocket.</param>
        /// <returns>The task returned by the inner WebSocket.</returns>
        /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
        /// <exception cref="InvalidOperationException">This instance is not connected.</exception>
        public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, 
            CancellationToken cancellationToken)
        {
            ThrowIfNotConnected();
            return innerWebSocket.CloseOutputAsync(closeStatus, statusDescription, cancellationToken);
        }
 
        /// <summary>
        /// Aborts the inner WebSocket (when one exists) and disposes this instance.
        /// Does nothing when this instance is already disposed.
        /// </summary>
        public override void Abort()
        {
            if (state == disposed)
            {
                return;
            }
            if (innerWebSocket != null)
            {
                innerWebSocket.Abort();
            }
            Dispose();
        }
 
        // Cancellation callback; 'obj' is the HttpWebRequest supplied at registration time.
        private void AbortRequest(object obj)
        {
            HttpWebRequest request = (HttpWebRequest)obj;
            request.Abort();
        }
 
        /// <summary>
        /// Moves this instance to the disposed stage, cancels and disposes the internal
        /// cancellation source and disposes the inner WebSocket when one exists.
        /// Only the first call performs cleanup.
        /// </summary>
        public override void Dispose()
        {
            int priorState = Interlocked.Exchange(ref state, disposed);
            if (priorState == disposed)
            {
                // No cleanup required.
                return;
            }
            // Cancelling runs the AbortRequest registration made in CreateAndConfigureRequest.
            cts.Cancel(false);
            cts.Dispose();
            if (innerWebSocket != null)
            {
                innerWebSocket.Dispose();
            }
        }
 
        // Throws ObjectDisposedException when disposed, InvalidOperationException when in any
        // other stage than 'connected'.
        private void ThrowIfNotConnected()
        {
            // Read the stage once so that a concurrent Dispose() between two reads cannot turn a
            // disposed instance into a "not connected" InvalidOperationException.
            int currentState = state;
            if (currentState == disposed)
            {
                throw new ObjectDisposedException(GetType().FullName);
            }
            else if (currentState != connected)
            {
                throw new InvalidOperationException(SR.GetString(SR.net_WebSockets_NotConnected));
            }
        }
    }
 
    /// <summary>
    /// Handshake and session settings for a <see cref="ClientWebSocket"/>. Once
    /// <c>SetToReadOnly</c> has been called, every public mutator throws
    /// <see cref="InvalidOperationException"/>.
    /// </summary>
    public sealed class ClientWebSocketOptions
    {
        // True once SetToReadOnly has run; checked by ThrowIfReadOnly before every mutation.
        private bool isReadOnly;
        private readonly IList<string> requestedSubProtocols;
        private readonly WebHeaderCollection requestHeaders;
        private TimeSpan keepAliveInterval;
        private int receiveBufferSize;
        private int sendBufferSize;
        // Null means "no buffer chosen yet"; GetOrCreateBuffer then allocates one on demand.
        private ArraySegment<byte>? buffer;
        private bool useDefaultCredentials;
        private ICredentials credentials;
        private IWebProxy proxy;
        // Created lazily by the ClientCertificates getter.
        private X509CertificateCollection clientCertificates;
        private CookieContainer cookies;
 
        internal ClientWebSocketOptions()
        {
            this.requestedSubProtocols = new List<string>();
            this.requestHeaders = new WebHeaderCollection(WebHeaderCollectionType.HttpWebRequest);
            // The instance cannot be read-only yet, so storing the field directly is equivalent
            // to going through the Proxy setter.
            this.proxy = WebRequest.DefaultWebProxy;
            this.receiveBufferSize = WebSocketHelpers.DefaultReceiveBufferSize;
            this.sendBufferSize = WebSocketHelpers.DefaultClientSendBufferSize;
            this.keepAliveInterval = WebSocket.DefaultKeepAliveInterval;
        }
 
        #region HTTP Settings
 
        /// <summary>
        /// Sets a header to be sent with the handshake request. Some headers, such as Host,
        /// are restricted.
        /// </summary>
        /// <param name="headerName">Name of the header.</param>
        /// <param name="headerValue">Value of the header.</param>
        /// <exception cref="InvalidOperationException">The options are read-only.</exception>
        public void SetRequestHeader(string headerName, string headerValue)
        {
            this.ThrowIfReadOnly();
            // Name and value are validated by WebHeaderCollection itself.
            this.requestHeaders.Set(headerName, headerValue);
        }
 
        internal WebHeaderCollection RequestHeaders
        {
            get { return this.requestHeaders; }
        }
 
        /// <summary>
        /// Whether the handshake request should use default credentials.
        /// </summary>
        /// <exception cref="InvalidOperationException">Set while the options are read-only.</exception>
        public bool UseDefaultCredentials
        {
            get { return this.useDefaultCredentials; }
            set
            {
                this.ThrowIfReadOnly();
                this.useDefaultCredentials = value;
            }
        }
 
        /// <summary>
        /// Credentials for the handshake request.
        /// </summary>
        /// <exception cref="InvalidOperationException">Set while the options are read-only.</exception>
        public ICredentials Credentials
        {
            get { return this.credentials; }
            set
            {
                this.ThrowIfReadOnly();
                this.credentials = value;
            }
        }
 
        /// <summary>
        /// Proxy for the handshake request. Initialized from <see cref="WebRequest.DefaultWebProxy"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">Set while the options are read-only.</exception>
        public IWebProxy Proxy
        {
            get { return this.proxy; }
            set
            {
                this.ThrowIfReadOnly();
                this.proxy = value;
            }
        }
 
        /// <summary>
        /// Client certificates for the handshake request. The getter creates an empty collection
        /// on first access when none has been assigned.
        /// </summary>
        /// <exception cref="InvalidOperationException">Set while the options are read-only.</exception>
        /// <exception cref="ArgumentNullException">Set to null.</exception>
        [SuppressMessage("Microsoft.Usage", "CA2227:CollectionPropertiesShouldBeReadOnly", 
            Justification = "This collectin will be handed off directly to HttpWebRequest.")]
        public X509CertificateCollection ClientCertificates
        {
            get
            {
                return this.clientCertificates
                    ?? (this.clientCertificates = new X509CertificateCollection());
            }
            set
            {
                // The read-only check deliberately comes before the argument check.
                this.ThrowIfReadOnly();
                if (value == null)
                {
                    throw new ArgumentNullException("value");
                }
                this.clientCertificates = value;
            }
        }
 
        // Unlike the public getter, this never allocates: null means no certificates were set.
        internal X509CertificateCollection InternalClientCertificates
        {
            get { return this.clientCertificates; }
        }
 
        /// <summary>
        /// Cookie container for the handshake request.
        /// </summary>
        /// <exception cref="InvalidOperationException">Set while the options are read-only.</exception>
        public CookieContainer Cookies
        {
            get { return this.cookies; }
            set
            {
                this.ThrowIfReadOnly();
                this.cookies = value;
            }
        }
 
        #endregion HTTP Settings
 
        #region WebSocket Settings
 
        /// <summary>
        /// Sets the buffer sizes and discards any previously supplied buffer, so that one is
        /// created internally when needed.
        /// </summary>
        /// <param name="receiveBufferSize">Requested receive buffer size.</param>
        /// <param name="sendBufferSize">Requested send buffer size.</param>
        /// <exception cref="InvalidOperationException">The options are read-only.</exception>
        public void SetBuffer(int receiveBufferSize, int sendBufferSize)
        {
            this.ThrowIfReadOnly();
            WebSocketHelpers.ValidateBufferSizes(receiveBufferSize, sendBufferSize);
 
            this.buffer = null;
            this.receiveBufferSize = receiveBufferSize;
            this.sendBufferSize = sendBufferSize;
        }
 
        /// <summary>
        /// Sets the buffer sizes and offers a caller-owned buffer. The buffer is only kept when
        /// the current application domain is fully trusted; otherwise it is silently ignored and
        /// an internal buffer is created later.
        /// </summary>
        /// <param name="receiveBufferSize">Requested receive buffer size.</param>
        /// <param name="sendBufferSize">Requested send buffer size.</param>
        /// <param name="buffer">Caller-owned buffer to use as the internal buffer.</param>
        /// <exception cref="InvalidOperationException">The options are read-only.</exception>
        public void SetBuffer(int receiveBufferSize, int sendBufferSize, ArraySegment<byte> buffer)
        {
            this.ThrowIfReadOnly();
            WebSocketHelpers.ValidateBufferSizes(receiveBufferSize, sendBufferSize);
            WebSocketHelpers.ValidateArraySegment(buffer, "buffer");
            WebSocketBuffer.Validate(buffer.Count, receiveBufferSize, sendBufferSize, false);
 
            this.receiveBufferSize = receiveBufferSize;
            this.sendBufferSize = sendBufferSize;
            
            // A caller-supplied buffer is honoured for full-trust applications only: the WebSocket
            // uses the buffer contents internally while marshalling data with embedded pointers
            // to native code, so a malicious application could use it to corrupt native memory.
            // For partial trust the buffer is dropped without notice and created internally later.
            bool callerBufferAllowed = AppDomain.CurrentDomain.IsFullyTrusted;
            this.buffer = callerBufferAllowed ? (ArraySegment<byte>?)buffer : null;
        }
 
        internal int ReceiveBufferSize
        {
            get { return this.receiveBufferSize; }
        }
 
        internal int SendBufferSize
        {
            get { return this.sendBufferSize; }
        }
 
        // Returns the stored buffer, creating and remembering one sized from the current
        // receive/send sizes when none is stored.
        internal ArraySegment<byte> GetOrCreateBuffer()
        {
            if (this.buffer.HasValue)
            {
                return this.buffer.Value;
            }

            this.buffer = WebSocket.CreateClientBuffer(this.receiveBufferSize, this.sendBufferSize);
            return this.buffer.Value;
        }
 
        /// <summary>
        /// Adds a sub-protocol to request during the handshake. Names are compared without
        /// regard to case, and a name may be added only once.
        /// </summary>
        /// <param name="subProtocol">Sub-protocol name to request.</param>
        /// <exception cref="InvalidOperationException">The options are read-only.</exception>
        /// <exception cref="ArgumentException">The sub-protocol was already added.</exception>
        public void AddSubProtocol(string subProtocol) 
        {
            this.ThrowIfReadOnly();
            WebSocketHelpers.ValidateSubprotocol(subProtocol);

            // Reject duplicates before appending.
            for (int index = 0; index < this.requestedSubProtocols.Count; index++)
            {
                bool alreadyRequested = string.Equals(this.requestedSubProtocols[index], subProtocol,
                    StringComparison.OrdinalIgnoreCase);
                if (alreadyRequested)
                {
                    throw new ArgumentException(SR.GetString(SR.net_WebSockets_NoDuplicateProtocol, subProtocol), 
                        "subProtocol");
                }
            }

            this.requestedSubProtocols.Add(subProtocol);
        }
 
        internal IList<string> RequestedSubProtocols
        {
            get { return this.requestedSubProtocols; }
        }
 
        /// <summary>
        /// Keep-alive interval handed to the WebSocket. Initialized from
        /// <see cref="WebSocket.DefaultKeepAliveInterval"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">Set while the options are read-only.</exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// Set to a value smaller than <see cref="Timeout.InfiniteTimeSpan"/>.
        /// </exception>
        public TimeSpan KeepAliveInterval 
        {
            get { return this.keepAliveInterval; }
            set
            {
                this.ThrowIfReadOnly();

                TimeSpan smallestAllowed = Timeout.InfiniteTimeSpan;
                if (value >= smallestAllowed)
                {
                    this.keepAliveInterval = value;
                    return;
                }

                throw new ArgumentOutOfRangeException("value", value,
                    SR.GetString(SR.net_WebSockets_ArgumentOutOfRange_TooSmall,
                    smallestAllowed.ToString()));
            }
        }
 
        #endregion WebSocket settings
 
        #region Helpers
 
        // Freezes the options; expected to be called once.
        internal void SetToReadOnly()
        {
            Contract.Assert(!this.isReadOnly, "Already set");
            this.isReadOnly = true;
        }
 
        // Guard used by every mutator: throws once the options have been frozen.
        private void ThrowIfReadOnly()
        {
            if (!this.isReadOnly)
            {
                return;
            }

            throw new InvalidOperationException(SR.GetString(SR.net_WebSockets_AlreadyStarted));
        }
 
        #endregion Helpers
    }
}