|
//------------------------------------------------------------------------------
// <copyright file="_DynamicWinsockMethods.cs" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//------------------------------------------------------------------------------
using System.Security;
using System.Collections.Generic;
using System.Runtime.InteropServices;
namespace System.Net.Sockets
{
// Caches the delegates for socket functions that are not bound statically but are looked up
// at run time with WSAIoctl(SIOGETEXTENSIONFUNCTIONPOINTER). One instance exists per
// (address family, socket type, protocol type) combination; each delegate is created lazily
// the first time it is requested through GetDelegate.
internal sealed class DynamicWinsockMethods
{
    // In practice there will never be more than four of these, so it's not worth a complicated
    // hash table structure. Store them in a list and search through it.
    private static readonly List<DynamicWinsockMethods> s_MethodTable = new List<DynamicWinsockMethods>();

    // Returns the shared instance for the given combination, creating and registering it on
    // first use. The whole lookup-or-add runs under a lock on the table.
    public static DynamicWinsockMethods GetMethods(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
    {
        lock (s_MethodTable)
        {
            DynamicWinsockMethods methods;
            for (int i = 0; i < s_MethodTable.Count; i++)
            {
                methods = s_MethodTable[i];
                if (methods.addressFamily == addressFamily && methods.socketType == socketType && methods.protocolType == protocolType)
                {
                    return methods;
                }
            }

            methods = new DynamicWinsockMethods(addressFamily, socketType, protocolType);
            s_MethodTable.Add(methods);
            return methods;
        }
    }

    // Key of this instance in s_MethodTable; fixed at construction.
    private readonly AddressFamily addressFamily;
    private readonly SocketType socketType;
    private readonly ProtocolType protocolType;

    // Serializes the lazy initialization of the delegate fields below.
    private readonly object lockObject;

    // These fields double as the "already initialized" guard that the Ensure* methods test
    // before taking the lock. They are volatile so that the unlocked read cannot be reordered
    // with the reads that follow it, and so that the guard write is not moved ahead of the
    // writes that precede it.
    private volatile AcceptExDelegate acceptEx;
    private volatile GetAcceptExSockaddrsDelegate getAcceptExSockaddrs;
    private volatile ConnectExDelegate connectEx;
    private volatile TransmitPacketsDelegate transmitPackets;
    private volatile DisconnectExDelegate disconnectEx;
    private volatile WSARecvMsgDelegate recvMsg;

    // Companion delegates wrapping the same function pointers as disconnectEx / recvMsg.
    // They are always written BEFORE their guard field (see EnsureDisconnectEx and
    // EnsureWSARecvMsg), so a non-null guard implies these are set as well.
    private DisconnectExDelegate_Blocking disconnectEx_Blocking;
    private WSARecvMsgDelegate_Blocking recvMsg_Blocking;

    private DynamicWinsockMethods(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
    {
        this.addressFamily = addressFamily;
        this.socketType = socketType;
        this.protocolType = protocolType;
        this.lockObject = new object();
    }

    // Returns the cached delegate of type T, loading its function pointer through
    // socketHandle on first use. T must be one of the delegate types handled below; for any
    // other type a debug assertion fires and null is returned.
    // Loading can throw SocketException (see LoadDynamicFunctionPointer).
    public T GetDelegate<T>(SafeCloseSocket socketHandle) where T: class
    {
        if (typeof(T) == typeof(AcceptExDelegate))
        {
            EnsureAcceptEx(socketHandle);
            return (T)(object)acceptEx;
        }
        else if (typeof(T) == typeof(GetAcceptExSockaddrsDelegate))
        {
            EnsureGetAcceptExSockaddrs(socketHandle);
            return (T)(object)getAcceptExSockaddrs;
        }
        else if (typeof(T) == typeof(ConnectExDelegate))
        {
            EnsureConnectEx(socketHandle);
            return (T)(object)connectEx;
        }
        else if (typeof(T) == typeof(DisconnectExDelegate))
        {
            EnsureDisconnectEx(socketHandle);
            return (T)(object)disconnectEx;
        }
        else if (typeof(T) == typeof(DisconnectExDelegate_Blocking))
        {
            EnsureDisconnectEx(socketHandle);
            return (T)(object)disconnectEx_Blocking;
        }
        else if (typeof(T) == typeof(WSARecvMsgDelegate))
        {
            EnsureWSARecvMsg(socketHandle);
            return (T)(object)recvMsg;
        }
        else if (typeof(T) == typeof(WSARecvMsgDelegate_Blocking))
        {
            EnsureWSARecvMsg(socketHandle);
            return (T)(object)recvMsg_Blocking;
        }
        else if (typeof(T) == typeof(TransmitPacketsDelegate))
        {
            EnsureTransmitPackets(socketHandle);
            return (T)(object)transmitPackets;
        }

        System.Diagnostics.Debug.Assert(false, "Invalid type passed to DynamicWinsockMethods.GetDelegate");
        return null;
    }

    // private methods to actually load the function pointers

    // Asks the socket for the function pointer identified by guid using the
    // SIOGETEXTENSIONFUNCTIONPOINTER ioctl. Throws SocketException if WSAIoctl does not
    // report SocketError.Success; the parameterless constructor picks up the last OS error.
    private IntPtr LoadDynamicFunctionPointer(SafeCloseSocket socketHandle, ref Guid guid)
    {
        IntPtr ptr = IntPtr.Zero;
        int length;
        SocketError errorCode;

        unsafe
        {
            errorCode = UnsafeNclNativeMethods.OSSOCK.WSAIoctl(
                socketHandle,
                IoctlSocketConstants.SIOGETEXTENSIONFUNCTIONPOINTER,
                ref guid,
                sizeof(Guid),
                out ptr,
                sizeof(IntPtr),
                out length,
                IntPtr.Zero,
                IntPtr.Zero);
        }

        if (errorCode != SocketError.Success)
        {
            throw new SocketException();
        }

        return ptr;
    }

    // Lazily creates acceptEx: unlocked check, then re-check under lockObject.
    private void EnsureAcceptEx(SafeCloseSocket socketHandle)
    {
        if (acceptEx == null)
        {
            lock (lockObject)
            {
                if (acceptEx == null)
                {
                    Guid guid = new Guid("{0xb5367df1,0xcbac,0x11cf,{0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92}}");
                    IntPtr ptrAcceptEx = LoadDynamicFunctionPointer(socketHandle, ref guid);
                    acceptEx = (AcceptExDelegate)Marshal.GetDelegateForFunctionPointer(ptrAcceptEx, typeof(AcceptExDelegate));
                }
            }
        }
    }

    // Lazily creates getAcceptExSockaddrs: unlocked check, then re-check under lockObject.
    private void EnsureGetAcceptExSockaddrs(SafeCloseSocket socketHandle)
    {
        if (getAcceptExSockaddrs == null)
        {
            lock (lockObject)
            {
                if (getAcceptExSockaddrs == null)
                {
                    Guid guid = new Guid("{0xb5367df2,0xcbac,0x11cf,{0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92}}");
                    IntPtr ptrGetAcceptExSockaddrs = LoadDynamicFunctionPointer(socketHandle, ref guid);
                    getAcceptExSockaddrs = (GetAcceptExSockaddrsDelegate)Marshal.GetDelegateForFunctionPointer(ptrGetAcceptExSockaddrs,
                        typeof(GetAcceptExSockaddrsDelegate));
                }
            }
        }
    }

    // Lazily creates connectEx: unlocked check, then re-check under lockObject.
    private void EnsureConnectEx(SafeCloseSocket socketHandle)
    {
        if (connectEx == null)
        {
            lock (lockObject)
            {
                if (connectEx == null)
                {
                    Guid guid = new Guid("{0x25a207b9,0x0ddf3,0x4660,{0x8e,0xe9,0x76,0xe5,0x8c,0x74,0x06,0x3e}}");
                    IntPtr ptrConnectEx = LoadDynamicFunctionPointer(socketHandle, ref guid);
                    connectEx = (ConnectExDelegate)Marshal.GetDelegateForFunctionPointer(ptrConnectEx, typeof(ConnectExDelegate));
                }
            }
        }
    }

    // Lazily creates disconnectEx and disconnectEx_Blocking from one function pointer.
    // disconnectEx is the guard tested outside the lock, so it must be the LAST field
    // written: otherwise a concurrent GetDelegate<DisconnectExDelegate_Blocking> call could
    // see the guard set, skip the lock, and return a still-null disconnectEx_Blocking.
    private void EnsureDisconnectEx(SafeCloseSocket socketHandle)
    {
        if (disconnectEx == null)
        {
            lock (lockObject)
            {
                if (disconnectEx == null)
                {
                    Guid guid = new Guid("{0x7fda2e11,0x8630,0x436f,{0xa0, 0x31, 0xf5, 0x36, 0xa6, 0xee, 0xc1, 0x57}}");
                    IntPtr ptrDisconnectEx = LoadDynamicFunctionPointer(socketHandle, ref guid);

                    // Build both delegates before touching any field, so a failure here
                    // leaves the instance untouched and the next call simply retries.
                    DisconnectExDelegate overlappedVariant =
                        (DisconnectExDelegate)Marshal.GetDelegateForFunctionPointer(ptrDisconnectEx, typeof(DisconnectExDelegate));
                    DisconnectExDelegate_Blocking blockingVariant =
                        (DisconnectExDelegate_Blocking)Marshal.GetDelegateForFunctionPointer(ptrDisconnectEx,
                            typeof(DisconnectExDelegate_Blocking));

                    disconnectEx_Blocking = blockingVariant;
                    disconnectEx = overlappedVariant; // guard: publish last
                }
            }
        }
    }

    // Lazily creates recvMsg and recvMsg_Blocking from one function pointer.
    // recvMsg is the guard tested outside the lock, so it must be the LAST field written:
    // otherwise a concurrent GetDelegate<WSARecvMsgDelegate_Blocking> call could see the
    // guard set, skip the lock, and return a still-null recvMsg_Blocking.
    private void EnsureWSARecvMsg(SafeCloseSocket socketHandle)
    {
        if (recvMsg == null)
        {
            lock (lockObject)
            {
                if (recvMsg == null)
                {
                    Guid guid = new Guid("{0xf689d7c8,0x6f1f,0x436b,{0x8a,0x53,0xe5,0x4f,0xe3,0x51,0xc3,0x22}}");
                    IntPtr ptrWSARecvMsg = LoadDynamicFunctionPointer(socketHandle, ref guid);

                    // Build both delegates before touching any field, so a failure here
                    // leaves the instance untouched and the next call simply retries.
                    WSARecvMsgDelegate overlappedVariant =
                        (WSARecvMsgDelegate)Marshal.GetDelegateForFunctionPointer(ptrWSARecvMsg, typeof(WSARecvMsgDelegate));
                    WSARecvMsgDelegate_Blocking blockingVariant =
                        (WSARecvMsgDelegate_Blocking)Marshal.GetDelegateForFunctionPointer(ptrWSARecvMsg,
                            typeof(WSARecvMsgDelegate_Blocking));

                    recvMsg_Blocking = blockingVariant;
                    recvMsg = overlappedVariant; // guard: publish last
                }
            }
        }
    }

    // Lazily creates transmitPackets: unlocked check, then re-check under lockObject.
    private void EnsureTransmitPackets(SafeCloseSocket socketHandle)
    {
        if (transmitPackets == null)
        {
            lock (lockObject)
            {
                if (transmitPackets == null)
                {
                    Guid guid = new Guid("{0xd9689da0,0x1f90,0x11d3,{0x99,0x71,0x00,0xc0,0x4f,0x68,0xc8,0x76}}");
                    IntPtr ptrTransmitPackets = LoadDynamicFunctionPointer(socketHandle, ref guid);
                    transmitPackets = (TransmitPacketsDelegate)Marshal.GetDelegateForFunctionPointer(ptrTransmitPackets,
                        typeof(TransmitPacketsDelegate));
                }
            }
        }
    }
}
// Managed signature for the native function pointer that DynamicWinsockMethods.EnsureAcceptEx
// obtains through SIOGETEXTENSIONFUNCTIONPOINTER and wraps with
// Marshal.GetDelegateForFunctionPointer.
// NOTE(review): parameter meanings are inferred from their names only and presumably mirror
// the native AcceptEx extension function - confirm against the Winsock documentation.
[SuppressUnmanagedCodeSecurity]
internal delegate bool AcceptExDelegate(
SafeCloseSocket listenSocketHandle,
SafeCloseSocket acceptSocketHandle,
IntPtr buffer,
int len,
int localAddressLength,
int remoteAddressLength,
out int bytesReceived,
SafeHandle overlapped);
// Managed signature for the native function pointer that
// DynamicWinsockMethods.EnsureGetAcceptExSockaddrs obtains through
// SIOGETEXTENSIONFUNCTIONPOINTER and wraps with Marshal.GetDelegateForFunctionPointer.
// Returns nothing; the four results come back through the out parameters.
// NOTE(review): presumably mirrors the native GetAcceptExSockaddrs extension function, with
// the out IntPtr values pointing into "buffer" - confirm against the Winsock documentation.
[SuppressUnmanagedCodeSecurity]
internal delegate void GetAcceptExSockaddrsDelegate(
IntPtr buffer,
int receiveDataLength,
int localAddressLength,
int remoteAddressLength,
out IntPtr localSocketAddress,
out int localSocketAddressLength,
out IntPtr remoteSocketAddress,
out int remoteSocketAddressLength);
// Managed signature for the native function pointer that DynamicWinsockMethods.EnsureConnectEx
// obtains through SIOGETEXTENSIONFUNCTIONPOINTER and wraps with
// Marshal.GetDelegateForFunctionPointer.
// NOTE(review): parameter meanings are inferred from their names only and presumably mirror
// the native ConnectEx extension function - confirm against the Winsock documentation.
[SuppressUnmanagedCodeSecurity]
internal delegate bool ConnectExDelegate(
SafeCloseSocket socketHandle,
IntPtr socketAddress,
int socketAddressSize,
IntPtr buffer,
int dataLength,
out int bytesSent,
SafeHandle overlapped);
// Managed signature for the native function pointer that
// DynamicWinsockMethods.EnsureDisconnectEx obtains through SIOGETEXTENSIONFUNCTIONPOINTER.
// This variant takes the socket as a SafeCloseSocket and the overlapped structure as a
// SafeHandle; DisconnectExDelegate_Blocking wraps the same pointer with raw IntPtr arguments.
// NOTE(review): the meaning of "flags" and "reserved" is not visible here - confirm against
// the native DisconnectEx documentation.
[SuppressUnmanagedCodeSecurity]
internal delegate bool DisconnectExDelegate(SafeCloseSocket socketHandle, SafeHandle overlapped, int flags, int reserved);
// Second managed signature for the same native function pointer as DisconnectExDelegate
// (both are created in DynamicWinsockMethods.EnsureDisconnectEx). It differs only in taking
// the socket handle and the overlapped argument as raw IntPtr values.
// NOTE(review): the "_Blocking" suffix suggests callers pass IntPtr.Zero for overlapped to
// get a synchronous call - confirm against the call sites.
[SuppressUnmanagedCodeSecurity]
internal delegate bool DisconnectExDelegate_Blocking(IntPtr socketHandle, IntPtr overlapped, int flags, int reserved);
// Managed signature for the native function pointer that
// DynamicWinsockMethods.EnsureWSARecvMsg obtains through SIOGETEXTENSIONFUNCTIONPOINTER.
// This variant takes the socket as a SafeCloseSocket and the overlapped structure as a
// SafeHandle; WSARecvMsgDelegate_Blocking wraps the same pointer with raw IntPtr arguments.
// The native result is surfaced as a SocketError value.
// NOTE(review): "msg" presumably points to a native WSAMSG structure - confirm against the
// call sites and the Winsock documentation.
[SuppressUnmanagedCodeSecurity]
internal delegate SocketError WSARecvMsgDelegate(
SafeCloseSocket socketHandle,
IntPtr msg,
out int bytesTransferred,
SafeHandle overlapped,
IntPtr completionRoutine);
// Second managed signature for the same native function pointer as WSARecvMsgDelegate
// (both are created in DynamicWinsockMethods.EnsureWSARecvMsg). It differs only in taking
// the socket handle and the overlapped argument as raw IntPtr values.
// NOTE(review): the "_Blocking" suffix suggests callers pass IntPtr.Zero for overlapped to
// get a synchronous call - confirm against the call sites.
[SuppressUnmanagedCodeSecurity]
internal delegate SocketError WSARecvMsgDelegate_Blocking(
IntPtr socketHandle,
IntPtr msg,
out int bytesTransferred,
IntPtr overlapped,
IntPtr completionRoutine);
// Managed signature for the native function pointer that
// DynamicWinsockMethods.EnsureTransmitPackets obtains through SIOGETEXTENSIONFUNCTIONPOINTER
// and wraps with Marshal.GetDelegateForFunctionPointer.
// Unlike the other delegates in this file, the overlapped argument is typed as
// SafeNativeOverlapped and the flags are passed as TransmitFileOptions.
// NOTE(review): "packetArray"/"elementCount" presumably describe a native array of
// TRANSMIT_PACKETS_ELEMENT entries - confirm against the Winsock documentation.
[SuppressUnmanagedCodeSecurity]
internal delegate bool TransmitPacketsDelegate(
SafeCloseSocket socketHandle,
IntPtr packetArray,
int elementCount,
int sendSize,
SafeNativeOverlapped overlapped,
TransmitFileOptions flags);
}
|