File: channels\tcp\tcpstreams.cs
Project: ndp\clr\src\managedlibraries\remoting\System.Runtime.Remoting.csproj (System.Runtime.Remoting)
// ==++==
// 
//   Copyright (c) Microsoft Corporation.  All rights reserved.
// 
// ==--==
//============================================================
//
// File:    TcpStreams.cs
//
// Summary: Defines streams used by TCP channel.
//
//============================================================
 
 
 
using System;
using System.Collections;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Globalization;
 
 
namespace System.Runtime.Remoting.Channels.Tcp
{
 
    // Base class for the read-only streams that expose the body of an
    //   incoming TCP channel message. Seeking, writing, and Length/Position
    //   queries are not supported.
    internal abstract class TcpReadingStream : Stream
    {
        // Reads and discards everything left in the stream, stopping once
        //   Read reports that no bytes were read.
        public void ReadToEnd()
        {
            // This method should never be called where it would be valid
            //   to use this data, so it is ok to throw the excess bytes
            //   away.
            byte[] buffer = new byte[64];
            int readCount;
            do
            {
                readCount = Read(buffer, 0, buffer.Length);
            } while (readCount > 0);
        } // ReadToEnd

        // True once the end of the message body has been consumed. The base
        //   implementation never reports the end; derived streams override it.
        public virtual bool FoundEnd { get { return false; } }

        public override bool CanRead { get { return true; } }
        public override bool CanSeek { get { return false; } }
        public override bool CanWrite { get { return false; } }

        public override long Length { get { throw new NotSupportedException(); } }

        public override long Position
        {
            get { throw new NotSupportedException(); }
            set { throw new NotSupportedException(); }
        }

        // Flushing a read-only stream is legal and has nothing to do, so this
        //   is a no-op rather than an error (per the Stream.Flush contract for
        //   streams that do not support writing).
        public override void Flush() { }

        public override long Seek(long offset, SeekOrigin origin) { throw new NotSupportedException(); }
        public override void SetLength(long value) { throw new NotSupportedException(); }

        public override void Write(byte[] buffer, int offset, int count)
        {
            throw new NotSupportedException();
        }

    } // TcpReadingStream
 
 
    // Exposes a message body whose length was announced up front: reads at
    //   most contentLength bytes from the underlying socket handler.
    internal sealed class TcpFixedLengthReadingStream : TcpReadingStream
    {
        private SocketHandler _inputStream; // source of the raw body bytes
        private int _bytesLeft;             // bytes left to read

        internal TcpFixedLengthReadingStream(SocketHandler inputStream,
                                             int contentLength)
        {
            _inputStream = inputStream;
            _bytesLeft = contentLength;
        } // TcpFixedLengthReadingStream

        // The end has been found once every announced byte has been consumed.
        public override bool FoundEnd { get { return _bytesLeft == 0; } }

        // When disposing, notifies the socket handler that this input stream
        //   was closed; the base stream is disposed in every case.
        protected override void Dispose(bool disposing)
        {
            try {
                if (disposing)
                    _inputStream.OnInputStreamClosed();
            }
            finally {
                base.Dispose(disposing);
            }
        } // Dispose

        // Reads up to count bytes without going past the announced length.
        //   Returns 0 once the whole length has been consumed; otherwise returns
        //   whatever the handler's Read returned.
        public override int Read(byte[] buffer, int offset, int count)
        {
            if (_bytesLeft == 0)
                return 0;

            int readCount = _inputStream.Read(buffer, offset, Math.Min(_bytesLeft, count));
            if (readCount > 0)
                _bytesLeft -= readCount;

            return readCount;
        } // Read

        // Returns the next byte, or -1 once the announced length has been
        //   consumed or when the handler's ReadByte returns -1.
        public override int ReadByte()
        {
            if (_bytesLeft == 0)
                return -1;

            int value = _inputStream.ReadByte();

            // Only charge bytes that were actually delivered (mirrors Read);
            //   a -1 from the handler must not shrink the remaining length.
            if (value != -1)
                _bytesLeft -= 1;

            return value;
        } // ReadByte

    } // TcpFixedLengthReadingStream
 
 
    // Exposes a chunked message body. Each chunk is an Int32 length, that many
    //   data bytes, then a "\r\n" trailer; a zero-length chunk (followed by its
    //   own "\r\n" trailer) marks the end of the stream.
    internal sealed class TcpChunkedReadingStream : TcpReadingStream
    {
        private SocketHandler _inputStream = null; // read chunked tcp data from here

        private int    _bytesLeft;          // bytes left in current chunk
        private bool   _bFoundEnd = false;  // has end of stream been reached?

        private byte[] _byteBuffer = new byte[1]; // buffer for reading bytes


        internal TcpChunkedReadingStream(SocketHandler inputStream)
        {
            _inputStream = inputStream;

            _bytesLeft = 0;
        } // TcpChunkedReadingStream

        public override bool FoundEnd { get { return _bFoundEnd; } }

        protected override void Dispose(bool disposing)
        {
            // NOTE(review): unlike TcpFixedLengthReadingStream, this stream does
            //   not call _inputStream.OnInputStreamClosed() when disposed (the
            //   original try block here was empty). Confirm this is intentional.
            base.Dispose(disposing);
        } // Dispose

        // Reads up to count bytes, crossing chunk boundaries as needed, and
        //   returns the number of bytes read (0 once the terminating chunk has
        //   been consumed).
        // Throws RemotingException when a chunk length is negative, when a
        //   chunk trailer is not "\r\n", or when the handler returns no data in
        //   the middle of a chunk.
        public override int Read(byte[] buffer, int offset, int count)
        {
            int bytesRead = 0;

            while (!_bFoundEnd && (count > 0))
            {
                // see if we need to start reading a new chunk
                if (_bytesLeft == 0)
                {
                    int chunkLength = _inputStream.ReadInt32();

                    // The chunk length comes straight off the wire. A negative
                    //   value would otherwise reach the handler as a negative read
                    //   count, so reject it as a framing error.
                    if (chunkLength < 0)
                    {
                        throw new RemotingException(
                            CoreChannel.GetResourceString(
                                "Remoting_Tcp_ChunkedEncodingError"));
                    }

                    _bytesLeft = chunkLength;

                    if (_bytesLeft == 0)
                    {
                        // zero-length chunk terminates the stream
                        ReadTrailer();
                        _bFoundEnd = true;
                    }
                }

                if (!_bFoundEnd)
                {
                    int readCount = Math.Min(_bytesLeft, count);
                    int bytesReadThisTime = _inputStream.Read(buffer, offset, readCount);
                    if (bytesReadThisTime <= 0)
                    {
                        throw new RemotingException(
                            CoreChannel.GetResourceString(
                                "Remoting_Tcp_ChunkedEncodingError"));
                    }

                    _bytesLeft -= bytesReadThisTime;
                    count -= bytesReadThisTime;
                    offset += bytesReadThisTime;
                    bytesRead += bytesReadThisTime;

                    // see if the end of the chunk was found
                    if (_bytesLeft == 0)
                    {
                        ReadTrailer();
                    }
                }
            } // while (count > 0)

            return bytesRead;
        } // Read

        // Returns the next byte of the body, or -1 once the end has been reached.
        public override int ReadByte()
        {
            int readCount = Read(_byteBuffer, 0, 1);
            if (readCount == 0)
                return -1;

            return _byteBuffer[0];
        } // ReadByte


        // Consumes the "\r\n" that follows every chunk, throwing a
        //   chunked-encoding RemotingException if either byte is wrong.
        private void ReadTrailer()
        {
            int ch = _inputStream.ReadByte();
            if (ch != '\r')
            {
                throw new RemotingException(
                    CoreChannel.GetResourceString(
                        "Remoting_Tcp_ChunkedEncodingError"));
            }

            ch = _inputStream.ReadByte();
            if (ch != '\n')
            {
                throw new RemotingException(
                    CoreChannel.GetResourceString(
                        "Remoting_Tcp_ChunkedEncodingError"));
            }
        } // ReadTrailer

    } // TcpChunkedReadingStream
 
    
 
 
    // Server side of a TCP channel connection: reads requests off the socket and writes replies.
    internal sealed class TcpServerSocketHandler : TcpSocketHandler
    {
        // Used to keep track of socket connections; incremented atomically
        //   each time a handler is constructed.
        private static Int64 _connectionIdCounter = 0;

        private Int64 _connectionId;   // id for this connection

        private bool _bOneWayRequest;  // is the incoming request one way?
        private bool _bChunked;        // is the incoming request chunked?
        private int  _contentLength;   // content length of incoming request

        private TcpReadingStream _requestStream; // the request stream


        // Wraps an accepted socket and assigns it the next connection id.
        internal TcpServerSocketHandler(Socket socket, RequestQueue requestQueue, Stream stream) :
                        base(socket, requestQueue, stream)
        {
            _connectionId = Interlocked.Increment(ref _connectionIdCounter);
        } // TcpServerSocketHandler


        // Determine if it's possible to service another request.
        //   This handler always allows one.
        public bool CanServiceAnotherRequest()
        {
            return true;
        } // CanServiceAnotherRequest


        // Prepare for reading a new request off of the same socket. Any unread
        //   part of the previous request body is drained first, so the next read
        //   does not start in the middle of the old body.
        protected override void PrepareForNewMessage()
        {
            if (_requestStream != null)
            {
                if (!_requestStream.FoundEnd)
                    _requestStream.ReadToEnd();
                _requestStream = null;
            }
        } // PrepareForNewMessage

        protected override void SendErrorMessageIfPossible(Exception e)
        {
            // A fatal exception occurred. We communicate this error by
            // writing an error message and empty message body.
            try
            {
                SendErrorResponse(e, true);
            }
            catch
            {
                // Deliberately best-effort: if sending fails the connection is
                //   presumably already dead, so there is nobody to report to.
            }
        } // SendErrorMessageIfPossible


        // Reads the request preamble, version, and operation, then the content
        //   length and the remaining headers. Records whether the request is
        //   one-way and how its body is delimited, for use by GetRequestStream
        //   and the Send* methods.
        // Throws RemotingException if the operation is neither Request nor
        //   OneWayRequest.
        public ITransportHeaders ReadHeaders()
        {
            BaseTransportHeaders headers = new BaseTransportHeaders();

            UInt16 operation;
            ReadVersionAndOperation(out operation);

            // make sure the operation is Request or OneWayRequest.
            if (operation == TcpOperations.Request)
            {
                _bOneWayRequest = false;
            }
            else if (operation == TcpOperations.OneWayRequest)
            {
                _bOneWayRequest = true;
            }
            else
            {
                throw new RemotingException(
                    String.Format(
                        CultureInfo.CurrentCulture, CoreChannel.GetResourceString("Remoting_Tcp_ExpectingRequestOp"),
                        operation.ToString(CultureInfo.CurrentCulture)));
            }

            // content length must come next (may be chunked or a specific length)
            ReadContentLength(out _bChunked, out _contentLength);

            // read to end of headers
            ReadToEndOfHeaders(headers);

            // add IP address and Connection Id to headers
            headers.IPAddress = ((IPEndPoint)NetSocket.RemoteEndPoint).Address;
            headers.ConnectionId = _connectionId;

            return headers;
        } // ReadHeaders


        // Returns a stream over the body of the request whose headers were just
        //   read: chunked or fixed-length, as announced in the headers. The stream
        //   is remembered so PrepareForNewMessage can drain it.
        public Stream GetRequestStream()
        {
            if (!_bChunked)
                _requestStream = new TcpFixedLengthReadingStream(this, _contentLength);
            else
                _requestStream = new TcpChunkedReadingStream(this);
            return _requestStream;
        } // GetRequestStream


        // Writes a successful reply (headers plus contentStream) to the network
        //   and closes contentStream. Does nothing for one-way requests.
        public void SendResponse(ITransportHeaders headers, Stream contentStream)
        {
            // bail out if the original request was OneWay (means the client doesn't even
            //   want or expect to receive responses or error messages)
            if (_bOneWayRequest)
                return;

            try
            {
                // build up headers and send
                ChunkedMemoryStream headerStream = new ChunkedMemoryStream(CoreChannel.BufferPool);
                try
                {
                    // output preamble and version
                    WritePreambleAndVersion(headerStream);
                    // output opcode
                    WriteUInt16(TcpOperations.Reply, headerStream);
                    // output content length delimiter
                    WriteUInt16(TcpContentDelimiter.ContentLength, headerStream);
                    WriteInt32((int)contentStream.Length, headerStream);

                    // No status code header is needed because if we're in this code path
                    //   the data transfer succeeded as far as the transport protocol is
                    //   concerned (and the success status code is optional).

                    WriteHeaders(headers, headerStream);

                    headerStream.WriteTo(NetStream);
                }
                finally
                {
                    // Close the header buffer even if the socket write throws,
                    //   so its buffers are released back to CoreChannel.BufferPool.
                    headerStream.Close();
                }

                StreamHelper.CopyStream(contentStream, NetStream);
            }
            finally
            {
                // This method owns contentStream once it is called; close it on
                //   failure as well as on success.
                contentStream.Close();
            }
        } // SendResponse

        // Builds the status phrase sent to the client for an exception. When
        //   CustomErrorsEnabled() returns false, the full e.ToString() text
        //   (including the stack trace) is sent. Otherwise only the generic
        //   "Remoting_InternalError" resource string is sent.
        private string GenerateFaultString(Exception e) {
            if (!CustomErrorsEnabled())
                return e.ToString();
            else {
                return CoreChannel.GetResourceString("Remoting_InternalError");
            }
        } // GenerateFaultString

        // Sends an error reply describing e (see GenerateFaultString).
        public void SendErrorResponse(Exception e, bool bCloseConnection)
        {
            SendErrorResponse(GenerateFaultString(e), bCloseConnection);
        } // SendErrorResponse

        // Sends an error reply with an empty body, a GenericError status code,
        //   status phrase e, and a CloseConnection header. Does nothing for
        //   one-way requests.
        // NOTE(review): bCloseConnection is not consulted; the CloseConnection
        //   header is always written. Confirm whether callers passing false
        //   expect otherwise.
        public void SendErrorResponse(string e, bool bCloseConnection)
        {
            // bail out if the original request was OneWay (means the client doesn't even
            //   want or expect to receive responses or error messages)
            if (_bOneWayRequest)
                return;

            // build up headers and send
            ChunkedMemoryStream headerStream = new ChunkedMemoryStream(CoreChannel.BufferPool);
            try
            {
                // output preamble and version
                WritePreambleAndVersion(headerStream);
                // output opcode
                WriteUInt16(TcpOperations.Reply, headerStream);
                // output content length delimiter (0-length stream)
                WriteUInt16(TcpContentDelimiter.ContentLength, headerStream);
                WriteInt32(0, headerStream);

                // output status code and reason
                WriteUInt16(TcpHeaders.StatusCode, headerStream);
                WriteByte(TcpHeaderFormat.UInt16, headerStream);
                WriteUInt16(TcpStatusCode.GenericError, headerStream);
                // The status phrase is exactly the text supplied by the caller. Via
                //   the Exception overload, that text contains the stack trace only
                //   when custom errors are disabled (see GenerateFaultString).
                WriteUInt16(TcpHeaders.StatusPhrase, headerStream);
                WriteByte(TcpHeaderFormat.CountedString, headerStream);
                WriteCountedString(e, headerStream);

                // indicate that we are about to close the connection
                WriteUInt16(TcpHeaders.CloseConnection, headerStream);
                WriteByte(TcpHeaderFormat.Void, headerStream);

                // end of headers
                WriteUInt16(TcpHeaders.EndOfHeaders, headerStream);

                headerStream.WriteTo(NetStream);
            }
            finally
            {
                // release the pooled header buffers even if the socket write fails
                headerStream.Close();
            }
        } // SendErrorResponse


    } // class TcpServerSocketHandler
 
    
    
 
} // namespace System.Runtime.Remoting.Channels.Tcp