File: System\ServiceModel\Channels\MaxMessageSizeStream.cs
Project: ndp\cdf\src\WCF\ServiceModel\System.ServiceModel.csproj (System.ServiceModel)
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation.  All rights reserved.
//------------------------------------------------------------
namespace System.ServiceModel.Channels
{
    using System.IO;
    using System.ServiceModel;
    using System.Diagnostics;
    using System.ServiceModel.Diagnostics.Application;
 
    /// <summary>
    /// A <see cref="DelegatingStream"/> that enforces a byte quota (<c>maxMessageSize</c>) on the
    /// data read from and written to the wrapped stream.
    /// </summary>
    /// <remarks>
    /// Read requests are clamped so that no more than the remaining quota is requested from the
    /// inner stream, and any read attempted once the quota has been consumed throws. Writes that
    /// would push the running total past the quota throw before anything is passed to the inner
    /// stream. The read and write totals are tracked independently.
    /// </remarks>
    class MaxMessageSizeStream : DelegatingStream
    {
        // Quota applied to both the read and the write direction; fixed at construction.
        readonly long maxMessageSize;

        // Total bytes reported as read by the inner stream so far.
        long totalBytesRead;

        // Total bytes accepted for writing so far (counted before the inner write is issued).
        long bytesWritten;

        /// <summary>Wraps <paramref name="stream"/> and enforces <paramref name="maxMessageSize"/>.</summary>
        /// <param name="stream">The stream to delegate to.</param>
        /// <param name="maxMessageSize">Maximum number of bytes that may be read, and separately written.</param>
        public MaxMessageSizeStream(Stream stream, long maxMessageSize)
            : base(stream)
        {
            this.maxMessageSize = maxMessageSize;
        }

        /// <summary>Starts an asynchronous read, clamping <paramref name="count"/> to the remaining quota.</summary>
        public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
        {
            count = PrepareRead(count);
            return base.BeginRead(buffer, offset, count, callback, state);
        }

        /// <summary>Starts an asynchronous write after checking it fits in the remaining quota.</summary>
        public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
        {
            PrepareWrite(count);
            return base.BeginWrite(buffer, offset, count, callback, state);
        }

        /// <summary>Completes an asynchronous read and adds the bytes read to the running total.</summary>
        public override int EndRead(IAsyncResult result)
        {
            return FinishRead(base.EndRead(result));
        }

        /// <summary>Reads at most the remaining quota and adds the bytes read to the running total.</summary>
        public override int Read(byte[] buffer, int offset, int count)
        {
            count = PrepareRead(count);
            return FinishRead(base.Read(buffer, offset, count));
        }

        /// <summary>Reads one byte; only a successful read (not end-of-stream) is counted.</summary>
        public override int ReadByte()
        {
            // Called only for its quota check; the clamped count is irrelevant for a single byte.
            PrepareRead(1);
            int value = base.ReadByte();
            if (value != -1)
            {
                FinishRead(1);
            }
            return value;
        }

        /// <summary>Writes <paramref name="count"/> bytes after checking they fit in the remaining quota.</summary>
        public override void Write(byte[] buffer, int offset, int count)
        {
            PrepareWrite(count);
            base.Write(buffer, offset, count);
        }

        /// <summary>Writes one byte after checking it fits in the remaining quota.</summary>
        public override void WriteByte(byte value)
        {
            PrepareWrite(1);
            base.WriteByte(value);
        }

        /// <summary>
        /// Builds (does not throw) the exception reported when the receive quota is exceeded: a
        /// <see cref="CommunicationException"/> wrapping a <see cref="QuotaExceededException"/>, both
        /// carrying the localized MaxReceivedMessageSizeExceeded message. Also emits the
        /// corresponding TD trace event when that event is enabled.
        /// </summary>
        /// <param name="maxMessageSize">The quota value formatted into the message.</param>
        internal static Exception CreateMaxReceivedMessageSizeExceededException(long maxMessageSize)
        {
            string message = SR.GetString(SR.MaxReceivedMessageSizeExceeded, maxMessageSize);
            Exception inner = new QuotaExceededException(message);

            if (TD.MaxReceivedMessageSizeExceededIsEnabled())
            {
                TD.MaxReceivedMessageSizeExceeded(message);
            }

            return new CommunicationException(message, inner);
        }

        /// <summary>
        /// Builds (does not throw) the exception reported when the send quota is exceeded: a
        /// <see cref="CommunicationException"/> wrapping a <see cref="QuotaExceededException"/>, both
        /// carrying the localized MaxSentMessageSizeExceeded message. Also emits the corresponding
        /// TD trace event when that event is enabled.
        /// </summary>
        /// <param name="maxMessageSize">The quota value formatted into the message.</param>
        internal static Exception CreateMaxSentMessageSizeExceededException(long maxMessageSize)
        {
            string message = SR.GetString(SR.MaxSentMessageSizeExceeded, maxMessageSize);
            Exception inner = new QuotaExceededException(message);

            if (TD.MaxSentMessageSizeExceededIsEnabled())
            {
                TD.MaxSentMessageSizeExceeded(message);
            }

            return new CommunicationException(message, inner);
        }

        /// <summary>
        /// Checks the read quota and returns <paramref name="bytesToRead"/> clamped to the number of
        /// bytes still allowed.
        /// </summary>
        /// <remarks>
        /// The check runs before the inner stream is touched, so a read issued once exactly
        /// <c>maxMessageSize</c> bytes have been consumed throws even if the inner stream is at
        /// end-of-stream.
        /// </remarks>
        int PrepareRead(int bytesToRead)
        {
            if (totalBytesRead >= maxMessageSize)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(CreateMaxReceivedMessageSizeExceededException(maxMessageSize));
            }

            long bytesRemaining = maxMessageSize - totalBytesRead;

            // When more than int.MaxValue bytes remain, any int-sized request already fits.
            if (bytesRemaining > int.MaxValue)
            {
                return bytesToRead;
            }

            return Math.Min(bytesToRead, (int)bytesRemaining);
        }

        /// <summary>Adds <paramref name="bytesRead"/> to the running read total and returns it unchanged.</summary>
        int FinishRead(int bytesRead)
        {
            totalBytesRead += bytesRead;
            return bytesRead;
        }

        /// <summary>
        /// Throws if writing <paramref name="bytesToWrite"/> more bytes would exceed the quota;
        /// otherwise records them as written.
        /// </summary>
        void PrepareWrite(int bytesToWrite)
        {
            // A negative count is invalid. Do not let it shrink the running total (which would free
            // up extra quota); leave it for the underlying stream's argument validation to reject
            // (presumably ArgumentOutOfRangeException per the Stream contract).
            if (bytesToWrite < 0)
            {
                return;
            }

            // Compare against the remaining quota rather than summing, so the check cannot overflow
            // when maxMessageSize is close to long.MaxValue (bytesWritten never exceeds the quota).
            if (bytesToWrite > maxMessageSize - bytesWritten)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(CreateMaxSentMessageSizeExceededException(maxMessageSize));
            }

            bytesWritten += bytesToWrite;
        }
    }
}