File: System.Activities.Presentation\System\Activities\Presentation\Model\GraphManager.cs
Project: ndp\cdf\src\NetFx40\Tools\System.Activities.Presentation.csproj (System.Activities.Presentation)
//----------------------------------------------------------------
// <copyright company="Microsoft Corporation">
//     Copyright (c) Microsoft Corporation.  All rights reserved.
// </copyright>
//----------------------------------------------------------------
 
namespace System.Activities.Presentation.Model
{
    using System;
    using System.Activities.Presentation;
    using System.Collections.Generic;
    using System.Diagnostics;
    using System.Linq;
    using System.Runtime;
    using System.Text;
    using Microsoft.Activities.Presentation.Xaml;
 
    // The graph is completely defined by a collection of vertices and a collection of edges. The back pointers are not part of the graph
    // definition; they are an auxiliary structure used to quickly trace back to the root vertex when a vertex is reachable from the root.
    // A vertex should have no back pointers if it is not reachable from the root.
    // This abstract base class is responsible for managing back pointers, while the derived class is responsible for managing vertices and edges.
    internal abstract class GraphManager<TVertex, TEdge, TBackPointer> where TVertex : class
    {
        // The root vertex of the graph; may be null (CalculateReachableVertices handles that case).
        protected abstract TVertex Root { get; }

        // Debug consistency check of the back-pointer invariants, reported through Fx.Assert:
        //  - every out-edge of a reachable vertex has a back pointer (checked inside CalculateReachableVertices);
        //  - every back pointer on a reachable vertex points at a reachable vertex and has an associated edge;
        //  - an unreachable vertex carries no back pointers.
        internal void VerifyBackPointers()
        {
            ICollection<TVertex> reachableVertices = this.CalculateReachableVertices(true);

            foreach (TVertex vertex in this.GetVertices())
            {
                if (reachableVertices.Contains(vertex))
                {
                    foreach (TBackPointer backPointer in this.GetBackPointers(vertex))
                    {
                        if (!reachableVertices.Contains(this.GetDestinationVertexFromBackPointer(backPointer)))
                        {
                            Fx.Assert(false, "a reachable vertex's back pointer should not point to a vertex that is not reachable");
                        }

                        if (!this.HasAssociatedEdge(backPointer))
                        {
                            Fx.Assert(false, "a reachable vertex doesn't have an outgoing edge to one of the vertex that have a back pointer to it");
                        }
                    }
                }
                else if (this.GetBackPointers(vertex).Any())
                {
                    Fx.Assert(false, "an unreachable vertex should not have any back pointer");
                }
            }
        }

        // Computes, by breadth-first traversal of out-edges, every vertex reachable from Root (Root included).
        // Vertices are compared by reference identity. Returns an empty set when Root is null.
        // verifyBackPointers: when true, also asserts that every traversed edge has a back pointer.
        protected ICollection<TVertex> CalculateReachableVertices(bool verifyBackPointers = false)
        {
            HashSet<TVertex> reachableVertices = new HashSet<TVertex>(ObjectReferenceEqualityComparer<TVertex>.Default);

            TVertex root = this.Root;
            if (root == null)
            {
                return reachableVertices;
            }

            Queue<TVertex> queue = new Queue<TVertex>();
            reachableVertices.Add(root);
            queue.Enqueue(root);

            while (queue.Count > 0)
            {
                TVertex vertex = queue.Dequeue();

                foreach (TEdge edge in this.GetOutEdges(vertex))
                {
                    if (verifyBackPointers && !this.HasBackPointer(edge))
                    {
                        Fx.Assert(false, "a reachable vertex doesn't have a back pointer to one of its incoming edges");
                    }

                    // HashSet.Add returns false for an already-seen vertex, so each vertex is enqueued at most once.
                    TVertex to = this.GetDestinationVertexFromEdge(edge);
                    if (reachableVertices.Add(to))
                    {
                        queue.Enqueue(to);
                    }
                }
            }

            return reachableVertices;
        }

        // Call when the root changes. Everything hanging off the old root loses its back pointers (the whole old graph
        // is treated as unreachable, so no reachability BFS is needed), then back pointers are rebuilt from the new root.
        protected void OnRootChanged(TVertex oldRoot, TVertex newRoot)
        {
            if (oldRoot != null)
            {
                this.RemoveBackPointers(oldRoot, true);
            }

            if (newRoot != null)
            {
                // The new root is passed explicitly as the "always reachable" vertex so this does not depend on
                // whether the Root property has already been updated by the caller.
                this.AddBackPointers(newRoot, newRoot);
            }
        }

        protected abstract IEnumerable<TVertex> GetVertices();

        protected abstract IEnumerable<TEdge> GetOutEdges(TVertex vertex);

        protected abstract IEnumerable<TBackPointer> GetBackPointers(TVertex vertex);

        protected abstract TVertex GetDestinationVertexFromEdge(TEdge edge);

        protected abstract TVertex GetSourceVertexFromEdge(TEdge edge);

        protected abstract TVertex GetDestinationVertexFromBackPointer(TBackPointer backPointer);

        // Call this method when an edge is removed. If the edge's source is not reachable, the edge never had a back
        // pointer and there is nothing to do. Otherwise the edge's back pointer is removed, and if the destination is no
        // longer reachable from the root, the newly unreachable sub-graph is stripped of its back pointers.
        // NOTE(review): reachability is recomputed via GetOutEdges, so this presumably expects the edge to already be
        // gone from the graph when called — confirm with derived classes.
        protected void OnEdgeRemoved(TEdge edgeRemoved)
        {
            Fx.Assert(edgeRemoved != null, "edgeRemoved should not be null");

            TVertex sourceVertex = this.GetSourceVertexFromEdge(edgeRemoved);
            if (!this.CanReachRootViaBackPointer(sourceVertex))
            {
                return;
            }

            this.RemoveAssociatedBackPointer(edgeRemoved);
            TVertex destinationVertex = this.GetDestinationVertexFromEdge(edgeRemoved);

            this.RemoveBackPointers(destinationVertex);
        }

        // Call this method when an edge is added. If the edge's source is not reachable, nothing changes. Otherwise the
        // edge gets a back pointer, and if its destination was not reachable before, the destination's sub-graph is
        // traversed to add back pointers to everything that has just become reachable.
        protected void OnEdgeAdded(TEdge edgeAdded)
        {
            Fx.Assert(edgeAdded != null, "edgeAdded should not be null");

            TVertex sourceVertex = this.GetSourceVertexFromEdge(edgeAdded);
            if (!this.CanReachRootViaBackPointer(sourceVertex))
            {
                return;
            }

            TVertex destinationVertex = this.GetDestinationVertexFromEdge(edgeAdded);

            // Must be sampled before the new back pointer is added, which would make the destination look reachable.
            bool wasReachable = this.CanReachRootViaBackPointer(destinationVertex);
            this.AddAssociatedBackPointer(edgeAdded);

            if (wasReachable)
            {
                return;
            }

            this.AddBackPointers(destinationVertex, this.Root);
        }

        protected abstract void RemoveAssociatedBackPointer(TEdge edge);

        protected abstract void AddAssociatedBackPointer(TEdge edge);

        protected abstract bool HasBackPointer(TEdge edge);

        protected abstract bool HasAssociatedEdge(TBackPointer backPointer);

        protected abstract void OnVerticesBecameReachable(IEnumerable<TVertex> reachableVertices);

        protected abstract void OnVerticesBecameUnreachable(IEnumerable<TVertex> unreachableVertices);

        // Returns true if vertex is the root, or if following back pointers (breadth-first) from vertex reaches the root.
        private bool CanReachRootViaBackPointer(TVertex vertex)
        {
            Fx.Assert(vertex != null, "vertex should not be null");

            TVertex root = this.Root;
            if (object.ReferenceEquals(vertex, root))
            {
                return true;
            }

            HashSet<TVertex> visited = new HashSet<TVertex>(ObjectReferenceEqualityComparer<TVertex>.Default);
            Queue<TVertex> queue = new Queue<TVertex>();

            visited.Add(vertex);
            queue.Enqueue(vertex);

            while (queue.Count > 0)
            {
                TVertex current = queue.Dequeue();
                foreach (TBackPointer backPointer in this.GetBackPointers(current))
                {
                    TVertex destinationVertex = this.GetDestinationVertexFromBackPointer(backPointer);
                    if (object.ReferenceEquals(destinationVertex, root))
                    {
                        return true;
                    }

                    if (visited.Add(destinationVertex))
                    {
                        queue.Enqueue(destinationVertex);
                    }
                }
            }

            return false;
        }

        // Traverses the sub-graph starting from vertex (which the caller has just made reachable) and adds a back
        // pointer for every traversed edge. Traversal continues through a destination only if it was not already
        // reachable. The given root, and any vertex that already has back pointers, count as already reachable.
        // All vertices that became reachable (vertex included) are reported via OnVerticesBecameReachable.
        private void AddBackPointers(TVertex vertex, TVertex root)
        {
            HashSet<TVertex> verticesBecameReachable = new HashSet<TVertex>(ObjectReferenceEqualityComparer<TVertex>.Default);
            Queue<TVertex> queue = new Queue<TVertex>();

            verticesBecameReachable.Add(vertex);
            queue.Enqueue(vertex);

            while (queue.Count > 0)
            {
                TVertex currentVertex = queue.Dequeue();

                foreach (TEdge edge in this.GetOutEdges(currentVertex))
                {
                    TVertex destinationVertex = this.GetDestinationVertexFromEdge(edge);

                    // Sample reachability before adding this edge's back pointer. The root is reachable even without
                    // back pointers. A cycle leading back to it must not report it as newly reachable, and must not
                    // re-add back pointers for all of its out-edges.
                    bool wasReachable = (root != null && object.ReferenceEquals(destinationVertex, root))
                        || this.GetBackPointers(destinationVertex).Any();
                    this.AddAssociatedBackPointer(edge);

                    if (!wasReachable && verticesBecameReachable.Add(destinationVertex))
                    {
                        queue.Enqueue(destinationVertex);
                    }
                }
            }

            this.OnVerticesBecameReachable(verticesBecameReachable);
        }

        // Traverses the sub-graph starting from vertex. A vertex that is still reachable stops the traversal of its
        // descendants; otherwise the back pointers that point to it are removed and traversal continues through its
        // descendants. The vertices found to be unreachable are reported via OnVerticesBecameUnreachable.
        // isAllVerticesUnreachable: when true, every vertex is treated as unreachable and the reachability BFS is skipped.
        private void RemoveBackPointers(TVertex vertex, bool isAllVerticesUnreachable = false)
        {
            // Only consulted when isAllVerticesUnreachable is false.
            ICollection<TVertex> reachableVertices = null;

            if (!isAllVerticesUnreachable)
            {
                reachableVertices = this.CalculateReachableVertices();
                if (reachableVertices.Contains(vertex))
                {
                    // Still reachable through another path: nothing to remove.
                    return;
                }
            }

            Queue<TVertex> queue = new Queue<TVertex>();
            HashSet<TVertex> unreachableVertices = new HashSet<TVertex>(ObjectReferenceEqualityComparer<TVertex>.Default);

            unreachableVertices.Add(vertex);
            queue.Enqueue(vertex);

            while (queue.Count > 0)
            {
                TVertex unreachableVertex = queue.Dequeue();
                foreach (TEdge edge in this.GetOutEdges(unreachableVertex))
                {
                    // The source of this edge is unreachable, so the edge must not carry a back pointer.
                    this.RemoveAssociatedBackPointer(edge);

                    TVertex to = this.GetDestinationVertexFromEdge(edge);
                    bool toIsUnreachable = isAllVerticesUnreachable || !reachableVertices.Contains(to);
                    if (toIsUnreachable && unreachableVertices.Add(to))
                    {
                        queue.Enqueue(to);
                    }
                }
            }

            this.OnVerticesBecameUnreachable(unreachableVertices);
        }
    }
}