// File: SqlClient\Query\SqlNamer.cs
// Project: ndp\fx\src\DLinq\Dlinq\System.Data.Linq.csproj (System.Data.Linq)
using System;
using System.Collections.Generic;
using System.Text;
using System.Data.Linq;
using System.Diagnostics.CodeAnalysis;
 
namespace System.Data.Linq.SqlClient {
 
    internal class SqlNamer {
        Visitor visitor;
 
        /// <summary>
        /// Creates a namer backed by a single <c>Visitor</c> instance that is
        /// reused for every call to <c>AssignNames</c>.
        /// </summary>
        internal SqlNamer() {
            visitor = new Visitor();
        }
 
        /// <summary>
        /// Runs the naming visitor over <paramref name="node"/>.
        /// </summary>
        /// <param name="node">Root of the tree to process.</param>
        /// <returns>Whatever the visitor returns for <paramref name="node"/>.</returns>
        internal SqlNode AssignNames(SqlNode node) {
            SqlNode named = this.visitor.Visit(node);
            return named;
        }
 
        class Visitor : SqlVisitor {
            int aliasCount;
            SqlAlias alias;
            bool makeUnique;
            bool useMappedNames;
            string lastName;
 
            /// <summary>
            /// Initializes the visitor with mapped names disabled and
            /// unique-name generation enabled.
            /// </summary>
            internal Visitor() {
                useMappedNames = false;
                makeUnique = true;
            }
 
            /// <summary>
            /// Produces the next alias name: the letter "t" followed by the current
            /// value of the alias counter, which is then advanced by one.
            /// </summary>
            /// <returns>The generated alias name.</returns>
            internal string GetNextAlias() {
                int ordinal = this.aliasCount;
                this.aliasCount = ordinal + 1;
                return "t" + ordinal;
            }
 
            /// <summary>
            /// Visits the aliased node and then gives the alias a freshly generated name.
            /// </summary>
            /// <param name="sqlAlias">The alias to process; it is updated in place.</param>
            /// <returns>The same <paramref name="sqlAlias"/> instance.</returns>
            internal override SqlAlias VisitAlias(SqlAlias sqlAlias) {
                // Track the alias being processed for the duration of the visit.
                SqlAlias enclosing = this.alias;
                this.alias = sqlAlias;

                // The aliased node is visited before this alias takes its name, so
                // any alias names generated during that visit are handed out first.
                var visitedNode = this.Visit(sqlAlias.Node);
                sqlAlias.Node = visitedNode;
                sqlAlias.Name = this.GetNextAlias();

                this.alias = enclosing;
                return sqlAlias;
            }
 
            /// <summary>
            /// Visits a scalar sub-select and then blanks the name of its first
            /// projected column, if it has one.
            /// </summary>
            /// <param name="ss">The scalar sub-select to process.</param>
            /// <returns>The same <paramref name="ss"/> instance.</returns>
            internal override SqlExpression VisitScalarSubSelect(SqlSubSelect ss) {
                base.VisitScalarSubSelect(ss);

                var projected = ss.Select.Row.Columns;
                if (projected.Count > 0) {
                    System.Diagnostics.Debug.Assert(ss != null && ss.Select != null && ss.Select.Row != null && ss.Select.Row.Columns.Count == 1);
                    // An empty name keeps the scalar sub-select from being redundantly named.
                    projected[0].Name = "";
                }

                return ss;
            }
 
            /// <summary>
            /// Visits an insert statement with unique-name generation switched off and
            /// mapped names switched on, then restores both flags to their prior values.
            /// </summary>
            /// <param name="insert">The insert statement to process.</param>
            /// <returns>The result of the base visitor for <paramref name="insert"/>.</returns>
            internal override SqlStatement VisitInsert(SqlInsert insert) {
                bool saveMakeUnique = this.makeUnique;
                bool saveUseMappedNames = this.useMappedNames;
                this.makeUnique = false;
                this.useMappedNames = true;
                try {
                    return base.VisitInsert(insert);
                }
                finally {
                    // Restore even when the visit throws: this visitor instance is reused
                    // for later AssignNames calls, so leaked flags would change how
                    // subsequent statements are named.
                    this.makeUnique = saveMakeUnique;
                    this.useMappedNames = saveUseMappedNames;
                }
            }
 
            /// <summary>
            /// Visits an update statement with unique-name generation switched off and
            /// mapped names switched on, then restores both flags to their prior values.
            /// </summary>
            /// <param name="update">The update statement to process.</param>
            /// <returns>The result of the base visitor for <paramref name="update"/>.</returns>
            internal override SqlStatement VisitUpdate(SqlUpdate update) {
                bool saveMakeUnique = this.makeUnique;
                bool saveUseMappedNames = this.useMappedNames;
                this.makeUnique = false;
                this.useMappedNames = true;
                try {
                    return base.VisitUpdate(update);
                }
                finally {
                    // Restore even when the visit throws: this visitor instance is reused
                    // for later AssignNames calls, so leaked flags would change how
                    // subsequent statements are named.
                    this.makeUnique = saveMakeUnique;
                    this.useMappedNames = saveUseMappedNames;
                }
            }
 
            /// <summary>
            /// Visits a select and then assigns a name and an ordinal to every column
            /// in its row. Columns without a name get one from
            /// <c>SqlNamer.DiscoverName</c>; when unique-name generation is enabled,
            /// a numeric suffix (starting at 2) is appended until the name passes
            /// <c>IsUniqueName</c>.
            /// </summary>
            /// <param name="select">The select to process.</param>
            /// <returns>The select returned by the base visitor, with its columns named.</returns>
            internal override SqlSelect VisitSelect(SqlSelect select) {
                select = base.VisitSelect(select);

                var columns = select.Row.Columns;
                int columnCount = columns.Count;

                // Pass 1: remember each column's starting name, then clear the names so
                // that columns not yet processed cannot collide during pass 2.
                string[] rootNames = new string[columnCount];
                for (int index = 0; index < columnCount; index++) {
                    SqlColumn column = columns[index];
                    rootNames[index] = column.Name ?? SqlNamer.DiscoverName(column);
                    column.Name = null;
                }

                // Column names reachable from the ORDER BY expressions.
                var reserved = this.GetColumnNames(select.OrderBy);

                // Pass 2: hand out the final names and ordinals in column order.
                for (int index = 0; index < columnCount; index++) {
                    SqlColumn column = columns[index];
                    string root = rootNames[index];
                    string candidate = root;
                    if (this.makeUnique) {
                        for (int suffix = 2; !this.IsUniqueName(columns, reserved, column, candidate); suffix++) {
                            candidate = root + suffix;
                        }
                    }
                    column.Name = candidate;
                    column.Ordinal = index;
                }

                return select;
            }
 
            /// <summary>
            /// Decides whether <paramref name="name"/> may be used for column
            /// <paramref name="c"/>. The name is rejected when another column in
            /// <paramref name="columns"/> already has it (compared ignoring case), or
            /// when <paramref name="c"/> is not a simple column and the name appears in
            /// <paramref name="reservedNames"/>.
            /// </summary>
            /// <param name="columns">The columns that share a row with <paramref name="c"/>.</param>
            /// <param name="reservedNames">Names that non-simple columns must not take.</param>
            /// <param name="c">The column being named.</param>
            /// <param name="name">The candidate name.</param>
            /// <returns><c>true</c> if the candidate name is acceptable; otherwise <c>false</c>.</returns>
            [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
            private bool IsUniqueName(List<SqlColumn> columns, ICollection<string> reservedNames, SqlColumn c, string name) {
                for (int i = 0; i < columns.Count; i++) {
                    SqlColumn other = columns[i];
                    if (other != c && string.Equals(other.Name, name, StringComparison.OrdinalIgnoreCase)) {
                        return false;
                    }
                }

                // Reserved names only constrain columns that are not simple reprojections.
                return IsSimpleColumn(c, name) || !reservedNames.Contains(name);
            }
 
            /// <summary>
            /// A column is a simple reprojection if its expression is null, or if its
            /// expression is a column reference and either the given name is null or
            /// empty or the given name matches the referenced column's name
            /// (compared ignoring case).
            /// </summary>
            /// <param name="c">The column to inspect.</param>
            /// <param name="name">The candidate name for <paramref name="c"/>.</param>
            /// <returns>
            /// <c>true</c> if <paramref name="c"/> is a simple reprojection under
            /// <paramref name="name"/>; <c>false</c> for any other kind of expression.
            /// </returns>
            private static bool IsSimpleColumn(SqlColumn c, string name) {
                if (c.Expression == null) {
                    return true;
                }

                if (c.Expression.NodeType != SqlNodeType.ColumnRef) {
                    return false;
                }

                // Direct cast rather than 'as': the node type was just checked, and a
                // mismatch should surface as an InvalidCastException here instead of a
                // NullReferenceException on the dereference below.
                SqlColumnRef colRef = (SqlColumnRef)c.Expression;
                return String.IsNullOrEmpty(name) || string.Compare(name, colRef.Column.Name, StringComparison.OrdinalIgnoreCase) == 0;
            }
 
            /// <summary>
            /// Visits <paramref name="expr"/> with no suggested name in effect; the
            /// previously suggested name is put back afterwards, even if the visit throws.
            /// </summary>
            /// <param name="expr">The expression to visit.</param>
            /// <returns>The visited expression.</returns>
            internal override SqlExpression VisitExpression(SqlExpression expr) {
                string outerName = this.lastName;
                try {
                    this.lastName = null;
                    var visited = this.Visit(expr);
                    return (SqlExpression)visited;
                }
                finally {
                    this.lastName = outerName;
                }
            }
 
            /// <summary>
            /// Visits <paramref name="expr"/> with <paramref name="name"/> as the
            /// suggested name; the previously suggested name is put back afterwards,
            /// even if the visit throws.
            /// </summary>
            /// <param name="expr">The expression to visit.</param>
            /// <param name="name">The name to suggest while visiting.</param>
            /// <returns>The visited expression.</returns>
            private SqlExpression VisitNamedExpression(SqlExpression expr, string name) {
                string outerName = this.lastName;
                try {
                    this.lastName = name;
                    var visited = this.Visit(expr);
                    return (SqlExpression)visited;
                }
                finally {
                    this.lastName = outerName;
                }
            }
 
            /// <summary>
            /// Gives the referenced column the currently suggested name, but only when
            /// the column has no name yet and a suggested name is in effect.
            /// </summary>
            /// <param name="cref">The column reference to process.</param>
            /// <returns>The same <paramref name="cref"/> instance.</returns>
            internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
                var target = cref.Column;
                bool needsName = target.Name == null && this.lastName != null;
                if (needsName) {
                    target.Name = this.lastName;
                }
                return cref;
            }
 
            /// <summary>
            /// Visits the arguments and member assignments of an object construction.
            /// When a constructor is present each argument is visited with the matching
            /// parameter name suggested; otherwise arguments are visited without a
            /// suggested name. Each member assignment is visited with the member's name
            /// suggested, or its mapped name when mapped names are enabled.
            /// </summary>
            /// <param name="sox">The construction node to process; updated in place.</param>
            /// <returns>The same <paramref name="sox"/> instance.</returns>
            internal override SqlExpression VisitNew(SqlNew sox) {
                if (sox.Constructor == null) {
                    int argCount = sox.Args.Count;
                    for (int index = 0; index < argCount; index++) {
                        sox.Args[index] = this.VisitExpression(sox.Args[index]);
                    }
                }
                else {
                    System.Reflection.ParameterInfo[] parameters = sox.Constructor.GetParameters();
                    int argCount = sox.Args.Count;
                    for (int index = 0; index < argCount; index++) {
                        sox.Args[index] = this.VisitNamedExpression(sox.Args[index], parameters[index].Name);
                    }
                }

                foreach (SqlMemberAssign assignment in sox.Members) {
                    string suggested = assignment.Member.Name;
                    if (this.useMappedNames) {
                        suggested = sox.MetaType.GetDataMember(assignment.Member).MappedName;
                    }
                    assignment.Expression = this.VisitNamedExpression(assignment.Expression, suggested);
                }

                return sox;
            }
 
            /// <summary>
            /// Visits the two parts of a grouping, suggesting the name "Key" for the
            /// key expression and "Group" for the group expression.
            /// </summary>
            /// <param name="g">The grouping to process; updated in place.</param>
            /// <returns>The same <paramref name="g"/> instance.</returns>
            internal override SqlExpression VisitGrouping(SqlGrouping g) {
                SqlExpression keyExpr = this.VisitNamedExpression(g.Key, "Key");
                g.Key = keyExpr;

                SqlExpression groupExpr = this.VisitNamedExpression(g.Group, "Group");
                g.Group = groupExpr;

                return g;
            }
 
            /// <summary>
            /// Visits an optional value: the has-value expression is visited with the
            /// name "test" suggested, the value expression with no suggested name.
            /// </summary>
            /// <param name="sov">The optional value to process; updated in place.</param>
            /// <returns>The same <paramref name="sov"/> instance.</returns>
            internal override SqlExpression VisitOptionalValue(SqlOptionalValue sov) {
                SqlExpression hasValueExpr = this.VisitNamedExpression(sov.HasValue, "test");
                sov.HasValue = hasValueExpr;

                SqlExpression valueExpr = this.VisitExpression(sov.Value);
                sov.Value = valueExpr;

                return sov;
            }
 
            /// <summary>
            /// Visits a method call: the target object is visited without a suggested
            /// name, and each argument is visited with the name of the method parameter
            /// at the same position suggested.
            /// </summary>
            /// <param name="mc">The method call to process; updated in place.</param>
            /// <returns>The same <paramref name="mc"/> instance.</returns>
            internal override SqlExpression VisitMethodCall(SqlMethodCall mc) {
                mc.Object = this.VisitExpression(mc.Object);

                System.Reflection.ParameterInfo[] parameters = mc.Method.GetParameters();
                int argCount = mc.Arguments.Count;
                int index = 0;
                while (index < argCount) {
                    string parameterName = parameters[index].Name;
                    mc.Arguments[index] = this.VisitNamedExpression(mc.Arguments[index], parameterName);
                    index++;
                }

                return mc;
            }
 
 
            /// <summary>
            /// Collects the column names found by running a
            /// <c>ColumnNameGatherer</c> over each expression in
            /// <paramref name="orderList"/>.
            /// </summary>
            /// <param name="orderList">The order expressions to scan.</param>
            /// <returns>The set of names accumulated by the gatherer.</returns>
            ICollection<string> GetColumnNames(IEnumerable<SqlOrderExpression> orderList)
            {
                ColumnNameGatherer gatherer = new ColumnNameGatherer();

                foreach (SqlOrderExpression ordering in orderList) {
                    gatherer.Visit(ordering.Expression);
                }

                ICollection<string> gathered = gatherer.Names;
                return gathered;
            }
        }
 
        /// <summary>
        /// Derives a name for an expression by following it to a named column:
        /// a column defers to its own expression, a column reference uses the
        /// referenced column's name (or recurses into that column when it is
        /// unnamed), and an expression set defers to its first expression.
        /// </summary>
        /// <param name="e">The expression to name; may be null.</param>
        /// <returns>The discovered name, or "value" when none can be found.</returns>
        internal static string DiscoverName(SqlExpression e) {
            if (e == null) {
                return "value";
            }

            switch (e.NodeType) {
                case SqlNodeType.Column: {
                    SqlColumn column = (SqlColumn)e;
                    return DiscoverName(column.Expression);
                }
                case SqlNodeType.ColumnRef: {
                    SqlColumnRef reference = (SqlColumnRef)e;
                    return reference.Column.Name ?? DiscoverName(reference.Column);
                }
                case SqlNodeType.ExprSet: {
                    SqlExprSet set = (SqlExprSet)e;
                    return DiscoverName(set.Expressions[0]);
                }
                default:
                    return "value";
            }
        }
        
        /// <summary>
        /// Visitor that accumulates the non-empty names of every column it visits,
        /// including columns reached through column references.
        /// </summary>
        class ColumnNameGatherer : SqlVisitor {
            /// <summary>
            /// The names gathered so far. Membership is case-insensitive (ordinal),
            /// matching the OrdinalIgnoreCase comparison the naming visitor uses when
            /// it checks column names for uniqueness.
            /// </summary>
            public HashSet<string> Names { get; set; }

            /// <summary>
            /// Creates a gatherer with an empty, case-insensitive name set.
            /// </summary>
            public ColumnNameGatherer()
                : base() {
                // A case-sensitive set would let a generated name that differs only in
                // case from a gathered name slip past the reserved-name check.
                this.Names = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
            }

            /// <summary>
            /// Records the column's name when it is neither null nor empty, then
            /// continues with the base visitor.
            /// </summary>
            /// <param name="col">The column being visited.</param>
            /// <returns>The result of the base visitor for <paramref name="col"/>.</returns>
            internal override SqlExpression VisitColumn(SqlColumn col) {
                if (!String.IsNullOrEmpty(col.Name)) {
                    this.Names.Add(col.Name);
                }

                return base.VisitColumn(col);
            }

            /// <summary>
            /// Visits the referenced column so its name is gathered, then continues
            /// with the base visitor.
            /// </summary>
            /// <param name="cref">The column reference being visited.</param>
            /// <returns>The result of the base visitor for <paramref name="cref"/>.</returns>
            internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
                Visit(cref.Column);

                return base.VisitColumnRef(cref);
            }
        }
    }
}