File: System\ServiceModel\Channels\TransactionChannelListener.cs
Project: ndp\cdf\src\WCF\ServiceModel\System.ServiceModel.csproj (System.ServiceModel)
//----------------------------------------------------------------------------
// Copyright (c) Microsoft Corporation.  All rights reserved.
//----------------------------------------------------------------------------
namespace System.ServiceModel.Channels
{
    using System.Collections.Generic;
    using System.Runtime;
    using System.ServiceModel;
    using System.ServiceModel.Description;
    using System.ServiceModel.Diagnostics;
    using System.ServiceModel.Security;
 
    sealed class TransactionChannelListener<TChannel> : DelegatingChannelListener<TChannel>, ITransactionChannelManager
        where TChannel : class, IChannel
    {
        TransactionFlowOption flowIssuedTokens;
        Dictionary<DirectionalAction, TransactionFlowOption> dictionary;
        SecurityStandardsManager standardsManager;
        TransactionProtocol transactionProtocol;
 
        public TransactionChannelListener(TransactionProtocol transactionProtocol, IDefaultCommunicationTimeouts timeouts, Dictionary<DirectionalAction, TransactionFlowOption> dictionary, IChannelListener<TChannel> innerListener)
            : base(timeouts, innerListener)
        {
            this.dictionary = dictionary;
            this.TransactionProtocol = transactionProtocol;
            this.Acceptor = new TransactionChannelAcceptor(this, innerListener);
 
            this.standardsManager = SecurityStandardsHelper.CreateStandardsManager(this.TransactionProtocol);
        }
 
        public TransactionProtocol TransactionProtocol
        {
            get
            {
                return this.transactionProtocol;
            }
            set
            {
                if (!TransactionProtocol.IsDefined(value))
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                        new ArgumentException(SR.GetString(SR.SFxBadTransactionProtocols)));
                this.transactionProtocol = value;
            }
        }
 
        public TransactionFlowOption FlowIssuedTokens
        {
            get { return this.flowIssuedTokens; }
            set { this.flowIssuedTokens = value; }
        }
 
        public SecurityStandardsManager StandardsManager
        {
            get { return this.standardsManager; }
            set { this.standardsManager = (value != null ? value : SecurityStandardsHelper.CreateStandardsManager(this.transactionProtocol)); }
        }
 
        public IDictionary<DirectionalAction, TransactionFlowOption> Dictionary
        {
            get { return this.dictionary; }
        }
 
        public TransactionFlowOption GetTransaction(MessageDirection direction, string action)
        {
            TransactionFlowOption txFlow;
            if (dictionary.TryGetValue(new DirectionalAction(direction, action), out txFlow))
                return txFlow;
 
            // Look for the wildcard action
            if (dictionary.TryGetValue(new DirectionalAction(direction, MessageHeaders.WildcardAction), out txFlow))
                return txFlow;
 
            return TransactionFlowOption.NotAllowed;
        }
 
        class TransactionChannelAcceptor : LayeredChannelAcceptor<TChannel, TChannel>
        {
            TransactionChannelListener<TChannel> listener;
 
            public TransactionChannelAcceptor(TransactionChannelListener<TChannel> listener, IChannelListener<TChannel> innerListener)
                : base(listener, innerListener)
            {
                this.listener = listener;
            }
 
            override protected TChannel OnAcceptChannel(TChannel innerChannel)
            {
                if (typeof(TChannel) == typeof(IInputSessionChannel))
                {
                    return (TChannel)(object)new TransactionInputSessionChannel(this.listener, (IInputSessionChannel)innerChannel);
                }
                if (typeof(TChannel) == typeof(IDuplexSessionChannel))
                {
                    return (TChannel)(object)new TransactionDuplexSessionChannel(this.listener, (IDuplexSessionChannel)innerChannel);
                }
                else if (typeof(TChannel) == typeof(IInputChannel))
                {
                    return (TChannel)(object)new TransactionInputChannel(this.listener, (IInputChannel)innerChannel);
                }
                else if (typeof(TChannel) == typeof(IReplyChannel))
                {
                    return (TChannel)(object)new TransactionReplyChannel(this.listener, (IReplyChannel)innerChannel);
                }
                else if (typeof(TChannel) == typeof(IReplySessionChannel))
                {
                    return (TChannel)(object)new TransactionReplySessionChannel(this.listener, (IReplySessionChannel)innerChannel);
                }
                else if (typeof(TChannel) == typeof(IDuplexChannel))
                {
                    return (TChannel)(object)new TransactionDuplexChannel(this.listener, (IDuplexChannel)innerChannel);
                }
                else
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(listener.CreateChannelTypeNotSupportedException(typeof(TChannel)));
                }
            }
        }
 
 
        //==============================================================
        //                Transaction channel classes
        //==============================================================
 
        sealed class TransactionInputChannel : TransactionReceiveChannelGeneric<IInputChannel>
        {
            public TransactionInputChannel(ChannelManagerBase channelManager, IInputChannel innerChannel)
                : base(channelManager, innerChannel, MessageDirection.Input)
            {
            }
        }
 
        sealed class TransactionReplyChannel : TransactionReplyChannelGeneric<IReplyChannel>
        {
            public TransactionReplyChannel(ChannelManagerBase channelManager, IReplyChannel innerChannel)
                : base(channelManager, innerChannel)
            {
            }
        }
 
        sealed class TransactionDuplexChannel : TransactionInputDuplexChannelGeneric<IDuplexChannel>
        {
            public TransactionDuplexChannel(ChannelManagerBase channelManager, IDuplexChannel innerChannel)
                : base(channelManager, innerChannel)
            {
            }
        }
 
        sealed class TransactionInputSessionChannel : TransactionReceiveChannelGeneric<IInputSessionChannel>, IInputSessionChannel
        {
            public TransactionInputSessionChannel(ChannelManagerBase channelManager, IInputSessionChannel innerChannel)
                : base(channelManager, innerChannel, MessageDirection.Input)
            {
            }
 
            public IInputSession Session { get { return InnerChannel.Session; } }
        }
 
        sealed class TransactionReplySessionChannel : TransactionReplyChannelGeneric<IReplySessionChannel>, IReplySessionChannel
        {
            public TransactionReplySessionChannel(ChannelManagerBase channelManager, IReplySessionChannel innerChannel)
                : base(channelManager, innerChannel)
            {
            }
 
            public IInputSession Session { get { return InnerChannel.Session; } }
        }
 
        sealed class TransactionDuplexSessionChannel : TransactionInputDuplexChannelGeneric<IDuplexSessionChannel>, IDuplexSessionChannel
        {
            public TransactionDuplexSessionChannel(ChannelManagerBase channelManager, IDuplexSessionChannel innerChannel)
                : base(channelManager, innerChannel)
            {
            }
 
            public IDuplexSession Session { get { return InnerChannel.Session; } }
        }
    }
 
 
    sealed class TransactionRequestContext : RequestContextBase
    {
        ITransactionChannel transactionChannel;
        RequestContext innerContext;
 
        public TransactionRequestContext(ITransactionChannel transactionChannel, ChannelBase channel, RequestContext innerContext,
            TimeSpan defaultCloseTimeout, TimeSpan defaultSendTimeout)
            : base(innerContext.RequestMessage, defaultCloseTimeout, defaultSendTimeout)
        {
            this.transactionChannel = transactionChannel;
            this.innerContext = innerContext;
        }
 
        protected override void OnAbort()
        {
            if (this.innerContext == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ObjectDisposedException(this.GetType().FullName));
            }
 
            this.innerContext.Abort();
        }
 
        protected override IAsyncResult OnBeginReply(Message message, TimeSpan timeout, AsyncCallback callback, object state)
        {
            if (this.innerContext == null)
            {
                throw TraceUtility.ThrowHelperError(new ObjectDisposedException(this.GetType().FullName), message);
            }
 
            if (message != null)
            {
                this.transactionChannel.WriteIssuedTokens(message, MessageDirection.Output);
            }
            return this.innerContext.BeginReply(message, timeout, callback, state);
        }
 
 
        protected override void OnClose(TimeSpan timeout)
        {
            if (this.innerContext == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ObjectDisposedException(this.GetType().FullName));
            }
 
            this.innerContext.Close(timeout);
        }
 
        protected override void OnEndReply(IAsyncResult result)
        {
            if (this.innerContext == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ObjectDisposedException(this.GetType().FullName));
            }
 
            this.innerContext.EndReply(result);
        }
 
        protected override void OnReply(Message message, TimeSpan timeout)
        {
            if (this.innerContext == null)
            {
                throw TraceUtility.ThrowHelperError(new ObjectDisposedException(this.GetType().FullName), message);
            }
 
            if (message != null)
            {
                this.transactionChannel.WriteIssuedTokens(message, MessageDirection.Output);
            }
            this.innerContext.Reply(message, timeout);
        }
    }
 
 
 
    //==============================================================
    //                Transaction channel base generic classes
    //==============================================================
 
    class TransactionReceiveChannelGeneric<TChannel> : TransactionChannel<TChannel>, IInputChannel
        where TChannel : class, IInputChannel
    {
        MessageDirection receiveMessageDirection;
 
        public TransactionReceiveChannelGeneric(ChannelManagerBase channelManager, TChannel innerChannel, MessageDirection direction)
            : base(channelManager, innerChannel)
        {
            this.receiveMessageDirection = direction;
        }
 
        public EndpointAddress LocalAddress
        {
            get
            {
                return InnerChannel.LocalAddress;
            }
        }
 
        public Message Receive()
        {
            return this.Receive(this.DefaultReceiveTimeout);
        }
 
        public Message Receive(TimeSpan timeout)
        {
            return InputChannel.HelpReceive(this, timeout);
        }
 
        public IAsyncResult BeginReceive(AsyncCallback callback, object state)
        {
            return this.BeginReceive(this.DefaultReceiveTimeout, callback, state);
        }
 
        public IAsyncResult BeginReceive(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return InputChannel.HelpBeginReceive(this, timeout, callback, state);
        }
 
        public Message EndReceive(IAsyncResult result)
        {
            return InputChannel.HelpEndReceive(result);
        }
 
        public IAsyncResult BeginTryReceive(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return InnerChannel.BeginTryReceive(timeout, callback, state);
        }
 
        public virtual bool EndTryReceive(IAsyncResult asyncResult, out Message message)
        {
            if (!InnerChannel.EndTryReceive(asyncResult, out message))
            {
                return false;
            }
 
            if (message != null)
            {
                ReadTransactionDataFromMessage(message, this.receiveMessageDirection);
            }
 
            return true;
        }
 
        public virtual bool TryReceive(TimeSpan timeout, out Message message)
        {
            if (!InnerChannel.TryReceive(timeout, out message))
            {
                return false;
            }
 
            if (message != null)
            {
                ReadTransactionDataFromMessage(message, this.receiveMessageDirection);
            }
 
            return true;
        }
 
        public bool WaitForMessage(TimeSpan timeout)
        {
            return InnerChannel.WaitForMessage(timeout);
        }
 
        public IAsyncResult BeginWaitForMessage(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return InnerChannel.BeginWaitForMessage(timeout, callback, state);
        }
 
        public bool EndWaitForMessage(IAsyncResult result)
        {
            return InnerChannel.EndWaitForMessage(result);
        }
    }
 
 
    //-------------------------------------------------------------
    class TransactionReplyChannelGeneric<TChannel> : TransactionChannel<TChannel>, IReplyChannel
        where TChannel : class, IReplyChannel
    {
 
        public TransactionReplyChannelGeneric(ChannelManagerBase channelManager, TChannel innerChannel)
            : base(channelManager, innerChannel)
        {
        }
 
        public EndpointAddress LocalAddress
        {
            get
            {
                return InnerChannel.LocalAddress;
            }
        }
 
        public RequestContext ReceiveRequest()
        {
            return this.ReceiveRequest(this.DefaultReceiveTimeout);
        }
 
        public RequestContext ReceiveRequest(TimeSpan timeout)
        {
            return ReplyChannel.HelpReceiveRequest(this, timeout);
        }
 
        public IAsyncResult BeginReceiveRequest(AsyncCallback callback, object state)
        {
            return this.BeginReceiveRequest(this.DefaultReceiveTimeout, callback, state);
        }
 
        public IAsyncResult BeginReceiveRequest(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return ReplyChannel.HelpBeginReceiveRequest(this, timeout, callback, state);
        }
 
        public RequestContext EndReceiveRequest(IAsyncResult result)
        {
            return ReplyChannel.HelpEndReceiveRequest(result);
        }
 
        public IAsyncResult BeginTryReceiveRequest(TimeSpan timeout, AsyncCallback callback, object state)
        {
            ReceiveTimeoutAsyncResult result = new ReceiveTimeoutAsyncResult(timeout, callback, state);
            result.InnerResult = this.InnerChannel.BeginTryReceiveRequest(timeout, result.InnerCallback, result.InnerState);
            return result;
        }
 
        RequestContext FinishReceiveRequest(RequestContext innerContext, TimeSpan timeout)
        {
            if (innerContext == null)
                return null;
 
            try
            {
                this.ReadTransactionDataFromMessage(innerContext.RequestMessage, MessageDirection.Input);
            }
            catch (FaultException fault)
            {
                string faultAction = fault.Action ?? innerContext.RequestMessage.Version.Addressing.DefaultFaultAction;
                Message reply = Message.CreateMessage(innerContext.RequestMessage.Version, fault.CreateMessageFault(), faultAction);
                try
                {
                    innerContext.Reply(reply, timeout);
                }
                finally
                {
                    reply.Close();
                }
                throw;
            }
 
            return new TransactionRequestContext(this, this, innerContext, this.DefaultCloseTimeout, this.DefaultSendTimeout);
        }
 
 
        public bool EndTryReceiveRequest(IAsyncResult asyncResult, out RequestContext requestContext)
        {
            if (asyncResult == null)
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("asyncResult");
 
            ReceiveTimeoutAsyncResult result = asyncResult as ReceiveTimeoutAsyncResult;
            if (result == null)
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentException(SR.GetString(SR.AsyncEndCalledWithAnIAsyncResult)));
 
            RequestContext innerContext;
            if (InnerChannel.EndTryReceiveRequest(result.InnerResult, out innerContext))
            {
                requestContext = FinishReceiveRequest(innerContext, result.TimeoutHelper.RemainingTime());
                return true;
            }
            else
            {
                requestContext = null;
                return false;
            }
        }
 
        public bool TryReceiveRequest(TimeSpan timeout, out RequestContext requestContext)
        {
            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
 
            RequestContext innerContext;
            if (InnerChannel.TryReceiveRequest(timeoutHelper.RemainingTime(), out innerContext))
            {
                requestContext = FinishReceiveRequest(innerContext, timeoutHelper.RemainingTime());
                return true;
            }
            else
            {
                requestContext = null;
                return false;
            }
        }
 
        public bool WaitForRequest(TimeSpan timeout)
        {
            return InnerChannel.WaitForRequest(timeout);
        }
 
        public IAsyncResult BeginWaitForRequest(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return InnerChannel.BeginWaitForRequest(timeout, callback, state);
        }
 
        public bool EndWaitForRequest(IAsyncResult result)
        {
            return InnerChannel.EndWaitForRequest(result);
        }
    }
 
 
    //-------------------------------------------------------------
    class TransactionInputDuplexChannelGeneric<TChannel> : TransactionDuplexChannelGeneric<TChannel>
        where TChannel : class, IDuplexChannel
    {
        public TransactionInputDuplexChannelGeneric(ChannelManagerBase channelManager, TChannel innerChannel)
            : base(channelManager, innerChannel, MessageDirection.Input)
        {
        }
    }
 
 
    //-------------------------------------------------------------
    class TransactionDuplexChannelGeneric<TChannel> : TransactionReceiveChannelGeneric<TChannel>, IDuplexChannel
        where TChannel : class, IDuplexChannel
    {
        MessageDirection sendMessageDirection;
 
        public TransactionDuplexChannelGeneric(ChannelManagerBase channelManager, TChannel innerChannel, MessageDirection direction)
            : base(channelManager, innerChannel, direction)
        {
            if (direction == MessageDirection.Input)
            {
                this.sendMessageDirection = MessageDirection.Output;
            }
            else
            {
                this.sendMessageDirection = MessageDirection.Input;
            }
        }
 
        public EndpointAddress RemoteAddress
        {
            get
            {
                return InnerChannel.RemoteAddress;
            }
        }
 
        public Uri Via
        {
            get
            {
                return InnerChannel.Via;
            }
        }
 
        public override void ReadTransactionDataFromMessage(Message message, MessageDirection direction)
        {
            try
            {
                base.ReadTransactionDataFromMessage(message, direction);
            }
            catch (FaultException fault)
            {
                Message reply = Message.CreateMessage(message.Version, fault.CreateMessageFault(), fault.Action);
 
                System.ServiceModel.Channels.RequestReplyCorrelator.AddressReply(reply, message);
                System.ServiceModel.Channels.RequestReplyCorrelator.PrepareReply(reply, message.Headers.MessageId);
 
                try
                {
                    this.Send(reply);
                }
                finally
                {
                    reply.Close();
                }
 
                throw;
            }
        }
 
        public IAsyncResult BeginSend(Message message, AsyncCallback callback, object state)
        {
            return this.BeginSend(message, this.DefaultSendTimeout, callback, state);
        }
 
        public virtual IAsyncResult BeginSend(Message message, TimeSpan timeout, AsyncCallback asyncCallback, object state)
        {
            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
            WriteTransactionDataToMessage(message, sendMessageDirection);
            return InnerChannel.BeginSend(message, timeoutHelper.RemainingTime(), asyncCallback, state);
        }
 
        public void EndSend(IAsyncResult result)
        {
            InnerChannel.EndSend(result);
        }
 
        public void Send(Message message)
        {
            this.Send(message, this.DefaultSendTimeout);
        }
 
        public virtual void Send(Message message, TimeSpan timeout)
        {
            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
            WriteTransactionDataToMessage(message, sendMessageDirection);
            InnerChannel.Send(message, timeoutHelper.RemainingTime());
        }
    }
 
 
    //==============================================================
    //                async helper classes
    //==============================================================
 
    class ReceiveTimeoutAsyncResult : AsyncResult
    {
        TimeoutHelper timeoutHelper;
        IAsyncResult innerResult;
        static AsyncCallback innerCallback = Fx.ThunkCallback(new AsyncCallback(Callback));
 
        internal ReceiveTimeoutAsyncResult(TimeSpan timeout, AsyncCallback callback, object state)
            : base(callback, state)
        {
            this.timeoutHelper = new TimeoutHelper(timeout);
        }
 
        internal TimeoutHelper TimeoutHelper
        {
            get { return this.timeoutHelper; }
        }
 
        internal AsyncCallback InnerCallback
        {
            get
            {
                if (ReceiveTimeoutAsyncResult.innerCallback == null)
                    ReceiveTimeoutAsyncResult.innerCallback = Fx.ThunkCallback(new AsyncCallback(Callback));
                return ReceiveTimeoutAsyncResult.innerCallback;
            }
        }
 
        internal IAsyncResult InnerResult
        {
            get
            {
                if (!(this.innerResult != null))
                {
                    // tx processing requires failfast when state is inconsistent
                    DiagnosticUtility.FailFast("ReceiveTimeoutAsyncResult.InnerResult: (this.innerResult != null)");
                }
                return this.innerResult;
            }
            set
            {
                if (value == null)
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("value");
 
                if (this.innerResult == null)
                {
                    this.innerResult = value;
                }
                else if (this.innerResult != value)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.SFxAsyncResultsDontMatch0)));
                }
            }
        }
 
        internal object InnerState
        {
            get { return this; }
        }
 
        static void Callback(IAsyncResult result)
        {
            if (result == null)
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("result");
 
            ReceiveTimeoutAsyncResult outerResult = (ReceiveTimeoutAsyncResult)result.AsyncState;
 
            outerResult.InnerResult = result;
            outerResult.Complete(result.CompletedSynchronously);
        }
    }
 
 
}