File: SqlClient\Query\SqlRetyper.cs
Project: ndp\fx\src\DLinq\Dlinq\System.Data.Linq.csproj (System.Data.Linq)
using System.Linq;
 
namespace System.Data.Linq.SqlClient {
    using System.Data.Linq.Mapping;
 
    /// <summary>
    /// Re-derives the provider (SQL) types of a SQL node tree. It predicts result types for
    /// unary, binary and function-call nodes, coerces the operands of comparisons, IN lists,
    /// CASE branches and assignments toward a common type, strips CONVERTs that a comparison
    /// does not need, and retypes unicode literals/parameters as non-unicode when they are
    /// compared against non-unicode expressions.
    /// </summary>
    internal class SqlRetyper {
        private readonly Visitor visitor;

        /// <summary>
        /// Creates a retyper whose type decisions are delegated to <paramref name="typeProvider"/>.
        /// </summary>
        /// <param name="typeProvider">Supplies type prediction and best-common-type rules.</param>
        /// <param name="model">Mapping model handed to the <see cref="SqlFactory"/> used to build new nodes.</param>
        internal SqlRetyper(TypeSystemProvider typeProvider, MetaModel model) {
            this.visitor = new Visitor(typeProvider, model);
        }

        /// <summary>
        /// Retypes <paramref name="node"/> and its subtree.
        /// </summary>
        /// <returns>
        /// The retyped node. Nodes may be mutated in place or replaced by new nodes, so callers
        /// must use the returned node rather than the argument.
        /// </returns>
        internal SqlNode Retype(SqlNode node) {
            return this.visitor.Visit(node);
        }

        /// <summary>Visitor that performs the retyping described on <see cref="SqlRetyper"/>.</summary>
        private class Visitor : SqlVisitor {
            private readonly TypeSystemProvider typeProvider;
            private readonly SqlFactory sql;

            internal Visitor(TypeSystemProvider typeProvider, MetaModel model) {
                this.sql = new SqlFactory(typeProvider, model);
                this.typeProvider = typeProvider;
            }

            /// <summary>
            /// Runs the base visitor, then predicts the unary node's SQL type from its operand's type.
            /// CONVERT nodes keep the type they were built with, and nodes whose operand (or operand
            /// SQL type) is null are left untouched.
            /// </summary>
            internal override SqlExpression VisitUnaryOperator(SqlUnary uo) {
                base.VisitUnaryOperator(uo);
                if (uo.NodeType != SqlNodeType.Convert && uo.Operand != null && uo.Operand.SqlType != null) {
                    uo.SetSqlType(this.typeProvider.PredictTypeForUnary(uo.NodeType, uo.Operand.SqlType));
                }
                return uo;
            }

            /// <summary>
            /// Returns true when an expression of CLR type <paramref name="from"/> may be compared
            /// directly against one of CLR type <paramref name="to"/>. This means a CONVERT around
            /// the <paramref name="from"/> expression can be dropped. Nullability is ignored.
            /// </summary>
            /// <remarks>
            /// Besides identity and assignability, only the numeric pairs listed in the switch are
            /// accepted.
            /// NOTE(review): some entries (e.g. SByte to UInt16/UInt32/UInt64) are not value-preserving
            /// in the CLR sense. This is presumably acceptable because the comparison executes in SQL;
            /// confirm.
            /// </remarks>
            private static bool CanDbConvert(Type from, Type to) {
                from = System.Data.Linq.SqlClient.TypeSystem.GetNonNullableType(from);
                to = System.Data.Linq.SqlClient.TypeSystem.GetNonNullableType(to);
                if (from == to)
                    return true;
                if (to.IsAssignableFrom(from))
                    return true;
                var tcTo = Type.GetTypeCode(to);
                var tcFrom = Type.GetTypeCode(from);
                switch (tcTo) {
                    case TypeCode.Int16: return tcFrom == TypeCode.Byte || tcFrom == TypeCode.SByte;
                    case TypeCode.Int32: return tcFrom == TypeCode.Byte || tcFrom == TypeCode.SByte || tcFrom == TypeCode.Int16 || tcFrom == TypeCode.UInt16;
                    case TypeCode.Int64: return tcFrom == TypeCode.Byte || tcFrom == TypeCode.SByte || tcFrom == TypeCode.Int16 || tcFrom == TypeCode.UInt16 || tcFrom == TypeCode.Int32 || tcFrom == TypeCode.UInt32;
                    case TypeCode.UInt16: return tcFrom == TypeCode.Byte || tcFrom == TypeCode.SByte;
                    case TypeCode.UInt32: return tcFrom == TypeCode.Byte || tcFrom == TypeCode.SByte || tcFrom == TypeCode.Int16 || tcFrom == TypeCode.UInt16;
                    case TypeCode.UInt64: return tcFrom == TypeCode.Byte || tcFrom == TypeCode.SByte || tcFrom == TypeCode.Int16 || tcFrom == TypeCode.UInt16 || tcFrom == TypeCode.Int32 || tcFrom == TypeCode.UInt32;
                    case TypeCode.Double: return tcFrom == TypeCode.Single;
                    case TypeCode.Decimal: return tcFrom == TypeCode.Single || tcFrom == TypeCode.Double;
                }
                return false;
            }

            /// <summary>
            /// Runs the base visitor, then applies three steps:
            /// 1. For comparisons whose operands are both non-boolean, a CONVERT on either side is
            ///    dropped when <see cref="CanDbConvert"/> allows it and the unwrapped operand's SQL type
            ///    does not compare as higher precedence (ComparePrecedenceTo == 1) than the other side.
            ///    The rebuilt node is visited again.
            /// 2. For any node with a right operand other than CONCAT, the operands are coerced toward
            ///    a common type. The node's SQL type is then predicted from the (possibly coerced)
            ///    operand types.
            /// 3. For comparisons, a unicode literal/parameter is retyped as non-unicode when it is
            ///    compared against a non-unicode, non-literal expression.
            /// </summary>
            /// <returns><paramref name="bo"/> or a replacement node.</returns>
            internal override SqlExpression VisitBinaryOperator(SqlBinary bo) {
                base.VisitBinaryOperator(bo);
                if (bo.NodeType.IsComparisonOperator()
                    && bo.Left.ClrType != typeof(bool) && bo.Right.ClrType != typeof(bool)) {
                    // Strip unnecessary CONVERT calls.
                    if (bo.Left.NodeType == SqlNodeType.Convert) {
                        var conv = (SqlUnary)bo.Left;
                        if (CanDbConvert(conv.Operand.ClrType, bo.Right.ClrType)
                            && conv.Operand.SqlType.ComparePrecedenceTo(bo.Right.SqlType) != 1) {
                            return VisitBinaryOperator(new SqlBinary(bo.NodeType, bo.ClrType, bo.SqlType, conv.Operand, bo.Right));
                        }
                    }
                    if (bo.Right.NodeType == SqlNodeType.Convert) {
                        var conv = (SqlUnary)bo.Right;
                        if (CanDbConvert(conv.Operand.ClrType, bo.Left.ClrType)
                            && conv.Operand.SqlType.ComparePrecedenceTo(bo.Left.SqlType) != 1) {
                            return VisitBinaryOperator(new SqlBinary(bo.NodeType, bo.ClrType, bo.SqlType, bo.Left, conv.Operand));
                        }
                    }
                }
                if (bo.Right != null && bo.NodeType != SqlNodeType.Concat) {
                    SqlExpression left = bo.Left;
                    SqlExpression right = bo.Right;
                    this.CoerceBinaryArgs(ref left, ref right);
                    if (bo.Left != left || bo.Right != right) {
                        // Coercion replaced an operand; rebuild the node around the new operands.
                        bo = sql.Binary(bo.NodeType, left, right);
                    }
                    bo.SetSqlType(typeProvider.PredictTypeForBinary(bo.NodeType, left.SqlType, right.SqlType));
                }
                if (bo.NodeType.IsComparisonOperator()) {
                    // When comparing a unicode value against a non-unicode expression,
                    // retype the value as non-unicode.
                    SqlSimpleTypeExpression valueToRetype = null;
                    if (NeedsNonUnicodeRetype(bo.Left, bo.Right)) {
                        valueToRetype = (SqlSimpleTypeExpression)bo.Right;
                    }
                    else if (NeedsNonUnicodeRetype(bo.Right, bo.Left)) {
                        valueToRetype = (SqlSimpleTypeExpression)bo.Left;
                    }
                    if (valueToRetype != null) {
                        valueToRetype.SetSqlType(valueToRetype.SqlType.GetNonUnicodeEquivalent());
                    }
                }
                return bo;
            }

            /// <summary>True when <paramref name="e"/> is a literal value or a client parameter.</summary>
            private static bool IsValueOrClientParameter(SqlExpression e) {
                return e.NodeType == SqlNodeType.Value || e.NodeType == SqlNodeType.ClientParameter;
            }

            /// <summary>
            /// True when <paramref name="val"/> is a literal/client parameter with a unicode SQL type
            /// that is compared against <paramref name="expr"/>, where <paramref name="expr"/> is
            /// neither a literal nor a client parameter and has a non-unicode SQL type.
            /// </summary>
            private static bool NeedsNonUnicodeRetype(SqlExpression expr, SqlExpression val) {
                return IsValueOrClientParameter(val)
                    && !IsValueOrClientParameter(expr)
                    && val.SqlType.IsUnicodeType
                    && !expr.SqlType.IsUnicodeType;
            }

            /// <summary>
            /// Treats <c>test IN (v1, ..., vn)</c> as the comparisons <c>test = vi</c> and coerces each
            /// pair as a binary operator would. If any operand changes, a new <see cref="SqlIn"/> is
            /// built. Its SQL type is predicted from the coerced test and the widest coerced value type.
            /// </summary>
            /// <remarks>
            /// The test expression is threaded through every pairwise coercion, so a coercion applied
            /// while pairing with one value is visible when pairing with the next.
            /// NOTE(review): unlike the other Visit* overrides here, this one does not call the base
            /// visitor, so subexpressions inside the IN are not themselves retyped. Confirm intent.
            /// </remarks>
            internal override SqlExpression VisitIn(SqlIn sin) {
                SqlExpression test = sin.Expression;
                bool requiresCoercion = false;
                var newValues = new System.Collections.Generic.List<SqlExpression>(sin.Values.Count);
                ProviderType valueType = null;
                for (int i = 0, n = sin.Values.Count; i < n; i++) {
                    SqlExpression value = sin.Values[i];
                    this.CoerceBinaryArgs(ref test, ref value);
                    if (value != sin.Values[i]) {
                        // Build up the 'widest' type by repeatedly applying PredictType.
                        valueType = null == valueType
                            ? value.SqlType
                            : this.typeProvider.PredictTypeForBinary(SqlNodeType.EQ, value.SqlType, valueType);
                        requiresCoercion = true;
                    }
                    newValues.Add(value);
                }
                if (test != sin.Expression) {
                    requiresCoercion = true;
                }
                if (requiresCoercion) {
                    // valueType is only accumulated from values that were themselves coerced. If only
                    // the test expression changed, it is still null. Fall back to the test's own
                    // (non-null, since it was coerced) type instead of passing a null operand type
                    // to PredictTypeForBinary.
                    ProviderType comparedType = valueType ?? test.SqlType;
                    ProviderType providerType = this.typeProvider.PredictTypeForBinary(SqlNodeType.EQ, test.SqlType, comparedType);
                    sin = new SqlIn(sin.ClrType, providerType, test, newValues, sin.SourceExpression);
                }
                return sin;
            }

            /// <summary>
            /// Runs the base visitor, then retypes a unicode literal/parameter pattern as non-unicode
            /// when the matched expression's SQL type is not unicode.
            /// </summary>
            internal override SqlExpression VisitLike(SqlLike like) {
                base.VisitLike(like);
                if (!like.Expression.SqlType.IsUnicodeType && like.Pattern.SqlType.IsUnicodeType &&
                    IsValueOrClientParameter(like.Pattern)) {
                    SqlSimpleTypeExpression pattern = (SqlSimpleTypeExpression)like.Pattern;
                    pattern.SetSqlType(pattern.SqlType.GetNonUnicodeEquivalent());
                }
                return like;
            }

            /// <summary>
            /// Runs the base visitor, then gives the scalar subquery the SQL type of its selection.
            /// </summary>
            internal override SqlExpression VisitScalarSubSelect(SqlSubSelect ss) {
                base.VisitScalarSubSelect(ss);
                ss.SetSqlType(ss.Select.Selection.SqlType);
                return ss;
            }

            /// <summary>
            /// Runs the base visitor, then computes the best common SQL type of all WHEN values and
            /// the ELSE value (if any). Every branch whose type differs, and is not runtime-only, is
            /// wrapped in a CONVERT to that type. Assumes at least one WHEN clause.
            /// </summary>
            internal override SqlExpression VisitSearchedCase(SqlSearchedCase c) {
                base.VisitSearchedCase(c);

                // Determine the best common type for all the WHEN and ELSE values.
                ProviderType type = c.Whens[0].Value.SqlType;
                for (int i = 1; i < c.Whens.Count; i++) {
                    ProviderType whenType = c.Whens[i].Value.SqlType;
                    type = typeProvider.GetBestType(type, whenType);
                }
                if (c.Else != null) {
                    ProviderType elseType = c.Else.SqlType;
                    type = typeProvider.GetBestType(type, elseType);
                }

                // Coerce each branch.
                foreach (SqlWhen when in c.Whens.Where(w => w.Value.SqlType != type && !w.Value.SqlType.IsRuntimeOnlyType)) {
                    when.Value = sql.UnaryConvert(when.Value.ClrType, type, when.Value, when.Value.SourceExpression);
                }

                if (c.Else != null && c.Else.SqlType != type && !c.Else.SqlType.IsRuntimeOnlyType) {
                    c.Else = sql.UnaryConvert(c.Else.ClrType, type, c.Else, c.Else.SourceExpression);
                }

                return c;
            }

            /// <summary>
            /// Runs the base visitor, then computes the best common SQL type of all WHEN values.
            /// Every WHEN value whose type differs, and is not runtime-only, is wrapped in a CONVERT
            /// to that type. Assumes at least one WHEN clause.
            /// </summary>
            internal override SqlExpression VisitSimpleCase(SqlSimpleCase c) {
                base.VisitSimpleCase(c);

                // Determine the best common type for all the WHEN values.
                ProviderType type = c.Whens[0].Value.SqlType;
                for (int i = 1; i < c.Whens.Count; i++) {
                    ProviderType whenType = c.Whens[i].Value.SqlType;
                    type = typeProvider.GetBestType(type, whenType);
                }

                // Coerce each branch.
                foreach (SqlWhen when in c.Whens.Where(w => w.Value.SqlType != type && !w.Value.SqlType.IsRuntimeOnlyType)) {
                    when.Value = sql.UnaryConvert(when.Value.ClrType, type, when.Value, when.Value.SourceExpression);
                }

                return c;
            }

            /// <summary>
            /// Runs the base visitor, then coerces the assigned value toward the target's type
            /// (see <see cref="CoerceToFirst"/>).
            /// </summary>
            internal override SqlStatement VisitAssign(SqlAssign sa) {
                base.VisitAssign(sa);
                SqlExpression right = sa.RValue;
                this.CoerceToFirst(sa.LValue, ref right);
                sa.RValue = right;
                return sa;
            }

            /// <summary>
            /// Visits each argument, then asks the type provider for the function's return type.
            /// The type is applied only when the first argument carries a SQL type and the provider
            /// returns a non-null type.
            /// </summary>
            internal override SqlExpression VisitFunctionCall(SqlFunctionCall fc) {
                for (int i = 0, n = fc.Arguments.Count; i < n; i++) {
                    fc.Arguments[i] = this.VisitExpression(fc.Arguments[i]);
                }
                if (fc.Arguments.Count > 0) {
                    ProviderType oldType = fc.Arguments[0].SqlType;
                    // Only if this has a real argument (not e.g. the symbol "DAY" in DATEDIFF(DAY,...)).
                    if (oldType != null) {
                        ProviderType newType = this.typeProvider.ReturnTypeOfFunction(fc);
                        if (newType != null) {
                            fc.SetSqlType(newType);
                        }
                    }
                }
                return fc;
            }

            /// <summary>
            /// Coerces <paramref name="arg2"/> to the CLR/SQL type of <paramref name="arg1"/>. Does
            /// nothing if either SQL type is null. The rules are:
            /// - A literal is rebuilt with its value converted via <c>DBConvert.ChangeType</c>.
            /// - A client parameter whose SQL type differs is retyped in place.
            /// - Anything else is wrapped in a CONVERT.
            /// </summary>
            /// <remarks>
            /// NOTE(review): the final branch also wraps operands (including client parameters) whose
            /// SQL type already matches. Confirm whether that redundant CONVERT is intended.
            /// </remarks>
            private void CoerceToFirst(SqlExpression arg1, ref SqlExpression arg2) {
                if (arg1.SqlType != null && arg2.SqlType != null) {
                    if (arg2.NodeType == SqlNodeType.Value) {
                        SqlValue val = (SqlValue)arg2;
                        arg2 = sql.Value(
                            arg1.ClrType, arg1.SqlType,
                            DBConvert.ChangeType(val.Value, arg1.ClrType),
                            val.IsClientSpecified, arg2.SourceExpression
                            );
                    }
                    else if (arg2.NodeType == SqlNodeType.ClientParameter && arg2.SqlType != arg1.SqlType) {
                        SqlClientParameter cp = (SqlClientParameter)arg2;
                        cp.SetSqlType(arg1.SqlType);
                    }
                    else {
                        arg2 = sql.UnaryConvert(arg1.ClrType, arg1.SqlType, arg2, arg2.SourceExpression);
                    }
                }
            }

            /// <summary>
            /// Coerces two operands toward a common type. Does nothing if either SQL type is null.
            /// Operands in the same type family go through <see cref="CoerceTypeFamily"/>.
            /// Operands in different families go through <see cref="CoerceTypes"/>, unless either
            /// operand is a CLR bool.
            /// </summary>
            private void CoerceBinaryArgs(ref SqlExpression arg1, ref SqlExpression arg2) {
                if (arg1.SqlType == null || arg2.SqlType == null) return;

                if (arg1.SqlType.IsSameTypeFamily(arg2.SqlType)) {
                    CoerceTypeFamily(arg1, arg2);
                }
                else {
                    // Don't coerce bools because predicates and bits have not been resolved yet.
                    // Leave this for booleanizer.
                    if (arg1.ClrType != typeof(bool) && arg2.ClrType != typeof(bool)) {
                        CoerceTypes(ref arg1, ref arg2);
                    }
                }
            }

            /// <summary>
            /// Reconciles two operands of the same type family. Only <c>SqlSimpleTypeExpression</c>
            /// operands are retyped, and always in place.
            /// - If both types have precision/scale and differ, or either is a high-precision
            ///   date/time type, both operands take the provider's best common type.
            /// - Otherwise, if one side is SQL DATE, the other side takes the DATE type.
            /// </summary>
            private void CoerceTypeFamily(SqlExpression arg1, SqlExpression arg2) {
                if ((arg1.SqlType.HasPrecisionAndScale && arg2.SqlType.HasPrecisionAndScale && arg1.SqlType != arg2.SqlType) ||
                    SqlFactory.IsSqlHighPrecisionDateTimeType(arg1) || SqlFactory.IsSqlHighPrecisionDateTimeType(arg2)) {
                    ProviderType best = typeProvider.GetBestType(arg1.SqlType, arg2.SqlType);
                    SetSqlTypeIfSimpleExpression(arg1, best);
                    SetSqlTypeIfSimpleExpression(arg2, best);
                    return;
                }

                // The SQL data type DATE is special, in that it has a higher range but lower
                // precedence, so we need to account for that here (DevDiv 175229).
                if (SqlFactory.IsSqlDateType(arg1) && !SqlFactory.IsSqlHighPrecisionDateTimeType(arg2)) {
                    SetSqlTypeIfSimpleExpression(arg2, arg1.SqlType);
                }
                else if (SqlFactory.IsSqlDateType(arg2) && !SqlFactory.IsSqlHighPrecisionDateTimeType(arg1)) {
                    SetSqlTypeIfSimpleExpression(arg1, arg2.SqlType);
                }
            }

            /// <summary>
            /// Sets <paramref name="sqlType"/> on <paramref name="expression"/> when it is a
            /// <c>SqlSimpleTypeExpression</c>. Other expressions are left unchanged.
            /// </summary>
            private static void SetSqlTypeIfSimpleExpression(SqlExpression expression, ProviderType sqlType) {
                SqlSimpleTypeExpression simpleExpression = expression as SqlSimpleTypeExpression;
                if (simpleExpression != null) {
                    simpleExpression.SetSqlType(sqlType);
                }
            }

            /// <summary>
            /// Coerces operands from different type families. The first matching rule applies:
            /// - A literal on either side (right side checked first) is rebuilt to match the other side.
            /// - A client parameter with a different SQL type is retyped in place to the other side's type.
            /// - Otherwise, the side that compares lower in precedence is wrapped in a CONVERT to the
            ///   other side's type. Equal precedence leaves both operands unchanged.
            /// </summary>
            private void CoerceTypes(ref SqlExpression arg1, ref SqlExpression arg2) {
                if (arg2.NodeType == SqlNodeType.Value) {
                    arg2 = CoerceValueForExpression((SqlValue)arg2, arg1);
                }
                else if (arg1.NodeType == SqlNodeType.Value) {
                    arg1 = CoerceValueForExpression((SqlValue)arg1, arg2);
                }
                else if (arg2.NodeType == SqlNodeType.ClientParameter && arg2.SqlType != arg1.SqlType) {
                    ((SqlClientParameter)arg2).SetSqlType(arg1.SqlType);
                }
                else if (arg1.NodeType == SqlNodeType.ClientParameter && arg1.SqlType != arg2.SqlType) {
                    ((SqlClientParameter)arg1).SetSqlType(arg2.SqlType);
                }
                else {
                    int coercionPrecedence = arg1.SqlType.ComparePrecedenceTo(arg2.SqlType);
                    if (coercionPrecedence > 0) {
                        arg2 = sql.UnaryConvert(arg1.ClrType, arg1.SqlType, arg2, arg2.SourceExpression);
                    }
                    else if (coercionPrecedence < 0) {
                        arg1 = sql.UnaryConvert(arg2.ClrType, arg2.SqlType, arg1, arg1.SourceExpression);
                    }
                }
            }

            /// <summary>
            /// Builds a new literal equivalent to <paramref name="value"/> but typed like
            /// <paramref name="expression"/>, in two steps:
            /// - The CLR value is converted with <c>DBConvert.ChangeType</c>, unless the literal's CLR
            ///   type is assignable from the expression's CLR type.
            /// - The SQL type is moved into the expression's type family via <c>ChangeTypeFamilyTo</c>.
            /// </summary>
            private SqlExpression CoerceValueForExpression(SqlValue value, SqlExpression expression) {
                object clrValue = value.Value;
                if (!value.ClrType.IsAssignableFrom(expression.ClrType)) {
                    clrValue = DBConvert.ChangeType(clrValue, expression.ClrType);
                }
                ProviderType newSqlType = typeProvider.ChangeTypeFamilyTo(value.SqlType, expression.SqlType);
                return sql.Value(expression.ClrType, newSqlType, clrValue, value.IsClientSpecified, value.SourceExpression);
            }
        }
    }
}