// File: System\Runtime\ThreadNeutralSemaphore.cs
// Project: ndp\cdf\src\System.ServiceModel.Internals\System.ServiceModel.Internals.csproj (System.ServiceModel.Internals)
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation.  All rights reserved.
//------------------------------------------------------------
 
namespace System.Runtime
{
    using System.Collections.Generic;
    using System.Threading;
    using System.Globalization;
    using System.Diagnostics.CodeAnalysis;
    using System.Diagnostics;
 
    // A counting semaphore with at most maxCount concurrent holders. No owning
    // thread is recorded, so a slot taken by Enter/TryEnter/EnterAsync may be
    // released by Exit from any thread. Callers that cannot get a slot immediately
    // are queued in FIFO order on an AsyncWaitHandle and are released either by
    // Exit (which hands its slot to the head of the queue) or by Abort (which
    // releases everybody and makes subsequent enters fail).
    [Fx.Tag.SynchronizationPrimitive(Fx.Tag.BlocksUsing.PrivatePrimitive,
        SupportsAsync = true, ReleaseMethod = "Exit")]
    class ThreadNeutralSemaphore
    {
#if DEBUG
        // Stack of the Exit() call that last brought the count down to zero; used to
        // enrich the error raised by a subsequent unbalanced Exit().
        StackTrace exitStack;
#endif

        // Lazily created delegate for OnEnteredAsync, shared by all instances.
        static Action<object, TimeoutException> enteredAsyncCallback;

        // Set once by Abort() under ThisLock; read without the lock after a wait completes.
        bool aborted;

        // Optional factory for the exception reported after Abort(); may be null.
        readonly Func<Exception> abortedExceptionGenerator;

        // Number of slots currently held. Guarded by ThisLock.
        int count;

        // Maximum number of slots that may be held at once.
        readonly int maxCount;

        [Fx.Tag.SynchronizationObject(Blocking = false)]
        readonly object ThisLock = new object();

        // Pending waiters in arrival order. Allocated on first use; guarded by ThisLock.
        [Fx.Tag.SynchronizationObject]
        Queue<AsyncWaitHandle> waiters;

        // Creates a semaphore whose abort exception is the default OperationCanceledException.
        public ThreadNeutralSemaphore(int maxCount)
            : this(maxCount, null)
        {
        }

        // Creates a semaphore with maxCount slots. abortedExceptionGenerator, when
        // non-null, supplies the exception reported to callers once Abort() has run.
        public ThreadNeutralSemaphore(int maxCount, Func<Exception> abortedExceptionGenerator)
        {
            Fx.Assert(maxCount > 0, "maxCount must be positive");
            this.maxCount = maxCount;
            this.abortedExceptionGenerator = abortedExceptionGenerator;
        }

        // Returns the shared OnEnteredAsync delegate, creating it on first access.
        static Action<object, TimeoutException> EnteredAsyncCallback
        {
            get
            {
                if (enteredAsyncCallback == null)
                {
                    enteredAsyncCallback = new Action<object, TimeoutException>(OnEnteredAsync);
                }

                return enteredAsyncCallback;
            }
        }

        // Returns the waiter queue, allocating it on first access.
        Queue<AsyncWaitHandle> Waiters
        {
            get
            {
                if (this.waiters == null)
                {
                    this.waiters = new Queue<AsyncWaitHandle>();
                }

                return this.waiters;
            }
        }

        // Tries to take a slot without blocking the calling thread.
        // Returns true when the slot was acquired before this method returned.
        // Returns false when the caller was queued; the outcome is then reported
        // through callback(state, exception) from OnEnteredAsync.
        // Throws the aborted exception if the semaphore has been aborted.
        public bool EnterAsync(TimeSpan timeout, FastAsyncCallback callback, object state)
        {
            Fx.Assert(callback != null, "must have a non-null call back for async purposes");

            // Shares the enter-or-enqueue step with the synchronous path.
            AsyncWaitHandle waiter = EnterCore();

            if (waiter == null)
            {
                return true;
            }

            bool completedSynchronously = waiter.WaitAsync(
                EnteredAsyncCallback, new EnterAsyncData(this, waiter, callback, state), timeout);

            // Abort() signals every queued waiter, so a completed wait does not by itself
            // mean that Exit() handed over a slot. Re-check the aborted flag here, exactly
            // as TryEnter(TimeSpan) and OnEnteredAsync do after their waits complete.
            // NOTE(review): assumes a true result from WaitAsync means the handle was
            // already signaled and the callback will not run - confirm against AsyncWaitHandle.
            if (completedSynchronously && this.aborted)
            {
                throw Fx.Exception.AsError(CreateObjectAbortedException());
            }

            return completedSynchronously;
        }

        // Completion routine for a queued EnterAsync. 'state' is the EnterAsyncData
        // created by EnterAsync; 'exception' is non-null when the wait timed out.
        static void OnEnteredAsync(object state, TimeoutException exception)
        {
            EnterAsyncData data = (EnterAsyncData)state;
            ThreadNeutralSemaphore thisPtr = data.Semaphore;
            Exception exceptionToPropagate = exception;

            if (exception != null)
            {
                if (!thisPtr.RemoveWaiter(data.Waiter))
                {
                    // The timeout raced with Exit and exit won.
                    // We've successfully entered.
                    exceptionToPropagate = null;
                }
            }

            Fx.Assert(!thisPtr.waiters.Contains(data.Waiter), "The waiter should have been removed already.");

            // An abort overrides both success and timeout.
            if (thisPtr.aborted)
            {
                exceptionToPropagate = thisPtr.CreateObjectAbortedException();
            }

            data.Callback(data.State, exceptionToPropagate);
        }

        // Takes a slot if one is free right now; never waits and never queues.
        // Note: unlike the other enter methods this one does not consult the
        // aborted flag, so it does not throw after Abort().
        public bool TryEnter()
        {
            lock (this.ThisLock)
            {
                if (this.count < this.maxCount)
                {
                    this.count++;
                    return true;
                }

                return false;
            }
        }

        // Takes a slot, blocking for up to 'timeout'. Throws the exception built by
        // CreateEnterTimedOutException on timeout and the aborted exception on abort.
        [Fx.Tag.Blocking(CancelMethod = "Abort")]
        public void Enter(TimeSpan timeout)
        {
            if (!TryEnter(timeout))
            {
                throw Fx.Exception.AsError(CreateEnterTimedOutException(timeout));
            }
        }

        // Takes a slot, blocking for up to 'timeout'. Returns true when the slot was
        // acquired and false on timeout. Throws the aborted exception if the semaphore
        // is aborted before the call or while waiting.
        [Fx.Tag.Blocking(CancelMethod = "Abort")]
        public bool TryEnter(TimeSpan timeout)
        {
            AsyncWaitHandle waiter = EnterCore();

            if (waiter == null)
            {
                // A slot was free; no wait was needed.
                return true;
            }

            bool timedOut = !waiter.Wait(timeout);

            if (timedOut && !RemoveWaiter(waiter))
            {
                // The waiter was already dequeued, either by Exit (the timeout raced
                // with Exit and exit won, so we have entered) or by Abort (detected
                // by the check below).
                timedOut = false;
            }

            // Check for abort only after the waiter is known to be out of the queue.
            // Checking before RemoveWaiter would let an Abort() that runs in between
            // be mistaken for a successful hand-over from Exit.
            if (this.aborted)
            {
                throw Fx.Exception.AsError(CreateObjectAbortedException());
            }

            return !timedOut;
        }

        // Builds the TimeoutException used to report that an enter timed out.
        internal static TimeoutException CreateEnterTimedOutException(TimeSpan timeout)
        {
            return new TimeoutException(InternalSR.LockTimeoutExceptionMessage(timeout));
        }

        // Builds the exception reported after Abort(): the result of the generator
        // supplied at construction, or an OperationCanceledException when none was given.
        Exception CreateObjectAbortedException()
        {
            if (this.abortedExceptionGenerator != null)
            {
                return this.abortedExceptionGenerator();
            }
            else
            {
                return new OperationCanceledException(InternalSR.ThreadNeutralSemaphoreAborted);
            }
        }

        // remove a waiter from our queue. Returns true if successful. Used to implement timeouts.
        bool RemoveWaiter(AsyncWaitHandle waiter)
        {
            bool removed = false;

            lock (this.ThisLock)
            {
                Queue<AsyncWaitHandle> queue = this.Waiters;

                // Rotate the queue exactly once: every entry is dequeued and, unless it is
                // the one being removed, re-enqueued. Completing the full rotation keeps
                // the remaining waiters in their original order.
                for (int i = queue.Count; i > 0; i--)
                {
                    AsyncWaitHandle temp = queue.Dequeue();
                    if (object.ReferenceEquals(temp, waiter))
                    {
                        removed = true;
                    }
                    else
                    {
                        queue.Enqueue(temp);
                    }
                }
            }

            return removed;
        }

        // Takes a slot if one is free and returns null; otherwise enqueues and returns
        // a new wait handle for the caller to wait on. Throws the aborted exception if
        // the semaphore has been aborted.
        AsyncWaitHandle EnterCore()
        {
            AsyncWaitHandle waiter;

            lock (this.ThisLock)
            {
                if (this.aborted)
                {
                    throw Fx.Exception.AsError(CreateObjectAbortedException());
                }

                if (this.count < this.maxCount)
                {
                    this.count++;
                    return null;
                }

                waiter = new AsyncWaitHandle();
                this.Waiters.Enqueue(waiter);
            }

            return waiter;
        }

        // Releases one slot. If a waiter is queued the slot is handed directly to it
        // (the count is left unchanged); otherwise the count is decremented.
        // Returns the number of held slots after the call, or -1 if the semaphore has
        // been aborted. Throws SynchronizationLockException when no slot is held.
        public int Exit()
        {
            AsyncWaitHandle waiter;

            int remainingCount = -1;
            lock (this.ThisLock)
            {
                if (this.aborted)
                {
                    return remainingCount;
                }

                if (this.count == 0)
                {
                    string message = InternalSR.InvalidSemaphoreExit;

#if DEBUG
                    if (!Fx.FastDebug && exitStack != null)
                    {
                        string originalStack = exitStack.ToString().Replace("\r\n", "\r\n    ");
                        message = string.Format(CultureInfo.InvariantCulture,
                            "Object synchronization method was called from an unsynchronized block of code. Previous Exit(): {0}", originalStack);
                    }
#endif

                    throw Fx.Exception.AsError(new SynchronizationLockException(message));
                }

                if (this.waiters == null || this.waiters.Count == 0)
                {
                    this.count--;

#if DEBUG
                    if (!Fx.FastDebug && this.count == 0)
                    {
                        exitStack = new StackTrace();
                    }
#endif

                    return this.count;
                }

                // Transfer ownership of the slot to the oldest waiter.
                waiter = this.waiters.Dequeue();
                remainingCount = this.count;
            }

            // Signal outside the lock.
            waiter.Set();
            return remainingCount;
        }

        // Abort the ThreadNeutralSemaphore object: marks it aborted and signals every
        // queued waiter. Calling it again has no effect.
        public void Abort()
        {
            lock (this.ThisLock)
            {
                if (this.aborted)
                {
                    return;
                }

                this.aborted = true;

                if (this.waiters != null)
                {
                    while (this.waiters.Count > 0)
                    {
                        AsyncWaitHandle waiter = this.waiters.Dequeue();
                        waiter.Set();
                    }
                }
            }
        }

        // State passed from EnterAsync to OnEnteredAsync through the wait handle.
        class EnterAsyncData
        {
            public EnterAsyncData(ThreadNeutralSemaphore semaphore, AsyncWaitHandle waiter, FastAsyncCallback callback, object state)
            {
                this.Waiter = waiter;
                this.Semaphore = semaphore;
                this.Callback = callback;
                this.State = state;
            }

            // The semaphore the caller is trying to enter.
            public ThreadNeutralSemaphore Semaphore
            {
                get;
                private set;
            }

            // The handle that was enqueued for this caller.
            public AsyncWaitHandle Waiter
            {
                get;
                private set;
            }

            // The caller's completion callback.
            public FastAsyncCallback Callback
            {
                get;
                private set;
            }

            // The caller's opaque state, passed back to Callback.
            public object State
            {
                get;
                private set;
            }
        }
    }
}