// File: System\ServiceModel\Channels\RequestChannel.cs
// Project: ndp\cdf\src\WCF\ServiceModel\System.ServiceModel.csproj (System.ServiceModel)
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation.  All rights reserved.
//------------------------------------------------------------
 
namespace System.ServiceModel.Channels
{
    using System.Collections.Generic;
    using System.Diagnostics;
    using System.Runtime;
    using System.ServiceModel;
    using System.ServiceModel.Diagnostics;
    using System.Threading;
 
    /// <summary>
    /// Base implementation of <see cref="IRequestChannel"/> shared by concrete request channels.
    /// It applies addressing headers to outgoing messages and tracks every request that has been
    /// started but not yet released, so pending requests can be aborted, faulted, or waited for
    /// when the channel shuts down.
    /// </summary>
    /// <remarks>
    /// Subclasses supply the request objects through <see cref="CreateRequest"/> and
    /// <see cref="CreateAsyncRequest"/>.
    /// </remarks>
    abstract class RequestChannel : ChannelBase, IRequestChannel
    {
        bool manualAddressing;

        // Requests tracked by TrackRequest and not yet released by ReleaseRequest.
        // The list instance is also the lock that guards closedEvent and closed.
        List<IRequestBase> outstandingRequests = new List<IRequestBase>();
        EndpointAddress to;
        Uri via;

        // Created by SetupWaitForPendingRequests when there is something to wait for, signalled by
        // ReleaseRequest once no tracked requests remain, and closed by FinishClose.
        ManualResetEvent closedEvent;

        // Set by FinishClose once closedEvent has been closed, so ReleaseRequest stops signalling it.
        bool closed;

        /// <summary>
        /// Initializes a request channel.
        /// </summary>
        /// <param name="channelFactory">The channel manager passed to the <see cref="ChannelBase"/> constructor.</param>
        /// <param name="to">The address exposed as <see cref="RemoteAddress"/>. The base <see cref="AddHeadersTo"/>
        /// applies it to outgoing messages unless <paramref name="manualAddressing"/> is true.</param>
        /// <param name="via">The URI exposed through <see cref="Via"/>.</param>
        /// <param name="manualAddressing">When true, the base <see cref="AddHeadersTo"/> does not apply
        /// <paramref name="to"/>, and <paramref name="to"/> may be null.</param>
        /// <remarks>
        /// Throws through <c>ThrowHelperArgumentNull("to")</c> when <paramref name="to"/> is null and
        /// <paramref name="manualAddressing"/> is false.
        /// </remarks>
        protected RequestChannel(ChannelManagerBase channelFactory, EndpointAddress to, Uri via, bool manualAddressing)
            : base(channelFactory)
        {
            if (!manualAddressing && to == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("to");
            }

            this.manualAddressing = manualAddressing;
            this.to = to;
            this.via = via;
        }

        /// <summary>Whether the caller, rather than this channel, is responsible for message addressing.</summary>
        protected bool ManualAddressing
        {
            get
            {
                return this.manualAddressing;
            }
        }

        /// <summary>The remote address supplied at construction (may be null under manual addressing).</summary>
        public EndpointAddress RemoteAddress
        {
            get
            {
                return this.to;
            }
        }

        /// <summary>The Via URI supplied at construction.</summary>
        public Uri Via
        {
            get
            {
                return this.via;
            }
        }

        /// <summary>
        /// Stops tracking every outstanding request and calls <see cref="IRequestBase.Abort"/> on each.
        /// </summary>
        protected void AbortPendingRequests()
        {
            IRequestBase[] requestsToAbort = TakePendingRequests();

            if (requestsToAbort != null)
            {
                foreach (IRequestBase request in requestsToAbort)
                {
                    request.Abort(this);
                }
            }
        }

        /// <summary>
        /// Asynchronously waits, up to <paramref name="timeout"/>, for all currently tracked requests to be
        /// released. Requests still outstanding when the timeout expires are aborted.
        /// </summary>
        protected IAsyncResult BeginWaitForPendingRequests(TimeSpan timeout, AsyncCallback callback, object state)
        {
            IRequestBase[] pendingRequests = SetupWaitForPendingRequests();
            return new WaitForPendingRequestsAsyncResult(timeout, this, pendingRequests, callback, state);
        }

        /// <summary>Completes an operation started with <see cref="BeginWaitForPendingRequests"/>.</summary>
        protected void EndWaitForPendingRequests(IAsyncResult result)
        {
            WaitForPendingRequestsAsyncResult.End(result);
        }

        /// <summary>
        /// Marks the wait as finished and closes closedEvent if it was created. Only the first call has
        /// any effect.
        /// </summary>
        void FinishClose()
        {
            lock (outstandingRequests)
            {
                if (!closed)
                {
                    closed = true;
                    if (closedEvent != null)
                    {
                        this.closedEvent.Close();
                    }
                }
            }
        }

        /// <summary>
        /// Snapshots the tracked requests and ensures closedEvent exists so a waiter can be signalled.
        /// </summary>
        /// <returns>The tracked requests, or null when none are outstanding.</returns>
        IRequestBase[] SetupWaitForPendingRequests()
        {
            lock (outstandingRequests)
            {
                if (outstandingRequests.Count == 0)
                {
                    return null;
                }

                if (closedEvent == null)
                {
                    closedEvent = new ManualResetEvent(false);
                }

                // The list is deliberately left intact: ReleaseRequest signals closedEvent only when
                // its count drops to zero, i.e. once the LAST of these requests has been released.
                // Clearing it here would make the very first release end the wait.
                return outstandingRequests.ToArray();
            }
        }

        /// <summary>
        /// Blocks for up to <paramref name="timeout"/> until all currently tracked requests have been
        /// released. On timeout, aborts the requests that are still outstanding. Always finishes the close.
        /// </summary>
        protected void WaitForPendingRequests(TimeSpan timeout)
        {
            IRequestBase[] pendingRequests = SetupWaitForPendingRequests();
            if (pendingRequests != null)
            {
                // ToMilliseconds maps TimeSpan.MaxValue to an infinite wait; WaitOne(TimeSpan) would
                // reject it as out of range. This matches the asynchronous path below.
                if (!closedEvent.WaitOne(TimeoutHelper.ToMilliseconds(timeout), false))
                {
                    // Abort only what is still tracked; requests released during the wait are left alone.
                    AbortPendingRequests();
                }
            }
            FinishClose();
        }

        /// <summary>
        /// Removes every tracked request from the outstanding list and returns them.
        /// </summary>
        /// <returns>The removed requests, or null when none were outstanding.</returns>
        IRequestBase[] TakePendingRequests()
        {
            lock (outstandingRequests)
            {
                if (outstandingRequests.Count == 0)
                {
                    return null;
                }

                IRequestBase[] requests = outstandingRequests.ToArray();
                outstandingRequests.Clear();
                return requests;
            }
        }

        /// <summary>
        /// Stops tracking every outstanding request and calls <see cref="IRequestBase.Fault"/> on each.
        /// </summary>
        protected void FaultPendingRequests()
        {
            IRequestBase[] requestsToFault = TakePendingRequests();

            if (requestsToFault != null)
            {
                foreach (IRequestBase request in requestsToFault)
                {
                    request.Fault(this);
                }
            }
        }

        /// <summary>
        /// Returns this channel when <typeparamref name="T"/> is <see cref="IRequestChannel"/>;
        /// otherwise defers to the base implementation.
        /// </summary>
        public override T GetProperty<T>()
        {
            if (typeof(T) == typeof(IRequestChannel))
            {
                return (T)(object)this;
            }

            return base.GetProperty<T>();
        }

        /// <summary>Aborts all outstanding requests when the channel is aborted.</summary>
        protected override void OnAbort()
        {
            AbortPendingRequests();
        }

        /// <summary>
        /// Notifies <paramref name="request"/> that it is being released and stops tracking it. Signals
        /// closedEvent (if a wait is in progress) once no tracked requests remain.
        /// </summary>
        void ReleaseRequest(IRequestBase request)
        {
            if (request != null)
            {
                // Synchronization of OnReleaseRequest is the
                // responsibility of the concrete implementation of request.
                request.OnReleaseRequest();
            }

            lock (outstandingRequests)
            {
                // Remove tolerates a request that is no longer tracked (for example because
                // AbortPendingRequests or FaultPendingRequests already took it), so no Contains() check.
                outstandingRequests.Remove(request);
                if (outstandingRequests.Count == 0)
                {
                    if (!closed && closedEvent != null)
                    {
                        closedEvent.Set();
                    }
                }
            }
        }

        /// <summary>
        /// Starts tracking <paramref name="request"/>. The open-state check runs under the same lock
        /// as the snapshot helpers, so no request can be added once the channel has left the Opened state.
        /// </summary>
        void TrackRequest(IRequestBase request)
        {
            lock (outstandingRequests)
            {
                ThrowIfDisposedOrNotOpen(); // make sure that we haven't already snapshotted our collection
                outstandingRequests.Add(request);
            }
        }

        /// <summary>Begins a request using the channel's default send timeout.</summary>
        public IAsyncResult BeginRequest(Message message, AsyncCallback callback, object state)
        {
            return this.BeginRequest(message, this.DefaultSendTimeout, callback, state);
        }

        /// <summary>
        /// Applies addressing headers, creates and tracks an asynchronous request, and starts sending it.
        /// </summary>
        /// <param name="message">The request message; must not be null.</param>
        /// <param name="timeout">Time allowed for the operation; must not be negative.</param>
        /// <returns>The tracked <see cref="IAsyncRequest"/>, to be passed to <see cref="EndRequest"/>.</returns>
        public IAsyncResult BeginRequest(Message message, TimeSpan timeout, AsyncCallback callback, object state)
        {
            if (message == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("message");
            }

            if (timeout < TimeSpan.Zero)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                    new ArgumentOutOfRangeException("timeout", timeout, SR.GetString(SR.SFxTimeoutOutOfRange0)));
            }

            ThrowIfDisposedOrNotOpen();

            AddHeadersTo(message);
            IAsyncRequest asyncRequest = CreateAsyncRequest(message, callback, state);
            TrackRequest(asyncRequest);

            bool throwing = true;
            try
            {
                asyncRequest.BeginSendRequest(message, timeout);
                throwing = false;
            }
            finally
            {
                if (throwing)
                {
                    // The caller never receives the IAsyncResult, so EndRequest cannot release it.
                    ReleaseRequest(asyncRequest);
                }
            }

            return asyncRequest;
        }

        /// <summary>Creates the request object used by the synchronous <see cref="Request(Message, TimeSpan)"/>.</summary>
        protected abstract IRequest CreateRequest(Message message);

        /// <summary>Creates the request object used by <see cref="BeginRequest(Message, TimeSpan, AsyncCallback, object)"/>.</summary>
        protected abstract IAsyncRequest CreateAsyncRequest(Message message, AsyncCallback callback, object state);

        /// <summary>
        /// Completes a request started with BeginRequest and returns its reply. The request is released
        /// whether or not completing it throws.
        /// </summary>
        /// <param name="result">The value returned by BeginRequest; must be an <see cref="IAsyncRequest"/>.</param>
        public Message EndRequest(IAsyncResult result)
        {
            if (result == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("result");
            }

            IAsyncRequest asyncRequest = result as IAsyncRequest;

            if (asyncRequest == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgument("result", SR.GetString(SR.InvalidAsyncResult));
            }

            try
            {
                Message reply = asyncRequest.End();

                if (DiagnosticUtility.ShouldTraceInformation)
                {
                    TraceUtility.TraceEvent(TraceEventType.Information, TraceCode.RequestChannelReplyReceived,
                        SR.GetString(SR.TraceCodeRequestChannelReplyReceived), reply);
                }

                return reply;
            }
            finally
            {
                ReleaseRequest(asyncRequest);
            }
        }

        /// <summary>Sends a request using the channel's default send timeout.</summary>
        public Message Request(Message message)
        {
            return this.Request(message, this.DefaultSendTimeout);
        }

        /// <summary>
        /// Sends <paramref name="message"/> and waits for the reply. The send and the wait each get the
        /// time remaining from <paramref name="timeout"/>. A <see cref="TimeoutException"/> from either
        /// phase is rethrown wrapped in a phase-specific message. The request is always released.
        /// </summary>
        /// <param name="message">The request message; must not be null.</param>
        /// <param name="timeout">Total time allowed; must not be negative.</param>
        public Message Request(Message message, TimeSpan timeout)
        {
            if (message == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("message");
            }

            if (timeout < TimeSpan.Zero)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                    new ArgumentOutOfRangeException("timeout", timeout, SR.GetString(SR.SFxTimeoutOutOfRange0)));
            }

            ThrowIfDisposedOrNotOpen();

            AddHeadersTo(message);
            IRequest request = CreateRequest(message);
            TrackRequest(request);
            try
            {
                Message reply;
                TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);

                TimeSpan savedTimeout = timeoutHelper.RemainingTime();
                try
                {
                    request.SendRequest(message, savedTimeout);
                }
                catch (TimeoutException timeoutException)
                {
                    throw TraceUtility.ThrowHelperError(new TimeoutException(SR.GetString(SR.RequestChannelSendTimedOut, savedTimeout),
                        timeoutException), message);
                }

                savedTimeout = timeoutHelper.RemainingTime();

                try
                {
                    reply = request.WaitForReply(savedTimeout);
                }
                catch (TimeoutException timeoutException)
                {
                    throw TraceUtility.ThrowHelperError(new TimeoutException(SR.GetString(SR.RequestChannelWaitForReplyTimedOut, savedTimeout),
                        timeoutException), message);
                }

                if (DiagnosticUtility.ShouldTraceInformation)
                {
                    TraceUtility.TraceEvent(TraceEventType.Information, TraceCode.RequestChannelReplyReceived,
                        SR.GetString(SR.TraceCodeRequestChannelReplyReceived), reply);
                }

                return reply;
            }
            finally
            {
                ReleaseRequest(request);
            }
        }

        /// <summary>
        /// Applies the remote address to <paramref name="message"/> unless manual addressing is on or no
        /// address was supplied.
        /// </summary>
        protected virtual void AddHeadersTo(Message message)
        {
            if (!manualAddressing && to != null)
            {
                to.ApplyTo(message);
            }
        }

        /// <summary>
        /// Asynchronous counterpart of WaitForPendingRequests. It waits on the channel's closedEvent through
        /// the thread pool, aborts whatever is still tracked on timeout, and then finishes the close.
        /// </summary>
        class WaitForPendingRequestsAsyncResult : AsyncResult
        {
            static WaitOrTimerCallback completeWaitCallBack = new WaitOrTimerCallback(OnCompleteWaitCallBack);

            // Snapshot taken when the wait started; null when nothing was outstanding.
            IRequestBase[] pendingRequests;
            RequestChannel requestChannel;
            TimeSpan timeout;
            RegisteredWaitHandle waitHandle;

            public WaitForPendingRequestsAsyncResult(TimeSpan timeout, RequestChannel requestChannel, IRequestBase[] pendingRequests, AsyncCallback callback, object state)
                : base(callback, state)
            {
                this.requestChannel = requestChannel;
                this.pendingRequests = pendingRequests;
                this.timeout = timeout;

                if (this.timeout == TimeSpan.Zero || this.pendingRequests == null)
                {
                    // Nothing to wait for, or no time to wait: finish synchronously.
                    AbortRequests();
                    CleanupEvents();
                    Complete(true);
                }
                else
                {
                    // NOTE(review): the callback can run before waitHandle is assigned (e.g. if closedEvent
                    // is already signalled); CleanupEvents then skips Unregister. Confirm this is acceptable.
                    this.waitHandle = ThreadPool.RegisterWaitForSingleObject(this.requestChannel.closedEvent, completeWaitCallBack, this, TimeoutHelper.ToMilliseconds(timeout), true);
                }
            }

            /// <summary>
            /// Aborts the requests the channel still tracks. Requests from the snapshot that have already
            /// been released are not aborted.
            /// </summary>
            void AbortRequests()
            {
                if (pendingRequests != null)
                {
                    this.requestChannel.AbortPendingRequests();
                }
            }

            /// <summary>Unregisters the thread-pool wait (if registered) and finishes the channel's close.</summary>
            void CleanupEvents()
            {
                if (requestChannel.closedEvent != null)
                {
                    if (waitHandle != null)
                    {
                        // NOTE(review): closedEvent is passed as the unregistration notification object and
                        // FinishClose closes it right afterwards; confirm that ordering is safe.
                        waitHandle.Unregister(requestChannel.closedEvent);
                    }
                    requestChannel.FinishClose();
                }
            }

            static void OnCompleteWaitCallBack(object state, bool timedOut)
            {
                WaitForPendingRequestsAsyncResult thisPtr = (WaitForPendingRequestsAsyncResult)state;
                Exception completionException = null;
                try
                {
                    if (timedOut)
                    {
                        thisPtr.AbortRequests();
                    }
                    thisPtr.CleanupEvents();
                }
#pragma warning suppress 56500 // Microsoft, transferring exception to another thread
                catch (Exception e)
                {
                    if (Fx.IsFatal(e))
                    {
                        throw;
                    }
                    completionException = e;
                }

                thisPtr.Complete(false, completionException);
            }

            public static void End(IAsyncResult result)
            {
                AsyncResult.End<WaitForPendingRequestsAsyncResult>(result);
            }
        }
    }
 
    /// <summary>
    /// Callbacks a <see cref="RequestChannel"/> invokes on each request it tracks.
    /// </summary>
    internal interface IRequestBase
    {
        /// <summary>
        /// Aborts the request. Invoked by the owning channel when it aborts its pending requests,
        /// including when a wait for pending requests runs out of time.
        /// </summary>
        /// <param name="requestChannel">The channel performing the abort.</param>
        void Abort(RequestChannel requestChannel);

        /// <summary>
        /// Faults the request. Invoked by the owning channel when it faults its pending requests.
        /// </summary>
        /// <param name="requestChannel">The channel performing the fault.</param>
        void Fault(RequestChannel requestChannel);

        /// <summary>
        /// Notifies the request that the channel is releasing it. Synchronizing this call is the
        /// implementation's responsibility.
        /// </summary>
        void OnReleaseRequest();
    }
 
    /// <summary>
    /// A request driven synchronously by <see cref="RequestChannel"/>: the channel first sends,
    /// then waits for the reply, giving each step the time remaining from the caller's timeout.
    /// </summary>
    internal interface IRequest : IRequestBase
    {
        /// <summary>Sends <paramref name="message"/> within <paramref name="timeout"/>.</summary>
        void SendRequest(Message message, TimeSpan timeout);

        /// <summary>
        /// Waits up to <paramref name="timeout"/> for the reply and returns it. The channel treats a
        /// <see cref="TimeoutException"/> from this call as a wait-for-reply timeout.
        /// </summary>
        Message WaitForReply(TimeSpan timeout);
    }
 
    /// <summary>
    /// A request driven asynchronously by <see cref="RequestChannel"/>. The object itself is the
    /// <see cref="IAsyncResult"/> handed back from BeginRequest and later passed to EndRequest.
    /// </summary>
    internal interface IAsyncRequest : IAsyncResult, IRequestBase
    {
        /// <summary>
        /// Starts sending <paramref name="message"/>. If this throws, the channel releases the request
        /// before the exception propagates.
        /// </summary>
        void BeginSendRequest(Message message, TimeSpan timeout);

        /// <summary>Completes the request and returns the reply; the channel releases the request afterwards.</summary>
        Message End();
    }
}