File: channels\core\socketmanager.cs
Project: ndp\clr\src\managedlibraries\remoting\System.Runtime.Remoting.csproj (System.Runtime.Remoting)
// ==++==
// 
//   Copyright (c) Microsoft Corporation.  All rights reserved.
// 
// ==--==
//==========================================================================
//  File:       SocketManager.cs
//
//  Summary:    Class for managing a socket connection.
//
//==========================================================================
 
using System;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Runtime.Remoting.Messaging;
using System.Security.Principal;
using System.Text;
using System.Threading;
 
 
namespace System.Runtime.Remoting.Channels
{
    // Predicate applied by SocketHandler.ReadToByte/ReadToChar to every byte that
    //   is scanned before the delimiter is found. Returning false makes the read
    //   fail with a RemotingException.
    internal delegate bool ValidateByteDelegate(byte b);
 
 
    internal abstract class SocketHandler
    {
        // socket manager data
        protected Socket NetSocket;    // network socket (set to null by Close)
        protected Stream NetStream; // network stream (set to null by Close)
        private DateTime _creationTime; // UTC time at which this handler was constructed
        private RequestQueue _requestQueue; // request queue to use for this connection (may be null)

        private byte[] _dataBuffer; // buffered data; obtained from CoreChannel.BufferPool and returned to it on Close
        private int    _dataBufferSize; // size of data buffer (length of the buffer taken in the constructor)
        private int    _dataOffset; // offset of remaining data in buffer
        private int    _dataCount;  // count of remaining bytes in buffer

        private AsyncCallback _beginReadCallback; // callback to use when doing an async read (BeginReadMessageCallback)
        private IAsyncResult _beginReadAsyncResult; // async result from doing a begin read
        private WaitCallback _dataArrivedCallback; // callback to signal once data is available; invoked with this handler as its argument
        private Object _dataArrivedCallbackState; // state object to go along with callback
#if !FEATURE_PAL    
        private WindowsIdentity _impersonationIdentity; // Identity to impersonate 
#endif // !FEATURE_PAL
        private byte[] _byteBuffer = new byte[4]; // scratch buffer shared by the ReadByte/ReadUInt16/ReadInt32 and Write* helpers


        // control cookie -
        //   The control cookie is used for synchronization when a "user"
        //   wants to retrieve this client socket manager from the socket
        //   cache. 1 means control is available, 0 means somebody has
        //   already taken it (see RaceForControl / ReleaseControl).
        private int _controlCookie = 1;
 
                    
 
        // hide default constructor: a handler must always be created through one
        //   of the constructors that take a socket and a stream.
        private SocketHandler(){}
 
        // Creates a handler around an already-connected socket and the stream
        //   used to talk over it. Records the creation time (UTC), takes a data
        //   buffer from the shared buffer pool and starts with that buffer empty.
        public SocketHandler(Socket socket, Stream netStream)
        {
            // the async-read completion delegate is created once and reused
            //   for every BeginRead issued by BeginReadMessage
            _beginReadCallback = BeginReadMessageCallback;
            _creationTime = DateTime.UtcNow;

            NetSocket = socket;
            NetStream = netStream;

            // no buffered data yet
            _dataCount = 0;
            _dataOffset = 0;

            byte[] pooledBuffer = CoreChannel.BufferPool.GetBuffer();
            _dataBuffer = pooledBuffer;
            _dataBufferSize = pooledBuffer.Length;
        } // SocketHandler
 
        // Same as the (socket, netStream) constructor, but additionally remembers
        //   the request queue through which incoming requests for this connection
        //   are dispatched (see BeginReadMessage / BeginReadMessageCallback).
        internal SocketHandler(Socket socket, RequestQueue requestQueue, Stream netStream) : this(socket, netStream)
        {        
            _requestQueue = requestQueue;
        } // SocketHandler
 
        // UTC time at which this handler was constructed.
        public DateTime CreationTime { get { return _creationTime; } }
 
        // Attempts to take control of this handler. The cookie is atomically
        //   swapped to 0; only the caller that observes the previous value 1 wins
        //   and may go on using the handler. Every other caller gets false and
        //   must not touch the object any further.
        public bool RaceForControl()
        {
            int previousCookie = Interlocked.Exchange(ref _controlCookie, 0);
            return previousCookie == 1;
        } // RaceForControl
 
        // Gives back the control obtained through RaceForControl by resetting the
        //   cookie to 1. The store goes through Interlocked.Exchange (the same
        //   primitive RaceForControl uses) rather than a plain assignment, so the
        //   release is an atomic, fully fenced write: everything the releasing
        //   thread did to this handler is published before another thread can
        //   win the next RaceForControl.
        public void ReleaseControl()
        {
            Interlocked.Exchange(ref _controlCookie, 1);
        } // ReleaseControl
 
 
        // Determines if the remote connection is from this machine: either a
        //   loopback address or one of the addresses CoreChannel reports as local.
        //   A handler without a socket, or whose socket has no remote end point,
        //   is treated as local.
        //
        //   NetSocket and RemoteEndPoint are each read exactly once into locals.
        //   Close() nulls NetSocket, and Socket.RemoteEndPoint can become null once
        //   the socket is no longer connected, so re-reading them after the null
        //   checks could dereference null.
        internal bool IsLocalhost()
        {
            Socket socket = NetSocket;
            if (socket == null) return true;

            EndPoint remoteEndPoint = socket.RemoteEndPoint;
            if (remoteEndPoint == null) return true;

            IPAddress remoteAddr = ((IPEndPoint)remoteEndPoint).Address;
            return IPAddress.IsLoopback(remoteAddr) || CoreChannel.IsLocalIpAddress(remoteAddr);
        } // IsLocalhost
 
        // Determines if the remote connection comes from a loopback address.
        //   (Stricter than IsLocalhost, which also accepts the machine's other
        //   local addresses.) A handler without a socket counts as local.
        // NOTE(review): unlike IsLocalhost, a null RemoteEndPoint is not guarded
        //   here and results in a NullReferenceException - confirm that callers
        //   only use this on connected sockets.
        internal bool IsLocal()
        {
            if (NetSocket != null)
            {
                IPEndPoint remoteEndPoint = (IPEndPoint)NetSocket.RemoteEndPoint;
                return IPAddress.IsLoopback(remoteEndPoint.Address);
            }

            return true;
        } // IsLocal
                
        // Asks the remoting configuration whether custom errors are enabled for
        //   this connection, based on whether the peer is on the local machine.
        //   If the lookup fails for any reason, custom errors are reported as
        //   enabled.
        internal bool CustomErrorsEnabled()
        {
            bool customErrors;

            try
            {
                bool peerIsLocalhost = IsLocalhost();
                customErrors = RemotingConfiguration.CustomErrorsEnabled(peerIsLocalhost);
            }
            catch
            {
                customErrors = true;
            }

            return customErrors;
        } // CustomErrorsEnabled
 
        // does any necessary cleanup before reading the incoming message.
        //   Called by BeginReadMessage before it starts an async read or processes
        //   data that is already buffered.
        protected abstract void PrepareForNewMessage();
 
        // allows derived classes to send an error message before the connection
        //   is torn down. CloseOnFatalError calls this with the exception that
        //   caused the failure and then closes the handler. The default
        //   implementation does nothing.
        protected virtual void SendErrorMessageIfPossible(Exception e)
        {
        }     
 
        // allows socket handler to do something when an input stream it handed
        //   out is closed. The input stream is responsible for calling this method.
        //   (usually, only client socket handlers will do anything with this).
        //   (input stream refers to data being read off of the network)
        //   The default implementation does nothing.
        public virtual void OnInputStreamClosed()
        {
        }
        
 
        // Closes the connection and releases everything this handler owns.
        //   Each step runs inside its own try/finally so that a failure in an
        //   earlier step (notifying the request queue, closing the stream) can no
        //   longer skip the later ones and leave the socket open. The pooled data
        //   buffer is always returned last. Exceptions are not swallowed; if more
        //   than one step throws, the last exception is the one that reaches the
        //   caller.
        public virtual void Close()
        {
            try
            {
                // give the request queue (if any) a chance to schedule more work
                if (_requestQueue != null)
                    _requestQueue.ScheduleMoreWorkIfNeeded();
            }
            finally
            {
                try
                {
                    // work on a local copy so that a concurrent Close() nulling
                    //   the field cannot cause a NullReferenceException here
                    Stream stream = NetStream;
                    if (stream != null)
                    {
                        stream.Close();
                        NetStream = null;
                    }
                }
                finally
                {
                    try
                    {
                        Socket socket = NetSocket;
                        if (socket != null)
                        {
                            socket.Close();
                            NetSocket = null;
                        }
                    }
                    finally
                    {
                        ReturnBufferToPool();
                    }
                }
            }
        } // Close
 
        // The handler's internal data buffer.
        //   As a risk mitigation, if the buffer has already been handed back to
        //   the pool (field is null) an assertion is raised and a fresh buffer is
        //   taken from the pool instead of letting callers hit a
        //   NullReferenceException. This isn't expected to happen in practice and
        //   the fallback is a candidate for removal in future versions.
        private byte[] DataBuffer
        {
            get
            {
                if (_dataBuffer != null)
                    return _dataBuffer;

                InternalRemotingServices.RemotingAssert(false, "SocketHandler claiming a byte buffer after it has been returned to the pool");
                _dataBuffer = CoreChannel.BufferPool.GetBuffer();

                return _dataBuffer;
            }
        }
 
        // Hands the data buffer back to the shared buffer pool. The field is
        //   cleared with an atomic exchange, so even if two callers get here at
        //   the same time only the one that actually obtains the non-null buffer
        //   returns it.
        protected void ReturnBufferToPool()
        {
            // cheap early-out when there is nothing to return
            if (_dataBuffer == null)
                return;

            byte[] releasedBuffer = Interlocked.Exchange(ref _dataBuffer, null);
            if (releasedBuffer == null)
                return;

            CoreChannel.BufferPool.ReturnBuffer(releasedBuffer);
        } // ReturnBufferToPool
 
 
        // Write-only: sets the callback that ProcessRequestNow invokes (with this
        //   handler as its argument) once data is available for the connection.
        public WaitCallback DataArrivedCallback
        {
            set { _dataArrivedCallback = value; }            
        } // DataArrivedCallback
 
        // State object that goes along with DataArrivedCallback. This class only
        //   stores it; the callback itself is invoked with the handler, so the
        //   callback is expected to fetch the state through this property.
        public Object DataArrivedCallbackState
        {
            get { return _dataArrivedCallbackState; }
            set { _dataArrivedCallbackState = value; }
        } // DataArrivedCallbackState
 
#if !FEATURE_PAL    

        // Windows identity to impersonate for this connection. Only stored and
        //   returned by this class; not available in FEATURE_PAL builds.
        public WindowsIdentity ImpersonationIdentity 
        {  
            get { return _impersonationIdentity;}
            set { _impersonationIdentity = value;}
        }

#endif // !FEATURE_PAL
 
        // Starts reading the next message on this connection.
        //   - With nothing buffered, an async read into the data buffer is
        //     started; BeginReadMessageCallback continues once it completes.
        //   - With data already buffered, no read is needed and the request is
        //     dispatched right away (through the request queue when there is one).
        //   Any exception raised while preparing or starting the read closes the
        //   connection via CloseOnFatalError.
        public void BeginReadMessage()
        {
            bool haveBufferedData = false;

            try
            {
                if (_requestQueue != null)
                    _requestQueue.ScheduleMoreWorkIfNeeded();

                PrepareForNewMessage();

                haveBufferedData = (_dataCount != 0);
                if (!haveBufferedData)
                {
                    _beginReadAsyncResult =
                        NetStream.BeginRead(DataBuffer, 0, _dataBufferSize,
                                            _beginReadCallback, null);
                }
            }
            catch (Exception e)
            {
                CloseOnFatalError(e);
            }

            if (!haveBufferedData)
                return;

            // dispatching happens outside the try block above, so a failure
            //   here is not routed through CloseOnFatalError by this method
            if (_requestQueue == null)
                ProcessRequestNow();
            else
                _requestQueue.ProcessNextRequest(this);

            _beginReadAsyncResult = null;
        } // BeginReadMessage
 
 
        // Completion routine for the async read started by BeginReadMessage.
        //   Finishes the read and records how much data is now buffered. A read
        //   of zero (or fewer) bytes means the socket was closed, so the handler
        //   is closed as well; any exception closes it via CloseOnFatalError.
        //   When data did arrive, the request is dispatched (through the request
        //   queue when there is one).
        public void BeginReadMessageCallback(IAsyncResult ar)
        {
            bool dataArrived;

            try
            {
                _beginReadAsyncResult = null;

                _dataOffset = 0;
                _dataCount = NetStream.EndRead(ar);

                dataArrived = _dataCount > 0;
                if (!dataArrived)
                {
                    // socket has been closed
                    Close();
                }
            }
            catch (Exception e)
            {
                dataArrived = false;
                CloseOnFatalError(e);
            }

            if (!dataArrived)
                return;

            if (_requestQueue == null)
                ProcessRequestNow();
            else
                _requestQueue.ProcessNextRequest(this);
        } // BeginReadMessageCallback     
 
 
        // Tears the connection down after an unrecoverable error.
        //   First gives the derived class a chance to report the error to the
        //   peer (SendErrorMessageIfPossible), then closes the handler. If either
        //   of those throws, Close() is attempted once more and any exception
        //   from that second attempt is deliberately swallowed. As a result no
        //   exception of any type caught by the bare catch blocks escapes from
        //   this method.
        internal void CloseOnFatalError(Exception e)
        {
            try
            {
               SendErrorMessageIfPossible(e);
              
               // Something bad happened, so we should just close everything and 
               // return any buffers to the pool.
               Close();
            }
            catch
            {
                // sending the error message or the first Close() failed;
                //   make one more attempt to release the connection
                try
                {
                    Close();
                }
                catch
                {
                    // this is to prevent any weird errors with closing
                    // a socket from showing up as an unhandled exception.
                }
            }
        } // CloseOnFatalError
 
 
        // Called when the SocketHandler is pulled off the pending request queue
        //   (or directly, when there is no queue). Invokes the data-arrived
        //   callback, if one has been set, passing this handler as the state
        //   argument. An exception thrown by the callback closes the connection
        //   via CloseOnFatalError.
        internal void ProcessRequestNow()
        {
            try
            {
                // copy to a local so the null check and the call see the same delegate
                WaitCallback handler = _dataArrivedCallback;
                if (handler == null)
                    return;

                handler(this);
            }
            catch (Exception e)
            {
                CloseOnFatalError(e);
            }
        } // ProcessRequestNow
 
 
        // Refuses the pending request: builds a "server is busy" RemotingException
        //   and runs the fatal-error path with it, which lets the derived class
        //   report the error and then closes the connection.
        internal void RejectRequestNowSinceServerIsBusy()
        {
            String busyMessage = CoreChannel.GetResourceString("Remoting_ServerIsBusy");
            RemotingException serverBusy = new RemotingException(busyMessage);

            CloseOnFatalError(serverBusy);
        } // RejectRequestNowSinceServerIsBusy
 
 
 
        // Reads a single byte from the connection and returns it as an int.
        //   Returns -1 only if Read reports -1. Read as written in this class
        //   returns the number of bytes read or throws, so that branch is
        //   effectively defensive.
        public int ReadByte()
        {
            int readResult = Read(_byteBuffer, 0, 1);
            if (readResult == -1)
                return -1;

            return _byteBuffer[0];
        } // ReadByte
 
        // Writes a single byte to outputStream, staging it in the handler's
        //   shared scratch buffer first.
        public void WriteByte(byte value, Stream outputStream)
        {
            _byteBuffer[0] = value;
            outputStream.Write(_byteBuffer, 0, 1);
        } // WriteByte
 
 
        // Reads two bytes from the connection and combines them as a
        //   little-endian unsigned 16-bit value (first byte is the low byte).
        public UInt16 ReadUInt16()
        {
            Read(_byteBuffer, 0, 2);

            int lowByte = _byteBuffer[0] & 0xFF;
            int highByte = _byteBuffer[1] << 8;

            return (UInt16)(lowByte | highByte);
        } // ReadUInt16
        
        // Writes value to outputStream as two bytes in little-endian order
        //   (low byte first), staging them in the shared scratch buffer.
        public void WriteUInt16(UInt16 value, Stream outputStream)
        {
            int widened = value;

            _byteBuffer[0] = (byte)widened;
            _byteBuffer[1] = (byte)(widened >> 8);

            outputStream.Write(_byteBuffer, 0, 2);
        } // WriteUInt16
 
 
        // Reads four bytes from the connection and combines them as a
        //   little-endian signed 32-bit value (first byte is the lowest byte).
        public int ReadInt32()
        {
            Read(_byteBuffer, 0, 4);

            int byte0 = _byteBuffer[0] & 0xFF;
            int byte1 = _byteBuffer[1] << 8;
            int byte2 = _byteBuffer[2] << 16;
            int byte3 = _byteBuffer[3] << 24;

            return byte0 | byte1 | byte2 | byte3;
        } // ReadInt32
 
        // Writes value to outputStream as four bytes in little-endian order
        //   (lowest byte first), staging them in the shared scratch buffer.
        public void WriteInt32(int value, Stream outputStream)
        {
            // byte i receives bits [8*i, 8*i + 7] of value
            for (int i = 0; i < 4; i++)
            {
                _byteBuffer[i] = (byte)(value >> (8 * i));
            }

            outputStream.Write(_byteBuffer, 0, 4);
        } // WriteInt32
 
 
        // Reads the next four bytes from the connection and reports whether
        //   they are equal, position by position, to the four bytes in buffer.
        //   The four bytes are consumed whether or not they match.
        protected bool ReadAndMatchFourBytes(byte[] buffer)
        {
            InternalRemotingServices.RemotingAssert(buffer.Length == 4, "expecting 4 byte buffer.");

            Read(_byteBuffer, 0, 4);

            // compare in order and stop at the first mismatch
            for (int i = 0; i < 4; i++)
            {
                if (_byteBuffer[i] != buffer[i])
                    return false;
            }

            return true;
        } // ReadAndMatchFourBytes
        
 
 
        // Reads exactly count bytes into buffer starting at offset and returns
        //   the number of bytes read. Data already sitting in the internal buffer
        //   is handed out first. The remainder comes from the socket: through
        //   the internal buffer when fewer than 256 bytes are still wanted,
        //   directly into the caller's array otherwise. The method keeps reading
        //   until the request is satisfied; ReadFromSocket throws a
        //   RemotingException if the underlying socket is closed before that.
        //   Callers are responsible for not asking for more than the peer sends.
        public int Read(byte[] buffer, int offset, int count)
        {
            byte[] internalBuffer = this.DataBuffer;
            int delivered = 0;

            // first hand out whatever is already buffered
            if (_dataCount > 0)
            {
                int chunk = Math.Min(_dataCount, count);
                StreamHelper.BufferCopy(internalBuffer, _dataOffset, buffer, offset, chunk);

                _dataOffset += chunk;
                _dataCount -= chunk;

                offset += chunk;
                count -= chunk;
                delivered += chunk;
            }

            // the internal buffer is empty whenever this loop body runs
            while (count > 0)
            {
                int chunk;

                if (count >= 256)
                {
                    // large request: read straight into the caller's array
                    chunk = ReadFromSocket(buffer, offset, count);
                }
                else
                {
                    // small request: not worth a socket call of its own, so
                    //   fill the internal buffer and copy out of it
                    BufferMoreData(internalBuffer);

                    chunk = Math.Min(_dataCount, count);
                    StreamHelper.BufferCopy(internalBuffer, _dataOffset, buffer, offset, chunk);

                    _dataOffset += chunk;
                    _dataCount -= chunk;
                }

                offset += chunk;
                count -= chunk;
                delivered += chunk;
            }

            return delivered;
        } // Read
 
 
        // Refills the internal buffer with a read from the socket and returns
        //   the number of bytes now buffered. Must only be called when the
        //   buffer is empty (_dataCount is 0), because the new data overwrites
        //   the buffer from its start. Throws (from ReadFromSocket) if the
        //   socket has been closed.
        private int BufferMoreData(byte[] dataBuffer)
        {
            InternalRemotingServices.RemotingAssert(_dataCount == 0, 
                "SocketHandler::BufferMoreData called with data still in buffer." +
                "DataCount=" + _dataCount + "; DataOffset" + _dataOffset);

            int received = ReadFromSocket(dataBuffer, 0, _dataBufferSize);

            _dataCount = received;
            _dataOffset = 0;

            return received;
        } // BufferMoreData
 
 
        // Performs one blocking read on the network stream and returns the
        //   number of bytes received, which is always greater than zero. A read
        //   of zero or fewer bytes is reported as a RemotingException saying
        //   the underlying socket was closed.
        private int ReadFromSocket(byte[] buffer, int offset, int count)
        {
            int received = NetStream.Read(buffer, offset, count);
            if (received > 0)
                return received;

            throw new RemotingException(
                CoreChannel.GetResourceString("Remoting_Socket_UnderlyingSocketClosed"));
        } // ReadFromSocket
        
 
        // Reads up to (and consumes) the next occurrence of byte b without
        //   validating the bytes in between. See the two-argument overload.
        protected byte[] ReadToByte(byte b)
        {
            return ReadToByte(b, null);
        } // ReadToByte
 
        // Reads from the connection until byte b is found.
        //   Returns the bytes that precede b (b itself is consumed but not
        //   included); the result is an empty array when b is the very next byte.
        //   If validator is not null it is called for every scanned byte other
        //   than b, and a RemotingException is thrown as soon as it returns false.
        //   When the buffered data runs out before b is found, more data is read
        //   from the socket and the scan continues; BufferMoreData throws if the
        //   socket has been closed. There is no upper bound on how many bytes are
        //   accumulated while searching.
        protected byte[] ReadToByte(byte b, ValidateByteDelegate validator)
        {
            byte[] readBytes = null;

            byte[] dataBuffer = this.DataBuffer;

            // start at current position and return byte array consisting of bytes
            //   up to where we found the byte.
            if (_dataCount == 0)
                BufferMoreData(dataBuffer);
                
            int dataEnd = _dataOffset + _dataCount; // one byte past last valid byte
            int startIndex = _dataOffset; // current position
            int endIndex = startIndex; // current index

            bool foundByte = false;
            bool bufferEnd;
            while (!foundByte)
            {            
                InternalRemotingServices.RemotingAssert(endIndex <= dataEnd, "endIndex shouldn't pass dataEnd");
                bufferEnd = endIndex == dataEnd;
                foundByte = !bufferEnd && (dataBuffer[endIndex] == b);

                // validate character if necessary
                //   (the delimiter itself is never passed to the validator)
                if ((validator != null) && !bufferEnd && !foundByte)
                {
                    if (!validator(dataBuffer[endIndex]))
                    {
                        throw new RemotingException(
                            CoreChannel.GetResourceString(
                                "Remoting_Http_InvalidDataReceived"));
                    }
                }

                // we're at the end of the currently buffered data or we've found our byte
                if (bufferEnd || foundByte)
                {
                    // store processed byte in the readBytes array
                    //   (count may be 0, e.g. when the delimiter is the first byte scanned)
                    int count = endIndex - startIndex;                                        
                    if (readBytes == null)
                    {
                        readBytes = new byte[count];
                        StreamHelper.BufferCopy(dataBuffer, startIndex, readBytes, 0, count);
                        }
                    else
                    {
                        // append this segment to what was collected from earlier buffers
                        int oldSize = readBytes.Length;
                        byte[] newBytes = new byte[oldSize + count];
                        StreamHelper.BufferCopy(readBytes, 0, newBytes, 0, oldSize);
                        StreamHelper.BufferCopy(dataBuffer, startIndex, newBytes, oldSize, count);
                        readBytes = newBytes;
                    }

                    // update data counters
                    _dataOffset += count;
                    _dataCount -= count;

                    if (bufferEnd)
                    {
                        // we still haven't found the byte, so buffer more data
                        //   and keep looking.
                        BufferMoreData(dataBuffer);

                        // reset indices
                        dataEnd = _dataOffset + _dataCount; // one byte past last valid byte
                        startIndex = _dataOffset; // current position
                        endIndex = startIndex; // current index
                    }
                    else
                    if (foundByte)
                    {
                        // skip over the byte that we were looking for
                        _dataOffset += 1;
                        _dataCount -= 1;
                    }        
                }
                else
                {
                    // still haven't found character or end of buffer, so advance position
                    endIndex++;
                }
            }
                
            return readBytes;
        } // ReadToByte
 
 
        // Reads up to (and consumes) the next occurrence of ch and returns the
        //   preceding bytes as a string, without validating them. See the
        //   two-argument overload.
        protected String ReadToChar(char ch)
        {
            return ReadToChar(ch, null);
        } // ReadToChar
 
        // Reads up to (and consumes) the next occurrence of ch, which is narrowed
        //   to a single byte for the search, and returns the preceding bytes
        //   decoded as ASCII. Returns null if ReadToByte returns null and
        //   String.Empty if nothing preceded ch. validator, when not null, is
        //   applied to each scanned byte by ReadToByte.
        protected String ReadToChar(char ch, ValidateByteDelegate validator)
        {
            byte[] rawBytes = ReadToByte((byte)ch, validator);

            if (rawBytes == null)
                return null;

            return (rawBytes.Length == 0)
                ? String.Empty
                : Encoding.ASCII.GetString(rawBytes);
        } // ReadToChar
 
 
        // Reads one CRLF-terminated line and returns it without the terminator.
        //   Everything up to the next '\r' is taken as the line; the byte after
        //   it is then read and must be '\n'. If it is anything else, null is
        //   returned (that byte has been consumed all the same).
        public String ReadToEndOfLine()
        {
            String line = ReadToChar('\r');
            int byteAfterCarriageReturn = ReadByte();

            return (byteAfterCarriageReturn == '\n') ? line : null;
        } // ReadToEndOfLine        
               
    
    } // SocketHandler
 
 
} // namespace System.Runtime.Remoting.Channels