|
//-----------------------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//-----------------------------------------------------------------------------
namespace System.ServiceModel.Security
{
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.IdentityModel.Policy;
using System.IdentityModel.Tokens;
using System.Runtime;
using System.Runtime.Diagnostics;
using System.Security.Authentication.ExtendedProtection;
using System.ServiceModel;
using System.ServiceModel.Channels;
using System.ServiceModel.Diagnostics;
using System.ServiceModel.Diagnostics.Application;
using System.ServiceModel.Description;
using System.ServiceModel.Dispatcher;
using System.ServiceModel.Security.Tokens;
using System.Net;
using System.Xml;
using SafeFreeCredentials = System.IdentityModel.SafeFreeCredentials;
class SecuritySessionSecurityTokenProvider : CommunicationObjectSecurityTokenProvider
{
// Formatter assigned to the "Issue" and "Renew" client operations built in InitializeFactories.
static readonly MessageOperationFormatter operationFormatter = new MessageOperationFormatter();
// Clone of the binding context supplied through IssuerBindingContext.
BindingContext issuerBindingContext;
// Secured request channel factory built by InitializeFactories; used for both Issue and Renew.
IChannelFactory<IRequestChannel> rstChannelFactory;
// Supplies the requested key size and validates the size of the issued key.
SecurityAlgorithmSuite securityAlgorithmSuite;
SecurityStandardsManager standardsManager;
// NOTE(review): not referenced anywhere in the visible portion of this file.
Object thisLock = new Object();
// Decides whether CreateRst attaches requestor entropy to the request.
SecurityKeyEntropyMode keyEntropyMode;
// Used when building the renew target and the endorsing supporting token for Renew.
SecurityTokenParameters issuedTokenParameters;
// Set by InitializeFactories; when true, PrepareRequest fills in the ReplyTo header itself.
bool requiresManualReplyAddressing;
EndpointAddress targetAddress;
// Clone of the binding element supplied through BootstrapSecurityBindingElement.
SecurityBindingElement bootstrapSecurityBindingElement;
Uri via;
// Token type URI, read from the secure conversation driver in OnOpen.
string sctUri;
// Privacy notice settings copied onto the security protocol factory in InitializeFactories.
Uri privacyNoticeUri;
int privacyNoticeVersion;
// Message version of the security channel factory, recorded by InitializeFactories.
MessageVersion messageVersion;
EndpointAddress localAddress;
// Propagated to every channel created by CreateChannel when non-null.
ChannelParameterCollection channelParameters;
// Credentials handle passed to the constructor or acquired in OnOpening.
SafeFreeCredentials credentialsHandle;
// True only when this instance acquired credentialsHandle itself (OnOpening); only then is it closed.
bool ownCredentialsHandle;
// Extra HTTP headers that PrepareRequest adds to outgoing requests.
WebHeaderCollection webHeaderCollection;
// Creates a provider around the given credentials handle. The handle may be null,
// in which case OnOpening acquires one. The standards manager and key entropy mode
// start at their defaults.
public SecuritySessionSecurityTokenProvider(SafeFreeCredentials credentialsHandle)
{
    this.credentialsHandle = credentialsHandle;
    standardsManager = SecurityStandardsManager.DefaultInstance;
    keyEntropyMode = AcceleratedTokenProvider.defaultKeyEntropyMode;
}
// HTTP headers added to every outgoing request (see PrepareRequest).
// Can only be changed while the provider is still mutable.
public WebHeaderCollection WebHeaders
{
    get { return webHeaderCollection; }
    set
    {
        this.CommunicationObject.ThrowIfDisposedOrImmutable();
        webHeaderCollection = value;
    }
}
// Algorithm suite that supplies the requested key size and validates issued keys.
// Can only be changed while the provider is still mutable.
public SecurityAlgorithmSuite SecurityAlgorithmSuite
{
    get { return securityAlgorithmSuite; }
    set
    {
        this.CommunicationObject.ThrowIfDisposedOrImmutable();
        securityAlgorithmSuite = value;
    }
}
// Entropy mode for the session key. The value is validated before it is stored.
// Can only be changed while the provider is still mutable.
public SecurityKeyEntropyMode KeyEntropyMode
{
    get { return keyEntropyMode; }
    set
    {
        this.CommunicationObject.ThrowIfDisposedOrImmutable();
        SecurityKeyEntropyModeHelper.Validate(value);
        keyEntropyMode = value;
    }
}
// Message version recorded from the security channel factory by InitializeFactories.
MessageVersion MessageVersion
{
    get { return messageVersion; }
}
// Address that Issue/Renew requests are sent to. Must be set before OnOpen.
public EndpointAddress TargetAddress
{
    get
    {
        return targetAddress;
    }
    set
    {
        this.CommunicationObject.ThrowIfDisposedOrImmutable();
        targetAddress = value;
    }
}
// Optional local address. When set it is used as ReplyTo for manually addressed
// requests and registered as a LocalAddressProvider binding parameter.
public EndpointAddress LocalAddress
{
    get
    {
        return localAddress;
    }
    set
    {
        this.CommunicationObject.ThrowIfDisposedOrImmutable();
        localAddress = value;
    }
}
// Optional via URI passed to channel creation; when null only the target is used.
public Uri Via
{
    get
    {
        return via;
    }
    set
    {
        this.CommunicationObject.ThrowIfDisposedOrImmutable();
        via = value;
    }
}
// Binding context used to build the issuer channel. The setter stores a clone of
// the supplied context and rejects null.
public BindingContext IssuerBindingContext
{
    get { return issuerBindingContext; }
    set
    {
        this.CommunicationObject.ThrowIfDisposedOrImmutable();
        if (value != null)
        {
            issuerBindingContext = value.Clone();
            return;
        }
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("value");
    }
}
// Security binding element that protects the Issue/Renew exchange. The setter
// stores a clone of the supplied element and rejects null.
public SecurityBindingElement BootstrapSecurityBindingElement
{
    get
    {
        return bootstrapSecurityBindingElement;
    }
    set
    {
        this.CommunicationObject.ThrowIfDisposedOrImmutable();
        if (value != null)
        {
            bootstrapSecurityBindingElement = (SecurityBindingElement)value.Clone();
            return;
        }
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("value");
    }
}
// Standards manager that supplies the trust and secure conversation drivers.
// The setter rejects null and any manager whose drivers do not support sessions.
public SecurityStandardsManager StandardsManager
{
    get { return standardsManager; }
    set
    {
        this.CommunicationObject.ThrowIfDisposedOrImmutable();

        if (value == null)
        {
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentNullException("value"));
        }

        bool trustSupportsSession = value.TrustDriver.IsSessionSupported;
        if (!trustSupportsSession)
        {
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentException(SR.GetString(SR.TrustDriverVersionDoesNotSupportSession), "value"));
        }

        bool secureConversationSupportsSession = value.SecureConversationDriver.IsSessionSupported;
        if (!secureConversationSupportsSession)
        {
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentException(SR.GetString(SR.SecureConversationDriverVersionDoesNotSupportSession), "value"));
        }

        standardsManager = value;
    }
}
// Token parameters for the session token. Must be set before OnOpen; used when
// building renew requests.
public SecurityTokenParameters IssuedSecurityTokenParameters
{
    get { return issuedTokenParameters; }
    set
    {
        this.CommunicationObject.ThrowIfDisposedOrImmutable();
        issuedTokenParameters = value;
    }
}
// Privacy notice URI copied onto the security protocol factory in InitializeFactories.
public Uri PrivacyNoticeUri
{
    get
    {
        return privacyNoticeUri;
    }
    set
    {
        this.CommunicationObject.ThrowIfDisposedOrImmutable();
        privacyNoticeUri = value;
    }
}
// Channel parameters propagated to every channel created for Issue/Renew.
public ChannelParameterCollection ChannelParameters
{
    get
    {
        return channelParameters;
    }
    set
    {
        this.CommunicationObject.ThrowIfDisposedOrImmutable();
        channelParameters = value;
    }
}
// Privacy notice version copied onto the security protocol factory in InitializeFactories.
public int PrivacyNoticeVersion
{
    get
    {
        return privacyNoticeVersion;
    }
    set
    {
        this.CommunicationObject.ThrowIfDisposedOrImmutable();
        privacyNoticeVersion = value;
    }
}
// Action of the Issue request, taken from the secure conversation driver.
public virtual XmlDictionaryString IssueAction
{
    get { return standardsManager.SecureConversationDriver.IssueAction; }
}
// Action of the Issue response, taken from the secure conversation driver.
public virtual XmlDictionaryString IssueResponseAction
{
    get { return standardsManager.SecureConversationDriver.IssueResponseAction; }
}
// Action of the Renew request, taken from the secure conversation driver.
public virtual XmlDictionaryString RenewAction
{
    get { return standardsManager.SecureConversationDriver.RenewAction; }
}
// Action of the Renew response, taken from the secure conversation driver.
public virtual XmlDictionaryString RenewResponseAction
{
    get { return standardsManager.SecureConversationDriver.RenewResponseAction; }
}
// Action of the Close request, taken from the secure conversation driver.
public virtual XmlDictionaryString CloseAction
{
    get { return this.standardsManager.SecureConversationDriver.CloseAction; }
}
// Action of the Close response, taken from the secure conversation driver.
public virtual XmlDictionaryString CloseResponseAction
{
    get { return this.standardsManager.SecureConversationDriver.CloseResponseAction; }
}
// ISecurityCommunicationObject methods
// Aborts the RST channel factory (if one was built) and releases the credentials handle.
public override void OnAbort()
{
    IChannelFactory<IRequestChannel> factory = this.rstChannelFactory;
    if (factory != null)
    {
        factory.Abort();
        this.rstChannelFactory = null;
    }

    this.FreeCredentialsHandle();
}
// Verifies that all required settings are present, builds and opens the RST channel
// factory within the given timeout, and caches the session token type URI.
// Throws InvalidOperationException when a required setting is missing.
public override void OnOpen(TimeSpan timeout)
{
    TimeoutHelper openTimeout = new TimeoutHelper(timeout);

    if (this.targetAddress == null)
    {
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.TargetAddressIsNotSet, this.GetType())));
    }

    if (this.issuerBindingContext == null)
    {
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.IssuerBuildContextNotSet, this.GetType())));
    }

    if (this.issuedTokenParameters == null)
    {
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.IssuedSecurityTokenParametersNotSet, this.GetType())));
    }

    if (this.bootstrapSecurityBindingElement == null)
    {
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.BootstrapSecurityBindingElementNotSet, this.GetType())));
    }

    if (this.securityAlgorithmSuite == null)
    {
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.SecurityAlgorithmSuiteNotSet, this.GetType())));
    }

    this.InitializeFactories();
    this.rstChannelFactory.Open(openTimeout.RemainingTime());
    this.sctUri = this.standardsManager.SecureConversationDriver.TokenTypeUri;
}
// Acquires a credentials handle when none was supplied to the constructor and
// marks it as owned so it is closed later. Requires the issuer binding context
// and bootstrap binding element in that case.
public override void OnOpening()
{
    base.OnOpening();

    if (this.credentialsHandle != null)
    {
        // The caller supplied a handle; nothing to acquire.
        return;
    }

    if (this.issuerBindingContext == null)
    {
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.IssuerBuildContextNotSet, this.GetType())));
    }

    if (this.bootstrapSecurityBindingElement == null)
    {
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.BootstrapSecurityBindingElementNotSet, this.GetType())));
    }

    this.credentialsHandle = SecurityUtils.GetCredentialsHandle(this.bootstrapSecurityBindingElement, this.issuerBindingContext);
    this.ownCredentialsHandle = true;
}
// Closes the RST channel factory (if one was built) within the given timeout and
// releases the credentials handle.
public override void OnClose(TimeSpan timeout)
{
    TimeoutHelper closeTimeout = new TimeoutHelper(timeout);

    IChannelFactory<IRequestChannel> factory = this.rstChannelFactory;
    if (factory != null)
    {
        factory.Close(closeTimeout.RemainingTime());
        this.rstChannelFactory = null;
    }

    this.FreeCredentialsHandle();
}
// Drops the reference to the credentials handle, closing it first only when this
// provider acquired it itself.
void FreeCredentialsHandle()
{
    SafeFreeCredentials handle = this.credentialsHandle;
    if (handle == null)
    {
        return;
    }

    if (this.ownCredentialsHandle)
    {
        handle.Close();
    }

    this.credentialsHandle = null;
}
// Builds the secured request channel factory used for Issue/Renew requests and stores
// it in rstChannelFactory, together with the factory's message version.
// Steps: configure the bootstrap binding element from the issuer context, create the
// security protocol factory, build the inner request channel factory (directly, or
// wrapped around a ServiceChannelFactory), and wrap both in a SecurityChannelFactory.
// Throws InvalidOperationException when the issuer context exposes no reader quotas.
void InitializeFactories()
{
ISecurityCapabilities securityCapabilities = this.BootstrapSecurityBindingElement.GetProperty<ISecurityCapabilities>(this.IssuerBindingContext);
SecurityCredentialsManager securityCredentials = this.IssuerBindingContext.BindingParameters.Find<SecurityCredentialsManager>();
if (securityCredentials == null)
{
// No credentials were supplied in the binding parameters; fall back to defaults.
securityCredentials = ClientCredentials.CreateDefaultCredentials();
}
// 'context' is the same instance as the issuerBindingContext field, not a copy.
BindingContext context = this.IssuerBindingContext;
this.bootstrapSecurityBindingElement.ReaderQuotas = context.GetInnerProperty<XmlDictionaryReaderQuotas>();
if (this.bootstrapSecurityBindingElement.ReaderQuotas == null)
{
throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.EncodingBindingElementDoesNotHandleReaderQuotas)));
}
TransportBindingElement transportBindingElement = context.RemainingBindingElements.Find<TransportBindingElement>();
if (transportBindingElement != null)
this.bootstrapSecurityBindingElement.MaxReceivedMessageSize = transportBindingElement.MaxReceivedMessageSize;
// The protocol factory is handed clones of the context, taken before the
// LocalAddressProvider parameter is added to 'context' further down.
SecurityProtocolFactory securityProtocolFactory = this.BootstrapSecurityBindingElement.CreateSecurityProtocolFactory<IRequestChannel>(this.IssuerBindingContext.Clone(), securityCredentials, false, this.IssuerBindingContext.Clone());
if (securityProtocolFactory is MessageSecurityProtocolFactory)
{
// Message security: turn on signing and encryption and include the body for the
// Issue/Renew requests and their responses.
MessageSecurityProtocolFactory soapBindingFactory = securityProtocolFactory as MessageSecurityProtocolFactory;
soapBindingFactory.ApplyConfidentiality = soapBindingFactory.ApplyIntegrity
= soapBindingFactory.RequireConfidentiality = soapBindingFactory.RequireIntegrity = true;
soapBindingFactory.ProtectionRequirements.IncomingSignatureParts.ChannelParts.IsBodyIncluded = true;
soapBindingFactory.ProtectionRequirements.OutgoingSignatureParts.ChannelParts.IsBodyIncluded = true;
MessagePartSpecification bodyPart = new MessagePartSpecification(true);
soapBindingFactory.ProtectionRequirements.IncomingSignatureParts.AddParts(bodyPart, IssueAction);
soapBindingFactory.ProtectionRequirements.IncomingEncryptionParts.AddParts(bodyPart, IssueAction);
soapBindingFactory.ProtectionRequirements.IncomingSignatureParts.AddParts(bodyPart, RenewAction);
soapBindingFactory.ProtectionRequirements.IncomingEncryptionParts.AddParts(bodyPart, RenewAction);
soapBindingFactory.ProtectionRequirements.OutgoingSignatureParts.AddParts(bodyPart, IssueResponseAction);
soapBindingFactory.ProtectionRequirements.OutgoingEncryptionParts.AddParts(bodyPart, IssueResponseAction);
soapBindingFactory.ProtectionRequirements.OutgoingSignatureParts.AddParts(bodyPart, RenewResponseAction);
soapBindingFactory.ProtectionRequirements.OutgoingEncryptionParts.AddParts(bodyPart, RenewResponseAction);
}
securityProtocolFactory.PrivacyNoticeUri = this.PrivacyNoticeUri;
securityProtocolFactory.PrivacyNoticeVersion = this.privacyNoticeVersion;
if (this.localAddress != null)
{
// Register the local address with a filter built from the Issue and Renew response actions.
MessageFilter issueAndRenewFilter = new SessionActionFilter(this.standardsManager, this.IssueResponseAction.Value, this.RenewResponseAction.Value);
context.BindingParameters.Add(new LocalAddressProvider(localAddress, issueAndRenewFilter));
}
ChannelBuilder channelBuilder = new ChannelBuilder(context, true);
IChannelFactory<IRequestChannel> innerChannelFactory;
// if the underlying transport does not support request/reply, wrap it inside
// a service channel factory.
if (channelBuilder.CanBuildChannelFactory<IRequestChannel>())
{
innerChannelFactory = channelBuilder.BuildChannelFactory<IRequestChannel>();
// Built directly on the transport: PrepareRequest must set ReplyTo itself.
requiresManualReplyAddressing = true;
}
else
{
ClientRuntime clientRuntime = new ClientRuntime("RequestSecuritySession", NamingHelper.DefaultNamespace);
clientRuntime.UseSynchronizationContext = false;
clientRuntime.AddTransactionFlowProperties = false;
clientRuntime.ValidateMustUnderstand = false;
ServiceChannelFactory serviceChannelFactory = ServiceChannelFactory.BuildChannelFactory(channelBuilder, clientRuntime);
// Register the two operations this provider sends, both using the shared message formatter.
ClientOperation issueOperation = new ClientOperation(serviceChannelFactory.ClientRuntime, "Issue", IssueAction.Value);
issueOperation.Formatter = operationFormatter;
serviceChannelFactory.ClientRuntime.Operations.Add(issueOperation);
ClientOperation renewOperation = new ClientOperation(serviceChannelFactory.ClientRuntime, "Renew", RenewAction.Value);
renewOperation.Formatter = operationFormatter;
serviceChannelFactory.ClientRuntime.Operations.Add(renewOperation);
innerChannelFactory = new RequestChannelFactory(serviceChannelFactory);
requiresManualReplyAddressing = false;
}
SecurityChannelFactory<IRequestChannel> securityChannelFactory = new SecurityChannelFactory<IRequestChannel>(
securityCapabilities, this.IssuerBindingContext, channelBuilder, securityProtocolFactory, innerChannelFactory);
// attach the ExtendedProtectionPolicy to the securityProtocolFactory so it will be
// available when building the channel.
if (transportBindingElement != null)
{
if (securityChannelFactory.SecurityProtocolFactory != null)
{
securityChannelFactory.SecurityProtocolFactory.ExtendedProtectionPolicy = transportBindingElement.GetProperty<ExtendedProtectionPolicy>(context);
}
}
this.rstChannelFactory = securityChannelFactory;
this.messageVersion = securityChannelFactory.MessageVersion;
}
// token provider methods
// Starts an asynchronous Issue operation against the configured target and via.
// Requires the provider to be open.
protected override IAsyncResult BeginGetTokenCore(TimeSpan timeout, AsyncCallback callback, object state)
{
    this.CommunicationObject.ThrowIfClosedOrNotOpen();
    SessionOperationAsyncResult issueResult = new SessionOperationAsyncResult(this, SecuritySessionOperation.Issue, this.targetAddress, this.via, null, timeout, callback, state);
    return issueResult;
}
// Completes an asynchronous Issue operation and returns the issued token.
protected override SecurityToken EndGetTokenCore(IAsyncResult result)
{
    SecurityToken issued = SessionOperationAsyncResult.End(result);
    return issued;
}
// Synchronously issues a new session token from the configured target.
// Requires the provider to be open.
protected override SecurityToken GetTokenCore(TimeSpan timeout)
{
    this.CommunicationObject.ThrowIfClosedOrNotOpen();
    GenericXmlSecurityToken issued = DoOperation(SecuritySessionOperation.Issue, targetAddress, via, null, timeout);
    return issued;
}
// Starts an asynchronous Renew operation for the given token against the
// configured target and via. Requires the provider to be open.
protected override IAsyncResult BeginRenewTokenCore(TimeSpan timeout, SecurityToken tokenToBeRenewed, AsyncCallback callback, object state)
{
    this.CommunicationObject.ThrowIfClosedOrNotOpen();
    SessionOperationAsyncResult renewResult = new SessionOperationAsyncResult(this, SecuritySessionOperation.Renew, this.targetAddress, this.via, tokenToBeRenewed, timeout, callback, state);
    return renewResult;
}
// Completes an asynchronous Renew operation and returns the renewed token.
protected override SecurityToken EndRenewTokenCore(IAsyncResult result)
{
    SecurityToken renewed = SessionOperationAsyncResult.End(result);
    return renewed;
}
// Synchronously renews the given session token against the configured target.
// Requires the provider to be open.
protected override SecurityToken RenewTokenCore(TimeSpan timeout, SecurityToken tokenToBeRenewed)
{
    this.CommunicationObject.ThrowIfClosedOrNotOpen();
    GenericXmlSecurityToken renewed = DoOperation(SecuritySessionOperation.Renew, targetAddress, via, tokenToBeRenewed, timeout);
    return renewed;
}
// Creates a request channel for an Issue or Renew operation. Configured channel
// parameters are propagated to it, and when this provider owns its credentials
// handle an SspiIssuanceChannelParameter is added as well.
// Throws NotSupportedException for any other operation.
IRequestChannel CreateChannel(SecuritySessionOperation operation, EndpointAddress target, Uri via)
{
    if (operation != SecuritySessionOperation.Issue && operation != SecuritySessionOperation.Renew)
    {
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new NotSupportedException());
    }

    IChannelFactory<IRequestChannel> factory = this.rstChannelFactory;
    IRequestChannel requestChannel = (via != null)
        ? factory.CreateChannel(target, via)
        : factory.CreateChannel(target);

    if (this.channelParameters != null)
    {
        this.channelParameters.PropagateChannelParameters(requestChannel);
    }

    if (this.ownCredentialsHandle)
    {
        ChannelParameterCollection parametersOnChannel = requestChannel.GetProperty<ChannelParameterCollection>();
        if (parametersOnChannel != null)
        {
            parametersOnChannel.Add(new SspiIssuanceChannelParameter(true, this.credentialsHandle));
        }
    }

    return requestChannel;
}
// Builds the request message for the given operation by dispatching to the Issue
// or Renew request builder. Throws NotSupportedException for any other operation.
Message CreateRequest(SecuritySessionOperation operation, EndpointAddress target, SecurityToken currentToken, out object requestState)
{
    switch (operation)
    {
        case SecuritySessionOperation.Issue:
            return this.CreateIssueRequest(target, out requestState);

        case SecuritySessionOperation.Renew:
            return this.CreateRenewRequest(target, currentToken, out requestState);

        default:
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new NotSupportedException());
    }
}
// Checks the reply for a negotiation fault, then extracts the issued token using
// the handler that matches the operation. Returns null for operations other than
// Issue and Renew.
GenericXmlSecurityToken ProcessReply(Message reply, SecuritySessionOperation operation, object requestState)
{
    ThrowIfFault(reply, this.targetAddress);

    switch (operation)
    {
        case SecuritySessionOperation.Issue:
            return this.ProcessIssueResponse(reply, requestState);

        case SecuritySessionOperation.Renew:
            return this.ProcessRenewResponse(reply, requestState);

        default:
            return null;
    }
}
// Traces the successful completion of a session operation.
void OnOperationSuccess(SecuritySessionOperation operation, EndpointAddress target, SecurityToken issuedToken, SecurityToken currentToken)
{
    SecurityTraceRecordHelper.TraceSecuritySessionOperationSuccess(
        operation,
        target,
        currentToken,
        issuedToken);
}
// Traces the failure of a session operation and aborts the channel, if any.
void OnOperationFailure(SecuritySessionOperation operation, EndpointAddress target, SecurityToken currentToken, Exception e, IChannel channel)
{
    SecurityTraceRecordHelper.TraceSecuritySessionOperationFailure(operation, target, currentToken, e);

    if (channel == null)
    {
        return;
    }

    channel.Abort();
}
// Synchronously performs an Issue or Renew exchange with the target:
// creates and opens a channel, sends the request, validates the reply and the
// issued key size, closes the channel and returns the issued token.
//
// target       - endpoint to send the request to; must not be null.
// via          - optional transport address; null means "use target".
// currentToken - token being renewed; required for Renew, ignored for Issue.
// timeout      - budget shared by open, request and close.
//
// On any non-fatal failure the channel is aborted and the failure is traced.
// A TimeoutException is surfaced as a new TimeoutException carrying the
// session-request timeout message, with the original as InnerException; every
// other exception is rethrown unchanged.
GenericXmlSecurityToken DoOperation(SecuritySessionOperation operation, EndpointAddress target, Uri via, SecurityToken currentToken, TimeSpan timeout)
{
    if (target == null)
    {
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("target");
    }
    if (operation == SecuritySessionOperation.Renew && currentToken == null)
    {
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("currentToken");
    }
    IRequestChannel channel = null;
    try
    {
        SecurityTraceRecordHelper.TraceBeginSecuritySessionOperation(operation, target, currentToken);
        channel = this.CreateChannel(operation, target, via);
        TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
        channel.Open(timeoutHelper.RemainingTime());
        object requestState;
        GenericXmlSecurityToken issuedToken;
        using (Message requestMessage = this.CreateRequest(operation, target, currentToken, out requestState))
        {
            EventTraceActivity eventTraceActivity = null;
            if (TD.MessageReceivedFromTransportIsEnabled())
            {
                eventTraceActivity = EventTraceActivityHelper.TryExtractActivity(requestMessage);
            }
            TraceUtility.ProcessOutgoingMessage(requestMessage, eventTraceActivity);
            using (Message reply = channel.Request(requestMessage, timeoutHelper.RemainingTime()))
            {
                if (reply == null)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new CommunicationException(SR.GetString(SR.FailToRecieveReplyFromNegotiation)));
                }
                if (eventTraceActivity == null && TD.MessageReceivedFromTransportIsEnabled())
                {
                    eventTraceActivity = EventTraceActivityHelper.TryExtractActivity(reply);
                }
                TraceUtility.ProcessIncomingMessage(reply, eventTraceActivity);
                ThrowIfFault(reply, this.targetAddress);
                issuedToken = ProcessReply(reply, operation, requestState);
                ValidateKeySize(issuedToken);
            }
        }
        channel.Close(timeoutHelper.RemainingTime());
        this.OnOperationSuccess(operation, target, issuedToken, currentToken);
        return issuedToken;
    }
    catch (Exception e)
    {
        if (Fx.IsFatal(e))
        {
            throw;
        }

        if (e is TimeoutException)
        {
            // Throw the wrapped exception explicitly. A bare "throw;" would rethrow
            // the original exception and silently discard the wrapper.
            Exception requestTimeout = new TimeoutException(SR.GetString(SR.ClientSecuritySessionRequestTimeout, timeout), e);
            OnOperationFailure(operation, target, currentToken, requestTimeout, channel);
            throw requestTimeout;
        }

        OnOperationFailure(operation, target, currentToken, e, channel);
        // Not a timeout: rethrow as-is so the original stack trace is preserved.
        throw;
    }
}
// Returns entropySize / 8 random bytes (entropySize is a size in bits).
byte[] GenerateEntropy(int entropySize)
{
    int byteCount = entropySize / 8;
    byte[] entropy = DiagnosticUtility.Utility.AllocateByteArray(byteCount);
    CryptoHelper.FillRandomBytes(entropy);
    return entropy;
}
// Creates the RST shared by Issue and Renew: key size from the algorithm suite,
// token type from the cached SCT URI, and requestor entropy when the entropy mode
// is client or combined. requestState receives the generated entropy, or null
// when none was generated.
// AppliesTo is not set on the RST; 'target' is currently unused here.
RequestSecurityToken CreateRst(EndpointAddress target, out object requestState)
{
    RequestSecurityToken request = new RequestSecurityToken(this.standardsManager);
    request.KeySize = this.SecurityAlgorithmSuite.DefaultSymmetricKeyLength;
    request.TokenType = this.sctUri;

    object state = null;
    SecurityKeyEntropyMode mode = this.KeyEntropyMode;
    if (mode == SecurityKeyEntropyMode.ClientEntropy || mode == SecurityKeyEntropyMode.CombinedEntropy)
    {
        byte[] clientEntropy = GenerateEntropy(request.KeySize);
        request.SetRequestorEntropy(clientEntropy);
        state = clientEntropy;
    }

    requestState = state;
    return request;
}
// Finalizes an outgoing request: prepares it for correlation, sets ReplyTo when
// manual reply addressing is required, and adds the configured web headers to the
// message's HTTP request property (creating the property when it is missing).
void PrepareRequest(Message message)
{
    RequestReplyCorrelator.PrepareRequest(message);

    if (this.requiresManualReplyAddressing)
    {
        message.Headers.ReplyTo = (this.localAddress != null)
            ? this.LocalAddress
            : EndpointAddress.AnonymousAddress;
    }

    if (this.webHeaderCollection == null || this.webHeaderCollection.Count <= 0)
    {
        return;
    }

    HttpRequestMessageProperty httpRequest;
    object existing = null;
    if (message.Properties.TryGetValue(HttpRequestMessageProperty.Name, out existing))
    {
        // An entry of a different type leaves httpRequest null, so the headers are skipped.
        httpRequest = existing as HttpRequestMessageProperty;
    }
    else
    {
        httpRequest = new HttpRequestMessageProperty();
        message.Properties.Add(HttpRequestMessageProperty.Name, httpRequest);
    }

    if (httpRequest != null && httpRequest.Headers != null)
    {
        httpRequest.Headers.Add(this.webHeaderCollection);
    }
}
// Builds the Issue request message: a read-only RST with the trust driver's issue
// request type, addressed with IssueAction. requestState receives the entropy
// state produced by CreateRst. Requires the provider to be open.
protected virtual Message CreateIssueRequest(EndpointAddress target, out object requestState)
{
    this.CommunicationObject.ThrowIfClosedOrNotOpen();

    RequestSecurityToken issueRst = this.CreateRst(target, out requestState);
    issueRst.RequestType = this.StandardsManager.TrustDriver.RequestTypeIssue;
    issueRst.MakeReadOnly();

    Message request = Message.CreateMessage(
        this.MessageVersion,
        ActionHeader.Create(this.IssueAction, this.MessageVersion.Addressing),
        issueRst);
    this.PrepareRequest(request);
    return request;
}
// Reads the RSTR (WS-Trust Feb 2005) or the single RSTR inside an RSTRC
// (WS-Trust 1.3) from the response body and builds the issued token from it.
// The server's authorization policies from the response's security property are
// attached, or an empty collection when there are none.
// Throws MessageSecurityException when an RSTRC holds more than one RSTR, and
// NotSupportedException for any other trust version.
GenericXmlSecurityToken ExtractToken(Message response, object requestState)
{
    // get the claims corresponding to the server
    SecurityMessageProperty serverSecurity = response.Properties.Security;
    ReadOnlyCollection<IAuthorizationPolicy> serverPolicies =
        (serverSecurity != null && serverSecurity.ServiceSecurityContext != null)
            ? serverSecurity.ServiceSecurityContext.AuthorizationPolicies
            : EmptyReadOnlyCollection<IAuthorizationPolicy>.Instance;

    RequestSecurityTokenResponse tokenResponse = null;
    using (XmlDictionaryReader reader = response.GetReaderAtBodyContents())
    {
        if (this.StandardsManager.MessageSecurityVersion.TrustVersion == TrustVersion.WSTrustFeb2005)
        {
            tokenResponse = this.StandardsManager.TrustDriver.CreateRequestSecurityTokenResponse(reader);
        }
        else if (this.StandardsManager.MessageSecurityVersion.TrustVersion == TrustVersion.WSTrust13)
        {
            RequestSecurityTokenResponseCollection responses = this.StandardsManager.TrustDriver.CreateRequestSecurityTokenResponseCollection(reader);
            foreach (RequestSecurityTokenResponse candidate in responses.RstrCollection)
            {
                if (tokenResponse != null)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new MessageSecurityException(SR.GetString(SR.MoreThanOneRSTRInRSTRC)));
                }
                tokenResponse = candidate;
            }
        }
        else
        {
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new NotSupportedException());
        }

        response.ReadFromBodyContentsToEnd(reader);
    }

    // A null requestState casts to a null entropy array.
    byte[] requestorEntropy = (byte[])requestState;

    // NOTE(review): tokenResponse is still null here if the RSTRC contained no RSTR.
    return tokenResponse.GetIssuedToken(null, null, this.KeyEntropyMode, requestorEntropy, this.sctUri, serverPolicies, this.SecurityAlgorithmSuite.DefaultSymmetricKeyLength, false);
}
// Extracts the issued token from an Issue response. Requires the provider to be open.
protected virtual GenericXmlSecurityToken ProcessIssueResponse(Message response, object requestState)
{
    this.CommunicationObject.ThrowIfClosedOrNotOpen();
    GenericXmlSecurityToken issued = this.ExtractToken(response, requestState);
    return issued;
}
// Builds the Renew request message: a read-only RST with the trust driver's renew
// request type whose renew target references the current session token, addressed
// with RenewAction. The current token is attached as an endorsing supporting token.
// Requires the provider to be open.
protected virtual Message CreateRenewRequest(EndpointAddress target, SecurityToken currentSessionToken, out object requestState)
{
    this.CommunicationObject.ThrowIfClosedOrNotOpen();

    RequestSecurityToken renewRst = this.CreateRst(target, out requestState);
    renewRst.RequestType = this.StandardsManager.TrustDriver.RequestTypeRenew;
    renewRst.RenewTarget = this.IssuedSecurityTokenParameters.CreateKeyIdentifierClause(currentSessionToken, SecurityTokenReferenceStyle.External);
    renewRst.MakeReadOnly();

    Message request = Message.CreateMessage(
        this.MessageVersion,
        ActionHeader.Create(this.RenewAction, this.MessageVersion.Addressing),
        renewRst);

    SecurityMessageProperty security = new SecurityMessageProperty();
    SupportingTokenSpecification endorsingToken = new SupportingTokenSpecification(currentSessionToken, EmptyReadOnlyCollection<IAuthorizationPolicy>.Instance, SecurityTokenAttachmentMode.Endorsing, this.IssuedSecurityTokenParameters);
    security.OutgoingSupportingTokens.Add(endorsingToken);
    request.Properties.Security = security;

    this.PrepareRequest(request);
    return request;
}
// Verifies that the response carries the expected renew response action, then
// extracts the renewed token. Throws SecurityNegotiationException when the action
// differs. Requires the provider to be open.
protected virtual GenericXmlSecurityToken ProcessRenewResponse(Message response, object requestState)
{
    this.CommunicationObject.ThrowIfClosedOrNotOpen();

    if (response.Headers.Action == this.RenewResponseAction.Value)
    {
        return this.ExtractToken(response, requestState);
    }

    throw TraceUtility.ThrowHelperError(new SecurityNegotiationException(SR.GetString(SR.InvalidRenewResponseAction, response.Headers.Action)), response);
}
// Delegates to SecurityUtils.ThrowIfNegotiationFault for the given message and target.
protected static void ThrowIfFault(Message message, EndpointAddress target)
{
    SecurityUtils.ThrowIfNegotiationFault(message, target);
}
// Validates the key of an issued token. The token must expose exactly one security
// key. When that key is symmetric, its size must be supported by the algorithm
// suite; a non-symmetric key is accepted without a size check.
// Throws SecurityNegotiationException otherwise. Requires the provider to be open.
protected void ValidateKeySize(GenericXmlSecurityToken issuedToken)
{
    this.CommunicationObject.ThrowIfClosedOrNotOpen();

    ReadOnlyCollection<SecurityKey> keys = issuedToken.SecurityKeys;
    if (keys == null || keys.Count != 1)
    {
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new SecurityNegotiationException(SR.GetString(SR.CannotObtainIssuedTokenKeySize)));
    }

    SymmetricSecurityKey symmetricKey = keys[0] as SymmetricSecurityKey;
    if (symmetricKey == null)
    {
        return;
    }

    if (!this.SecurityAlgorithmSuite.IsSymmetricKeyLengthSupported(symmetricKey.KeySize))
    {
        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new SecurityNegotiationException(SR.GetString(SR.InvalidIssuedTokenKeySize, symmetricKey.KeySize)));
    }
}
// Asynchronous counterpart of DoOperation: drives a single Issue or Renew exchange
// (create channel, open, request, process reply, close) through Begin/End calls.
// Each step either completes synchronously and falls through to the next step, or
// returns false and is resumed from its callback.
class SessionOperationAsyncResult : AsyncResult
{
// Static callbacks; the SessionOperationAsyncResult instance travels as the async state.
static AsyncCallback openChannelCallback = Fx.ThunkCallback(new AsyncCallback(OpenChannelCallback));
static AsyncCallback closeChannelCallback = Fx.ThunkCallback(new AsyncCallback(CloseChannelCallback));
SecuritySessionSecurityTokenProvider requestor;
SecuritySessionOperation operation;
EndpointAddress target;
Uri via;
SecurityToken currentToken;
// Result of the operation; set once the reply has been processed.
GenericXmlSecurityToken issuedToken;
IRequestChannel channel;
// Single timeout budget shared by open, request and close.
TimeoutHelper timeoutHelper;
// Captures the operation parameters, traces the start, and kicks off the operation.
// A non-fatal exception thrown while starting is reported through OnOperationFailure
// and then rethrown to the caller. If every step finished synchronously, the result
// is completed here.
public SessionOperationAsyncResult(SecuritySessionSecurityTokenProvider requestor, SecuritySessionOperation operation, EndpointAddress target, Uri via, SecurityToken currentToken, TimeSpan timeout, AsyncCallback callback, object state)
: base(callback, state)
{
this.requestor = requestor;
this.operation = operation;
this.target = target;
this.via = via;
this.currentToken = currentToken;
this.timeoutHelper = new TimeoutHelper(timeout);
SecurityTraceRecordHelper.TraceBeginSecuritySessionOperation(operation, target, currentToken);
bool completeSelf = false;
try
{
completeSelf = this.StartOperation();
}
#pragma warning suppress 56500 // covered by FxCOP
catch (Exception e)
{
if (Fx.IsFatal(e))
throw;
this.OnOperationFailure(e);
throw;
}
if (completeSelf)
{
this.OnOperationComplete();
Complete(true);
}
}
/*
* Session issuance/renewal consists of the following steps (some may be async):
* 1. Create a channel (sync)
* 2. Open the channel (async)
* 3. Create the request to send to server (sync)
* 4. Send the message and get reply (async)
* 5. Process the reply to get the token
* 6. Close the channel (async) and complete the async result
*/
// Steps 1-2. Returns true only when the whole operation finished synchronously.
bool StartOperation()
{
this.channel = this.requestor.CreateChannel(this.operation, this.target, this.via);
IAsyncResult result = this.channel.BeginOpen(this.timeoutHelper.RemainingTime(), openChannelCallback, this);
if (!result.CompletedSynchronously)
{
return false;
}
this.channel.EndOpen(result);
return this.OnChannelOpened();
}
// Resumes after an asynchronous open. Synchronous completions are handled by
// StartOperation, so they are ignored here.
static void OpenChannelCallback(IAsyncResult result)
{
if (result.CompletedSynchronously)
{
return;
}
SessionOperationAsyncResult self = (SessionOperationAsyncResult)result.AsyncState;
bool completeSelf = false;
Exception completionException = null;
try
{
self.channel.EndOpen(result);
completeSelf = self.OnChannelOpened();
if (completeSelf)
{
self.OnOperationComplete();
}
}
#pragma warning suppress 56500 // covered by FxCOP
catch (Exception e)
{
if (Fx.IsFatal(e))
throw;
completeSelf = true;
completionException = e;
self.OnOperationFailure(completionException);
}
if (completeSelf)
{
self.Complete(false, completionException);
}
}
// Steps 3-4. Creates the request and begins sending it. The request message and its
// request state travel to RequestCallback in a wrapper object.
bool OnChannelOpened()
{
object requestState;
Message requestMessage = this.requestor.CreateRequest(this.operation, this.target, this.currentToken, out requestState);
if (requestMessage == null)
{
throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.NullSessionRequestMessage, this.operation.ToString())));
}
ChannelOpenAsyncResultWrapper wrapper = new ChannelOpenAsyncResultWrapper();
wrapper.Message = requestMessage;
wrapper.RequestState = requestState;
// The request message is closed here unless the request went asynchronous,
// in which case RequestCallback closes it.
bool closeMessage = true;
try
{
IAsyncResult result = this.channel.BeginRequest(requestMessage, this.timeoutHelper.RemainingTime(), Fx.ThunkCallback(new AsyncCallback(this.RequestCallback)), wrapper);
if (!result.CompletedSynchronously)
{
closeMessage = false;
return false;
}
Message reply = this.channel.EndRequest(result);
return this.OnReplyReceived(reply, requestState);
}
finally
{
if (closeMessage)
{
wrapper.Message.Close();
}
}
}
// Resumes after an asynchronous request. Always closes the request message carried
// in the wrapper, whether or not processing the reply succeeded.
void RequestCallback(IAsyncResult result)
{
if (result.CompletedSynchronously)
{
return;
}
ChannelOpenAsyncResultWrapper wrapper = (ChannelOpenAsyncResultWrapper)result.AsyncState;
object requestState = wrapper.RequestState;
bool completeSelf = false;
Exception completionException = null;
try
{
Message reply = this.channel.EndRequest(result);
completeSelf = this.OnReplyReceived(reply, requestState);
if (completeSelf)
{
this.OnOperationComplete();
}
}
#pragma warning suppress 56500 // covered by FxCOP
catch (Exception e)
{
if (Fx.IsFatal(e))
throw;
completeSelf = true;
completionException = e;
this.OnOperationFailure(e);
}
finally
{
if (wrapper.Message != null)
wrapper.Message.Close();
}
if (completeSelf)
{
Complete(false, completionException);
}
}
// Step 5. Extracts and validates the issued token, disposes the reply, then moves on
// to closing the channel. Throws CommunicationException when no reply was received.
bool OnReplyReceived(Message reply, object requestState)
{
if (reply == null)
{
throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new CommunicationException(SR.GetString(SR.FailToRecieveReplyFromNegotiation)));
}
using (reply)
{
this.issuedToken = this.requestor.ProcessReply(reply, this.operation, requestState);
this.requestor.ValidateKeySize(this.issuedToken);
}
return this.OnReplyProcessed();
}
// Step 6. Begins closing the channel; returns true when the close finished synchronously.
bool OnReplyProcessed()
{
IAsyncResult result = this.channel.BeginClose(this.timeoutHelper.RemainingTime(), closeChannelCallback, this);
if (!result.CompletedSynchronously)
{
return false;
}
this.channel.EndClose(result);
return true;
}
// Resumes after an asynchronous close and completes the async result, with the
// exception if the close failed.
static void CloseChannelCallback(IAsyncResult result)
{
if (result.CompletedSynchronously)
{
return;
}
SessionOperationAsyncResult self = (SessionOperationAsyncResult)result.AsyncState;
Exception completionException = null;
try
{
self.channel.EndClose(result);
self.OnOperationComplete();
}
#pragma warning suppress 56500 // covered by FxCOP
catch (Exception e)
{
if (Fx.IsFatal(e))
throw;
completionException = e;
self.OnOperationFailure(completionException);
}
self.Complete(false, completionException);
}
// Reports the failure to the owning provider (which traces it and aborts the channel).
// A CommunicationException raised while doing so is traced and swallowed.
void OnOperationFailure(Exception e)
{
try
{
this.requestor.OnOperationFailure(operation, target, currentToken, e, this.channel);
}
catch (CommunicationException ex)
{
DiagnosticUtility.TraceHandledException(ex, TraceEventType.Information);
}
}
// Reports success to the owning provider.
void OnOperationComplete()
{
this.requestor.OnOperationSuccess(this.operation, this.target, this.issuedToken, this.currentToken);
}
// Ends the async result and returns the token it produced.
public static SecurityToken End(IAsyncResult result)
{
SessionOperationAsyncResult self = AsyncResult.End<SessionOperationAsyncResult>(result);
return self.issuedToken;
}
}
// Async state for SessionOperationAsyncResult's BeginRequest call: carries the
// outgoing request message and its request state to the request callback.
class ChannelOpenAsyncResultWrapper
{
    // Request message being sent; closed by whoever finishes the request.
    public Message Message;

    // Opaque state produced when the request was created.
    public object RequestState;
}
// Adapts a ServiceChannelFactory to IChannelFactory<IRequestChannel>: channel
// creation, open, close, abort and property lookup are forwarded to the wrapped
// service channel factory.
internal class RequestChannelFactory : ChannelFactoryBase<IRequestChannel>
{
    ServiceChannelFactory serviceChannelFactory;

    public RequestChannelFactory(ServiceChannelFactory serviceChannelFactory)
    {
        this.serviceChannelFactory = serviceChannelFactory;
    }

    // Creates the request channel through the wrapped service channel factory.
    protected override IRequestChannel OnCreateChannel(EndpointAddress address, Uri via)
    {
        return serviceChannelFactory.CreateChannel<IRequestChannel>(address, via);
    }

    protected override IAsyncResult OnBeginOpen(TimeSpan timeout, AsyncCallback callback, object state)
    {
        return this.serviceChannelFactory.BeginOpen(timeout, callback, state);
    }

    protected override void OnEndOpen(IAsyncResult result)
    {
        this.serviceChannelFactory.EndOpen(result);
    }

    // Closes the base factory first, then the wrapped service channel factory.
    protected override IAsyncResult OnBeginClose(TimeSpan timeout, AsyncCallback callback, object state)
    {
        return new ChainedCloseAsyncResult(timeout, callback, state, base.OnBeginClose, base.OnEndClose, this.serviceChannelFactory);
    }

    protected override void OnEndClose(IAsyncResult result)
    {
        ChainedCloseAsyncResult.End(result);
    }

    // Closes the base factory first, then the wrapped service channel factory.
    // Both closes draw from one timeout budget, so the whole close stays within
    // the timeout the caller asked for rather than up to twice that.
    protected override void OnClose(TimeSpan timeout)
    {
        TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
        base.OnClose(timeoutHelper.RemainingTime());
        this.serviceChannelFactory.Close(timeoutHelper.RemainingTime());
    }

    protected override void OnOpen(TimeSpan timeout)
    {
        this.serviceChannelFactory.Open(timeout);
    }

    // Aborts the wrapped service channel factory before the base factory.
    protected override void OnAbort()
    {
        this.serviceChannelFactory.Abort();
        base.OnAbort();
    }

    // Property lookups are answered by the wrapped service channel factory.
    public override T GetProperty<T>()
    {
        return this.serviceChannelFactory.GetProperty<T>();
    }
}
}
}
|