|
using System;
using System.Collections.Generic;
using System.Data.Linq;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
namespace System.Data.Linq.SqlClient {
internal abstract class SqlVisitor {
int nDepth;
// Visit a SqlNode
[SuppressMessage("Microsoft.Performance", "CA1800:DoNotCastUnnecessarily", Justification="Microsoft: Cast is dependent on node type and casts do not happen unecessarily in a single code path.")]
[SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
[SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
internal virtual SqlNode Visit(SqlNode node) {
SqlNode result = null;
if (node == null) {
return null;
}
try {
nDepth++;
CheckRecursionDepth(500, nDepth);
switch (node.NodeType) {
case SqlNodeType.Not:
case SqlNodeType.Not2V:
case SqlNodeType.Negate:
case SqlNodeType.BitNot:
case SqlNodeType.IsNull:
case SqlNodeType.IsNotNull:
case SqlNodeType.Count:
case SqlNodeType.LongCount:
case SqlNodeType.Max:
case SqlNodeType.Min:
case SqlNodeType.Sum:
case SqlNodeType.Avg:
case SqlNodeType.Stddev:
case SqlNodeType.Convert:
case SqlNodeType.ValueOf:
case SqlNodeType.OuterJoinedValue:
case SqlNodeType.ClrLength:
result = this.VisitUnaryOperator((SqlUnary)node);
break;
case SqlNodeType.Lift:
result = this.VisitLift((SqlLift)node);
break;
case SqlNodeType.Add:
case SqlNodeType.Sub:
case SqlNodeType.Mul:
case SqlNodeType.Div:
case SqlNodeType.Mod:
case SqlNodeType.BitAnd:
case SqlNodeType.BitOr:
case SqlNodeType.BitXor:
case SqlNodeType.And:
case SqlNodeType.Or:
case SqlNodeType.GE:
case SqlNodeType.GT:
case SqlNodeType.LE:
case SqlNodeType.LT:
case SqlNodeType.EQ:
case SqlNodeType.NE:
case SqlNodeType.EQ2V:
case SqlNodeType.NE2V:
case SqlNodeType.Concat:
case SqlNodeType.Coalesce:
result = this.VisitBinaryOperator((SqlBinary)node);
break;
case SqlNodeType.Between:
result = this.VisitBetween((SqlBetween)node);
break;
case SqlNodeType.In:
result = this.VisitIn((SqlIn)node);
break;
case SqlNodeType.Like:
result = this.VisitLike((SqlLike)node);
break;
case SqlNodeType.Treat:
result = this.VisitTreat((SqlUnary)node);
break;
case SqlNodeType.Alias:
result = this.VisitAlias((SqlAlias)node);
break;
case SqlNodeType.AliasRef:
result = this.VisitAliasRef((SqlAliasRef)node);
break;
case SqlNodeType.Member:
result = this.VisitMember((SqlMember)node);
break;
case SqlNodeType.Row:
result = this.VisitRow((SqlRow)node);
break;
case SqlNodeType.Column:
result = this.VisitColumn((SqlColumn)node);
break;
case SqlNodeType.ColumnRef:
result = this.VisitColumnRef((SqlColumnRef)node);
break;
case SqlNodeType.Table:
result = this.VisitTable((SqlTable)node);
break;
case SqlNodeType.UserQuery:
result = this.VisitUserQuery((SqlUserQuery)node);
break;
case SqlNodeType.StoredProcedureCall:
result = this.VisitStoredProcedureCall((SqlStoredProcedureCall)node);
break;
case SqlNodeType.UserRow:
result = this.VisitUserRow((SqlUserRow)node);
break;
case SqlNodeType.UserColumn:
result = this.VisitUserColumn((SqlUserColumn)node);
break;
case SqlNodeType.Multiset:
case SqlNodeType.ScalarSubSelect:
case SqlNodeType.Element:
case SqlNodeType.Exists:
result = this.VisitSubSelect((SqlSubSelect)node);
break;
case SqlNodeType.Join:
result = this.VisitJoin((SqlJoin)node);
break;
case SqlNodeType.Select:
result = this.VisitSelect((SqlSelect)node);
break;
case SqlNodeType.Parameter:
result = this.VisitParameter((SqlParameter)node);
break;
case SqlNodeType.New:
result = this.VisitNew((SqlNew)node);
break;
case SqlNodeType.Link:
result = this.VisitLink((SqlLink)node);
break;
case SqlNodeType.ClientQuery:
result = this.VisitClientQuery((SqlClientQuery)node);
break;
case SqlNodeType.JoinedCollection:
result = this.VisitJoinedCollection((SqlJoinedCollection)node);
break;
case SqlNodeType.Value:
result = this.VisitValue((SqlValue)node);
break;
case SqlNodeType.ClientArray:
result = this.VisitClientArray((SqlClientArray)node);
break;
case SqlNodeType.Insert:
result = this.VisitInsert((SqlInsert)node);
break;
case SqlNodeType.Update:
result = this.VisitUpdate((SqlUpdate)node);
break;
case SqlNodeType.Delete:
result = this.VisitDelete((SqlDelete)node);
break;
case SqlNodeType.MemberAssign:
result = this.VisitMemberAssign((SqlMemberAssign)node);
break;
case SqlNodeType.Assign:
result = this.VisitAssign((SqlAssign)node);
break;
case SqlNodeType.Block:
result = this.VisitBlock((SqlBlock)node);
break;
case SqlNodeType.SearchedCase:
result = this.VisitSearchedCase((SqlSearchedCase)node);
break;
case SqlNodeType.ClientCase:
result = this.VisitClientCase((SqlClientCase)node);
break;
case SqlNodeType.SimpleCase:
result = this.VisitSimpleCase((SqlSimpleCase)node);
break;
case SqlNodeType.TypeCase:
result = this.VisitTypeCase((SqlTypeCase)node);
break;
case SqlNodeType.Union:
result = this.VisitUnion((SqlUnion)node);
break;
case SqlNodeType.ExprSet:
result = this.VisitExprSet((SqlExprSet)node);
break;
case SqlNodeType.Variable:
result = this.VisitVariable((SqlVariable)node);
break;
case SqlNodeType.DoNotVisit:
result = this.VisitDoNotVisit((SqlDoNotVisitExpression)node);
break;
case SqlNodeType.OptionalValue:
result = this.VisitOptionalValue((SqlOptionalValue)node);
break;
case SqlNodeType.FunctionCall:
result = this.VisitFunctionCall((SqlFunctionCall)node);
break;
case SqlNodeType.TableValuedFunctionCall:
result = this.VisitTableValuedFunctionCall((SqlTableValuedFunctionCall)node);
break;
case SqlNodeType.MethodCall:
result = this.VisitMethodCall((SqlMethodCall)node);
break;
case SqlNodeType.Nop:
result = this.VisitNop((SqlNop)node);
break;
case SqlNodeType.SharedExpression:
result = this.VisitSharedExpression((SqlSharedExpression)node);
break;
case SqlNodeType.SharedExpressionRef:
result = this.VisitSharedExpressionRef((SqlSharedExpressionRef)node);
break;
case SqlNodeType.SimpleExpression:
result = this.VisitSimpleExpression((SqlSimpleExpression)node);
break;
case SqlNodeType.Grouping:
result = this.VisitGrouping((SqlGrouping)node);
break;
case SqlNodeType.DiscriminatedType:
result = this.VisitDiscriminatedType((SqlDiscriminatedType)node);
break;
case SqlNodeType.DiscriminatorOf:
result = this.VisitDiscriminatorOf((SqlDiscriminatorOf)node);
break;
case SqlNodeType.ClientParameter:
result = this.VisitClientParameter((SqlClientParameter)node);
break;
case SqlNodeType.RowNumber:
result = this.VisitRowNumber((SqlRowNumber)node);
break;
case SqlNodeType.IncludeScope:
result = this.VisitIncludeScope((SqlIncludeScope)node);
break;
default:
throw Error.UnexpectedNode(node);
}
} finally {
this.nDepth--;
}
return result;
}
/// <summary>
/// This method checks the recursion level to help diagnose/prevent
/// infinite recursion in debug builds. Calls are ommitted in non debug builds.
/// </summary>
[SuppressMessage("Microsoft.Usage", "CA2201:DoNotRaiseReservedExceptionTypes", Justification="Debug-only code.")]
[Conditional("DEBUG")]
internal static void CheckRecursionDepth(int maxLevel, int level) {
if (level > maxLevel) {
System.Diagnostics.Debug.Assert(false);
//**********************************************************************
// EXCLUDING FROM LOCALIZATION.
// Reason: This code only executes in DEBUG.
throw new Exception("Infinite Descent?");
//**********************************************************************
}
}
[SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
internal object Eval(SqlExpression expr) {
if (expr.NodeType == SqlNodeType.Value) {
return ((SqlValue)expr).Value;
}
throw Error.UnexpectedNode(expr.NodeType);
}
internal virtual SqlExpression VisitDoNotVisit(SqlDoNotVisitExpression expr) {
return expr.Expression;
}
internal virtual SqlRowNumber VisitRowNumber(SqlRowNumber rowNumber) {
for (int i = 0, n = rowNumber.OrderBy.Count; i < n; i++) {
rowNumber.OrderBy[i].Expression = this.VisitExpression(rowNumber.OrderBy[i].Expression);
}
return rowNumber;
}
internal virtual SqlExpression VisitExpression(SqlExpression exp) {
return (SqlExpression)this.Visit(exp);
}
internal virtual SqlSelect VisitSequence(SqlSelect sel) {
return (SqlSelect)this.Visit(sel);
}
internal virtual SqlExpression VisitNop(SqlNop nop) {
return nop;
}
internal virtual SqlExpression VisitLift(SqlLift lift) {
lift.Expression = this.VisitExpression(lift.Expression);
return lift;
}
internal virtual SqlExpression VisitUnaryOperator(SqlUnary uo) {
uo.Operand = this.VisitExpression(uo.Operand);
return uo;
}
internal virtual SqlExpression VisitBinaryOperator(SqlBinary bo) {
bo.Left = this.VisitExpression(bo.Left);
bo.Right = this.VisitExpression(bo.Right);
return bo;
}
internal virtual SqlAlias VisitAlias(SqlAlias a) {
a.Node = this.Visit(a.Node);
return a;
}
internal virtual SqlExpression VisitAliasRef(SqlAliasRef aref) {
return aref;
}
internal virtual SqlNode VisitMember(SqlMember m) {
m.Expression = this.VisitExpression(m.Expression);
return m;
}
internal virtual SqlExpression VisitCast(SqlUnary c) {
c.Operand = this.VisitExpression(c.Operand);
return c;
}
internal virtual SqlExpression VisitTreat(SqlUnary t) {
t.Operand = this.VisitExpression(t.Operand);
return t;
}
internal virtual SqlTable VisitTable(SqlTable tab) {
return tab;
}
internal virtual SqlUserQuery VisitUserQuery(SqlUserQuery suq) {
for (int i = 0, n = suq.Arguments.Count; i < n; i++) {
suq.Arguments[i] = this.VisitExpression(suq.Arguments[i]);
}
suq.Projection = this.VisitExpression(suq.Projection);
for (int i = 0, n = suq.Columns.Count; i < n; i++) {
suq.Columns[i] = (SqlUserColumn) this.Visit(suq.Columns[i]);
}
return suq;
}
internal virtual SqlStoredProcedureCall VisitStoredProcedureCall(SqlStoredProcedureCall spc) {
for (int i = 0, n = spc.Arguments.Count; i < n; i++) {
spc.Arguments[i] = this.VisitExpression(spc.Arguments[i]);
}
spc.Projection = this.VisitExpression(spc.Projection);
for (int i = 0, n = spc.Columns.Count; i < n; i++) {
spc.Columns[i] = (SqlUserColumn) this.Visit(spc.Columns[i]);
}
return spc;
}
internal virtual SqlExpression VisitUserColumn(SqlUserColumn suc) {
return suc;
}
internal virtual SqlExpression VisitUserRow(SqlUserRow row) {
return row;
}
internal virtual SqlRow VisitRow(SqlRow row) {
for (int i = 0, n = row.Columns.Count; i < n; i++) {
row.Columns[i].Expression = this.VisitExpression(row.Columns[i].Expression);
}
return row;
}
internal virtual SqlExpression VisitNew(SqlNew sox) {
for (int i = 0, n = sox.Args.Count; i < n; i++) {
sox.Args[i] = this.VisitExpression(sox.Args[i]);
}
for (int i = 0, n = sox.Members.Count; i < n; i++) {
sox.Members[i].Expression = this.VisitExpression(sox.Members[i].Expression);
}
return sox;
}
internal virtual SqlNode VisitLink(SqlLink link) {
// Don't visit the link's Expansion
for (int i = 0, n = link.KeyExpressions.Count; i < n; i++) {
link.KeyExpressions[i] = this.VisitExpression(link.KeyExpressions[i]);
}
return link;
}
internal virtual SqlExpression VisitClientQuery(SqlClientQuery cq) {
for (int i = 0, n = cq.Arguments.Count; i < n; i++) {
cq.Arguments[i] = this.VisitExpression(cq.Arguments[i]);
}
return cq;
}
internal virtual SqlExpression VisitJoinedCollection(SqlJoinedCollection jc) {
jc.Expression = this.VisitExpression(jc.Expression);
jc.Count = this.VisitExpression(jc.Count);
return jc;
}
internal virtual SqlExpression VisitClientArray(SqlClientArray scar) {
for (int i = 0, n = scar.Expressions.Count; i < n; i++) {
scar.Expressions[i] = this.VisitExpression(scar.Expressions[i]);
}
return scar;
}
internal virtual SqlExpression VisitClientParameter(SqlClientParameter cp) {
return cp;
}
internal virtual SqlExpression VisitColumn(SqlColumn col) {
col.Expression = this.VisitExpression(col.Expression);
return col;
}
internal virtual SqlExpression VisitColumnRef(SqlColumnRef cref) {
return cref;
}
internal virtual SqlExpression VisitParameter(SqlParameter p) {
return p;
}
internal virtual SqlExpression VisitValue(SqlValue value) {
return value;
}
internal virtual SqlExpression VisitSubSelect(SqlSubSelect ss) {
switch(ss.NodeType) {
case SqlNodeType.ScalarSubSelect: return this.VisitScalarSubSelect(ss);
case SqlNodeType.Multiset: return this.VisitMultiset(ss);
case SqlNodeType.Element: return this.VisitElement(ss);
case SqlNodeType.Exists: return this.VisitExists(ss);
}
throw Error.UnexpectedNode(ss.NodeType);
}
internal virtual SqlExpression VisitScalarSubSelect(SqlSubSelect ss) {
ss.Select = this.VisitSequence(ss.Select);
return ss;
}
internal virtual SqlExpression VisitMultiset(SqlSubSelect sms) {
sms.Select = this.VisitSequence(sms.Select);
return sms;
}
internal virtual SqlExpression VisitElement(SqlSubSelect elem) {
elem.Select = this.VisitSequence(elem.Select);
return elem;
}
internal virtual SqlExpression VisitExists(SqlSubSelect sqlExpr) {
sqlExpr.Select = this.VisitSequence(sqlExpr.Select);
return sqlExpr;
}
internal virtual SqlSource VisitJoin(SqlJoin join) {
join.Left = this.VisitSource(join.Left);
join.Right = this.VisitSource(join.Right);
join.Condition = this.VisitExpression(join.Condition);
return join;
}
internal virtual SqlSource VisitSource(SqlSource source) {
return (SqlSource) this.Visit(source);
}
internal virtual SqlSelect VisitSelectCore(SqlSelect select) {
select.From = this.VisitSource(select.From);
select.Where = this.VisitExpression(select.Where);
for (int i = 0, n = select.GroupBy.Count; i < n; i++) {
select.GroupBy[i] = this.VisitExpression(select.GroupBy[i]);
}
select.Having = this.VisitExpression(select.Having);
for (int i = 0, n = select.OrderBy.Count; i < n; i++) {
select.OrderBy[i].Expression = this.VisitExpression(select.OrderBy[i].Expression);
}
select.Top = this.VisitExpression(select.Top);
select.Row = (SqlRow)this.Visit(select.Row);
return select;
}
internal virtual SqlSelect VisitSelect(SqlSelect select) {
select = this.VisitSelectCore(select);
select.Selection = this.VisitExpression(select.Selection);
return select;
}
internal virtual SqlStatement VisitInsert(SqlInsert insert) {
insert.Table = (SqlTable)this.Visit(insert.Table);
insert.Expression = this.VisitExpression(insert.Expression);
insert.Row = (SqlRow)this.Visit(insert.Row);
return insert;
}
internal virtual SqlStatement VisitUpdate(SqlUpdate update) {
update.Select = this.VisitSequence(update.Select);
for (int i = 0, n = update.Assignments.Count; i < n; i++) {
update.Assignments[i] = (SqlAssign)this.Visit(update.Assignments[i]);
}
return update;
}
internal virtual SqlStatement VisitDelete(SqlDelete delete) {
delete.Select = this.VisitSequence(delete.Select);
return delete;
}
internal virtual SqlMemberAssign VisitMemberAssign(SqlMemberAssign ma) {
ma.Expression = this.VisitExpression(ma.Expression);
return ma;
}
internal virtual SqlStatement VisitAssign(SqlAssign sa) {
sa.LValue = this.VisitExpression(sa.LValue);
sa.RValue = this.VisitExpression(sa.RValue);
return sa;
}
internal virtual SqlBlock VisitBlock(SqlBlock b) {
for (int i = 0, n = b.Statements.Count; i < n; i++) {
b.Statements[i] = (SqlStatement)this.Visit(b.Statements[i]);
}
return b;
}
internal virtual SqlExpression VisitSearchedCase(SqlSearchedCase c) {
for (int i = 0, n = c.Whens.Count; i < n; i++) {
SqlWhen when = c.Whens[i];
when.Match = this.VisitExpression(when.Match);
when.Value = this.VisitExpression(when.Value);
}
c.Else = this.VisitExpression(c.Else);
return c;
}
internal virtual SqlExpression VisitClientCase(SqlClientCase c) {
c.Expression = this.VisitExpression(c.Expression);
for (int i = 0, n = c.Whens.Count; i < n; i++) {
SqlClientWhen when = c.Whens[i];
when.Match = this.VisitExpression(when.Match);
when.Value = this.VisitExpression(when.Value);
}
return c;
}
internal virtual SqlExpression VisitSimpleCase(SqlSimpleCase c) {
c.Expression = this.VisitExpression(c.Expression);
for (int i = 0, n = c.Whens.Count; i < n; i++) {
SqlWhen when = c.Whens[i];
when.Match = this.VisitExpression(when.Match);
when.Value = this.VisitExpression(when.Value);
}
return c;
}
internal virtual SqlExpression VisitTypeCase(SqlTypeCase tc) {
tc.Discriminator = this.VisitExpression(tc.Discriminator);
for (int i = 0, n = tc.Whens.Count; i < n; i++) {
SqlTypeCaseWhen when = tc.Whens[i];
when.Match = this.VisitExpression(when.Match);
when.TypeBinding = this.VisitExpression(when.TypeBinding);
}
return tc;
}
internal virtual SqlNode VisitUnion(SqlUnion su) {
su.Left = this.Visit(su.Left);
su.Right = this.Visit(su.Right);
return su;
}
internal virtual SqlExpression VisitExprSet(SqlExprSet xs) {
for (int i = 0, n = xs.Expressions.Count; i < n; i++) {
xs.Expressions[i] = this.VisitExpression(xs.Expressions[i]);
}
return xs;
}
internal virtual SqlExpression VisitVariable(SqlVariable v) {
return v;
}
internal virtual SqlExpression VisitOptionalValue(SqlOptionalValue sov) {
sov.HasValue = this.VisitExpression(sov.HasValue);
sov.Value = this.VisitExpression(sov.Value);
return sov;
}
internal virtual SqlExpression VisitBetween(SqlBetween between) {
between.Expression = this.VisitExpression(between.Expression);
between.Start = this.VisitExpression(between.Start);
between.End = this.VisitExpression(between.End);
return between;
}
internal virtual SqlExpression VisitIn(SqlIn sin) {
sin.Expression = this.VisitExpression(sin.Expression);
for (int i = 0, n = sin.Values.Count; i < n; i++) {
sin.Values[i] = this.VisitExpression(sin.Values[i]);
}
return sin;
}
internal virtual SqlExpression VisitLike(SqlLike like) {
like.Expression = this.VisitExpression(like.Expression);
like.Pattern = this.VisitExpression(like.Pattern);
like.Escape = this.VisitExpression(like.Escape);
return like;
}
internal virtual SqlExpression VisitFunctionCall(SqlFunctionCall fc) {
for (int i = 0, n = fc.Arguments.Count; i < n; i++) {
fc.Arguments[i] = this.VisitExpression(fc.Arguments[i]);
}
return fc;
}
internal virtual SqlExpression VisitTableValuedFunctionCall(SqlTableValuedFunctionCall fc) {
for (int i = 0, n = fc.Arguments.Count; i < n; i++) {
fc.Arguments[i] = this.VisitExpression(fc.Arguments[i]);
}
return fc;
}
internal virtual SqlExpression VisitMethodCall(SqlMethodCall mc) {
mc.Object = this.VisitExpression(mc.Object);
for (int i = 0, n = mc.Arguments.Count; i < n; i++) {
mc.Arguments[i] = this.VisitExpression(mc.Arguments[i]);
}
return mc;
}
internal virtual SqlExpression VisitSharedExpression(SqlSharedExpression shared) {
shared.Expression = this.VisitExpression(shared.Expression);
return shared;
}
internal virtual SqlExpression VisitSharedExpressionRef(SqlSharedExpressionRef sref) {
return sref;
}
internal virtual SqlExpression VisitSimpleExpression(SqlSimpleExpression simple) {
simple.Expression = this.VisitExpression(simple.Expression);
return simple;
}
internal virtual SqlExpression VisitGrouping(SqlGrouping g) {
g.Key = this.VisitExpression(g.Key);
g.Group = this.VisitExpression(g.Group);
return g;
}
internal virtual SqlExpression VisitDiscriminatedType(SqlDiscriminatedType dt) {
dt.Discriminator = this.VisitExpression(dt.Discriminator);
return dt;
}
internal virtual SqlExpression VisitDiscriminatorOf(SqlDiscriminatorOf dof) {
dof.Object = this.VisitExpression(dof.Object);
return dof;
}
internal virtual SqlNode VisitIncludeScope(SqlIncludeScope node) {
node.Child = this.Visit(node.Child);
return node;
}
#if DEBUG
int refersDepth;
#endif
internal bool RefersToColumn(SqlExpression exp, SqlColumn col) {
#if DEBUG
try {
refersDepth++;
System.Diagnostics.Debug.Assert(refersDepth < 20);
#endif
if (exp != null) {
switch (exp.NodeType) {
case SqlNodeType.Column:
return exp == col || this.RefersToColumn(((SqlColumn)exp).Expression, col);
case SqlNodeType.ColumnRef:
SqlColumnRef cref = (SqlColumnRef)exp;
return cref.Column == col || this.RefersToColumn(cref.Column.Expression, col);
case SqlNodeType.ExprSet:
SqlExprSet set = (SqlExprSet)exp;
for (int i = 0, n = set.Expressions.Count; i < n; i++) {
if (this.RefersToColumn(set.Expressions[i], col)) {
return true;
}
}
break;
case SqlNodeType.OuterJoinedValue:
return this.RefersToColumn(((SqlUnary)exp).Operand, col);
}
}
return false;
#if DEBUG
}
finally {
refersDepth--;
}
#endif
}
}
}
|