|
using System;
using System.Collections.Generic;
using System.Data.Linq;
using System.Data.Linq.Provider;
using System.Diagnostics.CodeAnalysis;
namespace System.Data.Linq.SqlClient {
/// <summary>
/// SQL with CASE statements is harder to read. This visitor attempts to reduce CASE
/// statements to equivalent (but easier to read) logic.
/// </summary>
internal class SqlCaseSimplifier {
/// <summary>
/// Entry point: runs the CASE-simplifying visitor over <paramref name="node"/>.
/// </summary>
/// <param name="node">Root of the SQL node tree to simplify.</param>
/// <param name="sql">Factory handed to the visitor for building replacement nodes.</param>
/// <returns>Whatever the visitor's Visit call returns for <paramref name="node"/>.</returns>
internal static SqlNode Simplify(SqlNode node, SqlFactory sql) {
    Visitor simplifier = new Visitor(sql);
    return simplifier.Visit(node);
}
class Visitor : SqlVisitor {
// Factory used to build replacement nodes. It is assigned exactly once, in the
// constructor, so it is declared readonly.
private readonly SqlFactory sql;

/// <summary>
/// Creates a visitor that builds its rewritten nodes with <paramref name="sql"/>.
/// </summary>
/// <param name="sql">Factory stored for later use by the rewrite helpers.</param>
internal Visitor(SqlFactory sql) {
    this.sql = sql;
}
/// <summary>
/// Pushes an equality or inequality test against a constant into a simple CASE:
///
///   (CASE XXX                           CASE XXX
///      WHEN AAA THEN MMMM                 WHEN AAA THEN (MMMM != RRRR)
///      WHEN BBB THEN NNNN    != RRRR ==>  WHEN BBB THEN (NNNN != RRRR)
///      ELSE OOOO                          ELSE (OOOO != RRRR)
///    END)                               END
///
/// Each comparison is then evaluated, so every branch ends up as a literal
/// true or false.
///
/// The rewrite is applied only to EQ, NE, EQ2V and NE2V nodes where one operand
/// is a simple CASE whose WHEN values are all Value nodes and the other operand
/// is itself a Value node. Everything else is passed to the base visitor.
/// </summary>
/// <param name="bo">The binary operator node being visited.</param>
/// <returns>The distributed CASE, or the base visitor's result when no rewrite applies.</returns>
internal override SqlExpression VisitBinaryOperator(SqlBinary bo) {
    SqlNodeType op = bo.NodeType;
    bool isEqualityTest = op == SqlNodeType.EQ || op == SqlNodeType.NE
        || op == SqlNodeType.EQ2V || op == SqlNodeType.NE2V;

    if (isEqualityTest) {
        if (bo.Left.NodeType == SqlNodeType.SimpleCase && bo.Right.NodeType == SqlNodeType.Value) {
            // CASE on the left, constant on the right.
            SqlSimpleCase leftCase = (SqlSimpleCase)bo.Left;
            if (AreCaseWhenValuesConstant(leftCase)) {
                return this.DistributeOperatorIntoCase(op, leftCase, bo.Right);
            }
        }
        else if (bo.Right.NodeType == SqlNodeType.SimpleCase && bo.Left.NodeType == SqlNodeType.Value) {
            // Constant on the left, CASE on the right.
            SqlSimpleCase rightCase = (SqlSimpleCase)bo.Right;
            if (AreCaseWhenValuesConstant(rightCase)) {
                return this.DistributeOperatorIntoCase(op, rightCase, bo.Left);
            }
        }
    }

    return base.VisitBinaryOperator(bo);
}
/// <summary>
/// Reports whether every clause of a simple CASE yields a Value node.
/// </summary>
/// <param name="sc">The simple CASE whose clauses are inspected.</param>
/// <returns>
/// false as soon as a clause is found whose Value has a node type other than
/// SqlNodeType.Value; true otherwise (including when there are no clauses).
/// </returns>
[SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
internal bool AreCaseWhenValuesConstant(SqlSimpleCase sc) {
    for (int i = 0, n = sc.Whens.Count; i < n; i++) {
        SqlNodeType valueKind = sc.Whens[i].Value.NodeType;
        if (valueKind != SqlNodeType.Value) {
            return false;
        }
    }
    return true;
}
/// <summary>
/// Helper for VisitBinaryOperator. Builds a new boolean CASE in which each
/// clause's value is replaced by the literal outcome of comparing that value
/// with <paramref name="expr"/>.
/// </summary>
/// <param name="nt">The comparison being distributed; must be EQ, NE, EQ2V or NE2V.</param>
/// <param name="sc">The simple CASE whose clause values are compared.</param>
/// <param name="expr">The other operand of the comparison; it is evaluated with Eval.</param>
/// <returns>The result of visiting the newly built CASE.</returns>
private SqlExpression DistributeOperatorIntoCase(SqlNodeType nt, SqlSimpleCase sc, SqlExpression expr) {
    // Validate the operator and remember whether it tests for equality.
    bool testsEquality;
    switch (nt) {
        case SqlNodeType.EQ:
        case SqlNodeType.EQ2V:
            testsEquality = true;
            break;
        case SqlNodeType.NE:
        case SqlNodeType.NE2V:
            testsEquality = false;
            break;
        default:
            throw Error.ArgumentOutOfRange("nt");
    }

    object comparand = Eval(expr);
    List<SqlExpression> clauseResults = new List<SqlExpression>();
    List<SqlExpression> clauseMatches = new List<SqlExpression>();

    for (int i = 0, n = sc.Whens.Count; i < n; i++) {
        SqlWhen clause = sc.Whens[i];
        clauseMatches.Add(clause.Match);

        object clauseValue = Eval(clause.Value);
        bool valuesAreEqual = clause.Value.SqlType.AreValuesEqual(clauseValue, comparand);

        // An equality operator yields true when the values are equal; an
        // inequality operator yields true when they differ.
        bool outcome = testsEquality == valuesAreEqual;
        clauseResults.Add(sql.ValueFromObject(outcome, false, sc.SourceExpression));
    }

    SqlExpression distributed = sql.Case(typeof(bool), sc.Expression, clauseMatches, clauseResults, sc.SourceExpression);
    return this.VisitExpression(distributed);
}
/// <summary>
/// Visits a simple CASE and tries to replace it with something simpler.
///
/// One clause is chosen as the comparand: the ELSE (the clause whose Match is
/// null) when there is one, otherwise the first clause. Every other clause whose
/// value equals the comparand's value (per SqlComparer.AreEqual) is dropped from
/// the working list. The rewrites are then attempted in this order:
///   1. a single remaining clause is replaced by its value;
///   2. a boolean CASE with only literal values becomes a boolean expression;
///   3. a CASE that lost clauses and has an ELSE becomes a smaller CASE.
/// If none applies, the (visited) original node is returned.
/// </summary>
/// <param name="c">The simple CASE node; its expression and clauses are visited in place.</param>
/// <returns>The simplified expression, or <paramref name="c"/> itself.</returns>
internal override SqlExpression VisitSimpleCase(SqlSimpleCase c) {
    c.Expression = this.VisitExpression(c.Expression);

    // Find the ELSE if it exists; otherwise fall back to the first clause.
    int compareWhen = 0;
    for (int i = 0, n = c.Whens.Count; i < n; i++) {
        if (c.Whens[i].Match == null) {
            compareWhen = i;
            break;
        }
    }

    SqlWhen comparand = c.Whens[compareWhen];
    comparand.Match = VisitExpression(comparand.Match);
    comparand.Value = VisitExpression(comparand.Value);

    // The comparand is appended to newWhens below, and
    // TryToWriteAsSimpleBooleanExpression casts every value in that list to
    // SqlValue. Its value therefore has to count towards allValuesLiteral too;
    // otherwise a non-literal comparand value fails that cast.
    bool allValuesLiteral = comparand.Value != null && comparand.Value.NodeType == SqlNodeType.Value;

    // Compare each other clause's value to the comparand's value.
    List<SqlWhen> newWhens = new List<SqlWhen>();
    for (int i = 0, n = c.Whens.Count; i < n; i++) {
        if (compareWhen != i) {
            SqlWhen when = c.Whens[i];
            when.Match = this.VisitExpression(when.Match);
            when.Value = this.VisitExpression(when.Value);
            if (!SqlComparer.AreEqual(comparand.Value, when.Value)) {
                newWhens.Add(when);
            }
            allValuesLiteral = allValuesLiteral && when.Value.NodeType == SqlNodeType.Value;
        }
    }
    newWhens.Add(comparand);

    // Did everything reduce to a single clause?
    SqlExpression rewrite = TryToConsolidateAllValueExpressions(newWhens.Count, comparand.Value);
    if (rewrite != null) {
        return rewrite;
    }

    // Can it be a conjunction (or disjunction) of clauses?
    rewrite = TryToWriteAsSimpleBooleanExpression(c.ClrType, c.Expression, newWhens, allValuesLiteral);
    if (rewrite != null) {
        return rewrite;
    }

    // Can any WHEN clauses be reduced to fall into the ELSE clause?
    rewrite = TryToWriteAsReducedCase(c.ClrType, c.Expression, newWhens, comparand.Match, c.Whens.Count);
    if (rewrite != null) {
        return rewrite;
    }

    return c;
}
/// <summary>
/// Collapses a CASE that has exactly one clause left into that clause's value:
///
///   CASE XXX
///     WHEN AAA THEN YYY    ===>    YYY
///   END
/// </summary>
/// <param name="valueCount">Number of clauses remaining in the CASE.</param>
/// <param name="value">The value to return when only one clause remains.</param>
/// <returns>
/// <paramref name="value"/> when <paramref name="valueCount"/> is 1; otherwise
/// null to signal that this rewrite does not apply.
/// </returns>
[SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
private SqlExpression TryToConsolidateAllValueExpressions(int valueCount, SqlExpression value) {
    return valueCount == 1 ? value : null;
}
/// <summary>
/// Rewrites a boolean CASE with literal values as a chain of comparisons:
///
///   CASE XXX
///     WHEN AAA THEN true     ===>    (XXX==AAA) || (XXX==BBB)
///     WHEN BBB THEN true
///     ELSE false
///   END
///
/// and likewise:
///
///   CASE XXX
///     WHEN AAA THEN false    ===>    (XXX!=AAA) &amp;&amp; (XXX!=BBB)
///     WHEN BBB THEN false
///     ELSE true
///   END
///
/// A clause whose value is true contributes an OR-ed equality test; a clause
/// whose value is false contributes an AND-ed inequality test. When the
/// discriminator is not known to be non-null and an ELSE is present, an IS NULL
/// (ELSE true) or IS NOT NULL (ELSE false) test is added as well.
/// </summary>
/// <param name="caseType">CLR type of the CASE; the rewrite applies only to bool.</param>
/// <param name="discriminator">The expression the CASE switches on.</param>
/// <param name="newWhens">The clauses to rewrite; a clause with a null Match is the ELSE.</param>
/// <param name="allValuesLiteral">
/// Must be true only when every value in <paramref name="newWhens"/>, the ELSE's
/// included, is a SqlValue holding a bool, because each one is cast that way.
/// </param>
/// <returns>The boolean expression, or null when the rewrite does not apply.</returns>
private SqlExpression TryToWriteAsSimpleBooleanExpression(Type caseType, SqlExpression discriminator, List<SqlWhen> newWhens, bool allValuesLiteral) {
    // Only boolean CASEs whose values are all literals qualify.
    if (caseType != typeof(bool) || !allValuesLiteral) {
        return null;
    }

    bool? discriminatorMayBeNull = SqlExpressionNullability.CanBeNull(discriminator);
    SqlExpression result = null;
    bool? elseResult = null;

    foreach (SqlWhen clause in newWhens) {
        SqlValue literal = (SqlValue)clause.Value;   // Guaranteed by allValuesLiteral.
        bool clauseResult = (bool)literal.Value;     // Guaranteed by caseType == typeof(bool).

        if (clause.Match == null) {
            // The ELSE carries no comparison of its own; remember its value
            // for the NULL handling below.
            elseResult = clauseResult;
            continue;
        }

        if (clauseResult) {
            result = sql.OrAccumulate(result, sql.Binary(SqlNodeType.EQ, discriminator, clause.Match));
        }
        else {
            result = sql.AndAccumulate(result, sql.Binary(SqlNodeType.NE, discriminator, clause.Match));
        }
    }

    // Unless the discriminator is known to be non-null, add the NULL test that
    // corresponds to the ELSE value.
    if (discriminatorMayBeNull != false && elseResult.HasValue) {
        if (elseResult.Value) {
            result = sql.OrAccumulate(result, sql.Unary(SqlNodeType.IsNull, discriminator, discriminator.SourceExpression));
        }
        else {
            result = sql.AndAccumulate(result, sql.Unary(SqlNodeType.IsNotNull, discriminator, discriminator.SourceExpression));
        }
    }

    return result;
}
/// <summary>
/// Drops the WHEN clauses that have the same value as the ELSE:
///
///   CASE XXX                      CASE XXX
///     WHEN AAA THEN YYY             WHEN AAA THEN YYY
///     WHEN BBB THEN ZZZ    ===>     WHEN CCC THEN YYY
///     WHEN CCC THEN YYY             ELSE ZZZ
///     ELSE ZZZ                    END
///   END
/// </summary>
/// <param name="caseType">CLR type given to the new CASE.</param>
/// <param name="discriminator">The expression the CASE switches on.</param>
/// <param name="newWhens">The clauses that remain after removing redundant ones.</param>
/// <param name="elseCandidate">Match of the comparand clause; null means the comparand is the ELSE.</param>
/// <param name="originalWhenCount">Number of clauses the CASE had before reduction.</param>
/// <returns>
/// A new SqlSimpleCase built from <paramref name="newWhens"/> when clauses were
/// removed and the comparand is the ELSE; otherwise null.
/// </returns>
[SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
private SqlExpression TryToWriteAsReducedCase(Type caseType, SqlExpression discriminator, List<SqlWhen> newWhens, SqlExpression elseCandidate, int originalWhenCount) {
    bool nothingRemoved = newWhens.Count == originalWhenCount;
    bool comparandIsElse = elseCandidate == null;

    // Only reduce when some clauses matched the comparand's value and that
    // comparand is the ELSE (its match is null).
    if (nothingRemoved || !comparandIsElse) {
        return null;
    }

    return new SqlSimpleCase(caseType, discriminator, newWhens, discriminator.SourceExpression);
}
}
}
}
|