File: System\ServiceModel\Dispatcher\EndpointAddressMessageFilter.cs
Project: ndp\cdf\src\WCF\ServiceModel\System.ServiceModel.csproj (System.ServiceModel)
//-----------------------------------------------------------------------------
// Copyright (c) Microsoft Corporation.  All rights reserved.
//-----------------------------------------------------------------------------
namespace System.ServiceModel.Dispatcher
{
    using System;
    using System.ServiceModel.Channels;
    using System.Collections.Generic;
    using System.Diagnostics;
    using System.IO;
    using System.ServiceModel.Security;
    using System.Text;
    using System.Xml;
 
    using QName = System.ServiceModel.Dispatcher.EndpointAddressProcessor.QName;
    using HeaderBit = System.ServiceModel.Dispatcher.EndpointAddressProcessor.HeaderBit;
 
    public class EndpointAddressMessageFilter : MessageFilter
    {
        EndpointAddress address;
        bool includeHostNameInComparison;
        EndpointAddressMessageFilterHelper helper;
        UriComparer comparer;
 
        public EndpointAddressMessageFilter(EndpointAddress address)
            : this(address, false)
        {
        }
 
        public EndpointAddressMessageFilter(EndpointAddress address, bool includeHostNameInComparison)
        {
            if (address == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("address");
            }
 
            this.address = address;
            this.includeHostNameInComparison = includeHostNameInComparison;
            this.helper = new EndpointAddressMessageFilterHelper(this.address);
 
            if (includeHostNameInComparison)
            {
                this.comparer = HostUriComparer.Value;
            }
            else
            {
                this.comparer = NoHostUriComparer.Value;
            }
        }
 
        public EndpointAddress Address
        {
            get
            {
                return this.address;
            }
        }
 
        public bool IncludeHostNameInComparison
        {
            get { return this.includeHostNameInComparison; }
        }
 
        protected internal override IMessageFilterTable<FilterData> CreateFilterTable<FilterData>()
        {
            return new EndpointAddressMessageFilterTable<FilterData>();
        }
 
        public override bool Match(MessageBuffer messageBuffer)
        {
            if (messageBuffer == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("messageBuffer");
            }
 
            Message msg = messageBuffer.CreateMessage();
            try
            {
                return Match(msg);
            }
            finally
            {
                msg.Close();
            }
        }
 
        public override bool Match(Message message)
        {
            if (message == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("message");
            }
 
            // To
#pragma warning suppress 56506 // Microsoft, Message.Headers can never be null
            Uri to = message.Headers.To;
            Uri actingAs = this.address.Uri;
 
            if (to == null || !this.comparer.Equals(actingAs, to))
            {
                return false;
            }
 
            return this.helper.Match(message);
        }
 
        internal Dictionary<string, HeaderBit[]> HeaderLookup
        {
            get { return this.helper.HeaderLookup; }
        }
 
        internal bool ComparePort
        {
            set
            {
                comparer.ComparePort = value;
            }
        }
 
        internal abstract class UriComparer : EqualityComparer<Uri>
        {
            protected UriComparer()
            {
                ComparePort = true;
            }
 
            protected abstract bool CompareHost { get; }
 
            internal bool ComparePort
            {
                get;
                set;
            }
 
            public override bool Equals(Uri u1, Uri u2)
            {
                return EndpointAddress.UriEquals(u1, u2, true /* ignoreCase */, CompareHost, ComparePort);
            }
 
            public override int GetHashCode(Uri uri)
            {
                return EndpointAddress.UriGetHashCode(uri, CompareHost, ComparePort);
            }
        }
 
        internal sealed class HostUriComparer : UriComparer
        {
            internal static readonly UriComparer Value = new HostUriComparer();
 
            HostUriComparer()
            {
            }
 
            protected override bool CompareHost
            {
                get
                {
                    return true;
                }
            }
        }
 
        internal sealed class NoHostUriComparer : UriComparer
        {
            internal static readonly UriComparer Value = new NoHostUriComparer();
 
            NoHostUriComparer()
            {
            }
 
            protected override bool CompareHost
            {
                get
                {
                    return false;
                }
            }
        }
    }
 
    internal class EndpointAddressMessageFilterHelper
    {
        EndpointAddress address;
        WeakReference processorPool;
 
        int size;
        byte[] mask;
        Dictionary<QName, int> qnameLookup;
        Dictionary<string, HeaderBit[]> headerLookup;
 
        internal EndpointAddressMessageFilterHelper(EndpointAddress address)
        {
            this.address = address;
 
            if (this.address.Headers.Count > 0)
            {
                CreateMask();
                this.processorPool = new WeakReference(null);
            }
            else
            {
                this.qnameLookup = null;
                this.headerLookup = null;
                this.size = 0;
                this.mask = null;
            }
        }
 
        void CreateMask()
        {
            int nextBit = 0;
            string key;
            HeaderBit[] bits;
            QName qname;
            this.qnameLookup = new Dictionary<QName, int>(EndpointAddressProcessor.QNameComparer);
            this.headerLookup = new Dictionary<string, HeaderBit[]>();
            StringBuilder builder = null;
 
            for (int i = 0; i < this.address.Headers.Count; ++i)
            {
                if (builder == null)
                    builder = new StringBuilder();
                else
                    builder.Remove(0, builder.Length);
 
 
                key = this.address.Headers[i].GetComparableForm(builder);
                if (this.headerLookup.TryGetValue(key, out bits))
                {
                    Array.Resize(ref bits, bits.Length + 1);
                    bits[bits.Length - 1] = new HeaderBit(nextBit++);
                    this.headerLookup[key] = bits;
                }
                else
                {
 
                    this.headerLookup.Add(key, new HeaderBit[] { new HeaderBit(nextBit++) });
                    AddressHeader parameter = this.address.Headers[i];
 
                    qname.name = parameter.Name;
                    qname.ns = parameter.Namespace;
                    this.qnameLookup[qname] = 1;
                }
            }
 
            if (nextBit == 0)
            {
                this.size = 0;
            }
            else
            {
                this.size = (nextBit - 1) / 8 + 1;
            }
 
            if (this.size > 0)
            {
                this.mask = new byte[size];
                for (int i = 0; i < this.size - 1; ++i)
                {
                    this.mask[i] = 0xff;
                }
 
                if ((nextBit % 8) == 0)
                {
                    this.mask[this.size - 1] = 0xff;
                }
                else
                {
                    this.mask[this.size - 1] = (byte)((1 << (nextBit % 8)) - 1);
                }
            }
        }
 
        internal Dictionary<string, HeaderBit[]> HeaderLookup
        {
            get
            {
                if (this.headerLookup == null)
                    this.headerLookup = new Dictionary<string, HeaderBit[]>();
                return this.headerLookup;
            }
        }
 
        EndpointAddressProcessor CreateProcessor(int length)
        {
            if (this.processorPool.Target != null)
            {
                lock (this.processorPool)
                {
                    object o = this.processorPool.Target;
                    if (o != null)
                    {
                        EndpointAddressProcessor p = (EndpointAddressProcessor)o;
                        this.processorPool.Target = p.Next;
                        p.Next = null;
                        p.Clear(length);
                        return p;
                    }
                }
            }
 
            return new EndpointAddressProcessor(length);
        }
 
        internal bool Match(Message message)
        {
            if (this.size == 0)
            {
                return true;
            }
 
            EndpointAddressProcessor context = CreateProcessor(this.size);
            context.ProcessHeaders(message, this.qnameLookup, this.headerLookup);
            bool result = context.TestExact(this.mask);
            ReleaseProcessor(context);
            return result;
        }
 
        void ReleaseProcessor(EndpointAddressProcessor context)
        {
            lock (this.processorPool)
            {
                context.Next = this.processorPool.Target as EndpointAddressProcessor;
                this.processorPool.Target = context;
            }
        }
 
    }
}