File: SqlClient\Query\SqlOuterApplyReducer.cs
Project: ndp\fx\src\DLinq\Dlinq\System.Data.Linq.csproj (System.Data.Linq)
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Data.Linq;
using System.Diagnostics.CodeAnalysis;
 
namespace System.Data.Linq.SqlClient {
 
	/// <summary>
	/// </summary>
	internal class SqlOuterApplyReducer { 
		internal static SqlNode Reduce(SqlNode node, SqlFactory factory, SqlNodeAnnotations annotations) {
			Visitor r = new Visitor(factory, annotations);
			return r.Visit(node);			
		}
 
		class Visitor : SqlVisitor {
            SqlFactory factory;
			SqlNodeAnnotations annotations;
 
            internal Visitor(SqlFactory factory, SqlNodeAnnotations annotations) {
                this.factory = factory;
                this.annotations = annotations;
            }
 
            internal override SqlSource VisitSource(SqlSource source) {
                source = base.VisitSource(source);
 
                SqlJoin join = source as SqlJoin;
                if (join != null) {
                    if (join.JoinType == SqlJoinType.OuterApply) {
                        // Reduce outer-apply into left-outer-join
                        HashSet<SqlAlias> leftProducedAliases = SqlGatherProducedAliases.Gather(join.Left);
                        HashSet<SqlExpression> liftedExpressions = new HashSet<SqlExpression>();
 
                        if (SqlPredicateLifter.CanLift(join.Right, leftProducedAliases, liftedExpressions) &&
                            SqlSelectionLifter.CanLift(join.Right, leftProducedAliases, liftedExpressions) &&
                            !SqlAliasDependencyChecker.IsDependent(join.Right, leftProducedAliases, liftedExpressions) ) {
 
                            SqlExpression liftedPredicate = SqlPredicateLifter.Lift(join.Right, leftProducedAliases);
                            List<List<SqlColumn>> liftedSelections = SqlSelectionLifter.Lift(join.Right, leftProducedAliases, liftedExpressions);
 
                            join.JoinType = SqlJoinType.LeftOuter;
                            join.Condition = liftedPredicate;
 
                            if (liftedSelections != null) {
                                foreach(List<SqlColumn> selection in liftedSelections) {
                                    source = this.PushSourceDown(source, selection);
                                }
                            }
                        }
                        else {
                            this.AnnotateSqlIncompatibility(join, SqlProvider.ProviderMode.Sql2000);
                        }
                    }
                    else if (join.JoinType == SqlJoinType.CrossApply) {
                        // reduce cross apply with special nested left-outer-join's into a single left-outer-join
                        //
                        // SELECT x.*, y.*
                        // FROM X
                        // CROSS APPLY (
                        //      SELECT y.*
                        //       FROM (
                        //          SELECT ?
                        //       ) 
                        //       LEFT OUTER JOIN (
                        //          SELECT y.* FROM Y
                        //       ) AS y
                        //
                        // ==>
                        // 
                        // SELECT x.*, y.*
                        // FROM X
                        // LEFT OUTER JOIN (
                        //     SELECT y.* FROM Y
                        // )
 
                        SqlJoin leftOuter = this.GetLeftOuterWithUnreferencedSingletonOnLeft(join.Right);
                        if (leftOuter != null) {
                            HashSet<SqlAlias> leftProducedAliases = SqlGatherProducedAliases.Gather(join.Left);
                            HashSet<SqlExpression> liftedExpressions = new HashSet<SqlExpression>();
 
                            if (SqlPredicateLifter.CanLift(leftOuter.Right, leftProducedAliases, liftedExpressions) &&
                                SqlSelectionLifter.CanLift(leftOuter.Right, leftProducedAliases, liftedExpressions) &&
                                !SqlAliasDependencyChecker.IsDependent(leftOuter.Right, leftProducedAliases, liftedExpressions)
                                ) {
 
                                SqlExpression liftedPredicate = SqlPredicateLifter.Lift(leftOuter.Right, leftProducedAliases);
                                List<List<SqlColumn>> liftedSelections = SqlSelectionLifter.Lift(leftOuter.Right, leftProducedAliases, liftedExpressions);
 
                                // add intermediate selections 
                                this.GetSelectionsBeforeJoin(join.Right, liftedSelections);
 
                                // push down all selections
                                foreach(List<SqlColumn> selection in liftedSelections.Where(s => s.Count > 0)) {
                                    source = this.PushSourceDown(source, selection);
                                }
 
                                join.JoinType = SqlJoinType.LeftOuter;
                                join.Condition = this.factory.AndAccumulate(leftOuter.Condition, liftedPredicate);
                                join.Right = leftOuter.Right;
                            }
                            else {
                                this.AnnotateSqlIncompatibility(join, SqlProvider.ProviderMode.Sql2000);
                            }
                        }
                    }
 
                    // re-balance join tree of left-outer-joins to expose LOJ w/ leftside unreferenced
                    while (join.JoinType == SqlJoinType.LeftOuter) {
                        // look for buried left-outer-joined-with-unreferenced singleton
                        SqlJoin leftLeftOuter = this.GetLeftOuterWithUnreferencedSingletonOnLeft(join.Left);
                        if (leftLeftOuter == null)
                            break;
 
                        List<List<SqlColumn>> liftedSelections = new List<List<SqlColumn>>();
 
                        // add intermediate selections 
                        this.GetSelectionsBeforeJoin(join.Left, liftedSelections);
 
                        // push down all selections
                        foreach(List<SqlColumn> selection in liftedSelections) {
                            source = this.PushSourceDown(source, selection);
                        }
 
                        // bubble this one up on-top of this 'join'.
                        SqlSource jRight = join.Right;
                        SqlExpression jCondition = join.Condition;
 
                        join.Left = leftLeftOuter.Left;
                        join.Right = leftLeftOuter;
                        join.Condition = leftLeftOuter.Condition;
 
                        leftLeftOuter.Left = leftLeftOuter.Right;
                        leftLeftOuter.Right = jRight;
                        leftLeftOuter.Condition = jCondition;
                    }
                }
 
                return source;
            }
 
            private void AnnotateSqlIncompatibility(SqlNode node, params SqlProvider.ProviderMode[] providers) {
                this.annotations.Add(node, new SqlServerCompatibilityAnnotation(Strings.SourceExpressionAnnotation(node.SourceExpression), providers));
            }
 
            [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
            private SqlSource PushSourceDown(SqlSource sqlSource, List<SqlColumn> cols) {
                SqlSelect ns = new SqlSelect(new SqlNop(cols[0].ClrType, cols[0].SqlType, sqlSource.SourceExpression), sqlSource, sqlSource.SourceExpression);
                ns.Row.Columns.AddRange(cols);
                return new SqlAlias(ns);
            }
 
            private SqlJoin GetLeftOuterWithUnreferencedSingletonOnLeft(SqlSource source) {
                SqlAlias alias = source as SqlAlias;
                if (alias != null) {
                    SqlSelect select = alias.Node as SqlSelect;
                    if (select != null &&
                        select.Where == null &&
                        select.Top == null &&
                        select.GroupBy.Count == 0 &&
                        select.OrderBy.Count == 0) {
                        return this.GetLeftOuterWithUnreferencedSingletonOnLeft(select.From);
                    }
                }
                SqlJoin join = source as SqlJoin;
                if (join == null || join.JoinType != SqlJoinType.LeftOuter)
                    return null;
                if (!this.IsSingletonSelect(join.Left))
                    return null;
                HashSet<SqlAlias> p = SqlGatherProducedAliases.Gather(join.Left);
				HashSet<SqlAlias> c = SqlGatherConsumedAliases.Gather(join.Right);
                if (p.Overlaps(c)) {
                    return null;
                }
                return join;
            }
 
            private void GetSelectionsBeforeJoin(SqlSource source, List<List<SqlColumn>> selections) {
                SqlJoin join = source as SqlJoin;
                if (join != null)
                    return;
                SqlAlias alias = source as SqlAlias;
                if (alias != null) {
                    SqlSelect select = alias.Node as SqlSelect;
                    if (select != null) {
                        this.GetSelectionsBeforeJoin(select.From, selections);
                        selections.Add(select.Row.Columns);
                    }
                }
            }
 
            [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
            private bool IsSingletonSelect(SqlSource source) {
                SqlAlias alias = source as SqlAlias;
                if (alias == null)
                    return false;
                SqlSelect select = alias.Node as SqlSelect;
                if (select == null)
                    return false;
                if (select.From != null)
                    return false;
                return true;
            }
		}
 
        class SqlGatherReferencedColumns {
            private SqlGatherReferencedColumns() { }
            internal static HashSet<SqlColumn> Gather(SqlNode node, HashSet<SqlColumn> columns) {
                Visitor v = new Visitor(columns);
                v.Visit(node);
                return columns;
            }
            class Visitor : SqlVisitor {
                HashSet<SqlColumn> columns;
                internal Visitor(HashSet<SqlColumn> columns) {
                    this.columns = columns;
                }
                internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
                    if (!this.columns.Contains(cref.Column)) {
                        this.columns.Add(cref.Column);
                        if (cref.Column.Expression != null) {
                            this.Visit(cref.Column.Expression);
                        }
                    }
                    return cref;
                }
            }
        }
 
        class SqlAliasesReferenced {
            HashSet<SqlAlias> aliases;
            bool referencesAny;
            Visitor visitor;
 
            internal SqlAliasesReferenced(HashSet<SqlAlias> aliases) {
                this.aliases = aliases;
                this.visitor = new Visitor(this);
            }
 
            internal bool ReferencesAny(SqlExpression expression) {
                this.referencesAny = false;
                this.visitor.Visit(expression);
                return this.referencesAny;
            }
 
            class Visitor: SqlVisitor {
                SqlAliasesReferenced parent;
 
                internal Visitor(SqlAliasesReferenced parent) {
                    this.parent = parent;
                }
                internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
                    if (this.parent.aliases.Contains(cref.Column.Alias)) {
                        this.parent.referencesAny = true;
                    }
                    else if (cref.Column.Expression != null) {
                        this.Visit(cref.Column.Expression);
                    }
                    return cref;
                }
 
                internal override SqlExpression VisitColumn(SqlColumn col) {
                    if (col.Expression != null) {
                        this.Visit(col.Expression);
                    }
                    return col;
                }
            }
        }
 
        static class SqlAliasDependencyChecker {
            internal static bool IsDependent(SqlNode node, HashSet<SqlAlias> aliasesToCheck, HashSet<SqlExpression> ignoreExpressions) {
                Visitor v = new Visitor(aliasesToCheck, ignoreExpressions);
                v.Visit(node);
                return v.hasDependency;
            }
            class Visitor : SqlVisitor {
                HashSet<SqlAlias> aliasesToCheck;
                HashSet<SqlExpression> ignoreExpressions;
                internal bool hasDependency;
 
                internal Visitor(HashSet<SqlAlias> aliasesToCheck, HashSet<SqlExpression> ignoreExpressions) {
                    this.aliasesToCheck = aliasesToCheck;
                    this.ignoreExpressions = ignoreExpressions;
                }
                internal override SqlNode Visit(SqlNode node) {
                    SqlExpression e = node as SqlExpression;
                    if (this.hasDependency)
                        return node;
                    if (e != null && this.ignoreExpressions.Contains(e)) {
                        return node;
                    }
                    return base.Visit(node);
                }
                internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
                    if (this.aliasesToCheck.Contains(cref.Column.Alias)) {
                        this.hasDependency = true;
                    }
                    else if (cref.Column.Expression != null) {
                        this.Visit(cref.Column.Expression);
                    }
                    return cref;
                }
                internal override SqlExpression VisitColumn(SqlColumn col) {
                    if (col.Expression != null) {
                        this.Visit(col.Expression);
                    }
                    return col;
                }
            }
        }
 
        static class SqlPredicateLifter {
            internal static bool CanLift(SqlSource source, HashSet<SqlAlias> aliasesForLifting, HashSet<SqlExpression> liftedExpressions) {
                System.Diagnostics.Debug.Assert(source != null);
                System.Diagnostics.Debug.Assert(aliasesForLifting != null);
                Visitor v = new Visitor(false, aliasesForLifting, liftedExpressions);
                v.VisitSource(source);
                return v.canLiftAll;
            }
 
            internal static SqlExpression Lift(SqlSource source, HashSet<SqlAlias> aliasesForLifting) {
                System.Diagnostics.Debug.Assert(source != null);
                System.Diagnostics.Debug.Assert(aliasesForLifting != null);
                Visitor v = new Visitor(true, aliasesForLifting, null);
                v.VisitSource(source);
                return v.lifted;
            }
 
            class Visitor : SqlVisitor {
                SqlAliasesReferenced aliases;
                HashSet<SqlExpression> liftedExpressions;
                bool doLifting;
                internal bool canLiftAll;
                internal SqlExpression lifted;
                SqlAggregateChecker aggregateChecker;
 
 
                internal Visitor(bool doLifting, HashSet<SqlAlias> aliasesForLifting, HashSet<SqlExpression> liftedExpressions) {
                    this.doLifting = doLifting;
                    this.aliases = new SqlAliasesReferenced(aliasesForLifting);
                    this.liftedExpressions = liftedExpressions;
                    this.canLiftAll = true;
                    this.aggregateChecker = new SqlAggregateChecker();
                }
 
                internal override SqlSelect VisitSelect(SqlSelect select) {
                    // check subqueries first
                    this.VisitSource(select.From);
 
                    // don't allow lifting through these operations
                    if (select.Top != null ||
                        select.GroupBy.Count > 0 ||
                        this.aggregateChecker.HasAggregates(select) ||
                        select.IsDistinct) {
                        this.canLiftAll = false;
                    }
 
                    // only lift predicates that actually reference the aliases
                    if (this.canLiftAll && select.Where != null) {
                        bool referencesAliases = this.aliases.ReferencesAny(select.Where);
                        if (referencesAliases) {
                            if (this.liftedExpressions != null) {
                                this.liftedExpressions.Add(select.Where);
                            }
                            if (this.doLifting) {
                                if (this.lifted != null)
                                    this.lifted = new SqlBinary(SqlNodeType.And, this.lifted.ClrType, this.lifted.SqlType, this.lifted, select.Where);
                                else
                                    this.lifted = select.Where;
                                select.Where = null;
                            }
                        }
                    }
 
                    return select;
                }
            }
        }
 
        static class SqlSelectionLifter {
            internal static bool CanLift(SqlSource source, HashSet<SqlAlias> aliasesForLifting, HashSet<SqlExpression> liftedExpressions) {
                Visitor v = new Visitor(false, aliasesForLifting, liftedExpressions);
                v.VisitSource(source);
                return v.canLiftAll;
            }
 
            internal static List<List<SqlColumn>> Lift(SqlSource source, HashSet<SqlAlias> aliasesForLifting, HashSet<SqlExpression> liftedExpressions) {
                Visitor v = new Visitor(true, aliasesForLifting, liftedExpressions);
                v.VisitSource(source);
                return v.lifted;
            }
 
            class Visitor : SqlVisitor {
                SqlAliasesReferenced aliases;
                HashSet<SqlColumn> referencedColumns;
                HashSet<SqlExpression> liftedExpressions;
                internal List<List<SqlColumn>> lifted;
                internal bool canLiftAll;
                bool hasLifted;
                bool doLifting;
                SqlAggregateChecker aggregateChecker;
 
                internal Visitor(bool doLifting, HashSet<SqlAlias> aliasesForLifting, HashSet<SqlExpression> liftedExpressions) {
                    this.doLifting = doLifting;
                    this.aliases = new SqlAliasesReferenced(aliasesForLifting);
                    this.referencedColumns = new HashSet<SqlColumn>();
                    this.liftedExpressions = liftedExpressions;
                    this.canLiftAll = true;
                    if (doLifting)
                        this.lifted = new List<List<SqlColumn>>();
                    this.aggregateChecker = new SqlAggregateChecker();
                }
 
                internal override SqlSource VisitJoin(SqlJoin join) {
                    this.ReferenceColumns(join.Condition);
                    return base.VisitJoin(join);
                }
 
                internal override SqlSelect VisitSelect(SqlSelect select) {
                    // reference all columns
                    this.ReferenceColumns(select.Where);
                    foreach(SqlOrderExpression oe in select.OrderBy) {
                        // 
                        this.ReferenceColumns(oe.Expression);
                    }
                    foreach(SqlExpression e in select.GroupBy) {
                        // 
                        this.ReferenceColumns(e);
                    }
                    this.ReferenceColumns(select.Having);
 
                    // determine what if anything should be lifted from this select
                    List<SqlColumn> lift = null;
                    List<SqlColumn> keep = null;
                    foreach (SqlColumn sc in select.Row.Columns) {
                        bool referencesAliasesForLifting = this.aliases.ReferencesAny(sc.Expression);
                        bool isLockedExpression = this.referencedColumns.Contains(sc);
                        if (referencesAliasesForLifting) {
                            // 
                            if (isLockedExpression) {
                                this.canLiftAll = false;
                                this.ReferenceColumns(sc);
                            }
                            else {
                                this.hasLifted = true;
                                if (this.doLifting) {
                                    if (lift == null)
                                        lift = new List<SqlColumn>();
                                    lift.Add(sc);
                                }
                            }
                        }
                        else {
                            if (this.doLifting) {
                                if (keep == null)
                                    keep = new List<SqlColumn>();
                                keep.Add(sc);
                            }
                            this.ReferenceColumns(sc);
                        }
                    }
 
                    // check subqueries too
                    if (this.canLiftAll) {
                        this.VisitSource(select.From);
                    }
 
                    // don't allow lifting through these operations
                    if (select.Top != null ||
                        select.GroupBy.Count > 0 ||
                        this.aggregateChecker.HasAggregates(select) ||
                        select.IsDistinct) {
                        if (this.hasLifted) {
                            // 
                            this.canLiftAll = false;
                        }
                    }
 
                    // do the actual lifting for this select
                    if (this.doLifting && this.canLiftAll) {
                        select.Row.Columns.Clear();
                        if (keep != null)
                            select.Row.Columns.AddRange(keep);
                        if (lift != null) {
                            // 
                            this.lifted.Add(lift);
                        }
                    }
 
                    return select;
                }
 
                private void ReferenceColumns(SqlExpression expression) {
                    if (expression != null) {
                        if (this.liftedExpressions == null || !this.liftedExpressions.Contains(expression)) {
                            SqlGatherReferencedColumns.Gather(expression, this.referencedColumns);
                        }
                    }
                }
            }
        }
	}
}