|
using System.Collections.Generic;
using System.Data.Linq.Mapping;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
namespace System.Data.Linq.SqlClient {
/// <summary>
/// Binds MemberAccess
/// Prefetches deferrable expressions (SqlLink) if necessary
/// Translates structured object comparison (EQ, NE) into memberwise comparison
/// Translates shared expressions (SqlSharedExpression, SqlSharedExpressionRef)
/// Optimizes out simple redundant operations :
/// XXX OR TRUE ==> TRUE
/// XXX AND FALSE ==> FALSE
/// NON-NULL EQ NULL ==> FALSE
/// NON-NULL NEQ NULL ==> TRUE
/// </summary>
internal class SqlBinder {
// Columnizer handed to the visitor (used there e.g. for ColumnizeSelection).
SqlColumnizer columnizer;
// Visitor that performs the actual binding traversal (see Bind).
Visitor visitor;
// Factory used to build new SQL nodes.
SqlFactory sql;
// Optional transformation applied to the tree before binding (see Prebind / PreBinder).
Func<SqlNode, SqlNode> prebinder;
// Backing field for OptimizeLinkExpansions; defaults to true. Its consumers are not visible in this chunk.
bool optimizeLinkExpansions = true;
// Backing field for SimplifyCaseStatements; defaults to true. Its consumers are not visible in this chunk.
bool simplifyCaseStatements = true;
/// <summary>
/// Creates a binder whose visitor shares this binder's SQL factory and columnizer.
/// </summary>
internal SqlBinder(Translator translator, SqlFactory sqlFactory, MetaModel model, DataLoadOptions shape, SqlColumnizer columnizer, bool canUseOuterApply) {
    this.sql = sqlFactory;
    this.columnizer = columnizer;
    // The visitor only stores its arguments, so the parameters can be passed straight through.
    this.visitor = new Visitor(this, translator, columnizer, sqlFactory, model, shape, canUseOuterApply);
}
/// <summary>
/// Optional pass run over the tree before binding; null means no pre-binding.
/// </summary>
internal Func<SqlNode, SqlNode> PreBinder {
    get {
        return this.prebinder;
    }
    set {
        this.prebinder = value;
    }
}
/// <summary>
/// Runs the configured pre-binder over <paramref name="node"/>, or returns it unchanged when none is set.
/// </summary>
private SqlNode Prebind(SqlNode node) {
    return this.prebinder == null ? node : this.prebinder(node);
}
/// <summary>
/// A chain of link-id to expression maps. Lookups search this scope first and then
/// each enclosing scope in turn.
/// </summary>
class LinkOptimizationScope {
    // Entries registered directly in this scope; allocated lazily on the first Add.
    Dictionary<object, SqlExpression> map;
    // Enclosing scope searched when a lookup misses here; null for the outermost scope.
    LinkOptimizationScope previous;

    internal LinkOptimizationScope(LinkOptimizationScope previous) {
        this.previous = previous;
    }

    /// <summary>Registers <paramref name="expr"/> under <paramref name="linkId"/> in this scope (duplicate ids throw).</summary>
    internal void Add(object linkId, SqlExpression expr) {
        Dictionary<object, SqlExpression> entries = this.map;
        if (entries == null) {
            entries = new Dictionary<object, SqlExpression>();
            this.map = entries;
        }
        entries.Add(linkId, expr);
    }

    /// <summary>Looks <paramref name="linkId"/> up in this scope and then in each enclosing scope.</summary>
    internal bool TryGetValue(object linkId, out SqlExpression expr) {
        for (LinkOptimizationScope scope = this; scope != null; scope = scope.previous) {
            if (scope.map != null && scope.map.TryGetValue(linkId, out expr)) {
                return true;
            }
        }
        expr = null;
        return false;
    }
}
/// <summary>
/// Applies the optional pre-binder and then binds the result with the visitor.
/// </summary>
internal SqlNode Bind(SqlNode node) {
    SqlNode prebound = this.Prebind(node);
    return this.visitor.Visit(prebound);
}
/// <summary>
/// Flag backed by optimizeLinkExpansions (true by default).
/// </summary>
internal bool OptimizeLinkExpansions {
    get {
        return this.optimizeLinkExpansions;
    }
    set {
        this.optimizeLinkExpansions = value;
    }
}
/// <summary>
/// Flag backed by simplifyCaseStatements (true by default).
/// </summary>
internal bool SimplifyCaseStatements {
    get {
        return this.simplifyCaseStatements;
    }
    set {
        this.simplifyCaseStatements = value;
    }
}
[SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling", Justification="These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
class Visitor : SqlVisitor {
// Owning binder.
SqlBinder binder;
// Used for TranslateEquals / TranslateLinkEquals / TranslateLinkIsNull.
Translator translator;
// Factory used to build new SQL nodes.
SqlFactory sql;
// Taken from sql.TypeProvider in the constructor.
TypeSystemProvider typeProvider;
// Created from the factory in the constructor; its use is not visible in this chunk.
SqlExpander expander;
// Used to columnize sub-select selections (see AccessMember).
SqlColumnizer columnizer;
// Created in the constructor; its use is not visible in this chunk.
SqlAggregateChecker aggregateChecker;
// SELECT currently being bound; AccessMember saves and restores it around sub-select rewrites.
SqlSelect currentSelect;
// Its use is not visible in this chunk.
SqlAlias currentAlias;
// Created empty in the constructor; its use is not visible in this chunk.
Dictionary<SqlAlias, SqlAlias> outerAliasMap;
// Innermost link-optimization scope; starts as a root scope with no parent.
LinkOptimizationScope linkMap;
// Meta model used to resolve CLR types to MetaTypes.
MetaModel model;
// Non-null only while inside an SqlIncludeScope; holds inheritance roots currently being eagerly included (see VisitNew).
HashSet<MetaType> alreadyIncluded;
// Load options consulted via IsPreloaded in VisitNew (may be null).
DataLoadOptions shape;
// Set to true once a user query is visited; it is not reset anywhere in this chunk.
bool disableInclude;
// When true, VisitNew visits constructor arguments without fetching them so links survive.
bool inGroupBy;
// Allows AccessMember to rewrite Element member access into an OUTER APPLY.
bool canUseOuterApply;
/// <summary>
/// Stores the collaborators and initializes the per-visitor binding state.
/// </summary>
internal Visitor(SqlBinder binder, Translator translator, SqlColumnizer columnizer, SqlFactory sqlFactory, MetaModel model, DataLoadOptions shape, bool canUseOuterApply) {
    // Collaborators supplied by the caller.
    this.binder = binder;
    this.translator = translator;
    this.columnizer = columnizer;
    this.sql = sqlFactory;
    this.model = model;
    this.shape = shape;
    this.canUseOuterApply = canUseOuterApply;

    // Helpers derived from the SQL factory.
    this.typeProvider = sqlFactory.TypeProvider;
    this.expander = new SqlExpander(sqlFactory);
    this.aggregateChecker = new SqlAggregateChecker();

    // Fresh binding state: a root link scope and an empty alias map.
    this.linkMap = new LinkOptimizationScope(null);
    this.outerAliasMap = new Dictionary<SqlAlias, SqlAlias>();
}
/// <summary>
/// Visits <paramref name="expr"/> and coerces the visited node back to an expression.
/// </summary>
internal override SqlExpression VisitExpression(SqlExpression expr) {
    var visited = this.Visit(expr);
    return this.ConvertToExpression(visited);
}
/// <summary>
/// Binds the child of an include scope. While the child is visited, alreadyIncluded is a
/// fresh set; afterwards it is always cleared to null.
/// </summary>
internal override SqlNode VisitIncludeScope(SqlIncludeScope scope) {
    this.alreadyIncluded = new HashSet<MetaType>();
    try {
        // Only the child is returned: the scope node itself is dropped, which keeps
        // SqlBinder idempotent when it is run over an already-bound tree.
        SqlNode child = scope.Child;
        return this.Visit(child);
    }
    finally {
        this.alreadyIncluded = null;
    }
}
/// <summary>
/// Marks include processing as disabled and then visits the user query normally.
/// The flag is not reset anywhere in this chunk.
/// </summary>
internal override SqlUserQuery VisitUserQuery(SqlUserQuery suq) {
    const bool includesDisabled = true;
    this.disableInclude = includesDisabled;
    SqlUserQuery visited = base.VisitUserQuery(suq);
    return visited;
}
/// <summary>
/// Binds <paramref name="expr"/> and forces it to a fetched form. The steps are: visit,
/// convert links, convert to a fetched expression, then coerce the result to an expression.
/// </summary>
internal SqlExpression FetchExpression(SqlExpression expr) {
    var visited = this.VisitExpression(expr);
    var linksConverted = this.ConvertLinks(visited);
    var fetched = this.ConvertToFetchedExpression(linksConverted);
    return this.ConvertToExpression(fetched);
}
/// <summary>
/// Fetches every argument of a SQL function call in place, in order, and returns the same node.
/// </summary>
internal override SqlExpression VisitFunctionCall(SqlFunctionCall fc) {
    var arguments = fc.Arguments;
    int argumentCount = arguments.Count;
    for (int index = 0; index < argumentCount; ++index) {
        var fetched = this.FetchExpression(arguments[index]);
        arguments[index] = fetched;
    }
    return fc;
}
/// <summary>
/// Fetches the matched expression and then the pattern of a LIKE, then defers to the base visitor.
/// </summary>
internal override SqlExpression VisitLike(SqlLike like) {
    SqlExpression fetchedSubject = this.FetchExpression(like.Expression);
    like.Expression = fetchedSubject;
    SqlExpression fetchedPattern = this.FetchExpression(like.Pattern);
    like.Pattern = fetchedPattern;
    return base.VisitLike(like);
}
/// <summary>
/// Fetches the grouping key and then the group expression in place, and returns the same node.
/// </summary>
internal override SqlExpression VisitGrouping(SqlGrouping g) {
    SqlExpression fetchedKey = this.FetchExpression(g.Key);
    g.Key = fetchedKey;
    SqlExpression fetchedGroup = this.FetchExpression(g.Group);
    g.Group = fetchedGroup;
    return g;
}
/// <summary>
/// Fetches the receiver and then each argument of a method call in place, and returns the same node.
/// </summary>
internal override SqlExpression VisitMethodCall(SqlMethodCall mc) {
    mc.Object = this.FetchExpression(mc.Object);
    var arguments = mc.Arguments;
    int argumentCount = arguments.Count;
    for (int index = 0; index < argumentCount; ++index) {
        var fetched = this.FetchExpression(arguments[index]);
        arguments[index] = fetched;
    }
    return mc;
}
/// <summary>
/// Binds a binary operator. The steps, in order:
///  1. Comparisons against a literal NULL become IS NULL / IS NOT NULL when the result type is not nullable.
///  2. Both operands are visited.
///  3. Equality against boolean literals is folded: a literal that agrees with the operator yields the other
///     operand (when it is a plain bool); a disagreeing literal yields its negation; two literals yield a literal.
///  4. AND / OR with a literal operand are simplified. AND with TRUE keeps the other side and otherwise yields FALSE.
///     OR with FALSE keeps the other side and otherwise yields TRUE.
///  5. Link equality is translated before the operands are fetched.
///  6. Operands are fetched and structural equality is translated. Comparisons of System.Type values are then
///     folded statically when possible, or reduced to a comparison of discriminators.
/// The two-valued operators (EQ2V / NE2V) are treated like EQ / NE throughout.
/// </summary>
[SuppressMessage("Microsoft.Maintainability", "CA1505:AvoidUnmaintainableCode", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
[SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
internal override SqlExpression VisitBinaryOperator(SqlBinary bo) {
    // Below we translate comparisons with constant NULL to either IS NULL or IS NOT NULL.
    // We only want to do this if the type of the binary expression is not nullable.
    switch (bo.NodeType) {
        case SqlNodeType.EQ:
        case SqlNodeType.EQ2V:
            if (this.IsConstNull(bo.Left) && !TypeSystem.IsNullableType(bo.ClrType)) {
                return this.VisitUnaryOperator(this.sql.Unary(SqlNodeType.IsNull, bo.Right, bo.SourceExpression));
            }
            else if (this.IsConstNull(bo.Right) && !TypeSystem.IsNullableType(bo.ClrType)) {
                return this.VisitUnaryOperator(this.sql.Unary(SqlNodeType.IsNull, bo.Left, bo.SourceExpression));
            }
            break;
        case SqlNodeType.NE:
        case SqlNodeType.NE2V:
            if (this.IsConstNull(bo.Left) && !TypeSystem.IsNullableType(bo.ClrType)) {
                return this.VisitUnaryOperator(this.sql.Unary(SqlNodeType.IsNotNull, bo.Right, bo.SourceExpression));
            }
            else if (this.IsConstNull(bo.Right) && !TypeSystem.IsNullableType(bo.ClrType)) {
                return this.VisitUnaryOperator(this.sql.Unary(SqlNodeType.IsNotNull, bo.Left, bo.SourceExpression));
            }
            break;
    }
    bo.Left = this.VisitExpression(bo.Left);
    bo.Right = this.VisitExpression(bo.Right);
    switch (bo.NodeType) {
        case SqlNodeType.EQ:
        case SqlNodeType.EQ2V:
        case SqlNodeType.NE:
        case SqlNodeType.NE2V: {
            SqlValue vLeft = bo.Left as SqlValue;
            SqlValue vRight = bo.Right as SqlValue;
            bool leftIsBool = vLeft != null && vLeft.Value is bool;
            bool rightIsBool = vRight != null && vRight.Value is bool;
            if (leftIsBool || rightIsBool) {
                bool equal = bo.NodeType != SqlNodeType.NE && bo.NodeType != SqlNodeType.NE2V;
                bool isTwoValue = bo.NodeType == SqlNodeType.EQ2V || bo.NodeType == SqlNodeType.NE2V;
                SqlNodeType negator = isTwoValue ? SqlNodeType.Not2V : SqlNodeType.Not;
                if (leftIsBool && !rightIsBool) {
                    bool value = (bool)vLeft.Value;
                    if (value ^ equal) {
                        return VisitUnaryOperator(new SqlUnary(negator, bo.ClrType, bo.SqlType, sql.DoNotVisitExpression(bo.Right), bo.SourceExpression));
                    }
                    if (bo.Right.ClrType == typeof(bool)) { // If the other side is nullable bool then this expression is already a reasonable way to handle three-values
                        return bo.Right;
                    }
                }
                else if (!leftIsBool && rightIsBool) {
                    bool value = (bool)vRight.Value;
                    if (value ^ equal) {
                        return VisitUnaryOperator(new SqlUnary(negator, bo.ClrType, bo.SqlType, sql.DoNotVisitExpression(bo.Left), bo.SourceExpression));
                    }
                    if (bo.Left.ClrType == typeof(bool)) { // If the other side is nullable bool then this expression is already a reasonable way to handle three-values
                        return bo.Left;
                    }
                }
                else if (leftIsBool && rightIsBool) {
                    // Here, both left and right are bools.
                    bool leftValue = (bool)vLeft.Value;
                    bool rightValue = (bool)vRight.Value;
                    if (equal) {
                        return sql.ValueFromObject(leftValue == rightValue, false, bo.SourceExpression);
                    }
                    else {
                        return sql.ValueFromObject(leftValue != rightValue, false, bo.SourceExpression);
                    }
                }
            }
            break;
        }
    }
    switch (bo.NodeType) {
        case SqlNodeType.And: {
            SqlValue vLeft = bo.Left as SqlValue;
            SqlValue vRight = bo.Right as SqlValue;
            if (vLeft != null && vRight == null) {
                if (vLeft.Value != null && (bool)vLeft.Value) {
                    return bo.Right;
                }
                return sql.ValueFromObject(false, false, bo.SourceExpression);
            }
            else if (vLeft == null && vRight != null) {
                if (vRight.Value != null && (bool)vRight.Value) {
                    return bo.Left;
                }
                return sql.ValueFromObject(false, false, bo.SourceExpression);
            }
            else if (vLeft != null && vRight != null) {
                return sql.ValueFromObject((bool)(vLeft.Value ?? false) && (bool)(vRight.Value ?? false), false, bo.SourceExpression);
            }
            break;
        }
        case SqlNodeType.Or: {
            SqlValue vLeft = bo.Left as SqlValue;
            SqlValue vRight = bo.Right as SqlValue;
            if (vLeft != null && vRight == null) {
                if (vLeft.Value != null && !(bool)vLeft.Value) {
                    return bo.Right;
                }
                return sql.ValueFromObject(true, false, bo.SourceExpression);
            }
            else if (vLeft == null && vRight != null) {
                if (vRight.Value != null && !(bool)vRight.Value) {
                    return bo.Left;
                }
                return sql.ValueFromObject(true, false, bo.SourceExpression);
            }
            else if (vLeft != null && vRight != null) {
                return sql.ValueFromObject((bool)(vLeft.Value ?? false) || (bool)(vRight.Value ?? false), false, bo.SourceExpression);
            }
            break;
        }
        case SqlNodeType.EQ:
        case SqlNodeType.NE:
        case SqlNodeType.EQ2V:
        case SqlNodeType.NE2V: {
            // Link comparisons are translated before the operands are fetched so links stay deferred.
            SqlExpression translated = this.translator.TranslateLinkEquals(bo);
            if (translated != bo) {
                return this.VisitExpression(translated);
            }
            break;
        }
    }
    bo.Left = this.ConvertToFetchedExpression(bo.Left);
    bo.Right = this.ConvertToFetchedExpression(bo.Right);
    switch (bo.NodeType) {
        case SqlNodeType.EQ:
        case SqlNodeType.NE:
        case SqlNodeType.EQ2V:
        case SqlNodeType.NE2V:
            SqlExpression translated = this.translator.TranslateEquals(bo);
            if (translated != bo) {
                return this.VisitExpression(translated);
            }
            // Special handling for typeof(Type) nodes. Reduce to a static check if possible;
            // strip SqlDiscriminatedType if possible;
            if (typeof(Type).IsAssignableFrom(bo.Left.ClrType)) {
                // The two-valued variants (EQ2V / NE2V) must fold exactly like EQ / NE.
                bool isEquality = bo.NodeType == SqlNodeType.EQ || bo.NodeType == SqlNodeType.EQ2V;
                SqlExpression left = TypeSource.GetTypeSource(bo.Left);
                SqlExpression right = TypeSource.GetTypeSource(bo.Right);
                MetaType[] leftPossibleTypes = GetPossibleTypes(left);
                MetaType[] rightPossibleTypes = GetPossibleTypes(right);
                bool someMatch = false;
                // Stop scanning as soon as any common candidate is found.
                for (int i = 0; i < leftPossibleTypes.Length && !someMatch; ++i) {
                    for (int j = 0; j < rightPossibleTypes.Length; ++j) {
                        if (leftPossibleTypes[i] == rightPossibleTypes[j]) {
                            someMatch = true;
                            break;
                        }
                    }
                }
                // Is a match possible?
                if (!someMatch) {
                    // No match is possible: equality is statically FALSE, inequality statically TRUE.
                    return this.VisitExpression(sql.ValueFromObject(!isEquality, false, bo.SourceExpression));
                }
                // Is the match known statically?
                if (leftPossibleTypes.Length == 1 && rightPossibleTypes.Length == 1) {
                    // Yes, the match is statically known.
                    return this.VisitExpression(sql.ValueFromObject(
                        isEquality == (leftPossibleTypes[0] == rightPossibleTypes[0]),
                        false,
                        bo.SourceExpression));
                }
                // If both sides are discriminated types, then create a comparison of discriminators instead;
                SqlDiscriminatedType leftDt = bo.Left as SqlDiscriminatedType;
                SqlDiscriminatedType rightDt = bo.Right as SqlDiscriminatedType;
                if (leftDt != null && rightDt != null) {
                    return this.VisitExpression(sql.Binary(bo.NodeType, leftDt.Discriminator, rightDt.Discriminator));
                }
            }
            // can only compare sql scalars
            if (TypeSystem.IsSequenceType(bo.Left.ClrType)) {
                throw Error.ComparisonNotSupportedForType(bo.Left.ClrType);
            }
            if (TypeSystem.IsSequenceType(bo.Right.ClrType)) {
                throw Error.ComparisonNotSupportedForType(bo.Right.ClrType);
            }
            break;
    }
    return bo;
}
/// <summary>
/// Returns the meta types that a Type-valued expression could evaluate to at run time.
/// Non-Type expressions yield an empty array. A discriminated type yields the non-abstract types
/// of its inheritance hierarchy, and a literal yields its single meta type. A searched CASE yields
/// the union over its WHEN values; the ELSE is not consulted. Any other node kind throws.
/// </summary>
private MetaType[] GetPossibleTypes(SqlExpression typeExpression) {
    if (!typeof(Type).IsAssignableFrom(typeExpression.ClrType)) {
        return new MetaType[0];
    }
    switch (typeExpression.NodeType) {
        case SqlNodeType.DiscriminatedType: {
            SqlDiscriminatedType discriminated = (SqlDiscriminatedType)typeExpression;
            return discriminated.TargetType.InheritanceTypes
                .Where(candidate => !candidate.Type.IsAbstract)
                .ToArray();
        }
        case SqlNodeType.Value: {
            Type literalType = (Type)((SqlValue)typeExpression).Value;
            return new MetaType[] { this.model.GetMetaType(literalType) };
        }
        case SqlNodeType.SearchedCase: {
            HashSet<MetaType> union = new HashSet<MetaType>();
            foreach (var branch in ((SqlSearchedCase)typeExpression).Whens) {
                union.UnionWith(GetPossibleTypes(branch.Value));
            }
            return union.ToArray();
        }
        default:
            throw Error.UnexpectedNode(typeExpression.NodeType);
    }
}
/// <summary>
/// Evaluate the object and extract its discriminator.
/// The object is fetched, and any OptionalValue / OuterJoinedValue wrappers are peeled off. Then:
///  - a TypeCase becomes a simple CASE over its discriminator. Each match maps to a copy of itself,
///    and the null (default) match maps to the default type's inheritance code.
///  - otherwise, if the inheritance root of the object's type has inheritance, the discriminator
///    member is accessed on the original (unfetched) dof.Object and visited.
///  - otherwise a typed NULL literal of dof.ClrType is returned.
/// </summary>
internal override SqlExpression VisitDiscriminatorOf(SqlDiscriminatorOf dof) {
SqlExpression obj = this.FetchExpression(dof.Object); // FetchExpression removes Link.
// It's valid to unwrap optional and outer-join values here because type case already handles
// NULL values correctly.
while (obj.NodeType == SqlNodeType.OptionalValue
|| obj.NodeType == SqlNodeType.OuterJoinedValue) {
if (obj.NodeType == SqlNodeType.OptionalValue) {
obj = ((SqlOptionalValue)obj).Value;
}
else {
obj = ((SqlUnary)obj).Operand;
}
}
if (obj.NodeType == SqlNodeType.TypeCase) {
SqlTypeCase tc = (SqlTypeCase)obj;
// Rewrite a case of discriminators. We can't just reduce to
// discriminator (yet) because the ELSE clause needs to be considered.
// Later in the conversion there is an optimization that will turn the CASE
// into a simple combination of ANDs and ORs.
// Also, cannot reduce to IsNull(Discriminator,DefaultDiscriminator) because
// other unexpected values besides NULL need to be handled.
List<SqlExpression> matches = new List<SqlExpression>();
List<SqlExpression> values = new List<SqlExpression>();
MetaType defaultType = tc.RowType.InheritanceDefault;
// NOTE(review): discriminator.GetType() below throws if the default type's InheritanceCode is null — presumably never the case; confirm.
object discriminator = defaultType.InheritanceCode;
foreach (SqlTypeCaseWhen when in tc.Whens) {
matches.Add(when.Match);
if (when.Match == null) {
// Default (null-match) branch: its value is the default type's inheritance code.
// NOTE(review): this takes the SQL type from tc.Whens[0].Match and so assumes the first WHEN is not the
// null-match default; verify against where SqlTypeCase whens are built.
SqlExpression @default = sql.Value(discriminator.GetType(), tc.Whens[0].Match.SqlType, defaultType.InheritanceCode, true, tc.SourceExpression);
values.Add(@default);
}
else {
// Must duplicate so that columnizer doesn't nominate the match as a value.
values.Add(sql.Value(discriminator.GetType(), when.Match.SqlType, ((SqlValue)when.Match).Value, true, tc.SourceExpression));
}
}
return sql.Case(tc.Discriminator.ClrType, tc.Discriminator, matches, values, tc.SourceExpression);
} else {
var mt = this.model.GetMetaType(obj.ClrType).InheritanceRoot;
if (mt.HasInheritance) {
// Access the discriminator member on the original object expression and bind that.
return this.VisitExpression(sql.Member(dof.Object, mt.Discriminator.Member));
}
}
// No inheritance mapping: there is no discriminator, so the result is NULL.
return sql.TypedLiteralNull(dof.ClrType, dof.SourceExpression);
}
/// <summary>
/// Simplifies a boolean searched CASE that has one WHEN and an ELSE:
///   CASE WHEN m THEN v ELSE false END  ==&gt;  m AND v
///   CASE WHEN m THEN true ELSE e END   ==&gt;  m OR e
/// Anything else is handed to the base visitor.
/// </summary>
internal override SqlExpression VisitSearchedCase(SqlSearchedCase c) {
    bool isBoolean = c.ClrType == typeof(bool) || c.ClrType == typeof(bool?);
    if (!isBoolean || c.Whens.Count != 1 || c.Else == null) {
        return base.VisitSearchedCase(c);
    }
    var onlyWhen = c.Whens[0];
    SqlValue elseLiteral = c.Else as SqlValue;
    SqlValue whenLiteral = onlyWhen.Value as SqlValue;
    bool elseIsFalse = elseLiteral != null && elseLiteral.Value != null && !(bool)elseLiteral.Value;
    if (elseIsFalse) {
        return this.VisitExpression(sql.Binary(SqlNodeType.And, onlyWhen.Match, onlyWhen.Value));
    }
    bool whenIsTrue = whenLiteral != null && whenLiteral.Value != null && (bool)whenLiteral.Value;
    if (whenIsTrue) {
        return this.VisitExpression(sql.Binary(SqlNodeType.Or, onlyWhen.Match, c.Else));
    }
    return base.VisitSearchedCase(c);
}
/// <summary>
/// True when <paramref name="sqlExpr"/> is a literal NULL. Literal nulls are SqlValues with a
/// null value that are not client-specified.
/// </summary>
[SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
private bool IsConstNull(SqlExpression sqlExpr) {
    SqlValue literal = sqlExpr as SqlValue;
    return literal != null && literal.Value == null && !literal.IsClientSpecified;
}
/// <summary>
/// Apply the 'TREAT' operator into the given target. The goal is for instances of non-assignable types
/// to be nulled out.
///  - OptionalValue / OuterJoinedValue: the wrapper is dropped and TREAT is applied to the inner value.
///  - New: becomes a typed NULL when the constructed type is not assignable to <paramref name="type"/>.
///  - TypeCase: TREAT is applied to every branch. If every branch became a literal NULL, that NULL is
///    returned (retyped to <paramref name="type"/>); otherwise the type case is retyped and returned.
///  - Anything else: becomes a typed NULL when neither type is assignable from the other.
/// </summary>
private SqlExpression ApplyTreat(SqlExpression target, Type type) {
    switch (target.NodeType) {
        case SqlNodeType.OptionalValue:
            return ApplyTreat(((SqlOptionalValue)target).Value, type);
        case SqlNodeType.OuterJoinedValue:
            return ApplyTreat(((SqlUnary)target).Operand, type);
        case SqlNodeType.New: {
            SqlNew n = (SqlNew)target;
            // Are we constructing a concrete instance of a type we know can't be assigned
            // to 'type'? If so, make it null.
            if (!type.IsAssignableFrom(n.ClrType)) {
                return sql.TypedLiteralNull(type, target.SourceExpression);
            }
            return target;
        }
        case SqlNodeType.TypeCase: {
            SqlTypeCase tc = (SqlTypeCase)target;
            // Null out type case options that are impossible now.
            int reducedToNull = 0;
            foreach (SqlTypeCaseWhen when in tc.Whens) {
                when.TypeBinding = ApplyTreat(when.TypeBinding, type);
                if (this.IsConstNull(when.TypeBinding)) {
                    ++reducedToNull;
                }
            }
            // If every case reduced to NULL then reduce the whole clause entirely to NULL.
            if (reducedToNull == tc.Whens.Count) {
                // This is not an optimization. We need to do this because the type-case may be the l-value of an assign.
                tc.Whens[0].TypeBinding.SetClrType(type);
                return tc.Whens[0].TypeBinding; // <-- Points to a SqlValue null.
            }
            tc.SetClrType(type);
            return target;
        }
        default:
            // 'target' is statically a SqlExpression (and non-null, since NodeType was read above),
            // so the check is applied directly. Neither direction of assignability means TREAT
            // can never succeed.
            if (!type.IsAssignableFrom(target.ClrType) && !target.ClrType.IsAssignableFrom(type)) {
                return sql.TypedLiteralNull(type, target.SourceExpression);
            }
            return target;
    }
}
/// <summary>
/// TREAT is handled by the general unary-operator binding.
/// </summary>
internal override SqlExpression VisitTreat(SqlUnary a) {
    SqlExpression bound = this.VisitUnaryOperator(a);
    return bound;
}
/// <summary>
/// Binds a unary operator. The operand is visited first.
/// Phase 1 handles IS NULL / IS NOT NULL cases that can be decided without fetching the operand.
/// Phase 2 fetches the operand and then applies the remaining rewrites: literal NOT / CONVERT folding,
/// the Not2V expansion, null-test distribution over CASE nodes, and TREAT.
/// </summary>
[SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
[SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
internal override SqlExpression VisitUnaryOperator(SqlUnary uo) {
uo.Operand = this.VisitExpression(uo.Operand);
// ------------------------------------------------------------
// PHASE 1: If possible, evaluate without fetching the operand.
// This is preferred because fetching LINKs causes them to not
// be deferred.
// ------------------------------------------------------------
if (uo.NodeType == SqlNodeType.IsNull || uo.NodeType == SqlNodeType.IsNotNull) {
// Null tests on links are translated by the Translator first.
SqlExpression translated = this.translator.TranslateLinkIsNull(uo);
if (translated != uo) {
return this.VisitExpression(translated);
}
if (uo.Operand.NodeType==SqlNodeType.OuterJoinedValue) {
SqlUnary ojv = uo.Operand as SqlUnary;
if (ojv.Operand.NodeType == SqlNodeType.OptionalValue) {
// Outer-joined optional value: test the (outer-joined) HasValue expression instead.
SqlOptionalValue ov = (SqlOptionalValue)ojv.Operand;
return this.VisitUnaryOperator(
new SqlUnary(uo.NodeType, uo.ClrType, uo.SqlType,
new SqlUnary(SqlNodeType.OuterJoinedValue, ov.ClrType, ov.SqlType, ov.HasValue, ov.SourceExpression)
, uo.SourceExpression)
);
}
else if (ojv.Operand.NodeType == SqlNodeType.TypeCase) {
// Outer-joined type case: test the (outer-joined) discriminator instead.
// Note that, unlike the branch above, this result is returned without being visited again.
SqlTypeCase tc = (SqlTypeCase)ojv.Operand;
return new SqlUnary(uo.NodeType, uo.ClrType, uo.SqlType,
new SqlUnary(SqlNodeType.OuterJoinedValue, tc.Discriminator.ClrType, tc.Discriminator.SqlType, tc.Discriminator, tc.SourceExpression),
uo.SourceExpression
);
}
}
}
// Fetch the expression.
uo.Operand = this.ConvertToFetchedExpression(uo.Operand);
// ------------------------------------------------------------
// PHASE 2: Evaluate operator on fetched expression.
// ------------------------------------------------------------
if ((uo.NodeType == SqlNodeType.Not || uo.NodeType == SqlNodeType.Not2V) && uo.Operand.NodeType == SqlNodeType.Value) {
// NOT of a literal is folded to the negated literal.
// NOTE(review): (bool)val.Value throws if the literal's value is null; confirm a null literal cannot reach here.
SqlValue val = (SqlValue)uo.Operand;
return sql.Value(typeof(bool), val.SqlType, !(bool)val.Value, val.IsClientSpecified, val.SourceExpression);
}
else if (uo.NodeType == SqlNodeType.Not2V) {
// Two-valued NOT. If the operand may be null, it becomes
// (CASE WHEN operand THEN 1 ELSE 0 END) = 0, so a NULL operand yields TRUE rather than NULL.
// Otherwise a plain NOT is enough.
if (SqlExpressionNullability.CanBeNull(uo.Operand) != false) {
SqlSearchedCase c = new SqlSearchedCase(
typeof(int),
new [] { new SqlWhen(uo.Operand, sql.ValueFromObject(1, false, uo.SourceExpression)) },
sql.ValueFromObject(0, false, uo.SourceExpression),
uo.SourceExpression
);
return sql.Binary(SqlNodeType.EQ, c, sql.ValueFromObject(0, false, uo.SourceExpression));
}
else {
return sql.Unary(SqlNodeType.Not, uo.Operand);
}
}
// push converts of client-expressions inside the client-expression (to be evaluated client side)
else if (uo.NodeType == SqlNodeType.Convert && uo.Operand.NodeType == SqlNodeType.Value) {
SqlValue val = (SqlValue)uo.Operand;
return sql.Value(uo.ClrType, uo.SqlType, DBConvert.ChangeType(val.Value, uo.ClrType), val.IsClientSpecified, val.SourceExpression);
}
else if (uo.NodeType == SqlNodeType.IsNull || uo.NodeType == SqlNodeType.IsNotNull) {
// An operand known to be non-null folds to a constant.
bool? canBeNull = SqlExpressionNullability.CanBeNull(uo.Operand);
if (canBeNull == false) {
return sql.ValueFromObject(uo.NodeType == SqlNodeType.IsNotNull, false, uo.SourceExpression);
}
SqlExpression exp = uo.Operand;
switch (exp.NodeType) {
case SqlNodeType.Element:
// Null test on a single-element sub-select becomes [NOT] EXISTS(select).
exp = sql.SubSelect(SqlNodeType.Exists, ((SqlSubSelect)exp).Select);
if (uo.NodeType == SqlNodeType.IsNull) {
exp = sql.Unary(SqlNodeType.Not, exp, exp.SourceExpression);
}
return exp;
case SqlNodeType.ClientQuery: {
// Client query wrapping an Element: [NOT] EXISTS over that select. Any other client
// query folds to the constant (IS NOT NULL => true, IS NULL => false).
SqlClientQuery cq = (SqlClientQuery)exp;
if (cq.Query.NodeType == SqlNodeType.Element) {
exp = sql.SubSelect(SqlNodeType.Exists, cq.Query.Select);
if (uo.NodeType == SqlNodeType.IsNull) {
exp = sql.Unary(SqlNodeType.Not, exp, exp.SourceExpression);
}
return exp;
}
return sql.ValueFromObject(uo.NodeType == SqlNodeType.IsNotNull, false, uo.SourceExpression);
}
case SqlNodeType.OptionalValue:
// Test the optional value's HasValue expression instead of the value itself.
uo.Operand = ((SqlOptionalValue)exp).HasValue;
return uo;
case SqlNodeType.ClientCase: {
// Distribute unary into simple case.
SqlClientCase sc = (SqlClientCase)uo.Operand;
List<SqlExpression> matches = new List<SqlExpression>();
List<SqlExpression> values = new List<SqlExpression>();
foreach (SqlClientWhen when in sc.Whens) {
matches.Add(when.Match);
values.Add(VisitUnaryOperator(sql.Unary(uo.NodeType, when.Value, when.Value.SourceExpression)));
}
return sql.Case(sc.ClrType, sc.Expression, matches, values, sc.SourceExpression);
}
case SqlNodeType.TypeCase: {
// Distribute unary into type case. In the process, convert to simple case.
SqlTypeCase tc = (SqlTypeCase)uo.Operand;
List<SqlExpression> newMatches = new List<SqlExpression>();
List<SqlExpression> newValues = new List<SqlExpression>();
foreach (SqlTypeCaseWhen when in tc.Whens) {
SqlUnary un = new SqlUnary(uo.NodeType, uo.ClrType, uo.SqlType, when.TypeBinding, when.TypeBinding.SourceExpression);
SqlExpression expr = VisitUnaryOperator(un);
// A null test on a type binding must not leave an object construction behind.
if (expr is SqlNew) {
throw Error.DidNotExpectTypeBinding();
}
newMatches.Add(when.Match);
newValues.Add(expr);
}
return sql.Case(uo.ClrType, tc.Discriminator, newMatches, newValues, tc.SourceExpression);
}
case SqlNodeType.Value: {
// Null test on a literal folds to a constant.
SqlValue val = (SqlValue)uo.Operand;
return sql.Value(typeof(bool), this.typeProvider.From(typeof(int)), (val.Value == null) == (uo.NodeType == SqlNodeType.IsNull), val.IsClientSpecified, uo.SourceExpression);
}
}
}
else if (uo.NodeType == SqlNodeType.Treat) {
// TREAT nulls out instances whose type cannot be assigned to uo.ClrType.
return ApplyTreat(VisitExpression(uo.Operand), uo.ClrType);
}
return uo;
}
/// <summary>
/// Binds an object construction.
/// Constructor arguments are only visited inside GROUP BY, so links remain links for SqlFlattener;
/// elsewhere they are fetched. Each member assignment is handled in one of three ways:
///  - an eagerly loaded association is visited while its inheritance root is temporarily added to
///    alreadyIncluded. This applies when it is preloaded by the load options, has no custom load
///    method and is not already being included.
///  - other associations and deferred members are visited but not fetched;
///  - all remaining members are fetched.
/// </summary>
internal override SqlExpression VisitNew(SqlNew sox) {
    int argCount = sox.Args.Count;
    for (int argIndex = 0; argIndex < argCount; argIndex++) {
        var arg = sox.Args[argIndex];
        sox.Args[argIndex] = this.inGroupBy ? this.VisitExpression(arg) : this.FetchExpression(arg);
    }

    int memberCount = sox.Members.Count;
    for (int memberIndex = 0; memberIndex < memberCount; memberIndex++) {
        SqlMemberAssign assign = sox.Members[memberIndex];
        MetaDataMember dataMember = sox.MetaType.GetDataMember(assign.Member);
        MetaType rootType = dataMember.DeclaringType.InheritanceRoot;
        bool includeEagerly = dataMember.IsAssociation
            && assign.Expression != null
            && assign.Expression.NodeType != SqlNodeType.Link
            && this.shape != null
            && this.shape.IsPreloaded(dataMember.Member)
            && dataMember.LoadMethod == null
            && this.alreadyIncluded != null
            && !this.alreadyIncluded.Contains(rootType);
        if (includeEagerly) {
            // Guard against re-including the same root while its association is being bound.
            this.alreadyIncluded.Add(rootType);
            assign.Expression = this.VisitExpression(assign.Expression);
            this.alreadyIncluded.Remove(rootType);
        }
        else if (dataMember.IsAssociation || dataMember.IsDeferred) {
            assign.Expression = this.VisitExpression(assign.Expression);
        }
        else {
            assign.Expression = this.FetchExpression(assign.Expression);
        }
    }
    return sox;
}
/// <summary>
/// Fetches the member's target expression and resolves the member access against it.
/// </summary>
internal override SqlNode VisitMember(SqlMember m) {
    var target = this.FetchExpression(m.Expression);
    return this.AccessMember(m, target);
}
[SuppressMessage("Microsoft.Performance", "CA1809:AvoidExcessiveLocals", Justification="These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
[SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
[SuppressMessage("Microsoft.Maintainability", "CA1505:AvoidUnmaintainableCode", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
[SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
private SqlNode AccessMember(SqlMember m, SqlExpression expo) {
SqlExpression exp = expo;
switch (exp.NodeType) {
case SqlNodeType.ClientCase: {
// Distribute into each case.
SqlClientCase sc = (SqlClientCase)exp;
Type newClrType = null;
List<SqlExpression> matches = new List<SqlExpression>();
List<SqlExpression> values = new List<SqlExpression>();
foreach (SqlClientWhen when in sc.Whens) {
SqlExpression newValue = (SqlExpression)AccessMember(m, when.Value);
if (newClrType == null) {
newClrType = newValue.ClrType;
}
else if (newClrType != newValue.ClrType) {
throw Error.ExpectedClrTypesToAgree(newClrType, newValue.ClrType);
}
matches.Add(when.Match);
values.Add(newValue);
}
SqlExpression result = sql.Case(newClrType, sc.Expression, matches, values, sc.SourceExpression);
return result;
}
case SqlNodeType.SimpleCase: {
// Distribute into each case.
SqlSimpleCase sc = (SqlSimpleCase)exp;
Type newClrType = null;
List<SqlExpression> newMatches = new List<SqlExpression>();
List<SqlExpression> newValues = new List<SqlExpression>();
foreach (SqlWhen when in sc.Whens) {
SqlExpression newValue = (SqlExpression)AccessMember(m, when.Value);
if (newClrType == null) {
newClrType = newValue.ClrType;
}
else if (newClrType != newValue.ClrType) {
throw Error.ExpectedClrTypesToAgree(newClrType, newValue.ClrType);
}
newMatches.Add(when.Match);
newValues.Add(newValue);
}
SqlExpression result = sql.Case(newClrType, sc.Expression, newMatches, newValues, sc.SourceExpression);
return result;
}
case SqlNodeType.SearchedCase: {
// Distribute into each case.
SqlSearchedCase sc = (SqlSearchedCase)exp;
List<SqlWhen> whens = new List<SqlWhen>(sc.Whens.Count);
foreach (SqlWhen when in sc.Whens) {
SqlExpression value = (SqlExpression)AccessMember(m, when.Value);
whens.Add(new SqlWhen(when.Match, value));
}
SqlExpression @else = (SqlExpression)AccessMember(m, sc.Else);
return sql.SearchedCase(whens.ToArray(), @else, sc.SourceExpression);
}
case SqlNodeType.TypeCase: {
// We don't allow derived types to map members to different database fields.
// Therefore, just pick the best SqlNew to call AccessMember on.
SqlTypeCase tc = (SqlTypeCase)exp;
// Find the best type binding for this member.
SqlNew tb = tc.Whens[0].TypeBinding as SqlNew;
foreach (SqlTypeCaseWhen when in tc.Whens) {
if (when.TypeBinding.NodeType == SqlNodeType.New) {
SqlNew sn = (SqlNew)when.TypeBinding;
if (m.Member.DeclaringType.IsAssignableFrom(sn.ClrType)) {
tb = sn;
break;
}
}
}
return AccessMember(m, tb);
}
case SqlNodeType.AliasRef: {
// convert alias.Member => column
SqlAliasRef aref = (SqlAliasRef)exp;
// if its a table, find the matching column
SqlTable tab = aref.Alias.Node as SqlTable;
if (tab != null) {
MetaDataMember mm = GetRequiredInheritanceDataMember(tab.RowType, m.Member);
System.Diagnostics.Debug.Assert(mm != null);
string name = mm.MappedName;
SqlColumn c = tab.Find(name);
if (c == null) {
ProviderType sqlType = sql.Default(mm);
c = new SqlColumn(m.ClrType, sqlType, name, mm, null, m.SourceExpression);
c.Alias = aref.Alias;
tab.Columns.Add(c);
}
return new SqlColumnRef(c);
}
// if it is a table valued function, find the matching result column
SqlTableValuedFunctionCall fc = aref.Alias.Node as SqlTableValuedFunctionCall;
if (fc != null) {
MetaDataMember mm = GetRequiredInheritanceDataMember(fc.RowType, m.Member);
System.Diagnostics.Debug.Assert(mm != null);
string name = mm.MappedName;
SqlColumn c = fc.Find(name);
if (c == null) {
ProviderType sqlType = sql.Default(mm);
c = new SqlColumn(m.ClrType, sqlType, name, mm, null, m.SourceExpression);
c.Alias = aref.Alias;
fc.Columns.Add(c);
}
return new SqlColumnRef(c);
}
break;
}
case SqlNodeType.OptionalValue:
// convert option(exp).Member => exp.Member
return this.AccessMember(m, ((SqlOptionalValue)exp).Value);
case SqlNodeType.OuterJoinedValue: {
SqlNode n = this.AccessMember(m, ((SqlUnary)exp).Operand);
SqlExpression e = n as SqlExpression;
if (e != null) return sql.Unary(SqlNodeType.OuterJoinedValue, e);
return n;
}
case SqlNodeType.Lift:
return this.AccessMember(m, ((SqlLift)exp).Expression);
case SqlNodeType.UserRow: {
// convert UserRow.Member => UserColumn
SqlUserRow row = (SqlUserRow)exp;
SqlUserQuery suq = row.Query;
MetaDataMember mm = GetRequiredInheritanceDataMember(row.RowType, m.Member);
System.Diagnostics.Debug.Assert(mm != null);
string name = mm.MappedName;
SqlUserColumn c = suq.Find(name);
if (c == null) {
ProviderType sqlType = sql.Default(mm);
c = new SqlUserColumn(m.ClrType, sqlType, suq, name, mm.IsPrimaryKey, m.SourceExpression);
suq.Columns.Add(c);
}
return c;
}
case SqlNodeType.New: {
// convert (new {Member = expr}).Member => expr
SqlNew sn = (SqlNew)exp;
SqlExpression e = sn.Find(m.Member);
if (e != null) {
return e;
}
MetaDataMember mm = sn.MetaType.PersistentDataMembers.FirstOrDefault(p => p.Member == m.Member);
if (!sn.SqlType.CanBeColumn && mm != null) {
throw Error.MemberNotPartOfProjection(m.Member.DeclaringType, m.Member.Name);
}
break;
}
case SqlNodeType.Element:
case SqlNodeType.ScalarSubSelect: {
// convert Scalar/Element(select exp).Member => Scalar/Element(select exp.Member) / select exp.Member
SqlSubSelect sub = (SqlSubSelect)exp;
SqlAlias alias = new SqlAlias(sub.Select);
SqlAliasRef aref = new SqlAliasRef(alias);
SqlSelect saveSelect = this.currentSelect;
try {
SqlSelect newSelect = new SqlSelect(aref, alias, sub.SourceExpression);
this.currentSelect = newSelect;
SqlNode result = this.Visit(sql.Member(aref, m.Member));
SqlExpression rexp = result as SqlExpression;
if (rexp != null) {
// If the expression is still a Member after being visited, but it cannot be a column, then it cannot be collapsed
// into the SubSelect because we need to keep track of the fact that this member has to be accessed on the client.
// This must be done after the expression has been Visited above, because otherwise we don't have
// enough context to know if the member can be a column or not.
if (rexp.NodeType == SqlNodeType.Member && !SqlColumnizer.CanBeColumn(rexp)) {
// If the original member expression is an Element, optimize it by converting to an OuterApply if possible.
// We have to do this here because we are creating a new member expression based on it, and there are no
// subsequent visitors that will do this optimization.
if (this.canUseOuterApply && exp.NodeType == SqlNodeType.Element && this.currentSelect != null) {
// Reset the currentSelect since we are not going to use the previous SqlSelect that was created
this.currentSelect = saveSelect;
this.currentSelect.From = sql.MakeJoin(SqlJoinType.OuterApply, this.currentSelect.From, alias, null, sub.SourceExpression);
exp = this.VisitExpression(aref);
}
return sql.Member(exp, m.Member);
}
// Since we are going to make a SubSelect out of this member expression, we need to make
// sure it gets columnized before it gets to the PostBindDotNetConverter, otherwise only the
// entire SubSelect will be columnized as a whole. Subsequent columnization does not know how to handle
// any function calls that may be produced by the PostBindDotNetConverter, but we know how to handle it here.
newSelect.Selection = rexp;
newSelect.Selection = this.columnizer.ColumnizeSelection(newSelect.Selection);
newSelect.Selection = this.ConvertLinks(newSelect.Selection);
SqlNodeType subType = (rexp is SqlTypeCase || !rexp.SqlType.CanBeColumn) ? SqlNodeType.Element : SqlNodeType.ScalarSubSelect;
SqlSubSelect subSel = sql.SubSelect(subType, newSelect);
return this.FoldSubquery(subSel);
}
SqlSelect rselect = result as SqlSelect;
if (rselect != null) {
SqlAlias ralias = new SqlAlias(rselect);
SqlAliasRef rref = new SqlAliasRef(ralias);
newSelect.Selection = this.ConvertLinks(this.VisitExpression(rref));
newSelect.From = new SqlJoin(SqlJoinType.CrossApply, alias, ralias, null, m.SourceExpression);
return newSelect;
}
throw Error.UnexpectedNode(result.NodeType);
}
finally {
this.currentSelect = saveSelect;
}
}
case SqlNodeType.Value: {
SqlValue val = (SqlValue)exp;
if (val.Value == null) {
return sql.Value(m.ClrType, m.SqlType, null, val.IsClientSpecified, m.SourceExpression);
}
else if (m.Member is PropertyInfo) {
PropertyInfo p = (PropertyInfo)m.Member;
return sql.Value(m.ClrType, m.SqlType, p.GetValue(val.Value, null), val.IsClientSpecified, m.SourceExpression);
}
else {
FieldInfo f = (FieldInfo)m.Member;
return sql.Value(m.ClrType, m.SqlType, f.GetValue(val.Value), val.IsClientSpecified, m.SourceExpression);
}
}
case SqlNodeType.Grouping: {
SqlGrouping g = ((SqlGrouping)exp);
if (m.Member.Name == "Key") {
return g.Key;
}
break;
}
case SqlNodeType.ClientParameter: {
SqlClientParameter cp = (SqlClientParameter)exp;
// create new accessor including this member access
LambdaExpression accessor =
Expression.Lambda(
typeof(Func<,>).MakeGenericType(typeof(object[]), m.ClrType),
Expression.MakeMemberAccess(cp.Accessor.Body, m.Member),
cp.Accessor.Parameters
);
return new SqlClientParameter(m.ClrType, m.SqlType, accessor, cp.SourceExpression);
}
default:
break;
}
if (m.Expression == exp) {
return m;
}
else {
return sql.Member(exp, m.Member);
}
}
/// <summary>
/// Repeatedly folds an ELEMENT sub-select whose selection is itself a MULTISET or ELEMENT sub-select
/// into one sub-select that CROSS APPLYs the inner SELECT, e.g.
/// ELEMENT(SELECT MULTISET(SELECT xxx FROM t1 WHERE p1) FROM t2 WHERE p2)
/// becomes MULTISET(SELECT xxx FROM t2 CROSS APPLY (SELECT xxx FROM t1 WHERE p1) WHERE p2).
/// The outer SELECT is modified in place and re-wrapped with the inner sub-select's node type.
/// </summary>
/// <param name="ss">The sub-select to fold.</param>
/// <returns>The folded sub-select, or <paramref name="ss"/> itself when nothing can be folded.</returns>
private SqlExpression FoldSubquery(SqlSubSelect ss) {
for (;;) {
if (ss.NodeType != SqlNodeType.Element) {
break;
}
SqlNodeType innerType = ss.Select.Selection.NodeType;
if (innerType != SqlNodeType.Multiset && innerType != SqlNodeType.Element) {
break;
}
SqlSubSelect inner = (SqlSubSelect)ss.Select.Selection;
SqlAlias innerAlias = new SqlAlias(inner.Select);
SqlAliasRef innerRef = new SqlAliasRef(innerAlias);
SqlSelect outer = ss.Select;
outer.Selection = this.ConvertLinks(this.VisitExpression(innerRef));
outer.From = new SqlJoin(SqlJoinType.CrossApply, outer.From, innerAlias, null, ss.SourceExpression);
if (innerType == SqlNodeType.Multiset) {
// a multiset keeps the CLR type of the element sub-select it replaces
ss = sql.SubSelect(SqlNodeType.Multiset, outer, ss.ClrType);
}
else {
ss = sql.SubSelect(SqlNodeType.Element, outer);
}
}
return ss;
}
/// <summary>
/// Looks up the MetaDataMember for <paramref name="mi"/> on the type of <paramref name="type"/>'s
/// inheritance hierarchy that corresponds to the member's declaring type.
/// The member is required: an exception is thrown if that hierarchy type cannot be found.
/// </summary>
/// <param name="type">The hierarchy type that should have the member.</param>
/// <param name="mi">The member to retrieve.</param>
/// <returns>The MetaDataMember for the type.</returns>
private static MetaDataMember GetRequiredInheritanceDataMember(MetaType type, MemberInfo mi) {
System.Diagnostics.Debug.Assert(type != null);
System.Diagnostics.Debug.Assert(mi != null);
MetaType declaringHierarchyType = type.GetInheritanceType(mi.DeclaringType);
if (declaringHierarchyType != null) {
return declaringHierarchyType.GetDataMember(mi);
}
throw Error.UnmappedDataMember(mi, mi.DeclaringType, type);
}
/// <summary>
/// Binds an assignment statement: the target is fetched first, then the source value.
/// </summary>
internal override SqlStatement VisitAssign(SqlAssign sa) {
// order matters only in that it mirrors the original: left side, then right side
sa.LValue = this.FetchExpression(sa.LValue);
sa.RValue = this.FetchExpression(sa.RValue);
return sa;
}
/// <summary>
/// Expands <paramref name="expression"/> with the visitor's expander. When the expander produced a new
/// expression, that new expression is bound by visiting it; otherwise the input is returned untouched.
/// </summary>
internal SqlExpression ExpandExpression(SqlExpression expression) {
SqlExpression result = this.expander.Expand(expression);
return (result == expression) ? result : this.VisitExpression(result);
}
/// <summary>
/// Alias references are bound by expanding them (see ExpandExpression).
/// </summary>
internal override SqlExpression VisitAliasRef(SqlAliasRef aref) {
SqlExpression expanded = this.ExpandExpression(aref);
return expanded;
}
/// <summary>
/// Visits an alias with it installed as the current alias, restoring the previous alias afterwards.
/// For an alias over a table, the alias that was current on entry is first recorded in outerAliasMap
/// (VisitLinkExpansion reads that map). The aliased node is visited and then passed through
/// ConvertToFetchedSequence.
/// </summary>
internal override SqlAlias VisitAlias(SqlAlias a) {
SqlAlias enclosingAlias = this.currentAlias;
bool aliasesTable = a.Node.NodeType == SqlNodeType.Table;
if (aliasesTable) {
this.outerAliasMap[a] = enclosingAlias;
}
this.currentAlias = a;
try {
SqlNode visited = this.Visit(a.Node);
a.Node = this.ConvertToFetchedSequence(visited);
}
finally {
this.currentAlias = enclosingAlias;
}
return a;
}
/// <summary>
/// Visits a link. When including is enabled and the link's member is preloaded by the load options
/// ('LoadWith') and has no custom LoadMethod, the link is eagerly converted to a fetched expression.
/// This happens unless the inheritance root of the member's declaring type is already being
/// included further up the call stack (presumably to stop recursive inclusion cycles).
/// Inside a GROUP BY, a link that already has an expansion is replaced by that expansion.
/// </summary>
internal override SqlNode VisitLink(SqlLink link) {
link = (SqlLink)base.VisitLink(link);
bool canPrefetch = !this.disableInclude && this.shape != null && this.alreadyIncluded != null;
if (canPrefetch) {
MetaDataMember member = link.Member;
if (this.shape.IsPreloaded(member.Member) && member.LoadMethod == null) {
// Is the other side of the relation in the list already?
MetaType root = member.DeclaringType.InheritanceRoot;
if (!this.alreadyIncluded.Contains(root)) {
this.alreadyIncluded.Add(root);
SqlNode fetched = this.ConvertToFetchedExpression(link);
this.alreadyIncluded.Remove(root);
return fetched;
}
}
}
if (!this.inGroupBy || link.Expansion == null) {
return link;
}
return this.VisitLinkExpansion(link);
}
/// <summary>
/// A reference to a shared expression is always replaced by a fresh copy of the shared expression,
/// so each use site owns its own tree.
/// </summary>
internal override SqlExpression VisitSharedExpressionRef(SqlSharedExpressionRef sref) {
SqlExpression shared = sref.SharedExpression.Expression;
return (SqlExpression)SqlDuplicator.Copy(shared);
}
/// <summary>
/// Binds a shared expression. Shared expressions in group-by/select must be plain column references.
/// Anything else is pushed down into a sub-select that projects it. The (possibly pushed-down)
/// expression is stored back on <paramref name="shared"/> and returned.
/// </summary>
internal override SqlExpression VisitSharedExpression(SqlSharedExpression shared) {
shared.Expression = this.VisitExpression(shared.Expression);
if (shared.Expression.NodeType != SqlNodeType.ColumnRef) {
// not simple: make a sub-select that projects the relevant bits
shared.Expression = this.PushDownExpression(shared.Expression);
}
return shared.Expression;
}
/// <summary>
/// Binds a simple-expression wrapper. If the bound inner expression is simple it is returned directly.
/// Otherwise it is pushed down into a sub-select and the resulting column reference is returned.
/// </summary>
internal override SqlExpression VisitSimpleExpression(SqlSimpleExpression simple) {
simple.Expression = this.VisitExpression(simple.Expression);
if (!SimpleExpression.IsSimple(simple.Expression)) {
SqlExpression pushed = this.PushDownExpression(simple.Expression);
// simple expressions must be scalar (such that they can be formed into a single column declaration)
System.Diagnostics.Debug.Assert(pushed is SqlColumnRef);
return pushed;
}
return simple.Expression;
}
// Adds a new sub-query that projects 'expr' on top of the current select's FROM clause, makes that
// sub-query the new FROM, and returns an expanded copy of the projection for the current scope.
// Assumes this.currentSelect is non-null (it is dereferenced unconditionally).
private SqlExpression PushDownExpression(SqlExpression expr) {
// columnize the expression like a selection; a column-capable literal is wrapped in a single column
SqlExpression projection;
if (expr.NodeType == SqlNodeType.Value && expr.SqlType.CanBeColumn) {
projection = new SqlColumn(expr.ClrType, expr.SqlType, null, null, expr, expr.SourceExpression);
}
else {
projection = this.columnizer.ColumnizeSelection(expr);
}
SqlSelect pushedDown = new SqlSelect(projection, this.currentSelect.From, projection.SourceExpression);
this.currentSelect.From = new SqlAlias(pushedDown);
return this.ExpandExpression(projection);
}
/// <summary>
/// Binds a join. CROSS APPLY / OUTER APPLY joins are handled here. The left source is visited first.
/// The right source is then visited with the left source's SELECT (if any) as currentSelect, and the
/// condition with no current select. The previous currentSelect is restored afterwards.
/// All other join types use the base visitor.
/// </summary>
internal override SqlSource VisitJoin(SqlJoin join) {
bool isApply = join.JoinType == SqlJoinType.CrossApply || join.JoinType == SqlJoinType.OuterApply;
if (!isApply) {
return base.VisitJoin(join);
}
join.Left = this.VisitSource(join.Left);
// captured after the left side has been visited, exactly as before
SqlSelect outerSelect = this.currentSelect;
try {
this.currentSelect = this.GetSourceSelect(join.Left);
join.Right = this.VisitSource(join.Right);
this.currentSelect = null;
join.Condition = this.VisitExpression(join.Condition);
return join;
}
finally {
this.currentSelect = outerSelect;
}
}
/// <summary>
/// Returns the SELECT that <paramref name="source"/> aliases, or null when the source is not an alias
/// or its aliased node is not a SELECT.
/// </summary>
[SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
private SqlSelect GetSourceSelect(SqlSource source) {
SqlAlias alias = source as SqlAlias;
return (alias != null) ? alias.Node as SqlSelect : null;
}
/// <summary>
/// Binds a SELECT. FROM is visited first, while the enclosing select is still current. This select then
/// becomes currentSelect while WHERE, GROUP BY (with inGroupBy set), HAVING, ORDER BY, TOP, Row and the
/// selection are visited. The selection is columnized and, unless link optimization is suspended for this
/// select, its links are converted. A literal WHERE TRUE is removed.
/// </summary>
/// <remarks>
/// Link optimizations are not preserved across a GROUP BY, aggregate or DISTINCT boundary. When the
/// binder's optimizeLinkExpansions is on and the select has any of these, a fresh LinkOptimizationScope is
/// pushed and link conversion of the selection is skipped. currentSelect, linkMap and inGroupBy are always
/// restored on exit.
/// </remarks>
internal override SqlSelect VisitSelect(SqlSelect select) {
LinkOptimizationScope saveScope = this.linkMap;
SqlSelect saveSelect = this.currentSelect;
bool saveInGroupBy = inGroupBy;
inGroupBy = false;
try {
// don't preserve any link optimizations across a group or distinct boundary
bool linkOptimize = true;
if (this.binder.optimizeLinkExpansions &&
(select.GroupBy.Count > 0 || this.aggregateChecker.HasAggregates(select) || select.IsDistinct)) {
linkOptimize = false;
this.linkMap = new LinkOptimizationScope(this.linkMap);
}
select.From = this.VisitSource(select.From);
this.currentSelect = select;
select.Where = this.VisitExpression(select.Where);
this.inGroupBy = true;
for (int i = 0, n = select.GroupBy.Count; i < n; i++) {
select.GroupBy[i] = this.VisitExpression(select.GroupBy[i]);
}
this.inGroupBy = false;
select.Having = this.VisitExpression(select.Having);
for (int i = 0, n = select.OrderBy.Count; i < n; i++) {
select.OrderBy[i].Expression = this.VisitExpression(select.OrderBy[i].Expression);
}
select.Top = this.VisitExpression(select.Top);
select.Row = (SqlRow)this.Visit(select.Row);
select.Selection = this.VisitExpression(select.Selection);
select.Selection = this.columnizer.ColumnizeSelection(select.Selection);
if (linkOptimize) {
select.Selection = ConvertLinks(select.Selection);
}
// Optimize out the where clause for WHERE TRUE.
// The literal is tested with 'is bool' instead of being unboxed directly. A NULL literal (or any
// non-boolean value) is left in place, rather than throwing NullReferenceException /
// InvalidCastException from the unbox.
if (select.Where != null && select.Where.NodeType == SqlNodeType.Value) {
object whereValue = ((SqlValue)select.Where).Value;
if (whereValue is bool && (bool)whereValue) {
select.Where = null;
}
}
}
finally {
this.currentSelect = saveSelect;
this.linkMap = saveScope;
this.inGroupBy = saveInGroupBy;
}
return select;
}
/// <summary>
/// Visits a sub-query with no current select and a fresh link-optimization scope, so link
/// optimizations are not preserved across sub-queries. Both are restored afterwards.
/// </summary>
internal override SqlExpression VisitSubSelect(SqlSubSelect ss) {
SqlSelect outerSelect = this.currentSelect;
LinkOptimizationScope outerScope = this.linkMap;
try {
this.currentSelect = null;
this.linkMap = new LinkOptimizationScope(outerScope);
return base.VisitSubSelect(ss);
}
finally {
this.currentSelect = outerSelect;
this.linkMap = outerScope;
}
}
/// <summary>
/// Converts links in a (columnized) selection into fetched expressions. The method recurses into
/// column expressions, OuterJoinedValue wrappers and client-case branches, because a client case
/// may have branches that are links.
/// </summary>
/// <param name="node">The expression to convert; may be null.</param>
/// <returns>
/// null for null input. Column and client-case nodes are updated in place and returned. A converted
/// OuterJoinedValue operand is re-wrapped in OuterJoinedValue unless it already is one. A link is
/// replaced by ConvertToFetchedExpression's result. Any other node is returned unchanged.
/// </returns>
private SqlExpression ConvertLinks(SqlExpression node) {
if (node == null) {
return null;
}
switch (node.NodeType) {
case SqlNodeType.Column: {
SqlColumn col = (SqlColumn)node;
if (col.Expression != null) {
col.Expression = this.ConvertLinks(col.Expression);
}
return node;
}
case SqlNodeType.OuterJoinedValue: {
SqlExpression o = ((SqlUnary)node).Operand;
SqlExpression e = this.ConvertLinks(o);
if (e == o) {
return node;
}
if (e.NodeType != SqlNodeType.OuterJoinedValue) {
return sql.Unary(SqlNodeType.OuterJoinedValue, e);
}
return e;
}
case SqlNodeType.Link:
return this.ConvertToFetchedExpression((SqlLink)node);
case SqlNodeType.ClientCase: {
SqlClientCase sc = (SqlClientCase)node;
foreach (SqlClientWhen when in sc.Whens) {
when.Value = ConvertLinks(when.Value);
if (!sc.ClrType.IsAssignableFrom(when.Value.ClrType)) {
// Arguments are (case type, converted value type). This is the same order that
// ConvertToFetchedExpression uses for this error; previously they were swapped here.
throw Error.DidNotExpectTypeChange(sc.ClrType, when.Value.ClrType);
}
}
return node;
}
}
return node;
}
/// <summary>
/// Ensures a bound node is an expression. Expressions pass through unchanged, and a SELECT is wrapped
/// in a MULTISET sub-select. Null yields null; any other node kind raises an unexpected-node error.
/// </summary>
internal SqlExpression ConvertToExpression(SqlNode node) {
if (node == null) {
return null;
}
SqlExpression asExpression = node as SqlExpression;
if (asExpression != null) {
return asExpression;
}
SqlSelect asSelect = node as SqlSelect;
if (asSelect == null) {
throw Error.UnexpectedNode(node.NodeType);
}
return sql.SubSelect(SqlNodeType.Multiset, asSelect);
}
/// <summary>
/// Converts delay-fetched structures in <paramref name="node"/> into expressions bound into the current
/// query. These structures are links, plus client/type/searched cases whose branches may contain links.
/// Used by VisitLink for 'LoadWith' prefetching and by ConvertLinks for link nodes.
/// </summary>
/// <param name="node">The node to convert; may be null.</param>
/// <returns>
/// null for null input; otherwise the converted expression. Type and searched cases are updated in
/// place. Client cases are rebuilt (as a SQL CASE, or via SimulateCaseOfSequences). Links may be hoisted
/// into a join on the current select, as described at the Link case.
/// </returns>
[SuppressMessage("Microsoft.Performance", "CA1800:DoNotCastUnnecessarily", Justification = "Microsoft: Cast is dependent on node type and casts do not happen unecessarily in a single code path.")]
[SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
[SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
internal SqlExpression ConvertToFetchedExpression(SqlNode node) {
if (node == null) {
return null;
}
switch (node.NodeType) {
case SqlNodeType.OuterJoinedValue: {
// Convert links inside the operand. If nothing changed, keep the original wrapper node.
// NOTE(review): when the operand did change, the converted operand is returned WITHOUT the
// OuterJoinedValue wrapper, unlike ConvertLinks which re-wraps it. Confirm this is intended.
SqlExpression o = ((SqlUnary)node).Operand;
SqlExpression e = this.ConvertLinks(o);
if (e == o) {
return (SqlExpression)node;
}
return e;
}
case SqlNodeType.ClientCase: {
// Need to recurse in case the object case has links.
SqlClientCase cc = (SqlClientCase)node;
List<SqlNode> fetchedValues = new List<SqlNode>();
bool allExprs = true;
foreach (SqlClientWhen when in cc.Whens) {
SqlNode fetchedValue = ConvertToFetchedExpression(when.Value);
// NOTE(review): this method returns SqlExpression, so the 'is' test is false only when the
// fetched value is null.
allExprs = allExprs && (fetchedValue is SqlExpression);
fetchedValues.Add(fetchedValue);
}
if (allExprs) {
// All WHEN values are simple expressions (no sequences).
// Rebuild the client case as a SQL CASE over the same discriminator and matches.
List<SqlExpression> matches = new List<SqlExpression>();
List<SqlExpression> values = new List<SqlExpression>();
for (int i = 0, c = fetchedValues.Count; i < c; ++i) {
SqlExpression fetchedValue = (SqlExpression)fetchedValues[i];
if (!cc.ClrType.IsAssignableFrom(fetchedValue.ClrType)) {
throw Error.DidNotExpectTypeChange(cc.ClrType, fetchedValue.ClrType);
}
matches.Add(cc.Whens[i].Match);
values.Add(fetchedValue);
}
node = sql.Case(cc.ClrType, cc.Expression, matches, values, cc.SourceExpression);
}
else {
// Some branch is not an expression: emulate the CASE with a UNION ALL of filtered SELECTs.
node = SimulateCaseOfSequences(cc, fetchedValues);
}
break;
}
case SqlNodeType.TypeCase: {
// Convert each type binding, then store the results back into the WHENs in place.
SqlTypeCase tc = (SqlTypeCase)node;
List<SqlNode> fetchedValues = new List<SqlNode>();
foreach (SqlTypeCaseWhen when in tc.Whens) {
SqlNode fetchedValue = ConvertToFetchedExpression(when.TypeBinding);
fetchedValues.Add(fetchedValue);
}
for (int i = 0, c = fetchedValues.Count; i < c; ++i) {
SqlExpression fetchedValue = (SqlExpression)fetchedValues[i];
tc.Whens[i].TypeBinding = fetchedValue;
}
break;
}
case SqlNodeType.SearchedCase: {
// Convert every WHEN condition/value and the ELSE in place.
SqlSearchedCase sc = (SqlSearchedCase)node;
foreach (SqlWhen when in sc.Whens) {
when.Match = this.ConvertToFetchedExpression(when.Match);
when.Value = this.ConvertToFetchedExpression(when.Value);
}
sc.Else = this.ConvertToFetchedExpression(sc.Else);
break;
}
case SqlNodeType.Link: {
SqlLink link = (SqlLink)node;
// An already-expanded link is bound through its expansion.
if (link.Expansion != null) {
return this.VisitLinkExpansion(link);
}
// Reuse a previous hoisting of the same link in this optimization scope (see linkMap.Add below).
SqlExpression cached;
if (this.linkMap.TryGetValue(link.Id, out cached)) {
return this.VisitExpression(cached);
}
// translate link into expanded form
node = this.translator.TranslateLink(link, true);
// New nodes may have been produced because of Subquery.
// Prebind again for method-call and static treat handling.
node = binder.Prebind(node);
// Make it an expression.
node = this.ConvertToExpression(node);
// bind the translation
node = this.Visit(node);
// Check for element node, rewrite as sql apply.
// A single-element association link is hoisted into a join on the current select's FROM.
// This only happens when there is a current select and link-expansion optimization is enabled.
if (this.currentSelect != null
&& node != null
&& node.NodeType == SqlNodeType.Element
&& link.Member.IsAssociation
&& this.binder.OptimizeLinkExpansions
) {
// if link in a non-nullable foreign key association then inner-join is okay to use (since it must always exist)
// otherwise use left-outer-join
SqlJoinType joinType = (link.Member.Association.IsForeignKey && !link.Member.Association.IsNullable)
? SqlJoinType.Inner : SqlJoinType.LeftOuter;
SqlSubSelect ss = (SqlSubSelect)node;
// The sub-select's WHERE becomes the join condition.
SqlExpression where = ss.Select.Where;
ss.Select.Where = null;
// form cross apply
SqlAlias sa = new SqlAlias(ss.Select);
// An inner join is downgraded to LEFT OUTER when the condition depends only on aliases from
// the nullable side of an existing outer join/apply in the FROM tree.
if (joinType == SqlJoinType.Inner && this.IsOuterDependent(this.currentSelect.From, sa, where))
{
joinType = SqlJoinType.LeftOuter;
}
this.currentSelect.From = sql.MakeJoin(joinType, this.currentSelect.From, sa, where, ss.SourceExpression);
SqlExpression result = new SqlAliasRef(sa);
// Remember the hoisted alias so later uses of the same link in this scope reuse the join.
this.linkMap.Add(link.Id, result);
return this.VisitExpression(result);
}
}
break;
}
return (SqlExpression)node;
}
// Decides whether joining 'alias' on 'where' must not be an INNER join. This is the case when every
// alias that 'where' references (ignoring aliases produced by 'alias' itself) is produced by a subtree
// of 'location' sitting on the nullable side of a LEFT OUTER JOIN / OUTER APPLY.
// ConvertToFetchedExpression downgrades the join to LEFT OUTER when this returns true.
private bool IsOuterDependent(SqlSource location, SqlAlias alias, SqlExpression where) {
HashSet<SqlAlias> consumed = SqlGatherConsumedAliases.Gather(where);
consumed.ExceptWith(SqlGatherProducedAliases.Gather(alias));
HashSet<SqlAlias> unusedProduced;
return this.IsOuterDependent(false, location, consumed, out unusedProduced);
}
// Walks the join tree rooted at 'location'. Returns true when some subtree whose produced aliases cover
// every alias in 'consumed' lies on the right-hand (nullable) side of a LEFT OUTER JOIN or OUTER APPLY.
// 'isOuterDependent' says whether 'location' itself sits on such a side.
// When the walk completes, 'produced' holds the aliases produced by 'location'. On an early 'true'
// return it holds whatever the recursive call that returned true assigned.
private bool IsOuterDependent(bool isOuterDependent, SqlSource location, HashSet<SqlAlias> consumed, out HashSet<SqlAlias> produced)
{
if (location.NodeType == SqlNodeType.Join)
{
// walk down both sides of the join; a 'true' from either side ends the search
SqlJoin join = (SqlJoin)location;
if (this.IsOuterDependent(isOuterDependent, join.Left, consumed, out produced))
return true;
HashSet<SqlAlias> rightProduced;
// The right child is flagged only by this join's own type. If this join already sits on a nullable
// side, that inherited flag is still reported by the subset check below once the union covers 'consumed'.
bool rightIsOuterDependent = join.JoinType == SqlJoinType.LeftOuter || join.JoinType == SqlJoinType.OuterApply;
if (this.IsOuterDependent(rightIsOuterDependent, join.Right, consumed, out rightProduced))
return true;
// produced = aliases of the left subtree plus those of the right subtree
produced.UnionWith(rightProduced);
}
else
{
SqlAlias a = location as SqlAlias;
if (a != null)
{
SqlSelect s = a.Node as SqlSelect;
// Descend into an aliased sub-select's FROM only when this alias is not itself on a nullable side;
// the recursive result is used only for an early 'true'.
if (s != null && !isOuterDependent && s.From != null)
{
if (this.IsOuterDependent(false, s.From, consumed, out produced))
return true;
}
}
// leaf (or alias whose inner FROM did not decide the question): gather what this source produces
produced = SqlGatherProducedAliases.Gather(location);
}
// look to see if this subtree fully satisfies join condition
if (consumed.IsSubsetOf(produced))
{
return isOuterDependent;
}
return false;
}
/// <summary>
/// Looks in 'node' for delay-fetched structures (e.g. links) and turns them into fetched structures that
/// are evaluated directly in the query. It is used where a sequence is expected; VisitAlias passes every
/// aliased node through here.
/// </summary>
/// <param name="node">The node to convert; may be null.</param>
/// <returns>
/// null for null input. Non-expression nodes are returned unchanged, after any OuterJoinedValue wrappers
/// are stripped. Otherwise the converted node is returned, with a top-level SqlSubSelect unwrapped to its
/// SqlSelect. The exception is the link-expansion path, which returns VisitLinkExpansion's result directly.
/// Raises Error.SequenceOperatorsNotSupportedForType when the expression's CLR type is not a sequence
/// type, and Error.QueryOnLocalCollectionNotSupported when the expression is a literal value.
/// </returns>
internal SqlNode ConvertToFetchedSequence(SqlNode node) {
if (node == null) {
return node;
}
// outer-joined wrappers are irrelevant for a sequence source; strip all of them
while (node.NodeType == SqlNodeType.OuterJoinedValue) {
node = ((SqlUnary)node).Operand;
}
SqlExpression expr = node as SqlExpression;
if (expr == null) {
// already a non-expression node (e.g. a SELECT); nothing to fetch
return node;
}
if (!TypeSystem.IsSequenceType(expr.ClrType)) {
throw Error.SequenceOperatorsNotSupportedForType(expr.ClrType);
}
if (expr.NodeType == SqlNodeType.Value) {
throw Error.QueryOnLocalCollectionNotSupported();
}
if (expr.NodeType == SqlNodeType.Link) {
SqlLink link = (SqlLink)expr;
if (link.Expansion != null) {
return this.VisitLinkExpansion(link);
}
// translate link into expanded form
// (passes 'false' here whereas ConvertToFetchedExpression passes 'true'; presumably this
// selects the sequence form of the translation - confirm against Translator.TranslateLink)
node = this.translator.TranslateLink(link, false);
// New nodes may have been produced because of Subquery.
// Prebind again for method-call and static treat handling.
node = binder.Prebind(node);
// bind the translation
node = this.Visit(node);
}
else if (expr.NodeType == SqlNodeType.Grouping) {
// a grouping used as a sequence stands for its group
node = ((SqlGrouping)expr).Group;
}
else if (expr.NodeType == SqlNodeType.ClientCase) {
/*
* Client case needs to be handled here because it may be a client-case
* of delay-fetch structures such as links (or other client cases of links):
*
* CASE [Disc]
* WHEN 'X' THEN A
* WHEN 'Y' THEN B
* END
*
* Abstractly, this would be rewritten as
*
* CASE [Disc]
* WHEN 'X' THEN ConvertToFetchedSequence(A)
* WHEN 'Y' THEN ConvertToFetchedSequence(B)
* END
*
* The hitch is that the result of ConvertToFetchedSequence() is likely
* to be a SELECT which is not legal in a CASE. Instead, we need to rewrite as
*
* SELECT [ProjectionX] WHERE [Disc]='X'
* UNION ALL
* SELECT [ProjectionY] WHERE [Disc]='Y'
*
* In other words, a UNION ALL in which, for any given value of [Disc], only one
* SELECT has a WHERE clause that can produce a non-empty set.
*
* If no branch changed, the client case is left as it is. If every original branch
* value is structurally equal to the first one, the first converted branch is used
* and no UNION is built.
*/
SqlClientCase sc = (SqlClientCase)expr;
List<SqlNode> newValues = new List<SqlNode>();
bool rewrite = false;
bool allSame = true;
foreach (SqlClientWhen when in sc.Whens) {
SqlNode newValue = ConvertToFetchedSequence(when.Value);
rewrite = rewrite || (newValue != when.Value);
newValues.Add(newValue);
allSame = allSame && SqlComparer.AreEqual(when.Value, sc.Whens[0].Value);
}
if (rewrite) {
if (allSame) {
// If all branches are the same then just take one.
node = newValues[0];
}
else {
node = this.SimulateCaseOfSequences(sc, newValues);
}
}
}
// a sequence-valued sub-select is used through its SELECT
SqlSubSelect ss = node as SqlSubSelect;
if (ss != null) {
node = ss.Select;
}
return node;
}
/// <summary>
/// Binds a link through its expansion. When the expansion is a reference to a table alias, the
/// reference is redirected to the alias recorded for it in outerAliasMap (filled by VisitAlias).
/// A missing entry is asserted in debug builds and then falls back to visiting the expansion.
/// </summary>
private SqlExpression VisitLinkExpansion(SqlLink link) {
SqlAliasRef expansionRef = link.Expansion as SqlAliasRef;
bool expandsToTable = expansionRef != null && expansionRef.Alias.Node.NodeType == SqlNodeType.Table;
if (expandsToTable) {
SqlAlias recordedOuter;
if (this.outerAliasMap.TryGetValue(expansionRef.Alias, out recordedOuter)) {
return this.VisitAliasRef(new SqlAliasRef(recordedOuter));
}
// should not happen
System.Diagnostics.Debug.Assert(false);
}
return this.VisitExpression(link.Expansion);
}
/// <summary>
/// Given a ClientCase and a list of sequences (one per WHEN), builds a structure equivalent to a CASE
/// of SELECTs. Each SELECT gets a WHERE filter on the client case's discriminator, and the filtered
/// SELECTs are combined with UNION ALL.
/// </summary>
private SqlSelect SimulateCaseOfSequences(SqlClientCase clientCase, List<SqlNode> sequences) {
// (1) Exactly one case alternative: no WHERE clause is needed.
if (sequences.Count == 1) {
return (SqlSelect)sequences[0];
}
// (2) More than one case alternative: each SELECT's WHERE is ANDed with [Disc] = D, where D is that
// alternative's literal discriminator value. A trailing WHEN with no Match acts as ELSE and is
// filtered by [Disc] <> D for every preceding D.
int lastIndex = clientCase.Whens.Count - 1;
bool hasElse = clientCase.Whens[lastIndex].Match == null;
int matchedCount = hasElse ? sequences.Count - 1 : sequences.Count;
SqlNode combined = null;
SqlExpression noMatchFilter = null;
for (int i = 0; i < matchedCount; ++i) {
SqlSelect branch = (SqlSelect)sequences[i];
SqlExpression isThisCase = sql.Binary(SqlNodeType.EQ, clientCase.Expression, clientCase.Whens[i].Match);
branch.Where = sql.AndAccumulate(branch.Where, isThisCase);
noMatchFilter = sql.AndAccumulate(noMatchFilter, sql.Binary(SqlNodeType.NE, clientCase.Expression, clientCase.Whens[i].Match));
combined = (combined == null) ? (SqlNode)branch : new SqlUnion(branch, combined, true /* Union All */);
}
// Handle 'else' if present.
if (hasElse) {
SqlSelect elseBranch = (SqlSelect)sequences[lastIndex];
elseBranch.Where = sql.AndAccumulate(elseBranch.Where, noMatchFilter);
combined = (combined == null) ? (SqlNode)elseBranch : new SqlUnion(elseBranch, combined, true /* Union All */);
}
SqlAlias unionAlias = new SqlAlias(combined);
SqlAliasRef unionRef = new SqlAliasRef(unionAlias);
return new SqlSelect(unionRef, unionAlias, combined.SourceExpression);
}
}
}
}
|