File: System\Data\Common\Internal\Materialization\CoordinatorScratchpad.cs
Project: ndp\fx\src\DataEntity\System.Data.Entity.csproj (System.Data.Entity)
//------------------------------------------------------------------------------
// <copyright file="CoordinatorScratchpad.cs" company="Microsoft">
//     Copyright (c) Microsoft Corporation.  All rights reserved.
// </copyright>
// <owner current="true" primary="true">Microsoft</owner>
// <owner current="true" primary="false">Microsoft</owner>
//------------------------------------------------------------------------------
 
using System.Collections.Generic;
using System.Data.Objects.ELinq;
using System.Data.Query.InternalTrees;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Security;
using System.Security.Permissions;
 
namespace System.Data.Common.Internal.Materialization
{
    /// <summary>
    /// Used in the Translator to aggregate information about a (nested) reader 
    /// coordinator. After the translator visits the columnMaps, it will compile
    /// the coordinator(s) which produces an immutable CoordinatorFactory that 
    /// can be shared amongst many query instances.
    /// </summary>
    internal class CoordinatorScratchpad
    {
        #region private state

        private readonly Type _elementType;
        private CoordinatorScratchpad _parent;
        private readonly List<CoordinatorScratchpad> _nestedCoordinatorScratchpads;

        /// <summary>
        /// Map from original expressions to expressions with detailed error handling.
        /// </summary>
        private readonly Dictionary<Expression, Expression> _expressionWithErrorHandlingMap;

        /// <summary>
        /// Expressions that should be precompiled (i.e. reduced to constants in
        /// compiled delegates).
        /// </summary>
        private readonly HashSet<LambdaExpression> _inlineDelegates;

        /// <summary>
        /// List of all record types that we can return at this level in the query.
        /// Created lazily by CreateRecordStateScratchpad; null until the first one is added.
        /// </summary>
        private List<RecordStateScratchpad> _recordStateScratchpads;

        #endregion

        #region constructor

        /// <summary>
        /// Creates an empty scratchpad for a coordinator producing elements of the given type.
        /// </summary>
        /// <param name="elementType">Element type; used in Compile to close CoordinatorFactory&lt;T&gt;.</param>
        internal CoordinatorScratchpad(Type elementType)
        {
            _elementType = elementType;
            _nestedCoordinatorScratchpads = new List<CoordinatorScratchpad>();
            _expressionWithErrorHandlingMap = new Dictionary<Expression, Expression>();
            _inlineDelegates = new HashSet<LambdaExpression>();
        }

        #endregion

        #region "public" surface area

        /// <summary>
        /// For nested collections, returns the parent coordinator. Set when this scratchpad
        /// is registered via the parent's AddNestedCoordinator; null otherwise.
        /// </summary>
        internal CoordinatorScratchpad Parent
        {
            get { return _parent; }
        }

        /// <summary>
        /// Gets or sets an Expression setting key values (these keys are used
        /// to determine when a collection has entered a new chapter) from the
        /// underlying store data reader.
        /// </summary>
        internal Expression SetKeys { get; set; }

        /// <summary>
        /// Gets or sets an Expression returning 'true' when the key values for
        /// the current nested result (see SetKeys) are equal to the current key
        /// values on the underlying data reader.
        /// </summary>
        internal Expression CheckKeys { get; set; }

        /// <summary>
        /// Gets or sets an expression returning 'true' if the current row in
        /// the underlying data reader contains an element of the collection.
        /// </summary>
        internal Expression HasData { get; set; }

        /// <summary>
        /// Gets or sets an Expression yielding an element of the current collection
        /// given values in the underlying data reader.
        /// </summary>
        internal Expression Element { get; set; }

        /// <summary>
        /// Gets or sets an Expression initializing the collection storing results from this coordinator.
        /// </summary>
        internal Expression InitializeCollection { get; set; }

        /// <summary>
        /// Indicates which Shaper.State slot is home for this collection's coordinator.
        /// Used by Parent to pull out nested collection aggregators/streamers.
        /// </summary>
        internal int StateSlotNumber { get; set; }

        /// <summary>
        /// Gets or sets the depth of the current coordinator. A root collection has depth 0.
        /// </summary>
        internal int Depth { get; set; }

        /// <summary>
        /// Allows sub-expressions to register an 'interest' in exceptions thrown when reading elements
        /// for this coordinator. When an exception is thrown, we rerun the delegate using the slower
        /// but more error-friendly versions of expressions (e.g. reader.GetValue + type check instead
        /// of reader.GetInt32())
        /// </summary>
        /// <param name="expression">The lean and mean raw version of the expression. Registering the
        /// same expression again replaces the previous error-handling version.</param>
        /// <param name="expressionWithErrorHandling">The slower version of the same expression with better
        /// error handling</param>
        internal void AddExpressionWithErrorHandling(Expression expression, Expression expressionWithErrorHandling)
        {
            _expressionWithErrorHandlingMap[expression] = expressionWithErrorHandling;
        }

        /// <summary>
        /// Registers a lambda expression for pre-compilation (i.e. reduction to a constant expression)
        /// within materialization expression. Otherwise, the expression will be compiled every time
        /// the enclosing delegate is invoked.
        /// </summary>
        /// <param name="expression">Lambda expression to register (matched by reference during Compile).</param>
        internal void AddInlineDelegate(LambdaExpression expression)
        {
            _inlineDelegates.Add(expression);
        }

        /// <summary>
        /// Registers a coordinator for a nested collection contained in elements of this collection,
        /// and makes this scratchpad its Parent.
        /// </summary>
        /// <param name="nested">The nested scratchpad; expected to be exactly one level deeper than this one.</param>
        internal void AddNestedCoordinator(CoordinatorScratchpad nested)
        {
            Debug.Assert(nested.Depth == this.Depth + 1, "can only nest depth + 1");
            nested._parent = this;
            _nestedCoordinatorScratchpads.Add(nested);
        }

        /// <summary>
        /// Use the information stored on the scratchpad to compile an immutable factory used
        /// to construct the coordinators used at runtime when materializing results.
        /// Record-state and nested-coordinator scratchpads are compiled recursively.
        /// </summary>
        /// <returns>A CoordinatorFactory&lt;T&gt; closed over the element type given at construction.</returns>
        [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.NoOptimization)]
        internal CoordinatorFactory Compile()
        {
            RecordStateFactory[] recordStateFactories;
            if (null != _recordStateScratchpads)
            {
                recordStateFactories = new RecordStateFactory[_recordStateScratchpads.Count];
                for (int i = 0; i < recordStateFactories.Length; i++)
                {
                    recordStateFactories[i] = _recordStateScratchpads[i].Compile();
                }
            }
            else
            {
                recordStateFactories = new RecordStateFactory[0];
            }

            CoordinatorFactory[] nestedCoordinators = new CoordinatorFactory[_nestedCoordinatorScratchpads.Count];
            for (int i = 0; i < nestedCoordinators.Length; i++)
            {
                nestedCoordinators[i] = _nestedCoordinatorScratchpads[i].Compile();
            }

            // compile inline delegates
            ReplacementExpressionVisitor replacementVisitor = new ReplacementExpressionVisitor(null, this._inlineDelegates);
            Expression element = new SecurityBoundaryExpressionVisitor().Visit(replacementVisitor.Visit(this.Element));

            // substitute expressions that have error handlers into a new expression (used
            // when a more detailed exception message is needed)
            replacementVisitor = new ReplacementExpressionVisitor(this._expressionWithErrorHandlingMap, this._inlineDelegates);
            Expression elementWithErrorHandling = new SecurityBoundaryExpressionVisitor().Visit(replacementVisitor.Visit(this.Element));

            // CoordinatorFactory<T> is generic over the element type, which is only known at runtime here.
            CoordinatorFactory result = (CoordinatorFactory)Activator.CreateInstance(typeof(CoordinatorFactory<>).MakeGenericType(_elementType), new object[] {
                                                            this.Depth,
                                                            this.StateSlotNumber,
                                                            this.HasData,
                                                            this.SetKeys,
                                                            this.CheckKeys,
                                                            nestedCoordinators,
                                                            element,
                                                            elementWithErrorHandling,
                                                            this.InitializeCollection,
                                                            recordStateFactories
                                                            });
            return result;
        }

        /// <summary>
        /// Allocates a new RecordStateScratchpad and adds it to the list of the ones we're
        /// responsible for; will create the list if it hasn't already been created.
        /// </summary>
        /// <returns>The newly created scratchpad.</returns>
        internal RecordStateScratchpad CreateRecordStateScratchpad()
        {
            RecordStateScratchpad recordStateScratchpad = new RecordStateScratchpad();

            if (null == _recordStateScratchpads)
            {
                _recordStateScratchpads = new List<RecordStateScratchpad>();
            }
            _recordStateScratchpads.Add(recordStateScratchpad);
            return recordStateScratchpad;
        }
        #endregion

        #region Nested types

        /// <summary>
        /// Visitor supporting (non-recursive) replacement of LINQ sub-expressions and
        /// compilation of inline delegates.
        /// </summary>
        private sealed class ReplacementExpressionVisitor : EntityExpressionVisitor
        {
            // Map from original expressions to replacement expressions (may be null: no replacements).
            private readonly Dictionary<Expression, Expression> _replacementDictionary;

            // Lambdas to precompile into constants (may be null: no precompilation).
            private readonly HashSet<LambdaExpression> _inlineDelegates;

            internal ReplacementExpressionVisitor(Dictionary<Expression, Expression> replacementDictionary,
                HashSet<LambdaExpression> inlineDelegates)
            {
                this._replacementDictionary = replacementDictionary;
                this._inlineDelegates = inlineDelegates;
            }

            /// <summary>
            /// Returns the registered replacement for <paramref name="expression"/> if there is one;
            /// otherwise precompiles registered inline lambdas into constants; otherwise visits normally.
            /// </summary>
            internal override Expression Visit(Expression expression)
            {
                if (null == expression)
                {
                    return expression;
                }

                Expression result;

                // check to see if a substitution has been provided for this expression
                Expression replacement;
                if (null != this._replacementDictionary && this._replacementDictionary.TryGetValue(expression, out replacement))
                {
                    // once a substitution is found, we stop walking the sub-expression and
                    // return immediately (since recursive replacement is not needed or wanted)
                    result = replacement;
                }
                else
                {
                    // check if we need to precompile an inline delegate
                    bool preCompile = false;
                    LambdaExpression lambda = null;

                    if (expression.NodeType == ExpressionType.Lambda &&
                        null != _inlineDelegates)
                    {
                        lambda = (LambdaExpression)expression;
                        preCompile = _inlineDelegates.Contains(lambda);
                    }

                    if (preCompile)
                    {
                        // do replacement in the body of the lambda expression
                        Expression body = Visit(lambda.Body);

                        // compile to a delegate
                        result = Expression.Constant(Translator.Compile(body.Type, body));
                    }
                    else
                    {
                        result = base.Visit(expression);
                    }
                }

                return result;
            }
        }

        /// <summary>
        /// Used to replace references to user expressions with compiled delegates
        /// which represent those expressions.
        /// </summary>
        /// <remarks>
        /// The materialization delegate used to be one big function, which included
        /// user-provided expressions in various places in the tree. Due to security reasons
        /// (Dev11 311339), we need to separate this delegate into two pieces: trusted code,
        /// run under a security assert, and untrusted code, run under the current AppDomain's
        /// permission set.
        ///
        /// This visitor does that separation by compiling the untrusted code into delegates
        /// and re-inserting them back into the expression tree. When the untrusted code is
        /// run, it will run in another stack frame that does not have a security assert
        /// associated with it; therefore, any attempt to take advantage of MemberAccess
        /// reflection permissions will be blocked by the CLR.
        ///
        /// The compiled user delegates accept two parameters, one of type DbDataReader
        /// to speed up access to the current reader, and the other of type object[],
        /// which contains all other values that they might require to correctly materialize an object. Most of these
        /// objects require the <see cref="Shaper"/>, so they must be run inside of trusted code.
        /// </remarks>
        private sealed class SecurityBoundaryExpressionVisitor : EntityExpressionVisitor
        {
            private static readonly MethodInfo s_userMaterializationFuncInvokeMethod = typeof(Func<DbDataReader, object[], object>).GetMethod("Invoke");

            // Parameters of every compiled user delegate.
            private readonly ParameterExpression _values = Expression.Parameter(typeof(object[]), "values");
            private readonly ParameterExpression _reader = Expression.Parameter(typeof(DbDataReader), "reader");

            // Trusted-side expressions (boxed to object) whose values are passed to user delegates in
            // _values; an argument's position here is its index in _values.
            private readonly List<Expression> _initializationArguments = new List<Expression>();

            // Nesting depth of user-expression markers currently being visited (0 = trusted code).
            private int _userExpressionDepth;

            /// <summary>
            /// Used to track the type of a constructor argument or member assignment
            /// when it could be a special type we create (e.g., CompensatingCollection{T}
            /// for collections and Grouping{K,V} for groups).
            /// </summary>
            private Type _userArgumentType;

            /// <summary>
            /// Rewrites constructor invocations inside user expressions so that every non-reader value is
            /// read from the values array; at the top-level user expression, compiles the result into a
            /// delegate invoked from trusted code.
            /// </summary>
            internal override Expression Visit(Expression exp)
            {
                if (exp == null)
                {
                    return exp;
                }

                var nex = exp as NewExpression;
                if (nex != null && _userExpressionDepth >= 1)
                {
                    // We are creating an internal type like CompensatingCollection<T> or Grouping<K, V>
                    // and at this particular point we are sure that the user isn't creating these
                    // since _userArgumentType is not null.
                    if (_userArgumentType != null && !nex.Type.IsPublic && nex.Type.Assembly == typeof(SecurityBoundaryExpressionVisitor).Assembly)
                    {
                        return this.CreateInitializationArgumentReplacement(nex, _userArgumentType);
                    }

                    // A NewExpression for a value type's implicit default constructor has a null
                    // Constructor and no arguments, so there is nothing to rewrite.
                    if (nex.Constructor != null)
                    {
                        var constructorParameters = nex.Constructor.GetParameters();
                        var arguments = nex.Arguments;
                        var newArguments = new List<Expression>(arguments.Count);
                        for (var i = 0; i < arguments.Count; ++i)
                        {
                            var argument = arguments[i];

                            // Visit this argument because it itself could be a user expression e.g.
                            // new { Argument = new SecureString { m_length = 32 } }
                            _userArgumentType = constructorParameters[i].ParameterType;
                            var visitedArgument = this.Visit(argument);

                            // If it hasn't changed, it's trusted code: change the argument to access the
                            // values array. (Untrusted code would have its Convert and MarkAsUserExpression
                            // expressions removed, so it is kept as rewritten.)
                            newArguments.Add(visitedArgument == argument
                                ? this.CreateInitializationArgumentReplacement(argument)
                                : visitedArgument);
                        }

                        nex = Expression.New(nex.Constructor, newArguments);
                    }

                    // Only compile into a delegate if this is the top-level user expression.
                    if (_userExpressionDepth == 1)
                    {
                        return this.CompileUserExpression(nex, nex.Type);
                    }

                    return nex;
                }

                return base.Visit(exp);
            }

            /// <summary>
            /// Inside a user expression (where a special argument type is being tracked), keeps the
            /// "IsDBNull ? ... : reader/user call" pattern intact and moves any other conditional
            /// wholesale into the values array.
            /// </summary>
            internal override Expression VisitConditional(ConditionalExpression c)
            {
                if (_userExpressionDepth >= 1 && _userArgumentType != null)
                {
                    var test = c.Test as MethodCallExpression;
                    var ifFalse = c.IfFalse as MethodCallExpression;

                    // We can optimize the path that checks for DbNull and then
                    // reads a value directly off the reader or invokes another user expression.
                    if (test != null && test.Object != null
                        && typeof(DbDataReader).IsAssignableFrom(test.Object.Type)
                        && test.Method.Name == "IsDBNull")
                    {
                        if (ifFalse != null
                            && ((ifFalse.Object != null && typeof(DbDataReader).IsAssignableFrom(ifFalse.Object.Type))
                                || IsUserExpressionMethod(ifFalse.Method)))
                        {
                            return base.VisitConditional(c);
                        }
                    }

                    // If there's something more complicated then we have to replace it all.
                    // We can't just replace the false expression because it may not be evaluated
                    // if the test returns true.
                    return this.CreateInitializationArgumentReplacement(c);
                }

                return base.VisitConditional(c);
            }

            /// <summary>
            /// Inside a user expression, redirects reader-typed member accesses on the shaper parameter
            /// to the user delegate's own reader parameter.
            /// </summary>
            internal override Expression VisitMemberAccess(MemberExpression m)
            {
                if (_userExpressionDepth >= 1)
                {
                    // Sometimes we will add expressions inside of a user expression that is actually
                    // our code, but we need to rewrite it since it accesses the shaper's reader to check if a column is null.
                    // e.g. Select(x => new { Y = new Entity { Name = x.Name } })
                    // -> new f<>__AnonymousType`1(IIF($shaper.Reader.IsDbNull(0), null, new Entity { Name = $shaper.Reader.GetString(0) }))
                    if (typeof(DbDataReader).IsAssignableFrom(m.Type))
                    {
                        var shaper = m.Expression as ParameterExpression;
                        if (shaper != null && shaper == Translator.Shaper_Parameter)
                        {
                            return _reader;
                        }
                    }
                }

                return base.VisitMemberAccess(m);
            }

            /// <summary>
            /// Rewrites object initializers inside user expressions and, for the top-level user
            /// expression, compiles the rewritten initializer into a delegate invoked from trusted code.
            /// </summary>
            internal override Expression VisitMemberInit(MemberInitExpression init)
            {
                if (_userExpressionDepth >= 1)
                {
                    var newMemberInit = base.VisitMemberInit(init);

                    // Only compile into a delegate if this is the top-level user expression.
                    if (newMemberInit != init && _userExpressionDepth == 1)
                    {
                        return this.CompileUserExpression(newMemberInit, init.Type);
                    }

                    return newMemberInit;
                }

                return base.VisitMemberInit(init);
            }

            /// <summary>
            /// Inside a user expression, records the assigned member's type in _userArgumentType
            /// before visiting the assigned value.
            /// </summary>
            internal override MemberAssignment VisitMemberAssignment(MemberAssignment assignment)
            {
                if (_userExpressionDepth >= 1)
                {
                    var fieldInfo = assignment.Member as FieldInfo;
                    var propertyInfo = assignment.Member as PropertyInfo;
                    if (fieldInfo != null)
                    {
                        _userArgumentType = fieldInfo.FieldType;
                    }
                    else if (propertyInfo != null)
                    {
                        _userArgumentType = propertyInfo.PropertyType;
                    }
                }

                return base.VisitMemberAssignment(assignment);
            }

            /// <summary>
            /// Strips user-expression markers (tracking the nesting depth), keeps reader calls inside
            /// user code, and moves every other call inside user code into the values array.
            /// </summary>
            internal override Expression VisitMethodCall(MethodCallExpression m)
            {
                var method = m.Method;
                if (IsUserExpressionMethod(method))
                {
                    Debug.Assert(
                        m.Arguments.Count == 1,
                        "m.Arguments.Count == 1",
                        "There should be one expression argument provided to the user expression marker.");

                    try
                    {
                        // Clear this type because we are about to process a user expression
                        _userArgumentType = null;

                        _userExpressionDepth++;
                        return this.Visit(m.Arguments[0]);
                    }
                    finally
                    {
                        _userExpressionDepth--;
                    }
                }
                else if (_userExpressionDepth >= 1)
                {
                    // If this method call is on a DbDataReader then we can replace it; otherwise,
                    // assume it's something on the shaper and extract the value into the values array.
                    if (m.Object != null && typeof(DbDataReader).IsAssignableFrom(m.Object.Type))
                    {
                        return base.VisitMethodCall(m);
                    }

                    return this.CreateInitializationArgumentReplacement(m);
                }

                return base.VisitMethodCall(m);
            }

            /// <summary>
            /// Compiles a rewritten top-level user expression into a Func&lt;DbDataReader, object[], object&gt;
            /// (so it runs in its own frame, without the trusted code's security assert) and returns a
            /// trusted-side expression that invokes it with the shaper's reader and the captured values.
            /// </summary>
            /// <param name="userExpression">The rewritten user expression. The compiled delegate's only
            /// parameters are _reader and _values.</param>
            /// <param name="resultType">Type the delegate's object result is converted back to.</param>
            /// <remarks>
            /// NOTE(review): _initializationArguments is never cleared, so when one element contains
            /// several top-level user expressions, each later invocation is also passed (and re-evaluates)
            /// the arguments captured for earlier ones. Indices stay consistent; confirm whether clearing
            /// the list after each compile is safe before changing it.
            /// </remarks>
            private Expression CompileUserExpression(Expression userExpression, Type resultType)
            {
                // The delegate returns object, so box explicitly: Expression.Lambda rejects a value-type
                // body for an object-returning delegate (e.g. a user initializer creating a struct).
                // For reference types this conversion is a no-op.
                var body = Expression.Convert(userExpression, typeof(object));
                var userMaterializationFunc = Expression.Lambda<Func<DbDataReader, object[], object>>(body, _reader, _values).Compile();

                // Convert the user code into a func invocation that runs without elevated permissions.
                return Expression.Convert(
                    Expression.Call(
                        Expression.Constant(userMaterializationFunc),
                        s_userMaterializationFuncInvokeMethod,
                        Translator.Shaper_Reader,
                        Expression.NewArrayInit(typeof(object), _initializationArguments)),
                    resultType);
            }

            /// <summary>
            /// Captures <paramref name="original"/> (evaluated in trusted code) into the values array and
            /// returns an expression reading it back, typed as the original expression.
            /// </summary>
            private Expression CreateInitializationArgumentReplacement(Expression original)
            {
                return this.CreateInitializationArgumentReplacement(original, original.Type);
            }

            /// <summary>
            /// Captures <paramref name="original"/> (boxed to object, evaluated in trusted code) into the
            /// values array and returns <c>(expressionType)values[index]</c>, where index is its position
            /// in _initializationArguments.
            /// </summary>
            private Expression CreateInitializationArgumentReplacement(Expression original, Type expressionType)
            {
                _initializationArguments.Add(Expression.Convert(original, typeof(object)));

                return Expression.Convert(
                    Expression.MakeBinary(ExpressionType.ArrayIndex, _values, Expression.Constant(_initializationArguments.Count - 1)),
                    expressionType);
            }

            /// <summary>
            /// Returns true when <paramref name="method"/> is a closed form of the generic
            /// InitializerMetadata.UserExpressionMarker method.
            /// </summary>
            private static bool IsUserExpressionMethod(MethodInfo method)
            {
                return method.IsGenericMethod && method.GetGenericMethodDefinition() == InitializerMetadata.UserExpressionMarker;
            }
        }
        #endregion
    }
}