File: System\ServiceModel\ComIntegration\TypeCacheManager.cs
Project: ndp\cdf\src\WCF\ServiceModel\System.ServiceModel.csproj (System.ServiceModel)
//-----------------------------------------------------------------------------
// Copyright (c) Microsoft Corporation.  All rights reserved.
//-----------------------------------------------------------------------------
namespace System.ServiceModel.ComIntegration
{
    using System;
    using System.Collections.Generic;
    using System.Diagnostics;
    using System.Globalization;
    using System.Reflection;
    using System.Runtime;
    using System.Runtime.Diagnostics;
    using System.Runtime.InteropServices;
    using System.Runtime.InteropServices.ComTypes;
    using System.ServiceModel.Diagnostics;
    using System.Threading;
    using Microsoft.Win32;
 
    // Resolves COM interface IIDs, UDT GUIDs and CLSIDs to managed types. Resolved assemblies
    // (keyed by type library ID) and resolved types (keyed by IID / UDT GUID) are cached per instance.
    class TypeCacheManager : ITypeCacheManager
    {
        enum RegKind
        {
            Default = 0,
            Register = 1,
            None = 2
        }

        // HRESULT used when a loaded type library object does not expose ITypeLib2.
        const int E_NOINTERFACE = unchecked((int)0x80004002);

        // Custom-data GUID queried on a type library. When present and non-empty, its value is treated
        // as a full assembly name and passed to Assembly.Load (see ResolveAssemblyFromTypeLibID);
        // when absent, the type library is treated as a native (non-CLR) one.
        // Not readonly: it is passed by ref to ITypeLib2.GetCustData.
        static Guid clrAssemblyCustomID = new Guid("90883F05-3D28-11D2-8F17-00A0C9A6186D");

        // Guards lazy creation of the shared instance handed out by Provider.
        static readonly object instanceLock = new object();

        // Gives access to the shared TypeCacheManager, creating it on first use.
        public static ITypeCacheManager Provider
        {
            get
            {
                lock (instanceLock)
                {
                    if (instance == null)
                    {
                        ITypeCacheManager localInstance = new TypeCacheManager();
                        // Ensure the constructor's writes are visible before the reference is published;
                        // 'instance' is internal, so code elsewhere may read it without taking instanceLock.
                        Thread.MemoryBarrier();
                        instance = localInstance;
                    }
                }

                return instance;
            }
        }

        // Shared instance returned by Provider. Internal and writable — presumably so other code in
        // the assembly can read or substitute it; confirm against callers before tightening.
        internal static ITypeCacheManager instance;

        // Type library ID -> assembly loaded or generated for it. Guarded by assemblyTableLock.
        private readonly Dictionary<Guid, Assembly> assemblyTable;
        // Interface IID or UDT GUID -> resolved managed type. Guarded by typeTableLock.
        private readonly Dictionary<Guid, Type> typeTable;
        private readonly object typeTableLock;
        private readonly object assemblyTableLock;

        internal TypeCacheManager()
        {
            assemblyTable = new Dictionary<Guid, Assembly>();
            typeTable = new Dictionary<Guid, Type>();
            typeTableLock = new object();
            assemblyTableLock = new object();
        }

        // Reads a registry value and converts it to a string; returns null when the value is absent.
        static string GetRegistryValueAsString(RegistryKey key, string valueName)
        {
            object value = key.GetValue(valueName);
            return (value == null) ? null : value.ToString();
        }

        // Looks up the type library registered for an interface IID.
        //   iid      - the interface to look up.
        //   isServer - true: consult only HKLM\Software\Classes (machine-wide registrations);
        //              false: consult the HKEY_CLASSES_ROOT view.
        //   version  - receives the registered "Version" value, or "1.0" when it is missing/empty.
        // Returns the type library GUID.
        // Throws InvalidOperationException when the interface is not registered, has no type library
        // entry, or the type library entry is not a valid GUID.
        private Guid GettypeLibraryIDFromIID(Guid iid, bool isServer, out String version)
        {
            RegistryKey interfaceKeyRoot;
            string keyName;
            if (isServer)
            {
                interfaceKeyRoot = Registry.LocalMachine;
                keyName = String.Concat("software\\classes\\interface\\{", iid.ToString(), "}\\typelib");
            }
            else
            {
                interfaceKeyRoot = Registry.ClassesRoot;
                keyName = String.Concat("interface\\{", iid.ToString(), "}\\typelib");
            }

            using (RegistryKey interfaceKey = interfaceKeyRoot.OpenSubKey(keyName, false))
            {
                if (interfaceKey == null)
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.InterfaceNotRegistered)));

                // GetValue returns null for a missing value; read null-safely so the checks below
                // (rather than a NullReferenceException) decide the outcome.
                string typeLibID = GetRegistryValueAsString(interfaceKey, "");
                if (string.IsNullOrEmpty(typeLibID))
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.NoTypeLibraryFoundForInterface)));

                version = GetRegistryValueAsString(interfaceKey, "Version");
                if (string.IsNullOrEmpty(version))
                    version = "1.0";

                Guid typeLibraryID;
                if (!DiagnosticUtility.Utility.TryCreateGuid(typeLibID, out typeLibraryID))
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.BadInterfaceRegistration)));
                }
                return typeLibraryID;
            }
        }

        // Parses "major[.minor[.anything]]" into two ushorts (hex when parseVersionAsHex, decimal otherwise).
        // Anything after the second dot is ignored. A null/empty version yields 0.0; a missing minor part yields 0.
        // Throws InvalidOperationException (BadInterfaceVersion) when a component is malformed or out of range.
        private void ParseVersion(string version, bool parseVersionAsHex, out ushort major, out ushort minor)
        {
            major = 0;
            minor = 0;
            if (String.IsNullOrEmpty(version))
                return;

            NumberStyles numberStyle = parseVersionAsHex ? NumberStyles.HexNumber : NumberStyles.None;

            string majorPart = version;
            string minorPart = null;
            int firstDot = version.IndexOf('.');
            if (firstDot != -1)
            {
                majorPart = version.Substring(0, firstDot);
                minorPart = version.Substring(firstDot + 1);
                int secondDot = minorPart.IndexOf('.');
                if (secondDot != -1) // Ignore anything beyond the first minor version.
                    minorPart = minorPart.Substring(0, secondDot);
            }

            try
            {
                major = ushort.Parse(majorPart, numberStyle, NumberFormatInfo.InvariantInfo);
                if (minorPart != null)
                    minor = ushort.Parse(minorPart, numberStyle, NumberFormatInfo.InvariantInfo);
            }
            catch (FormatException)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.BadInterfaceVersion)));
            }
            catch (OverflowException)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.BadInterfaceVersion)));
            }
        }

        // Loads a registered type library (language-neutral LCID 0) and returns it as ITypeLib2.
        // The caller owns the returned RCW and must release it.
        // Throws COMException (FailedToLoadTypeLibrary) when loading fails or the object lacks ITypeLib2.
        private ITypeLib2 GettypeLibrary(Guid typeLibraryID, string version, bool parseVersionAsHex)
        {
            const int lcidLocalIndependent = 0;
            ushort major;
            ushort minor;
            ParseVersion(version, parseVersionAsHex, out major, out minor);

            object otlb;
            int hr = SafeNativeMethods.LoadRegTypeLib(ref typeLibraryID, major, minor, lcidLocalIndependent, out otlb);
            if (hr != 0 || null == otlb)
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new COMException(SR.GetString(SR.FailedToLoadTypeLibrary), hr));

            ITypeLib2 typeLibrary = otlb as ITypeLib2;
            if (typeLibrary == null)
            {
                // Returning null would crash the caller with a NullReferenceException and leak the RCW;
                // release it here and report a load failure instead.
                if (Marshal.IsComObject(otlb))
                    Marshal.ReleaseComObject(otlb);
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new COMException(SR.GetString(SR.FailedToLoadTypeLibrary), E_NOINTERFACE));
            }
            return typeLibrary;
        }

        // Resolves the assembly that defines interface 'iid' via its registered type library.
        // Registry versions are parsed as hex.
        private Assembly ResolveAssemblyFromIID(Guid iid, bool noAssemblyGeneration, bool isServer)
        {
            String version;
            Guid typeLibraryID = GettypeLibraryIDFromIID(iid, isServer, out version);

            return ResolveAssemblyFromTypeLibID(iid, typeLibraryID, version, true, noAssemblyGeneration);
        }

        // Returns the assembly for a type library, caching it by type library ID.
        // If the type library carries the clrAssemblyCustomID custom data, that assembly name is loaded;
        // otherwise an assembly is generated from the native type library, unless noAssemblyGeneration
        // is set, in which case InvalidOperationException (NativeTypeLibraryNotAllowed) is thrown.
        // Any failure is written to the event log and rethrown. 'iid' is used only for tracing/logging
        // and for the generator (callers may pass Guid.Empty).
        private Assembly ResolveAssemblyFromTypeLibID(Guid iid, Guid typeLibraryID, string version, bool parseVersionAsHex, bool noAssemblyGeneration)
        {
            ComPlusTLBImportTrace.Trace(TraceEventType.Verbose, TraceCode.ComIntegrationTLBImportStarting,
                                           SR.TraceCodeComIntegrationTLBImportStarting, iid, typeLibraryID);
            Assembly asm;
            ITypeLib2 typeLibrary = null;

            try
            {
                lock (assemblyTableLock)
                {
                    assemblyTable.TryGetValue(typeLibraryID, out asm);
                    if (asm == null)
                    {
                        typeLibrary = GettypeLibrary(typeLibraryID, version, parseVersionAsHex);

                        object opaqueData = null;
                        typeLibrary.GetCustData(ref clrAssemblyCustomID, out opaqueData);

                        // No (or empty) assembly-name custom data: this is not a CLR type library.
                        String assembly = opaqueData as String;
                        bool generateNativeAssembly = String.IsNullOrEmpty(assembly);

                        if (generateNativeAssembly && noAssemblyGeneration)
                            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.NativeTypeLibraryNotAllowed, typeLibraryID)));

                        if (!generateNativeAssembly)
                        {
                            ComPlusTLBImportTrace.Trace(TraceEventType.Verbose, TraceCode.ComIntegrationTLBImportFromAssembly,
                                          SR.TraceCodeComIntegrationTLBImportFromAssembly, iid, typeLibraryID, assembly);
                            asm = Assembly.Load(assembly);            // Assembly.Load will get a full assembly name
                        }
                        else
                        {
                            ComPlusTLBImportTrace.Trace(TraceEventType.Verbose, TraceCode.ComIntegrationTLBImportFromTypelib,
                                               SR.TraceCodeComIntegrationTLBImportFromTypelib, iid, typeLibraryID);
                            asm = TypeLibraryHelper.GenerateAssemblyFromNativeTypeLibrary(iid, typeLibraryID, typeLibrary as ITypeLib);
                        }

                        assemblyTable[typeLibraryID] = asm;
                    }
                }
            }
            catch (Exception e)
            {
                DiagnosticUtility.EventLog.LogEvent(TraceEventType.Error,
                    (ushort)System.Runtime.Diagnostics.EventLogCategory.ComPlus,
                    (uint)System.Runtime.Diagnostics.EventLogEventId.ComPlusTLBImportError,
                    iid.ToString(),
                    typeLibraryID.ToString(),
                    e.ToString());
                throw;
            }
            finally
            {
                // The type library RCW is only needed while resolving; release it on every path.
                if (typeLibrary != null)
                    Marshal.ReleaseComObject((object)typeLibrary);
            }

            if (null == asm)
            {
                throw Fx.AssertAndThrow("Assembly should not be null");
            }
            ComPlusTLBImportTrace.Trace(TraceEventType.Verbose, TraceCode.ComIntegrationTLBImportFinished,
                               SR.TraceCodeComIntegrationTLBImportFinished, iid, typeLibraryID);
            return asm;
        }

        // True when attrProvider has no CoClassAttribute applied (inherited attributes are not considered).
        private bool NoCoClassAttributeOnType(ICustomAttributeProvider attrProvider)
        {
            object[] attrs = System.ServiceModel.Description.ServiceReflector.GetCustomAttributes(attrProvider, typeof(CoClassAttribute), false);
            return attrs.Length == 0;
        }

        // Returns the cached assembly for a type library ID, or null if it has not been resolved yet.
        Assembly ITypeCacheManager.ResolveAssembly(Guid assembly)
        {
            Assembly ret = null;

            lock (assemblyTableLock)
            {
                this.assemblyTable.TryGetValue(assembly, out ret);
            }

            return ret;
        }

        // Finds the value type whose GUID is typeDefId in the assembly for the given type library
        // (version parsed as decimal), caching the result.
        // Throws InvalidOperationException (UdtNotFoundInAssembly) when no such value type exists.
        void ITypeCacheManager.FindOrCreateType(Guid typeLibId, string typeLibVersion, Guid typeDefId, out Type userDefinedType, bool noAssemblyGeneration)
        {
            lock (typeTableLock)
            {
                typeTable.TryGetValue(typeDefId, out userDefinedType);
                if (userDefinedType == null)
                {
                    Assembly asm = ResolveAssemblyFromTypeLibID(Guid.Empty, typeLibId, typeLibVersion, false, noAssemblyGeneration);
                    foreach (Type t in asm.GetTypes())
                    {
                        if (t.GUID == typeDefId && t.IsValueType)
                        {
                            userDefinedType = t;
                            break;
                        }
                    }
                    if (userDefinedType == null)
                        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.UdtNotFoundInAssembly, typeDefId)));

                    typeTable[typeDefId] = userDefinedType;
                }
            }
        }

        // Finds the managed interface whose GUID is iid, caching the result.
        // An interface without CoClassAttribute is preferred; otherwise the last matching interface that
        // has CoClassAttribute is used. Throws InvalidOperationException (InterfaceNotFoundInAssembly)
        // when the resolved assembly contains no matching interface.
        public void FindOrCreateType(Guid iid, out Type interfaceType, bool noAssemblyGeneration, bool isServer)
        {
            lock (typeTableLock)
            {
                typeTable.TryGetValue(iid, out interfaceType);
                if (interfaceType == null)
                {
                    Type coClassInterface = null;
                    Assembly asm = ResolveAssemblyFromIID(iid, noAssemblyGeneration, isServer);
                    foreach (Type t in asm.GetTypes())
                    {
                        if (t.GUID != iid || !t.IsInterface)
                            continue;

                        if (NoCoClassAttributeOnType(t))
                        {
                            interfaceType = t;
                            break;
                        }
                        coClassInterface = t;
                    }

                    if (interfaceType == null)
                        interfaceType = coClassInterface;
                    if (interfaceType == null)
                        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.InterfaceNotFoundInAssembly)));

                    typeTable[iid] = interfaceType;
                }
            }
        }

        // With no serverType, delegates to the IID-based lookup. Otherwise returns the interface
        // implemented by serverType whose GUID is iid (not cached).
        // Throws InvalidOperationException (InterfaceNotFoundInAssembly) when serverType has no such interface.
        void ITypeCacheManager.FindOrCreateType(Type serverType, Guid iid, out Type interfaceType, bool noAssemblyGeneration, bool isServer)
        {
            interfaceType = null;
            if (serverType == null)
            {
                FindOrCreateType(iid, out interfaceType, noAssemblyGeneration, isServer);
                return;
            }

            if (!serverType.IsClass)
            {
                throw Fx.AssertAndThrow("This should be a class");
            }
            foreach (Type interfaceInType in serverType.GetInterfaces())
            {
                if (interfaceInType.GUID == iid)
                {
                    interfaceType = interfaceInType;
                    break;
                }
            }
            if (interfaceType == null)
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.InterfaceNotFoundInAssembly)));
        }

        // Maps a COM CLSID to the managed class registered for it, or returns null when none can be found.
        // HKLM\Software\Classes\CLSID\{clsid}\InprocServer32 is consulted first; if that key exists its
        // answer is final. Only when it is missing is the alternate-bitness HKCR view tried.
        // Exceptions from Assembly.Load / Assembly.GetTypes propagate to the caller.
        public static Type ResolveClsidToType(Guid clsid)
        {
            string keyName = String.Concat("software\\classes\\clsid\\{", clsid.ToString(), "}\\InprocServer32");
            using (RegistryKey clsidKey = Registry.LocalMachine.OpenSubKey(keyName, false))
            {
                if (clsidKey != null)
                    return FindClassInAssembly(ReadAssemblyName(clsidKey), clsid);
            }

            // We failed to get the hive information from the native process hive; try the alternate bitness.
            using (RegistryHandle hkcr = RegistryHandle.GetBitnessHKCR(IntPtr.Size != 8))
            {
                if (hkcr != null)
                {
                    using (RegistryHandle clsidKey = hkcr.OpenSubKey(String.Concat("CLSID\\{", clsid.ToString(), "}\\InprocServer32")))
                    {
                        // RegistryHandle.OpenSubKey yields null for a missing key; an unregistered CLSID
                        // must produce null here rather than a NullReferenceException.
                        if (clsidKey != null)
                            return FindClassInAssembly(ReadAssemblyName(clsidKey), clsid);
                    }
                }
            }
            return null;
        }

        // Returns the "Assembly" value under the subkey named after this runtime's ImageRuntimeVersion.
        // If that subkey is missing, returns the first non-empty "Assembly" value found under any other
        // subkey. Returns null/empty when nothing is registered.
        static string ReadAssemblyName(RegistryKey inprocServerKey)
        {
            using (RegistryKey runtimeKey = inprocServerKey.OpenSubKey(typeof(TypeCacheManager).Assembly.ImageRuntimeVersion))
            {
                if (runtimeKey != null)
                    return (string)runtimeKey.GetValue("Assembly");
            }

            foreach (string subKeyName in inprocServerKey.GetSubKeyNames())
            {
                if (String.IsNullOrEmpty(subKeyName))
                    continue;
                using (RegistryKey anyKey = inprocServerKey.OpenSubKey(subKeyName))
                {
                    // The subkey may have disappeared since GetSubKeyNames was called.
                    if (anyKey == null)
                        continue;
                    string assemblyName = (string)anyKey.GetValue("Assembly");
                    if (!String.IsNullOrEmpty(assemblyName))
                        return assemblyName;
                }
            }
            return null;
        }

        // RegistryHandle counterpart of ReadAssemblyName(RegistryKey), used for the alternate-bitness hive.
        static string ReadAssemblyName(RegistryHandle inprocServerKey)
        {
            using (RegistryHandle runtimeKey = inprocServerKey.OpenSubKey(typeof(TypeCacheManager).Assembly.ImageRuntimeVersion))
            {
                if (runtimeKey != null)
                    return runtimeKey.GetStringValue("Assembly");
            }

            foreach (string subKeyName in inprocServerKey.GetSubKeyNames())
            {
                if (String.IsNullOrEmpty(subKeyName))
                    continue;
                using (RegistryHandle anyKey = inprocServerKey.OpenSubKey(subKeyName))
                {
                    if (anyKey == null)
                        continue;
                    string assemblyName = anyKey.GetStringValue("Assembly");
                    if (!String.IsNullOrEmpty(assemblyName))
                        return assemblyName;
                }
            }
            return null;
        }

        // Loads assemblyName and returns its class whose GUID equals clsid.
        // Returns null when assemblyName is null/empty or no such class exists.
        static Type FindClassInAssembly(string assemblyName, Guid clsid)
        {
            if (String.IsNullOrEmpty(assemblyName))
                return null;

            Assembly asm = Assembly.Load(assemblyName);
            foreach (Type type in asm.GetTypes())
            {
                if (type.IsClass && (type.GUID == clsid))
                    return type;
            }
            return null;
        }

        // Resolves (server-side, assembly generation allowed) and returns the managed interface for iid.
        internal Type VerifyType(Guid iid)
        {
            Type interfaceType;
            ((ITypeCacheManager)(this)).FindOrCreateType(iid, out interfaceType, false, true);
            return interfaceType;
        }
    }
}