File: System\Linq\Parallel\Merging\DefaultMergeHelper.cs
Project: ndp\fx\src\Core\System.Core.csproj (System.Core)
// ==++==
//
//   Copyright (c) Microsoft Corporation.  All rights reserved.
// 
// ==--==
// =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+
//
// DefaultMergeHelper.cs
//
// <OWNER>Microsoft</OWNER>
//
// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using System.Diagnostics.Contracts;
 
namespace System.Linq.Parallel
{
    /// <summary>
    /// The default merge helper uses a set of straightforward algorithms for output
    /// merging. Namely, for synchronous merges, the input data is yielded from the
    /// input data streams in "depth first" left-to-right order. For asynchronous merges,
    /// on the other hand, we use a biased choice algorithm to favor input channels in
    /// a "fair" way. No order preservation is carried out by this helper. 
    /// </summary>
    /// <typeparam name="TInputOutput"></typeparam>
    /// <typeparam name="TIgnoreKey"></typeparam>
    internal class DefaultMergeHelper<TInputOutput, TIgnoreKey> : IMergeHelper<TInputOutput>
    {
        private QueryTaskGroupState m_taskGroupState; // State shared among tasks.
        private PartitionedStream<TInputOutput, TIgnoreKey> m_partitions; // Source partitions.
        private AsynchronousChannel<TInputOutput>[] m_asyncChannels; // Destination channels (async).
        private SynchronousChannel<TInputOutput>[] m_syncChannels; // Destination channels (sync).
        private IEnumerator<TInputOutput> m_channelEnumerator; // Output enumerator.
        private TaskScheduler m_taskScheduler; // The task manager to execute the query.
        private bool m_ignoreOutput; // Whether we're enumerating "for effect".
 
        //-----------------------------------------------------------------------------------
        // Instantiates a new merge helper.
        //
        // Arguments:
        //     partitions   - the source partitions from which to consume data.
        //     ignoreOutput - whether we're enumerating "for effect" or for output.
        //     pipeline     - whether to use a pipelined merge.
        //
 
        internal DefaultMergeHelper(PartitionedStream<TInputOutput, TIgnoreKey> partitions, bool ignoreOutput, ParallelMergeOptions options, 
            TaskScheduler taskScheduler, CancellationState cancellationState, int queryId)
        {
            Contract.Assert(partitions != null);
 
            m_taskGroupState = new QueryTaskGroupState(cancellationState, queryId);
            m_partitions = partitions;
            m_taskScheduler = taskScheduler;
            m_ignoreOutput = ignoreOutput;
            IntValueEvent consumerEvent = new IntValueEvent();
 
            TraceHelpers.TraceInfo("DefaultMergeHelper::.ctor(..): creating a default merge helper");
 
            // If output won't be ignored, we need to manufacture a set of channels for the consumer.
            // Otherwise, when the merge is executed, we'll just invoke the activities themselves.
            if (!ignoreOutput)
            {
                // Create the asynchronous or synchronous channels, based on whether we're pipelining.
                if (options != ParallelMergeOptions.FullyBuffered)
                {
                    if (partitions.PartitionCount > 1)
                    {
                        m_asyncChannels =
                            MergeExecutor<TInputOutput>.MakeAsynchronousChannels(partitions.PartitionCount, options, consumerEvent, cancellationState.MergedCancellationToken);
                        m_channelEnumerator = new AsynchronousChannelMergeEnumerator<TInputOutput>(m_taskGroupState, m_asyncChannels, consumerEvent);
                    }
                    else
                    {
                        // If there is only one partition, we don't need to create channels. The only producer enumerator
                        // will be used as the result enumerator.
                        m_channelEnumerator = ExceptionAggregator.WrapQueryEnumerator(partitions[0], m_taskGroupState.CancellationState).GetEnumerator();
                    }
                }
                else
                {
                    m_syncChannels =
                        MergeExecutor<TInputOutput>.MakeSynchronousChannels(partitions.PartitionCount);
                    m_channelEnumerator = new SynchronousChannelMergeEnumerator<TInputOutput>(m_taskGroupState, m_syncChannels);
                }
 
                Contract.Assert(m_asyncChannels == null || m_asyncChannels.Length == partitions.PartitionCount);
                Contract.Assert(m_syncChannels == null || m_syncChannels.Length == partitions.PartitionCount);
                Contract.Assert(m_channelEnumerator != null, "enumerator can't be null if we're not ignoring output");
            }
        }
 
        //-----------------------------------------------------------------------------------
        // Schedules execution of the merge itself.
        //
        // Arguments:
        //    ordinalIndexState - the state of the ordinal index of the merged partitions
        //
 
        void IMergeHelper<TInputOutput>.Execute()
        {
            if (m_asyncChannels != null)
            {
                SpoolingTask.SpoolPipeline<TInputOutput, TIgnoreKey>(m_taskGroupState, m_partitions, m_asyncChannels, m_taskScheduler);
            }
            else if (m_syncChannels != null)
            {
                SpoolingTask.SpoolStopAndGo<TInputOutput, TIgnoreKey>(m_taskGroupState, m_partitions, m_syncChannels, m_taskScheduler);
            }
            else if (m_ignoreOutput)
            {
                SpoolingTask.SpoolForAll<TInputOutput, TIgnoreKey>(m_taskGroupState, m_partitions, m_taskScheduler);
            }
            else
            {
                // The last case is a pipelining merge when DOP = 1. In this case, the consumer thread itself will compute the results,
                // so we don't need any tasks to compute the results asynchronously.
                Contract.Assert(m_partitions.PartitionCount == 1);
            }
        }
 
        //-----------------------------------------------------------------------------------
        // Gets the enumerator from which to enumerate output results.
        //
 
        IEnumerator<TInputOutput> IMergeHelper<TInputOutput>.GetEnumerator()
        {
            Contract.Assert(m_ignoreOutput || m_channelEnumerator != null);
            return m_channelEnumerator;
        }
 
        //-----------------------------------------------------------------------------------
        // Returns the results as an array.
        //
        // @
 
 
 
 
 
        public TInputOutput[] GetResultsAsArray()
        {
            if (m_syncChannels != null)
            {
                // Right size an array.
                int totalSize = 0;
                for (int i = 0; i < m_syncChannels.Length; i++)
                {
                    totalSize += m_syncChannels[i].Count;
                }
                TInputOutput[] array = new TInputOutput[totalSize];
 
                // And then blit the elements in.
                int current = 0;
                for (int i = 0; i < m_syncChannels.Length; i++)
                {
                    m_syncChannels[i].CopyTo(array, current);
                    current += m_syncChannels[i].Count;
                }
                return array;
            }
            else
            {
                List<TInputOutput> output = new List<TInputOutput>();
                using (IEnumerator<TInputOutput> enumerator = ((IMergeHelper<TInputOutput>)this).GetEnumerator())
                {
                    while (enumerator.MoveNext())
                    {
                        output.Add(enumerator.Current);
                    }
                }
 
                return output.ToArray();
            }            
        }
    }
}