File: System\ServiceModel\Channels\UdpSocketReceiveManager.cs
Project: ndp\cdf\src\NetFx40\System.ServiceModel.Channels\System.ServiceModel.Channels.csproj (System.ServiceModel.Channels)
//----------------------------------------------------------------
// Copyright (c) Microsoft Corporation.  All rights reserved.
//----------------------------------------------------------------
 
namespace System.ServiceModel.Channels
{
    using System;
    using System.Collections.Generic;
    using System.Net;
    using System.Net.Sockets;
    using System.Runtime;
    using System.Threading;
 
    // Keeps asynchronous receives outstanding on a fixed set of UdpSocket instances and hands each
    // received datagram to an IUdpReceiveHandler.
    //
    // - Up to maxPendingReceivesPerSocket receives are kept pending on each socket (see EnsureReceiving).
    // - Each datagram is received into a buffer taken from receiveBufferPool, copied into a buffer taken
    //   from bufferManager, and the pooled receive buffer is returned immediately
    //   (see CopyMessageIntoBufferManager). receiveHandler.HandleDataReceived is responsible for
    //   returning the BufferManager buffer.
    // - Open/Close are reference counted: the first Open opens the sockets; when the count drops back
    //   to zero, the buffer pool, the buffer manager and the sockets are closed and the instance is
    //   permanently disposed (openCount == -1).
    sealed class UdpSocketReceiveManager
    {
        BufferManager bufferManager;                // destination buffers handed to receiveHandler; cleared on final Close
        Action<object> continueReceivingCallback;   // ContinueReceiving; scheduled when a receive loop has to move threads
        int maxPendingReceivesPerSocket;            // limit on outstanding receives per socket
        AsyncCallback onReceiveFrom;                // completion callback for BeginReceiveFrom (created on first Open)
        Action<object> onStartReceiving;            // OnStartReceiving delegate, created lazily by Open
        int openCount;                              // Open/Close reference count; -1 once disposed
        IUdpReceiveHandler receiveHandler;          // consumer of received datagrams and of async errors
        UdpSocket[] receiveSockets;
        Action onMessageDequeued;                   // passed to the handler; restarts receiving via EnsureReceiving
        object thisLock;
        int messageBufferSize;                      // computed from receiveHandler.MaxReceivedMessageSize in the ctor
        ConnectionBufferPool receiveBufferPool;     // receive buffers handed to BeginReceiveFrom

        // receiveSockets: the sockets to receive on (must be non-empty).
        // maxPendingReceivesPerSocket: limit on outstanding receives per socket (must be > 0).
        // bufferManager: source of the buffers handed to receiveHandler.
        // receiveHandler: consumer of received datagrams; its MaxReceivedMessageSize determines messageBufferSize.
        internal UdpSocketReceiveManager(UdpSocket[] receiveSockets, int maxPendingReceivesPerSocket, BufferManager bufferManager, IUdpReceiveHandler receiveHandler)
        {
            Fx.Assert(receiveSockets != null, "receiveSockets parameter is null");
            Fx.Assert(receiveSockets.Length > 0, "receiveSockets parameter is empty");
            Fx.Assert(maxPendingReceivesPerSocket > 0, "maxPendingReceivesPerSocket can't be <= 0");
            Fx.Assert(bufferManager != null, "bufferManager argument should not be null");
            // The null check must come before the assert below dereferences receiveHandler.
            Fx.Assert(receiveHandler != null, "receiveHandler should not be null");
            Fx.Assert(receiveHandler.MaxReceivedMessageSize > 0, "maxReceivedMessageSize must be > 0");

            this.receiveHandler = receiveHandler;
            this.thisLock = new object();
            this.bufferManager = bufferManager;
            this.receiveSockets = receiveSockets;
            this.maxPendingReceivesPerSocket = maxPendingReceivesPerSocket;
            this.messageBufferSize = UdpUtility.ComputeMessageBufferSize(receiveHandler.MaxReceivedMessageSize);

            int maxPendingReceives = maxPendingReceivesPerSocket * receiveSockets.Length;
            this.receiveBufferPool = new ConnectionBufferPool(this.messageBufferSize, maxPendingReceives);
        }

        // True once the final Close has run (openCount is set to -1 there).
        bool IsDisposed
        {
            get
            {
                return this.openCount < 0;
            }
        }

        // Replaces the handler that receives datagrams and async errors. The new handler must report the
        // same MaxReceivedMessageSize, because the receive buffers were sized from the original one.
        public void SetReceiveHandler(IUdpReceiveHandler handler)
        {
            Fx.Assert(handler != null, "IUdpReceiveHandler can't be null");
            Fx.Assert(handler.MaxReceivedMessageSize == this.receiveHandler.MaxReceivedMessageSize, "new receive handler's max message size doesn't match");
            Fx.Assert(this.openCount > 0, "SetReceiveHandler called on a closed UdpSocketReceiveManager");
            this.receiveHandler = handler;
        }

        // Decrements the open count. When it reaches zero, the manager is marked disposed and the receive
        // buffer pool, the buffer manager and every socket are closed. Calls after disposal are no-ops.
        public void Close()
        {
            lock (this.thisLock)
            {
                if (this.IsDisposed)
                {
                    return;
                }

                this.openCount--;

                if (this.openCount == 0)
                {
                    this.openCount = -1;
                    this.receiveBufferPool.Close();
                    this.bufferManager.Clear();

                    for (int i = 0; i < this.receiveSockets.Length; i++)
                    {
                        this.receiveSockets[i].Close();
                    }
                }
            }
        }

        // Increments the open count. The first Open opens every socket and creates the receive callbacks.
        // Every Open then starts receiving. This happens inline when already on a thread-pool thread;
        // otherwise OnStartReceiving is scheduled via ActionItem.Schedule. The reason: a completion-port
        // operation is canceled if the thread that issued it disappears.
        // Throws ObjectDisposedException (via ThrowIfDisposed) after the final Close. Non-fatal errors
        // while starting to receive are reported to receiveHandler.HandleAsyncException.
        public void Open()
        {
            lock (this.thisLock)
            {
                ThrowIfDisposed();

                this.openCount++;

                if (this.openCount == 1)
                {
                    for (int i = 0; i < this.receiveSockets.Length; i++)
                    {
                        this.receiveSockets[i].Open();
                    }

                    this.onMessageDequeued = new Action(OnMessageDequeued);
                    this.onReceiveFrom = Fx.ThunkCallback(new AsyncCallback(OnReceiveFrom));
                    this.continueReceivingCallback = new Action<object>(ContinueReceiving);
                }
            }

            try
            {
                if (Thread.CurrentThread.IsThreadPoolThread)
                {
                    EnsureReceiving();
                }
                else
                {
                    if (this.onStartReceiving == null)
                    {
                        this.onStartReceiving = new Action<object>(OnStartReceiving);
                    }

                    ActionItem.Schedule(this.onStartReceiving, this);
                }
            }
            catch (Exception ex)
            {
                if (!TryHandleException(ex))
                {
                    throw;
                }
            }
        }

        // Entry point scheduled by Open; state is the UdpSocketReceiveManager to start.
        static void OnStartReceiving(object state)
        {
            UdpSocketReceiveManager thisPtr = (UdpSocketReceiveManager)state;

            try
            {
                if (thisPtr.IsDisposed)
                {
                    return;
                }

                thisPtr.EnsureReceiving();
            }
            catch (Exception ex)
            {
                if (!thisPtr.TryHandleException(ex))
                {
                    throw;
                }
            }
        }

        // Callback handed to receiveHandler.HandleDataReceived: restarts receiving on all sockets and
        // reports non-fatal errors to the handler.
        void OnMessageDequeued()
        {
            try
            {
                EnsureReceiving();
            }
            catch (Exception ex)
            {
                if (!TryHandleException(ex))
                {
                    throw;
                }
            }
        }

        // Keeps issuing receives on one socket for as long as they complete synchronously and the handler
        // accepts more data. Called from OnReceiveFrom and scheduled by EnsureReceiving; socket is a UdpSocket.
        // NOTE(review): unlike EnsureReceiving, a CommunicationException thrown by StartAsyncReceive here
        // ends this loop (it is only reported via TryHandleException). Confirm that receiving on the socket
        // is resumed elsewhere, e.g. via onMessageDequeued -> EnsureReceiving.
        void ContinueReceiving(object socket)
        {
            try
            {
                UdpSocket udpSocket = (UdpSocket)socket;
                while (StartAsyncReceive(udpSocket))
                {
                    Fx.Assert(Thread.CurrentThread.IsThreadPoolThread, "Receive loop is running on a non-threadpool thread.  If this thread disappears while a completion port operation is outstanding, then the operation will get canceled.");
                }
            }
            catch (Exception ex)
            {
                if (!TryHandleException(ex))
                {
                    throw;
                }
            }
        }

        // Completion callback for BeginReceiveFrom. Synchronous completions are ignored here because
        // StartAsyncReceive processes them itself. If the manager was disposed meanwhile, the result is
        // dropped. Otherwise the receive is ended, the data is copied out of the pool buffer and handed to
        // the handler. Afterwards the receive loop on the same socket resumes, unless the manager is
        // disposed or the handler declined more data.
        void OnReceiveFrom(IAsyncResult result)
        {
            if (result.CompletedSynchronously)
            {
                return;
            }

            UdpSocketReceiveState state = (UdpSocketReceiveState)result.AsyncState;

            ArraySegment<byte> messageBytes;
            bool continueReceiving = true;

            try
            {
                lock (this.thisLock)
                {
                    if (this.IsDisposed)
                    {
                        return;
                    }

                    messageBytes = EndReceiveFrom(result, state);
                }
                messageBytes = this.CopyMessageIntoBufferManager(messageBytes);

                //when receiveHandler.HandleDataReceived is called, it will return the buffer to the buffer manager.
                continueReceiving = this.receiveHandler.HandleDataReceived(messageBytes, state.RemoteEndPoint, state.Socket.InterfaceIndex, this.onMessageDequeued);
            }
            catch (Exception ex)
            {
                if (!TryHandleException(ex))
                {
                    throw;
                }
            }
            finally
            {
                if (!this.IsDisposed && continueReceiving)
                {
                    ContinueReceiving(state.Socket);
                }
            }
        }

        // Issues one BeginReceiveFrom on socket, provided the manager is not disposed and the socket is
        // below maxPendingReceivesPerSocket.
        // Returns true only if that receive completed synchronously AND receiveHandler.HandleDataReceived
        // returned true, i.e. the caller should issue another receive.
        // Returns false if no receive was started, if the receive is still pending, or if the handler
        // declined more data.
        bool StartAsyncReceive(UdpSocket socket)
        {
            Fx.Assert(socket != null, "UdpSocketReceiveManager.StartAsyncReceive: Socket should never be null");
            bool completedSync = false;

            ArraySegment<byte> messageBytes = default(ArraySegment<byte>);
            UdpSocketReceiveState state = null;

            lock (this.thisLock)
            {
                if (!this.IsDisposed && socket.PendingReceiveCount < this.maxPendingReceivesPerSocket)
                {
                    IAsyncResult result = null;
                    byte[] receiveBuffer = this.receiveBufferPool.Take();
                    try
                    {
                        state = new UdpSocketReceiveState(socket, receiveBuffer);
                        EndPoint remoteEndpoint = socket.CreateIPAnyEndPoint();

                        result = socket.BeginReceiveFrom(receiveBuffer, 0, receiveBuffer.Length, ref remoteEndpoint, onReceiveFrom, state);
                    }
                    catch (Exception e)
                    {
                        // The receive never started, so the pool buffer would otherwise be lost.
                        if (!Fx.IsFatal(e))
                        {
                            this.receiveBufferPool.Return(receiveBuffer);
                        }
                        throw;
                    }

                    if (result.CompletedSynchronously)
                    {
                        completedSync = true;
                        messageBytes = EndReceiveFrom(result, state);
                    }
                }
            }

            if (completedSync)
            {
                messageBytes = this.CopyMessageIntoBufferManager(messageBytes);
                //if HandleDataReceived returns false, it means that the max pending message count was hit.
                //when receiveHandler.HandleDataReceived is called (whether now or later), it will return the buffer to the buffer manager.
                return this.receiveHandler.HandleDataReceived(messageBytes, state.RemoteEndPoint, state.Socket.InterfaceIndex, this.onMessageDequeued);
            }

            return false;
        }

        // Copies the received bytes out of the pooled receive buffer into a buffer taken from
        // bufferManager, then returns the pooled buffer. The pooled buffer is also returned (before
        // rethrowing) when TakeBuffer or the copy fails with a non-fatal exception.
        // The resulting segment starts at offset 0 and has the same length as receiveBuffer.
        ArraySegment<byte> CopyMessageIntoBufferManager(ArraySegment<byte> receiveBuffer)
        {
            int dataLength = receiveBuffer.Count;
            byte[] dataBuffer;

            try
            {
                dataBuffer = this.bufferManager.TakeBuffer(dataLength);
                Array.Copy(receiveBuffer.Array, receiveBuffer.Offset, dataBuffer, 0, dataLength);
            }
            catch (Exception e)
            {
                // Same pattern as EndReceiveFrom: don't leak the pool buffer on a non-fatal failure.
                if (!Fx.IsFatal(e))
                {
                    this.receiveBufferPool.Return(receiveBuffer.Array);
                }
                throw;
            }

            this.receiveBufferPool.Return(receiveBuffer.Array);
            return new ArraySegment<byte>(dataBuffer, 0, dataLength);
        }

        // Tops up every socket to maxPendingReceivesPerSocket outstanding receives.
        // A CommunicationException from StartAsyncReceive is reported to the handler and the loop goes on.
        // A synchronous completion (or such an exception) on a non-thread-pool thread instead schedules
        // continueReceivingCallback for that socket, and the loop moves on to the next socket.
        // NOTE(review): StartAsyncReceive returns false both when the receive is pending and when the
        // handler declined more data; this loop does not distinguish the two. Confirm that is intended.
        void EnsureReceiving()
        {
            for (int i = 0; i < this.receiveSockets.Length; i++)
            {
                UdpSocket socket = this.receiveSockets[i];

                while (!this.IsDisposed && socket.PendingReceiveCount < this.maxPendingReceivesPerSocket)
                {
                    bool jumpThreads = false;
                    try
                    {
                        if (StartAsyncReceive(socket) && !Thread.CurrentThread.IsThreadPoolThread)
                        {
                            jumpThreads = true;
                        }
                    }
                    catch (CommunicationException ex)
                    {
                        //message too big, ICMP errors, etc, are translated by the socket into a CommunicationException derived exception.
                        //These should not be fatal to the receive loop, so we need to continue receiving.
                        this.receiveHandler.HandleAsyncException(ex);
                        jumpThreads = !Thread.CurrentThread.IsThreadPoolThread;
                    }

                    if (jumpThreads)
                    {
                        ActionItem.Schedule(this.continueReceivingCallback, socket);
                        break; //while loop.
                    }
                }
            }
        }

        void ThrowIfDisposed()
        {
            if (this.IsDisposed)
            {
                throw FxTrace.Exception.AsError(new ObjectDisposedException("SocketReceiveManager"));
            }
        }

        // Returns false for fatal exceptions (the caller must rethrow). Otherwise reports ex to
        // receiveHandler.HandleAsyncException and returns true.
        bool TryHandleException(Exception ex)
        {
            if (Fx.IsFatal(ex))
            {
                return false;
            }

            this.receiveHandler.HandleAsyncException(ex);
            return true;
        }

        // Call under thisLock. Completes the receive, records the sender in state.RemoteEndPoint and
        // returns the received bytes, which are backed by state.ReceiveBuffer. On a non-fatal failure,
        // the receive buffer is returned to the pool before rethrowing.
        ArraySegment<byte> EndReceiveFrom(IAsyncResult result, UdpSocketReceiveState state)
        {
            try
            {
                EndPoint remoteEndpoint = null;
                ArraySegment<byte> messageBytes = state.Socket.EndReceiveFrom(result, ref remoteEndpoint);
                state.RemoteEndPoint = remoteEndpoint;
                Fx.Assert(messageBytes.Array == state.ReceiveBuffer, "Array returned by Socket.EndReceiveFrom must match the array passed in through the UdpSocketReceiveState");
                return messageBytes;
            }
            catch (Exception e)
            {
                if (!Fx.IsFatal(e))
                {
                    this.receiveBufferPool.Return(state.ReceiveBuffer);
                }
                throw;
            }
        }

        // Per-receive state passed as the AsyncState of BeginReceiveFrom: the socket, the pooled buffer
        // the data is received into, and (after EndReceiveFrom) the sender's endpoint.
        internal class UdpSocketReceiveState
        {
            public UdpSocketReceiveState(UdpSocket socket, byte[] receiveBuffer)
            {
                Fx.Assert(socket != null, "UdpSocketReceiveState.ctor: socket should not be null");

                this.Socket = socket;
                this.ReceiveBuffer = receiveBuffer;
            }

            public EndPoint RemoteEndPoint
            {
                get;
                set;
            }

            internal UdpSocket Socket
            {
                get;
                private set;
            }

            internal byte[] ReceiveBuffer
            {
                get;
                private set;
            }
        }
    }
}