File: fx\src\data\System\Data\SqlClient\SqlCommandSet.cs
Project: ndp\System.Data.csproj (System.Data)
//------------------------------------------------------------------------------
// <copyright file="SqlBatchCommand.cs" company="Microsoft">
//     Copyright (c) Microsoft Corporation.  All rights reserved.
// </copyright>
// <owner current="true" primary="true">Microsoft</owner>
// <owner current="true" primary="false">Microsoft</owner>
//------------------------------------------------------------------------------
 
namespace System.Data.SqlClient {
 
    using System;
    using System.Collections.Generic;
    using System.ComponentModel;
    using System.Data;
    using System.Data.Common;
    using System.Diagnostics;
    using System.Globalization;
    using System.Text;
    using System.Text.RegularExpressions;
 
    internal sealed class SqlCommandSet {
 
        private const string SqlIdentifierPattern = "^@[\\p{Lo}\\p{Lu}\\p{Ll}\\p{Lm}_@#][\\p{Lo}\\p{Lu}\\p{Ll}\\p{Lm}\\p{Nd}\uff3f_@#\\$]*$";
        private static readonly Regex SqlIdentifierParser = new Regex(SqlIdentifierPattern, RegexOptions.ExplicitCapture|RegexOptions.Singleline);
 
        private List<LocalCommand> _commandList = new List<LocalCommand>();
 
        private SqlCommand _batchCommand;
 
        private static int _objectTypeCount; // Bid counter
        internal readonly int _objectID = System.Threading.Interlocked.Increment(ref _objectTypeCount);
 
        private sealed class LocalCommand {
            internal readonly string CommandText;
            internal readonly SqlParameterCollection Parameters;
            internal readonly int ReturnParameterIndex;
            internal readonly CommandType CmdType;
            internal readonly SqlCommandColumnEncryptionSetting ColumnEncryptionSetting;
 
            internal LocalCommand(string commandText, SqlParameterCollection parameters,  int returnParameterIndex, CommandType cmdType, SqlCommandColumnEncryptionSetting columnEncryptionSetting) {
                Debug.Assert(0 <= commandText.Length, "no text");
                this.CommandText = commandText;
                this.Parameters = parameters;
                this.ReturnParameterIndex = returnParameterIndex;
                this.CmdType = cmdType;
                this.ColumnEncryptionSetting = columnEncryptionSetting;
            }
        }
 
        internal SqlCommandSet() : base() {
            _batchCommand = new SqlCommand();
        }
 
        private SqlCommand BatchCommand {
            get {
                SqlCommand command = _batchCommand;
                if (null == command) {
                    throw ADP.ObjectDisposed(this);
                }
                return command;
            }
        }
 
        internal int CommandCount {
            get {
                return CommandList.Count;
            }
        }
 
        private List<LocalCommand> CommandList {
            get {
                List<LocalCommand> commandList = _commandList;
                if (null == commandList) {
                    throw ADP.ObjectDisposed(this);
                }
                return commandList;
            }
        }
 
        internal int CommandTimeout {
            /*get {
                return BatchCommand.CommandTimeout;
            }*/
            set {
                BatchCommand.CommandTimeout = value;
            }
        }
 
        internal SqlConnection Connection {
            get {
                return BatchCommand.Connection;
            }
            set {
                BatchCommand.Connection = value;
            }
        }
 
        internal SqlTransaction Transaction {
            /*get {
                return BatchCommand.Transaction;
            }*/
            set {
                BatchCommand.Transaction = value;
            }
        }
 
        internal int ObjectID {
            get {
                return _objectID;
            }
        }
 
        internal void Append(SqlCommand command) {
            ADP.CheckArgumentNull(command, "command");
            Bid.Trace("<sc.SqlCommandSet.Append|API> %d#, command=%d, parameterCount=%d\n", ObjectID, command.ObjectID, command.Parameters.Count);
 
            string cmdText = command.CommandText;
            if (ADP.IsEmpty(cmdText)) {
                throw ADP.CommandTextRequired(ADP.Append);
            }
 
            CommandType commandType = command.CommandType;
            switch(commandType) {
            case CommandType.Text:
            case CommandType.StoredProcedure:
                break;
            case CommandType.TableDirect:
                Debug.Assert(false, "command.CommandType");
                throw System.Data.SqlClient.SQL.NotSupportedCommandType(commandType);
            default:
                Debug.Assert(false, "command.CommandType");
                throw ADP.InvalidCommandType(commandType);
            }
 
            SqlParameterCollection parameters = null;
 
            SqlParameterCollection collection = command.Parameters;
            if (0 < collection.Count) {
                parameters = new SqlParameterCollection();
 
                // clone parameters so they aren't destroyed
                for(int i = 0; i < collection.Count; ++i) {
                    SqlParameter p = new SqlParameter();
                    collection[i].CopyTo(p);
                    parameters.Add(p);
 
                    // SQL Injection awarene
                    if (!SqlIdentifierParser.IsMatch(p.ParameterName)) {
                        throw ADP.BadParameterName(p.ParameterName);
                    }
                }
 
                foreach(SqlParameter p in parameters) {
                    // deep clone the parameter value if byte[] or char[]
                    object obj = p.Value;
                    byte[] byteValues = (obj as byte[]);
                    if (null != byteValues) {
                        int offset = p.Offset;
                        int size = p.Size;
                        int countOfBytes = byteValues.Length - offset;
                        if ((0 != size) && (size < countOfBytes)) {
                            countOfBytes = size;
                        }
                        byte[] copy = new byte[Math.Max(countOfBytes, 0)];
                        Buffer.BlockCopy(byteValues, offset, copy, 0, copy.Length);
                        p.Offset = 0;
                        p.Value = copy;
                    }
                    else {
                        char[] charValues = (obj as char[]);
                        if (null != charValues) {
                            int offset = p.Offset;
                            int size = p.Size;
                            int countOfChars = charValues.Length - offset;
                            if ((0 != size) && (size < countOfChars)) {
                                countOfChars = size;
                            }
                            char[] copy = new char[Math.Max(countOfChars, 0)];
                            Buffer.BlockCopy(charValues, offset, copy, 0, copy.Length*2);
                            p.Offset = 0;
                            p.Value = copy;
                        }
                        else {
                            ICloneable cloneable = (obj as ICloneable);
                            if (null != cloneable) {
                                p.Value = cloneable.Clone();
                            }
                        }
                    }
                }
            }
 
            int returnParameterIndex = -1;
            if (null != parameters) {
                for(int i = 0; i < parameters.Count; ++i) {
                    if (ParameterDirection.ReturnValue == parameters[i].Direction) {
                        returnParameterIndex = i;
                        break;
                    }
                }
            }
            LocalCommand cmd = new LocalCommand(cmdText, parameters, returnParameterIndex, command.CommandType, command.ColumnEncryptionSetting);
            CommandList.Add(cmd);
        }
 
        internal static void BuildStoredProcedureName(StringBuilder builder, string part) {
            if ((null != part) && (0 < part.Length)) {
                if ('[' == part[0]) {
                    int count = 0;
                    foreach(char c in part) {
                        if (']' == c) {
                            count++;
                        }
                    }
                    if (1 == (count%2)) {
                        builder.Append(part);
                        return;
                    }
                }
 
                // the part is not escaped, escape it now
                SqlServerEscapeHelper.EscapeIdentifier(builder, part);
            }
        }
 
        internal void Clear() {
            Bid.Trace("<sc.SqlCommandSet.Clear|API> %d#\n", ObjectID);
            DbCommand batchCommand = BatchCommand;
            if (null != batchCommand) {
                batchCommand.Parameters.Clear();
                batchCommand.CommandText = null;
            }
            List<LocalCommand> commandList = _commandList;
            if (null != commandList) {
                commandList.Clear();
            }
        }
 
        internal void Dispose() {
            Bid.Trace("<sc.SqlCommandSet.Dispose|API> %d#\n", ObjectID);
            SqlCommand command = _batchCommand;
            _commandList = null;
            _batchCommand = null;
 
            if (null != command) {
                command.Dispose();
            }
        }
 
        internal int ExecuteNonQuery() {
            SqlConnection.ExecutePermission.Demand();
 
            IntPtr hscp;
            Bid.ScopeEnter(out hscp, "<sc.SqlCommandSet.ExecuteNonQuery|API> %d#", ObjectID);
            try {
                if (Connection.IsContextConnection) {
                    throw SQL.BatchedUpdatesNotAvailableOnContextConnection();
                }
                ValidateCommandBehavior(ADP.ExecuteNonQuery, CommandBehavior.Default);
                BatchCommand.BatchRPCMode = true;
                BatchCommand.ClearBatchCommand();
                BatchCommand.Parameters.Clear();
                for (int ii = 0 ; ii < _commandList.Count; ii++) {
                    LocalCommand cmd = _commandList[ii];
                    BatchCommand.AddBatchCommand(cmd.CommandText, cmd.Parameters, cmd.CmdType, cmd.ColumnEncryptionSetting);
                }
                return BatchCommand.ExecuteBatchRPCCommand();
            }
            finally {
                Bid.ScopeLeave(ref hscp);
            }
        }
 
        internal SqlParameter GetParameter(int commandIndex, int parameterIndex) {
            return CommandList[commandIndex].Parameters[parameterIndex];
        }
 
        internal bool GetBatchedAffected(int commandIdentifier, out int recordsAffected, out Exception error) {
            error = BatchCommand.GetErrors(commandIdentifier);
            int? affected = BatchCommand.GetRecordsAffected(commandIdentifier);
            recordsAffected = affected.GetValueOrDefault();
            return affected.HasValue;
        }
 
        internal int GetParameterCount(int commandIndex) {
            return CommandList[commandIndex].Parameters.Count;
        }
 
        private void ValidateCommandBehavior(string method, CommandBehavior behavior) {
            if (0 != (behavior & ~(CommandBehavior.SequentialAccess|CommandBehavior.CloseConnection))) {
                ADP.ValidateCommandBehavior(behavior);
                throw ADP.NotSupportedCommandBehavior(behavior & ~(CommandBehavior.SequentialAccess|CommandBehavior.CloseConnection), method);
            }
        }
    }
}