File: DataAccess\SqlConnectionHelper.cs
Project: ndp\fx\src\xsp\system\Web\System.Web.csproj (System.Web)
//------------------------------------------------------------------------------
// <copyright file="SqlConnectionHelper.cs" company="Microsoft">
//     Copyright (c) Microsoft Corporation.  All rights reserved.
// </copyright>
//------------------------------------------------------------------------------
 
namespace System.Web.DataAccess {
 
    using System;
    using System.Collections.Specialized;
    using System.Configuration;
    using System.Configuration.Provider;
    using System.Data;
    using System.Data.SqlClient;
    using System.Diagnostics;
    using System.Globalization;
    using System.IO;
    using System.Security.Permissions;
    using System.Threading;
    using System.Web.Configuration;
    using System.Web.Hosting;
    using System.Web.Management;
    using System.Web.Util;
 
    internal static class SqlConnectionHelper {
        // Name of the AppDomain data slot that caches the resolved data directory (see GetDataDirectory).
        internal const string s_strDataDir = "DataDirectory";
        // Upper-cased substitution token looked for in upper-cased connection strings.
        internal const string s_strUpperDataDirWithToken = "|DATADIRECTORY|";
        // Upper-cased file extension appended to the temporary database file name.
        internal const string s_strSqlExprFileExt = ".MDF";
        // Upper-cased connection string keyword matched against upper-cased segments.
        internal const string s_strUpperUserInstance = "USER INSTANCE";
        // Upper-cased marker whose presence in a connection string identifies a LocalDB connection.
        private const string s_localDbName = "(LOCALDB)";
        // Serializes database file creation in EnsureDBFile. Read-only so the lock target can never be swapped.
        private static readonly object s_lock = new object();
 
        // Rejects connection strings that set the "User Instance" flag.
        // Throws ProviderException when the parsed connection string has UserInstance enabled;
        // otherwise returns normally.
        internal static void EnsureNoUserInstance(string connectionString) {
            bool userInstanceRequested = new SqlConnectionStringBuilder(connectionString).UserInstance;
            if (!userInstanceRequested) {
                return;
            }

            throw new ProviderException(SR.GetString(SR.LocalDB_cannot_have_userinstance_flag));
        }
 
        // Creates a SqlConnectionHolder for the given connection string and opens it.
        //
        // connectionString:    the connection string to open (must not be null).
        // revertImpersonation: forwarded to SqlConnectionHolder.Open.
        //
        // Returns the opened holder. Any exception raised while preparing or opening
        // the connection propagates to the caller.
        internal static SqlConnectionHolder GetConnection(string connectionString, bool revertImpersonation) {
            string upperConnectionString = connectionString.ToUpperInvariant();

            // A |DataDirectory| token means the database file may have to be created first.
            if (upperConnectionString.Contains(s_strUpperDataDirWithToken)) {
                EnsureDBFile(connectionString);
            }

            // The UserInstance flag is only rejected for LocalDB connections.
            if (upperConnectionString.Contains(s_localDbName)) {
                EnsureNoUserInstance(connectionString);
            }

            SqlConnectionHolder holder = new SqlConnectionHolder(connectionString);
            bool opened = false;

            // NOTE(review): the try/finally nested inside a rethrowing catch looks deliberate
            // (presumably so the cleanup runs before any caller-side exception filter) --
            // confirm before simplifying.
            try {
                try {
                    holder.Open(null, revertImpersonation);
                    opened = true;
                }
                finally {
                    if (!opened) {
                        holder.Close();
                        holder = null;
                    }
                }
            }
            catch {
                throw;
            }

            return holder;
        }
 
        // Resolves the connection string to use.
        //
        // specifiedConnectionString: either a literal connection string or the name of an
        //                            entry in the <connectionStrings> config section.
        // lookupConnectionString:    true to treat the value as a config entry name.
        // appLevel:                  true to read the application-level config, false to use
        //                            RuntimeConfig.GetConfig().
        //
        // Returns null when the input is null/empty or when the named entry is missing
        // (or has a null connection string); otherwise the resolved connection string.
        internal static string GetConnectionString(string specifiedConnectionString, bool lookupConnectionString, bool appLevel) {
            System.Web.Util.Debug.Assert((specifiedConnectionString != null) && (specifiedConnectionString.Length != 0));
            if (string.IsNullOrEmpty(specifiedConnectionString)) {
                return null;
            }

            // No lookup requested: the caller already supplied the literal connection string.
            if (!lookupConnectionString) {
                return specifiedConnectionString;
            }

            // Look the name up in the <connectionStrings> config section.
            RuntimeConfig runtimeConfig;
            if (appLevel) {
                runtimeConfig = RuntimeConfig.GetAppConfig();
            }
            else {
                runtimeConfig = RuntimeConfig.GetConfig();
            }

            ConnectionStringSettings settings = runtimeConfig.ConnectionStrings.ConnectionStrings[specifiedConnectionString];
            return (settings == null) ? null : settings.ConnectionString;
        }
 
        // Returns the directory that the |DataDirectory| token resolves to.
        //
        // When hosted, this is the data directory under the application path. Otherwise the
        // value cached in the AppDomain "DataDirectory" slot is used; when that slot is empty
        // the directory is derived from the main module's location (falling back to the
        // current directory) and stored back into the slot.
        [PermissionSet(SecurityAction.Assert, Unrestricted = true)]
        internal static string GetDataDirectory() {
            if (HostingEnvironment.IsHosted)
                return Path.Combine(HttpRuntime.AppDomainAppPath, HttpRuntime.DataDirectoryName);

            string dataDir = AppDomain.CurrentDomain.GetData(s_strDataDir) as string;
            if (string.IsNullOrEmpty(dataDir)) {
                string appPath = null;

#if !FEATURE_PAL // FEATURE_PAL does not support ProcessModule
                // Process is IDisposable; release it as soon as the executable path has been read.
                using (Process p = Process.GetCurrentProcess()) {
                    ProcessModule pm = (p != null ? p.MainModule : null);
                    string exeName = (pm != null ? pm.FileName : null);

                    if (!string.IsNullOrEmpty(exeName))
                        appPath = Path.GetDirectoryName(exeName);
                }
#endif // !FEATURE_PAL

                if (string.IsNullOrEmpty(appPath))
                    appPath = Environment.CurrentDirectory;

                dataDir = Path.Combine(appPath, HttpRuntime.DataDirectoryName);
                // Cache the result in the AppDomain slot so later calls skip the lookup above.
                AppDomain.CurrentDomain.SetData(s_strDataDir, dataDir, new FileIOPermission(FileIOPermissionAccess.PathDiscovery, dataDir));
            }

            return dataDir;
        }
 
        // Makes sure the database file referenced through the |DataDirectory| token exists,
        // creating it when it is missing.
        //
        // The connection string is split on ';' and each segment is inspected:
        //   * the segment containing |DataDirectory| yields the file name and is replaced by
        //     "Pooling=false" in the working copy of the connection string;
        //   * an "Initial Catalog"/"Database" segment is replaced by "Database=master";
        //   * a "User Instance" segment must be "true" (checked for non-LocalDB strings only);
        //   * a "Connect Timeout" segment suppresses the default timeout appended below.
        //
        // Returns without doing anything when a non-LocalDB connection string has no
        // "User Instance=true" segment, or when the file already exists.
        //
        // Throws ProviderException when no usable file name was found (including a name
        // containing "..") or when the trust level is below High.
        private static void EnsureDBFile(string connectionString) {
            string partialFileName = null;
            string fullFileName = null;
            string dataDir = GetDataDirectory();
            bool lookingForDataDir = true;
            bool lookingForDB = true;
            string[] splitedConnStr = connectionString.Split(new char[] { ';' }, StringSplitOptions.RemoveEmptyEntries);
            bool lookingForUserInstance = !connectionString.ToUpperInvariant().Contains(s_localDbName); // We don't require UserInstance=True for LocalDb
            bool lookingForTimeout = true;

            // Every segment is scanned; there is intentionally no early exit once the file
            // name and database segments are found, because "User Instance" and
            // "Connect Timeout" may appear later in the string and must still be honored.
            foreach (string str in splitedConnStr) {
                string strUpper = str.ToUpper(CultureInfo.InvariantCulture).Trim();

                if (lookingForDataDir && strUpper.Contains(s_strUpperDataDirWithToken)) {
                    lookingForDataDir = false;

                    // Replace the AttachDBFilename part with "Pooling=false"
                    connectionString = connectionString.Replace(str, "Pooling=false");

                    // Extract the filenames
                    int startPos = strUpper.IndexOf(s_strUpperDataDirWithToken, StringComparison.Ordinal) + s_strUpperDataDirWithToken.Length;
                    partialFileName = strUpper.Substring(startPos).Trim();
                    while (partialFileName.StartsWith("\\", StringComparison.Ordinal))
                        partialFileName = partialFileName.Substring(1);
                    if (partialFileName.Contains("..")) // don't allow it to traverse-up
                        partialFileName = null;
                    else
                        fullFileName = Path.Combine(dataDir, partialFileName);
                }
                else if (lookingForDB && (strUpper.StartsWith("INITIAL CATALOG", StringComparison.Ordinal) || strUpper.StartsWith("DATABASE", StringComparison.Ordinal))) {
                    lookingForDB = false;
                    // The creation connection must target master, not the database being created.
                    connectionString = connectionString.Replace(str, "Database=master");
                }
                else if (lookingForUserInstance && strUpper.StartsWith(s_strUpperUserInstance, StringComparison.Ordinal)) {
                    lookingForUserInstance = false;
                    int pos = strUpper.IndexOf('=');
                    if (pos < 0)
                        return;
                    string strTemp = strUpper.Substring(pos + 1).Trim();
                    // Anything other than "User Instance=true" disables auto-creation.
                    if (strTemp != "TRUE")
                        return;
                }
                else if (lookingForTimeout && strUpper.StartsWith("CONNECT TIMEOUT", StringComparison.Ordinal)) {
                    lookingForTimeout = false;
                }
            }

            // Still looking means the keyword was required but never present: nothing to do.
            if (lookingForUserInstance)
                return;

            if (fullFileName == null)
                throw new ProviderException(SR.GetString(SR.SqlExpress_file_not_found_in_connection_string));

            if (File.Exists(fullFileName))
                return;

            if (!HttpRuntime.HasAspNetHostingPermission(AspNetHostingPermissionLevel.High))
                throw new ProviderException(SR.GetString(SR.Provider_can_not_create_file_in_this_trust_level));

            if (!connectionString.Contains("Database=master"))
                connectionString += ";Database=master";
            // Only supply a default timeout when the caller did not specify one.
            if (lookingForTimeout)
                connectionString += ";Connect Timeout=45";

            using (new ApplicationImpersonationContext()) {
                lock (s_lock) {
                    // Re-check under the lock: another thread may have created the file meanwhile.
                    if (!File.Exists(fullFileName)) {
                        CreateMdfFile(fullFileName, dataDir, connectionString);
                    }
                }
            }
        }
 
        // Creates the database file at fullFileName.
        //
        // The database is first installed into a temporary "<name>_TMP.MDF" file next to the
        // target, detached, and then moved (or, failing that, copied) to the final location.
        // The data directory is created first when it does not exist.
        //
        // When there is no current HttpContext, or custom errors are enabled, failures are
        // rethrown unchanged; otherwise they are wrapped in an HttpException carrying an
        // error formatter.
        [PermissionSet(SecurityAction.Assert, Unrestricted = true)]
        private static void CreateMdfFile(string fullFileName, string dataDir, string connectionString) {
            HttpContext currentContext = HttpContext.Current;
            // True only while Directory.CreateDirectory is running, so the catch block below
            // can tell a directory-creation failure from a later write failure.
            bool dirCreationInProgress = false;

            try {
                if (!Directory.Exists(dataDir)) {
                    dirCreationInProgress = true;
                    Directory.CreateDirectory(dataDir);
                    dirCreationInProgress = false;

                    // Best effort: a failure here must not abort the database creation.
                    try {
                        if (currentContext != null) {
                            HttpRuntime.RestrictIISFolders(currentContext);
                        }
                    }
                    catch { }
                }

                fullFileName = fullFileName.ToUpper(CultureInfo.InvariantCulture);

                // Build a sanitized name: every character that is not a letter or digit becomes '_'.
                char[] nameChars = Path.GetFileNameWithoutExtension(fullFileName).ToCharArray();
                for (int i = nameChars.Length - 1; i >= 0; i--) {
                    if (!char.IsLetterOrDigit(nameChars[i])) {
                        nameChars[i] = '_';
                    }
                }
                string sanitizedName = new string(nameChars);

                // Database name: at most 30 characters of the sanitized name plus a unique suffix.
                string namePrefix = (sanitizedName.Length > 30) ? sanitizedName.Substring(0, 30) : sanitizedName;
                string databaseName = namePrefix + "_" + Guid.NewGuid().ToString("N", CultureInfo.InvariantCulture);

                string targetDirectory = Path.GetDirectoryName(fullFileName);
                string tempFileName = Path.Combine(targetDirectory, sanitizedName + "_TMP" + s_strSqlExprFileExt);

                // Auto create the temporary database, then detach it so the file can be moved.
                SqlServices.Install(databaseName, tempFileName, connectionString);
                DetachDB(databaseName, connectionString);

                try {
                    File.Move(tempFileName, fullFileName);
                }
                catch {
                    // Move failed: fall back to copy + delete unless the target already exists.
                    if (!File.Exists(fullFileName)) {
                        File.Copy(tempFileName, fullFileName);
                        try {
                            File.Delete(tempFileName);
                        }
                        catch { }
                    }
                }

                // Best effort removal of the temporary log file.
                string tempLogFileName = tempFileName.Replace("_TMP.MDF", "_TMP_log.LDF");
                try {
                    File.Delete(tempLogFileName);
                }
                catch { }
            }
            catch (Exception e) {
                bool rethrowUnchanged = (currentContext == null) || currentContext.IsCustomErrorEnabled;
                if (rethrowUnchanged) {
                    throw;
                }

                HttpException wrapped = new HttpException(e.Message, e);
                if (e is UnauthorizedAccessException) {
                    DataConnectionErrorEnum reason = dirCreationInProgress
                        ? DataConnectionErrorEnum.CanNotCreateDataDir
                        : DataConnectionErrorEnum.CanNotWriteToDataDir;
                    wrapped.SetFormatter(new SqlExpressConnectionErrorFormatter(reason));
                }
                else {
                    wrapped.SetFormatter(new SqlExpressDBFileAutoCreationErrorFormatter(e));
                }
                throw wrapped;
            }
        }
 
        // Detaches the named database from the server reached through connectionString.
        //
        // Best effort: any failure while opening the connection or running the commands is
        // deliberately ignored. The connection and both commands are disposed in all cases.
        private static void DetachDB(string databaseName, string connectionString) {
            using (SqlConnection connection = new SqlConnection(connectionString)) {
                try {
                    connection.Open();

                    // Switch to master before detaching.
                    using (SqlCommand useMaster = new SqlCommand("USE master", connection)) {
                        useMaster.ExecuteNonQuery();
                    }

                    using (SqlCommand detach = new SqlCommand("sp_detach_db", connection)) {
                        detach.CommandType = CommandType.StoredProcedure;
                        detach.Parameters.AddWithValue("@dbname", databaseName);
                        detach.Parameters.AddWithValue("@skipchecks", "true");
                        detach.ExecuteNonQuery();
                    }
                }
                catch {
                    // Intentionally swallowed: detaching is best effort.
                }
            }
        }
    }
 
    // Wraps a SqlConnection and remembers whether this holder opened it, so that
    // Open and Close can be called repeatedly without error.
    internal sealed class SqlConnectionHolder {
        internal SqlConnection _Connection;
        // True between a successful Open and the next Close.
        private bool _Opened;

        // The wrapped connection.
        internal SqlConnection Connection {
            get { return _Connection; }
        }

        // Creates the underlying connection. An ArgumentException raised by the
        // SqlConnection constructor is rethrown as an ArgumentException naming the
        // "connectionString" parameter, with the original exception as inner exception.
        internal SqlConnectionHolder(string connectionString) {
            try {
                _Connection = new SqlConnection(connectionString);
                System.Web.Util.Debug.Assert(_Connection != null);
            }
            catch (ArgumentException inner) {
                string message = SR.GetString(SR.SqlError_Connection_String);
                throw new ArgumentException(message, "connectionString", inner);
            }
        }

        // Opens the connection unless this holder already opened it.
        // context:           not used by this method.
        // revertImpersonate: when true the connection is opened inside an
        //                    ApplicationImpersonationContext.
        internal void Open(HttpContext context, bool revertImpersonate) {
            if (!_Opened) {
                if (!revertImpersonate) {
                    _Connection.Open();
                }
                else {
                    using (new ApplicationImpersonationContext()) {
                        _Connection.Open();
                    }
                }

                // Only reached when Open did not throw.
                _Opened = true;
            }
        }

        // Closes the connection if this holder opened it; otherwise does nothing.
        internal void Close() {
            if (_Opened) {
                _Connection.Close();
                _Opened = false;
            }
        }
    }
}