Ver código fonte

Add AuthnRequestConsumerResolver

Closes gh-8141
Josh Cummings 5 anos atrás
pai
commit
2c960d2ad1

+ 68 - 1
config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,12 +16,16 @@
 
 package org.springframework.security.config.annotation.web.configurers.saml2;
 
+import java.io.ByteArrayOutputStream;
 import java.io.IOException;
+import java.net.URLDecoder;
 import java.time.Duration;
 import java.util.Arrays;
 import java.util.Base64;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.zip.Inflater;
+import java.util.zip.InflaterOutputStream;
 import javax.servlet.ServletException;
 import javax.servlet.http.HttpServletRequest;
 
@@ -54,9 +58,12 @@ import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
+import org.springframework.security.saml2.Saml2Exception;
 import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
+import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
+import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
@@ -69,7 +76,11 @@ import org.springframework.security.web.context.HttpSessionSecurityContextReposi
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.test.util.ReflectionTestUtils;
 import org.springframework.test.web.servlet.MockMvc;
+import org.springframework.test.web.servlet.MvcResult;
+import org.springframework.web.util.UriComponents;
+import org.springframework.web.util.UriComponentsBuilder;
 
+import static java.nio.charset.StandardCharsets.UTF_8;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
@@ -157,6 +168,20 @@ public class Saml2LoginConfigurerTests {
 		verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class));
 	}
 
+	@Test
+	public void authenticationRequestWhenAuthnRequestConsumerResolverThenUses() throws Exception {
+		this.spring.register(CustomAuthnRequestConsumerResolver.class).autowire();
+
+		MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id"))
+				.andReturn();
+		UriComponents components = UriComponentsBuilder
+				.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
+		String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
+		String decoded = URLDecoder.decode(samlRequest, "UTF-8");
+		String inflated = samlInflate(samlDecode(decoded));
+		assertThat(inflated).contains("ForceAuthn=\"true\"");
+	}
+
 	private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
 		// get the OpenSamlAuthenticationProvider
 		Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
@@ -275,6 +300,29 @@ public class Saml2LoginConfigurerTests {
 		}
 	}
 
+	@EnableWebSecurity
+	@Import(Saml2LoginConfigBeans.class)
+	static class CustomAuthnRequestConsumerResolver extends WebSecurityConfigurerAdapter {
+
+		@Override
+		protected void configure(HttpSecurity http) throws Exception {
+			http
+				.authorizeRequests(authz -> authz
+					.anyRequest().authenticated()
+				)
+				.saml2Login(saml2 -> {});
+		}
+
+		@Bean
+		Saml2AuthenticationRequestFactory authenticationRequestFactory() {
+			OpenSamlAuthenticationRequestFactory authenticationRequestFactory =
+					new OpenSamlAuthenticationRequestFactory();
+			authenticationRequestFactory.setAuthnRequestConsumerResolver(
+					context -> authnRequest -> authnRequest.setForceAuthn(true));
+			return authenticationRequestFactory;
+		}
+	}
+
 	private static AuthenticationManager getAuthenticationManagerMock(String role) {
 		return new AuthenticationManager() {
 
@@ -315,4 +363,23 @@ public class Saml2LoginConfigurerTests {
 		}
 	}
 
+	private static org.apache.commons.codec.binary.Base64 BASE64 =
+			new org.apache.commons.codec.binary.Base64(0, new byte[]{'\n'});
+
+	private static byte[] samlDecode(String s) {
+		return BASE64.decode(s);
+	}
+
+	private static String samlInflate(byte[] b) {
+		try {
+			ByteArrayOutputStream out = new ByteArrayOutputStream();
+			InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true));
+			iout.write(b);
+			iout.finish();
+			return new String(out.toByteArray(), UTF_8);
+		}
+		catch (IOException e) {
+			throw new Saml2Exception("Unable to inflate string", e);
+		}
+	}
 }

+ 20 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java

@@ -21,6 +21,8 @@ import java.time.Instant;
 import java.util.List;
 import java.util.Map;
 import java.util.UUID;
+import java.util.function.Consumer;
+import java.util.function.Function;
 
 import org.joda.time.DateTime;
 import org.opensaml.saml.common.xml.SAMLConstants;
@@ -43,6 +45,9 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 	private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
 	private String protocolBinding = SAMLConstants.SAML2_POST_BINDING_URI;
 
+	private Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver
+			= context -> authnRequest -> {};
+
 	@Override
 	@Deprecated
 	public String createAuthenticationRequest(Saml2AuthenticationRequest request) {
@@ -95,8 +100,10 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 	}
 
 	private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
-		return createAuthnRequest(context.getIssuer(),
+		AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(),
 				context.getDestination(), context.getAssertionConsumerServiceUrl());
+		this.authnRequestConsumerResolver.apply(context).accept(authnRequest);
+		return authnRequest;
 	}
 
 	private AuthnRequest createAuthnRequest(String issuer, String destination, String assertionConsumerServiceUrl) {
@@ -114,6 +121,18 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 		return auth;
 	}
 
+	/**
+	 * Set the {@link AuthnRequest} post-processor resolver
+	 *
+	 * @param authnRequestConsumerResolver
+	 * @since 5.4
+	 */
+	public void setAuthnRequestConsumerResolver(
+			Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver) {
+		Assert.notNull(authnRequestConsumerResolver, "authnRequestConsumerResolver cannot be null");
+		this.authnRequestConsumerResolver = authnRequestConsumerResolver;
+	}
+
 	/**
 	 * '
 	 * Use this {@link Clock} with {@link Instant#now()} for generating

+ 36 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java

@@ -16,6 +16,9 @@
 
 package org.springframework.security.saml2.provider.service.authentication;
 
+import java.util.function.Consumer;
+import java.util.function.Function;
+
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Rule;
@@ -29,9 +32,13 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes
 
 import static java.nio.charset.StandardCharsets.UTF_8;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
 import static org.hamcrest.CoreMatchers.containsString;
-import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
+import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
 import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
 import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
 import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
@@ -160,6 +167,34 @@ public class OpenSamlAuthenticationRequestFactoryTests {
 		factory.setProtocolBinding("my-invalid-binding");
 	}
 
+	@Test
+	public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
+		Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
+				mock(Function.class);
+		when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
+		this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
+
+		this.factory.createPostAuthenticationRequest(this.context);
+		verify(authnRequestConsumerResolver).apply(this.context);
+	}
+
+	@Test
+	public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
+		Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
+				mock(Function.class);
+		when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
+		this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
+
+		this.factory.createRedirectAuthenticationRequest(this.context);
+		verify(authnRequestConsumerResolver).apply(this.context);
+	}
+
+	@Test
+	public void setAuthnRequestConsumerResolverWhenNullThenException() {
+		assertThatCode(() -> this.factory.setAuthnRequestConsumerResolver(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
 	private AuthnRequest getAuthNRequest(Saml2MessageBinding binding) {
 		AbstractSaml2AuthenticationRequest result = (binding == REDIRECT) ?
 				factory.createRedirectAuthenticationRequest(context) :