Browse Source

Throw Saml2AuthenticationException

Closes gh-9310
Han YanJing 4 years ago
parent
commit
6e41246a2b

+ 40 - 1
config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -30,12 +30,14 @@ import java.util.zip.InflaterOutputStream;
 
 import javax.servlet.ServletException;
 import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
 
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
+import org.mockito.ArgumentCaptor;
 import org.opensaml.saml.saml2.core.Assertion;
 import org.opensaml.saml.saml2.core.AuthnRequest;
 
@@ -62,10 +64,13 @@ import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
 import org.springframework.security.saml2.Saml2Exception;
+import org.springframework.security.saml2.core.Saml2ErrorCodes;
+import org.springframework.security.saml2.core.Saml2Utils;
 import org.springframework.security.saml2.core.TestSaml2X509Credentials;
 import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
 import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
+import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
@@ -78,6 +83,7 @@ import org.springframework.security.saml2.provider.service.servlet.filter.Saml2W
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
 import org.springframework.security.web.FilterChainProxy;
 import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.context.HttpRequestResponseHolder;
 import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
 import org.springframework.security.web.context.SecurityContextRepository;
@@ -210,6 +216,24 @@ public class Saml2LoginConfigurerTests {
 		verify(CustomAuthenticationConverter.authenticationConverter).convert(any(HttpServletRequest.class));
 	}
 
+	@Test
+	public void authenticateWithInvalidDeflatedSAMLResponseThenFailureHandlerUses() throws Exception {
+		this.spring.register(CustomAuthenticationFailureHandler.class).autowire();
+		byte[] invalidDeflated = "invalid".getBytes();
+		String encoded = Saml2Utils.samlEncode(invalidDeflated);
+		MockHttpServletRequestBuilder request = get("/login/saml2/sso/registration-id").queryParam("SAMLResponse",
+				encoded);
+		this.mvc.perform(request);
+		ArgumentCaptor<Saml2AuthenticationException> captor = ArgumentCaptor
+				.forClass(Saml2AuthenticationException.class);
+		verify(CustomAuthenticationFailureHandler.authenticationFailureHandler).onAuthenticationFailure(
+				any(HttpServletRequest.class), any(HttpServletResponse.class), captor.capture());
+		Saml2AuthenticationException exception = captor.getValue();
+		assertThat(exception.getSaml2Error().getErrorCode()).isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE);
+		assertThat(exception.getSaml2Error().getDescription()).isEqualTo("Unable to inflate string");
+		assertThat(exception.getCause()).isInstanceOf(IOException.class);
+	}
+
 	private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
 		// get the OpenSamlAuthenticationProvider
 		Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
@@ -314,6 +338,21 @@ public class Saml2LoginConfigurerTests {
 
 	}
 
+	@EnableWebSecurity
+	@Import(Saml2LoginConfigBeans.class)
+	static class CustomAuthenticationFailureHandler extends WebSecurityConfigurerAdapter {
+
+		static final AuthenticationFailureHandler authenticationFailureHandler = mock(
+				AuthenticationFailureHandler.class);
+
+		@Override
+		protected void configure(HttpSecurity http) throws Exception {
+			http.authorizeRequests((authz) -> authz.anyRequest().authenticated())
+					.saml2Login((saml2) -> saml2.failureHandler(authenticationFailureHandler));
+		}
+
+	}
+
 	@EnableWebSecurity
 	@Import(Saml2LoginConfigBeans.class)
 	static class CustomAuthenticationRequestContextResolver extends WebSecurityConfigurerAdapter {

+ 14 - 5
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -28,7 +28,9 @@ import org.apache.commons.codec.binary.Base64;
 
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpMethod;
-import org.springframework.security.saml2.Saml2Exception;
+import org.springframework.security.saml2.core.Saml2Error;
+import org.springframework.security.saml2.core.Saml2ErrorCodes;
+import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.web.authentication.AuthenticationConverter;
@@ -83,7 +85,13 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 	}
 
 	private byte[] samlDecode(String s) {
-		return BASE64.decode(s);
+		try {
+			return BASE64.decode(s);
+		}
+		catch (Exception ex) {
+			throw new Saml2AuthenticationException(
+					new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Failed to decode SAMLResponse"), ex);
+		}
 	}
 
 	private String samlInflate(byte[] b) {
@@ -94,8 +102,9 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 			inflaterOutputStream.finish();
 			return new String(out.toByteArray(), StandardCharsets.UTF_8);
 		}
-		catch (IOException ex) {
-			throw new Saml2Exception("Unable to inflate string", ex);
+		catch (Exception ex) {
+			throw new Saml2AuthenticationException(
+					new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Unable to inflate string"), ex);
 		}
 	}
 

+ 39 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -29,7 +29,9 @@ import org.mockito.junit.MockitoJUnitRunner;
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.core.io.ClassPathResource;
 import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.security.saml2.core.Saml2ErrorCodes;
 import org.springframework.security.saml2.core.Saml2Utils;
+import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
@@ -37,6 +39,7 @@ import org.springframework.util.StreamUtils;
 import org.springframework.web.util.UriUtils;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
@@ -64,6 +67,22 @@ public class Saml2AuthenticationTokenConverterTests {
 				.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
 	}
 
+	@Test
+	public void convertWhenSamlResponseInvalidBase64ThenSaml2AuthenticationException() {
+		Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
+				this.relyingPartyRegistrationResolver);
+		given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
+				.willReturn(this.relyingPartyRegistration);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setParameter("SAMLResponse", "invalid");
+		assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> converter.convert(request))
+				.withCauseInstanceOf(IllegalArgumentException.class)
+				.satisfies((ex) -> assertThat(ex.getSaml2Error().getErrorCode())
+						.isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE))
+				.satisfies((ex) -> assertThat(ex.getSaml2Error().getDescription())
+						.isEqualTo("Failed to decode SAMLResponse"));
+	}
+
 	@Test
 	public void convertWhenNoSamlResponseThenNull() {
 		Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
@@ -100,6 +119,25 @@ public class Saml2AuthenticationTokenConverterTests {
 				.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
 	}
 
+	@Test
+	public void convertWhenGetRequestInvalidDeflatedThenSaml2AuthenticationException() {
+		Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
+				this.relyingPartyRegistrationResolver);
+		given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
+				.willReturn(this.relyingPartyRegistration);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setMethod("GET");
+		byte[] invalidDeflated = "invalid".getBytes();
+		String encoded = Saml2Utils.samlEncode(invalidDeflated);
+		request.setParameter("SAMLResponse", encoded);
+		assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> converter.convert(request))
+				.withCauseInstanceOf(IOException.class)
+				.satisfies((ex) -> assertThat(ex.getSaml2Error().getErrorCode())
+						.isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE))
+				.satisfies(
+						(ex) -> assertThat(ex.getSaml2Error().getDescription()).isEqualTo("Unable to inflate string"));
+	}
+
 	@Test
 	public void constructorWhenResolverIsNullThenIllegalArgument() {
 		assertThatIllegalArgumentException().isThrownBy(() -> new Saml2AuthenticationTokenConverter(null));