|
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Data.Linq;
using System.Diagnostics.CodeAnalysis;
namespace System.Data.Linq.SqlClient {
/// <summary>
/// </summary>
internal class SqlOuterApplyReducer {
internal static SqlNode Reduce(SqlNode node, SqlFactory factory, SqlNodeAnnotations annotations) {
Visitor r = new Visitor(factory, annotations);
return r.Visit(node);
}
class Visitor : SqlVisitor {
SqlFactory factory;
SqlNodeAnnotations annotations;
internal Visitor(SqlFactory factory, SqlNodeAnnotations annotations) {
this.factory = factory;
this.annotations = annotations;
}
internal override SqlSource VisitSource(SqlSource source) {
source = base.VisitSource(source);
SqlJoin join = source as SqlJoin;
if (join != null) {
if (join.JoinType == SqlJoinType.OuterApply) {
// Reduce outer-apply into left-outer-join
HashSet<SqlAlias> leftProducedAliases = SqlGatherProducedAliases.Gather(join.Left);
HashSet<SqlExpression> liftedExpressions = new HashSet<SqlExpression>();
if (SqlPredicateLifter.CanLift(join.Right, leftProducedAliases, liftedExpressions) &&
SqlSelectionLifter.CanLift(join.Right, leftProducedAliases, liftedExpressions) &&
!SqlAliasDependencyChecker.IsDependent(join.Right, leftProducedAliases, liftedExpressions) ) {
SqlExpression liftedPredicate = SqlPredicateLifter.Lift(join.Right, leftProducedAliases);
List<List<SqlColumn>> liftedSelections = SqlSelectionLifter.Lift(join.Right, leftProducedAliases, liftedExpressions);
join.JoinType = SqlJoinType.LeftOuter;
join.Condition = liftedPredicate;
if (liftedSelections != null) {
foreach(List<SqlColumn> selection in liftedSelections) {
source = this.PushSourceDown(source, selection);
}
}
}
else {
this.AnnotateSqlIncompatibility(join, SqlProvider.ProviderMode.Sql2000);
}
}
else if (join.JoinType == SqlJoinType.CrossApply) {
// reduce cross apply with special nested left-outer-join's into a single left-outer-join
//
// SELECT x.*, y.*
// FROM X
// CROSS APPLY (
// SELECT y.*
// FROM (
// SELECT ?
// )
// LEFT OUTER JOIN (
// SELECT y.* FROM Y
// ) AS y
//
// ==>
//
// SELECT x.*, y.*
// FROM X
// LEFT OUTER JOIN (
// SELECT y.* FROM Y
// )
SqlJoin leftOuter = this.GetLeftOuterWithUnreferencedSingletonOnLeft(join.Right);
if (leftOuter != null) {
HashSet<SqlAlias> leftProducedAliases = SqlGatherProducedAliases.Gather(join.Left);
HashSet<SqlExpression> liftedExpressions = new HashSet<SqlExpression>();
if (SqlPredicateLifter.CanLift(leftOuter.Right, leftProducedAliases, liftedExpressions) &&
SqlSelectionLifter.CanLift(leftOuter.Right, leftProducedAliases, liftedExpressions) &&
!SqlAliasDependencyChecker.IsDependent(leftOuter.Right, leftProducedAliases, liftedExpressions)
) {
SqlExpression liftedPredicate = SqlPredicateLifter.Lift(leftOuter.Right, leftProducedAliases);
List<List<SqlColumn>> liftedSelections = SqlSelectionLifter.Lift(leftOuter.Right, leftProducedAliases, liftedExpressions);
// add intermediate selections
this.GetSelectionsBeforeJoin(join.Right, liftedSelections);
// push down all selections
foreach(List<SqlColumn> selection in liftedSelections.Where(s => s.Count > 0)) {
source = this.PushSourceDown(source, selection);
}
join.JoinType = SqlJoinType.LeftOuter;
join.Condition = this.factory.AndAccumulate(leftOuter.Condition, liftedPredicate);
join.Right = leftOuter.Right;
}
else {
this.AnnotateSqlIncompatibility(join, SqlProvider.ProviderMode.Sql2000);
}
}
}
// re-balance join tree of left-outer-joins to expose LOJ w/ leftside unreferenced
while (join.JoinType == SqlJoinType.LeftOuter) {
// look for buried left-outer-joined-with-unreferenced singleton
SqlJoin leftLeftOuter = this.GetLeftOuterWithUnreferencedSingletonOnLeft(join.Left);
if (leftLeftOuter == null)
break;
List<List<SqlColumn>> liftedSelections = new List<List<SqlColumn>>();
// add intermediate selections
this.GetSelectionsBeforeJoin(join.Left, liftedSelections);
// push down all selections
foreach(List<SqlColumn> selection in liftedSelections) {
source = this.PushSourceDown(source, selection);
}
// bubble this one up on-top of this 'join'.
SqlSource jRight = join.Right;
SqlExpression jCondition = join.Condition;
join.Left = leftLeftOuter.Left;
join.Right = leftLeftOuter;
join.Condition = leftLeftOuter.Condition;
leftLeftOuter.Left = leftLeftOuter.Right;
leftLeftOuter.Right = jRight;
leftLeftOuter.Condition = jCondition;
}
}
return source;
}
private void AnnotateSqlIncompatibility(SqlNode node, params SqlProvider.ProviderMode[] providers) {
this.annotations.Add(node, new SqlServerCompatibilityAnnotation(Strings.SourceExpressionAnnotation(node.SourceExpression), providers));
}
[SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
private SqlSource PushSourceDown(SqlSource sqlSource, List<SqlColumn> cols) {
SqlSelect ns = new SqlSelect(new SqlNop(cols[0].ClrType, cols[0].SqlType, sqlSource.SourceExpression), sqlSource, sqlSource.SourceExpression);
ns.Row.Columns.AddRange(cols);
return new SqlAlias(ns);
}
private SqlJoin GetLeftOuterWithUnreferencedSingletonOnLeft(SqlSource source) {
SqlAlias alias = source as SqlAlias;
if (alias != null) {
SqlSelect select = alias.Node as SqlSelect;
if (select != null &&
select.Where == null &&
select.Top == null &&
select.GroupBy.Count == 0 &&
select.OrderBy.Count == 0) {
return this.GetLeftOuterWithUnreferencedSingletonOnLeft(select.From);
}
}
SqlJoin join = source as SqlJoin;
if (join == null || join.JoinType != SqlJoinType.LeftOuter)
return null;
if (!this.IsSingletonSelect(join.Left))
return null;
HashSet<SqlAlias> p = SqlGatherProducedAliases.Gather(join.Left);
HashSet<SqlAlias> c = SqlGatherConsumedAliases.Gather(join.Right);
if (p.Overlaps(c)) {
return null;
}
return join;
}
private void GetSelectionsBeforeJoin(SqlSource source, List<List<SqlColumn>> selections) {
SqlJoin join = source as SqlJoin;
if (join != null)
return;
SqlAlias alias = source as SqlAlias;
if (alias != null) {
SqlSelect select = alias.Node as SqlSelect;
if (select != null) {
this.GetSelectionsBeforeJoin(select.From, selections);
selections.Add(select.Row.Columns);
}
}
}
[SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
private bool IsSingletonSelect(SqlSource source) {
SqlAlias alias = source as SqlAlias;
if (alias == null)
return false;
SqlSelect select = alias.Node as SqlSelect;
if (select == null)
return false;
if (select.From != null)
return false;
return true;
}
}
class SqlGatherReferencedColumns {
private SqlGatherReferencedColumns() { }
internal static HashSet<SqlColumn> Gather(SqlNode node, HashSet<SqlColumn> columns) {
Visitor v = new Visitor(columns);
v.Visit(node);
return columns;
}
class Visitor : SqlVisitor {
HashSet<SqlColumn> columns;
internal Visitor(HashSet<SqlColumn> columns) {
this.columns = columns;
}
internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
if (!this.columns.Contains(cref.Column)) {
this.columns.Add(cref.Column);
if (cref.Column.Expression != null) {
this.Visit(cref.Column.Expression);
}
}
return cref;
}
}
}
class SqlAliasesReferenced {
HashSet<SqlAlias> aliases;
bool referencesAny;
Visitor visitor;
internal SqlAliasesReferenced(HashSet<SqlAlias> aliases) {
this.aliases = aliases;
this.visitor = new Visitor(this);
}
internal bool ReferencesAny(SqlExpression expression) {
this.referencesAny = false;
this.visitor.Visit(expression);
return this.referencesAny;
}
class Visitor: SqlVisitor {
SqlAliasesReferenced parent;
internal Visitor(SqlAliasesReferenced parent) {
this.parent = parent;
}
internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
if (this.parent.aliases.Contains(cref.Column.Alias)) {
this.parent.referencesAny = true;
}
else if (cref.Column.Expression != null) {
this.Visit(cref.Column.Expression);
}
return cref;
}
internal override SqlExpression VisitColumn(SqlColumn col) {
if (col.Expression != null) {
this.Visit(col.Expression);
}
return col;
}
}
}
static class SqlAliasDependencyChecker {
internal static bool IsDependent(SqlNode node, HashSet<SqlAlias> aliasesToCheck, HashSet<SqlExpression> ignoreExpressions) {
Visitor v = new Visitor(aliasesToCheck, ignoreExpressions);
v.Visit(node);
return v.hasDependency;
}
class Visitor : SqlVisitor {
HashSet<SqlAlias> aliasesToCheck;
HashSet<SqlExpression> ignoreExpressions;
internal bool hasDependency;
internal Visitor(HashSet<SqlAlias> aliasesToCheck, HashSet<SqlExpression> ignoreExpressions) {
this.aliasesToCheck = aliasesToCheck;
this.ignoreExpressions = ignoreExpressions;
}
internal override SqlNode Visit(SqlNode node) {
SqlExpression e = node as SqlExpression;
if (this.hasDependency)
return node;
if (e != null && this.ignoreExpressions.Contains(e)) {
return node;
}
return base.Visit(node);
}
internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
if (this.aliasesToCheck.Contains(cref.Column.Alias)) {
this.hasDependency = true;
}
else if (cref.Column.Expression != null) {
this.Visit(cref.Column.Expression);
}
return cref;
}
internal override SqlExpression VisitColumn(SqlColumn col) {
if (col.Expression != null) {
this.Visit(col.Expression);
}
return col;
}
}
}
static class SqlPredicateLifter {
internal static bool CanLift(SqlSource source, HashSet<SqlAlias> aliasesForLifting, HashSet<SqlExpression> liftedExpressions) {
System.Diagnostics.Debug.Assert(source != null);
System.Diagnostics.Debug.Assert(aliasesForLifting != null);
Visitor v = new Visitor(false, aliasesForLifting, liftedExpressions);
v.VisitSource(source);
return v.canLiftAll;
}
internal static SqlExpression Lift(SqlSource source, HashSet<SqlAlias> aliasesForLifting) {
System.Diagnostics.Debug.Assert(source != null);
System.Diagnostics.Debug.Assert(aliasesForLifting != null);
Visitor v = new Visitor(true, aliasesForLifting, null);
v.VisitSource(source);
return v.lifted;
}
class Visitor : SqlVisitor {
SqlAliasesReferenced aliases;
HashSet<SqlExpression> liftedExpressions;
bool doLifting;
internal bool canLiftAll;
internal SqlExpression lifted;
SqlAggregateChecker aggregateChecker;
internal Visitor(bool doLifting, HashSet<SqlAlias> aliasesForLifting, HashSet<SqlExpression> liftedExpressions) {
this.doLifting = doLifting;
this.aliases = new SqlAliasesReferenced(aliasesForLifting);
this.liftedExpressions = liftedExpressions;
this.canLiftAll = true;
this.aggregateChecker = new SqlAggregateChecker();
}
internal override SqlSelect VisitSelect(SqlSelect select) {
// check subqueries first
this.VisitSource(select.From);
// don't allow lifting through these operations
if (select.Top != null ||
select.GroupBy.Count > 0 ||
this.aggregateChecker.HasAggregates(select) ||
select.IsDistinct) {
this.canLiftAll = false;
}
// only lift predicates that actually reference the aliases
if (this.canLiftAll && select.Where != null) {
bool referencesAliases = this.aliases.ReferencesAny(select.Where);
if (referencesAliases) {
if (this.liftedExpressions != null) {
this.liftedExpressions.Add(select.Where);
}
if (this.doLifting) {
if (this.lifted != null)
this.lifted = new SqlBinary(SqlNodeType.And, this.lifted.ClrType, this.lifted.SqlType, this.lifted, select.Where);
else
this.lifted = select.Where;
select.Where = null;
}
}
}
return select;
}
}
}
static class SqlSelectionLifter {
internal static bool CanLift(SqlSource source, HashSet<SqlAlias> aliasesForLifting, HashSet<SqlExpression> liftedExpressions) {
Visitor v = new Visitor(false, aliasesForLifting, liftedExpressions);
v.VisitSource(source);
return v.canLiftAll;
}
internal static List<List<SqlColumn>> Lift(SqlSource source, HashSet<SqlAlias> aliasesForLifting, HashSet<SqlExpression> liftedExpressions) {
Visitor v = new Visitor(true, aliasesForLifting, liftedExpressions);
v.VisitSource(source);
return v.lifted;
}
class Visitor : SqlVisitor {
SqlAliasesReferenced aliases;
HashSet<SqlColumn> referencedColumns;
HashSet<SqlExpression> liftedExpressions;
internal List<List<SqlColumn>> lifted;
internal bool canLiftAll;
bool hasLifted;
bool doLifting;
SqlAggregateChecker aggregateChecker;
internal Visitor(bool doLifting, HashSet<SqlAlias> aliasesForLifting, HashSet<SqlExpression> liftedExpressions) {
this.doLifting = doLifting;
this.aliases = new SqlAliasesReferenced(aliasesForLifting);
this.referencedColumns = new HashSet<SqlColumn>();
this.liftedExpressions = liftedExpressions;
this.canLiftAll = true;
if (doLifting)
this.lifted = new List<List<SqlColumn>>();
this.aggregateChecker = new SqlAggregateChecker();
}
internal override SqlSource VisitJoin(SqlJoin join) {
this.ReferenceColumns(join.Condition);
return base.VisitJoin(join);
}
internal override SqlSelect VisitSelect(SqlSelect select) {
// reference all columns
this.ReferenceColumns(select.Where);
foreach(SqlOrderExpression oe in select.OrderBy) {
//
this.ReferenceColumns(oe.Expression);
}
foreach(SqlExpression e in select.GroupBy) {
//
this.ReferenceColumns(e);
}
this.ReferenceColumns(select.Having);
// determine what if anything should be lifted from this select
List<SqlColumn> lift = null;
List<SqlColumn> keep = null;
foreach (SqlColumn sc in select.Row.Columns) {
bool referencesAliasesForLifting = this.aliases.ReferencesAny(sc.Expression);
bool isLockedExpression = this.referencedColumns.Contains(sc);
if (referencesAliasesForLifting) {
//
if (isLockedExpression) {
this.canLiftAll = false;
this.ReferenceColumns(sc);
}
else {
this.hasLifted = true;
if (this.doLifting) {
if (lift == null)
lift = new List<SqlColumn>();
lift.Add(sc);
}
}
}
else {
if (this.doLifting) {
if (keep == null)
keep = new List<SqlColumn>();
keep.Add(sc);
}
this.ReferenceColumns(sc);
}
}
// check subqueries too
if (this.canLiftAll) {
this.VisitSource(select.From);
}
// don't allow lifting through these operations
if (select.Top != null ||
select.GroupBy.Count > 0 ||
this.aggregateChecker.HasAggregates(select) ||
select.IsDistinct) {
if (this.hasLifted) {
//
this.canLiftAll = false;
}
}
// do the actual lifting for this select
if (this.doLifting && this.canLiftAll) {
select.Row.Columns.Clear();
if (keep != null)
select.Row.Columns.AddRange(keep);
if (lift != null) {
//
this.lifted.Add(lift);
}
}
return select;
}
private void ReferenceColumns(SqlExpression expression) {
if (expression != null) {
if (this.liftedExpressions == null || !this.liftedExpressions.Contains(expression)) {
SqlGatherReferencedColumns.Gather(expression, this.referencedColumns);
}
}
}
}
}
}
}
|