Преглед на файлове

Add Authorization Endpoint filter

Fixes gh-66
Paurav Munshi преди 5 години
родител
ревизия
54e219a397

+ 206 - 8
core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

@@ -15,33 +15,231 @@
  */
 package org.springframework.security.oauth2.server.authorization.web;
 
+import java.io.IOException;
+import java.util.stream.Stream;
+
+import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
 import org.springframework.core.convert.converter.Converter;
+import org.springframework.http.HttpStatus;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
 import org.springframework.security.crypto.keygen.StringKeyGenerator;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.TokenType;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.security.web.DefaultRedirectStrategy;
+import org.springframework.security.web.RedirectStrategy;
+import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
+import org.springframework.security.web.util.matcher.RequestMatcher;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
 import org.springframework.web.filter.OncePerRequestFilter;
-
-import javax.servlet.FilterChain;
-import javax.servlet.ServletException;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-import java.io.IOException;
+import org.springframework.web.util.UriComponentsBuilder;
 
 /**
  * @author Joe Grandja
+ * @author Paurav Munshi
+ * @since 0.0.1
  */
 public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
-	private Converter<HttpServletRequest, OAuth2AuthorizationRequest> authorizationRequestConverter;
+
+	private static final String DEFAULT_ENDPOINT = "/oauth2/authorize";
+
+	private Converter<HttpServletRequest, OAuth2AuthorizationRequest> authorizationRequestConverter = new OAuth2AuthorizationRequestConverter();
 	private RegisteredClientRepository registeredClientRepository;
 	private OAuth2AuthorizationService authorizationService;
-	private StringKeyGenerator codeGenerator;
+	private StringKeyGenerator codeGenerator = new Base64StringKeyGenerator();
+	private RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy();
+	private RequestMatcher authorizationEndpointMatcher = new AntPathRequestMatcher(DEFAULT_ENDPOINT);
+
+	public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository,
+			OAuth2AuthorizationService authorizationService) {
+		Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null.");
+		Assert.notNull(authorizationService, "authorizationService cannot be null.");
+		this.registeredClientRepository = registeredClientRepository;
+		this.authorizationService = authorizationService;
+	}
+
+	public final void setAuthorizationRequestConverter(
+			Converter<HttpServletRequest, OAuth2AuthorizationRequest> authorizationRequestConverter) {
+		Assert.notNull(authorizationRequestConverter, "authorizationRequestConverter cannot be set to null");
+		this.authorizationRequestConverter = authorizationRequestConverter;
+	}
+
+	public final void setCodeGenerator(StringKeyGenerator codeGenerator) {
+		Assert.notNull(codeGenerator, "codeGenerator cannot be set to null");
+		this.codeGenerator = codeGenerator;
+	}
+
+	public final void setAuthorizationRedirectStrategy(RedirectStrategy authorizationRedirectStrategy) {
+		Assert.notNull(authorizationRedirectStrategy, "authorizationRedirectStrategy cannot be set to null");
+		this.authorizationRedirectStrategy = authorizationRedirectStrategy;
+	}
+
+	public final void setAuthorizationEndpointMatcher(RequestMatcher authorizationEndpointMatcher) {
+		Assert.notNull(authorizationEndpointMatcher, "authorizationEndpointMatcher cannot be set to null");
+		this.authorizationEndpointMatcher = authorizationEndpointMatcher;
+	}
+
+	@Override
+	protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
+		boolean pathMatch = this.authorizationEndpointMatcher.matches(request);
+		String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE);
+		boolean responseTypeMatch = OAuth2ParameterNames.CODE.equals(responseType);
+		if (pathMatch && responseTypeMatch) {
+			return false;
+		}else {
+			return true;
+		}
+	}
 
 	@Override
 	protected void doFilterInternal(HttpServletRequest request,
 			HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
 
+		RegisteredClient client = null;
+		OAuth2AuthorizationRequest authorizationRequest = null;
+		OAuth2Authorization authorization = null;
+
+		try {
+			checkUserAuthenticated();
+			Authentication auth = SecurityContextHolder.getContext().getAuthentication();
+			client = fetchRegisteredClient(request);
+
+			authorizationRequest = this.authorizationRequestConverter.convert(request);
+			validateAuthorizationRequest(authorizationRequest, client);
+
+			String code = this.codeGenerator.generateKey();
+			authorization = buildOAuth2Authorization(auth, client, authorizationRequest, code);
+			this.authorizationService.save(authorization);
+
+			String redirectUri = getRedirectUri(authorizationRequest, client);
+			sendCodeOnSuccess(request, response, authorizationRequest, redirectUri, code);
+		}
+		catch(OAuth2AuthorizationException authorizationException) {
+			OAuth2Error authorizationError = authorizationException.getError();
+
+			if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST)
+					|| authorizationError.getErrorCode().equals(OAuth2ErrorCodes.ACCESS_DENIED)) {
+				sendErrorInResponse(response, authorizationError);
+			}
+			else if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE)
+					|| authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)) {
+				String redirectUri = getRedirectUri(authorizationRequest, client);
+				sendErrorInRedirect(request, response, authorizationRequest, authorizationError, redirectUri);
+			}
+			else {
+				throw new ServletException(authorizationException);
+			}
+		}
+
+	}
+
+	private void checkUserAuthenticated() {
+		Authentication currentAuth = SecurityContextHolder.getContext().getAuthentication();
+		if (currentAuth==null || !currentAuth.isAuthenticated()) {
+			throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
+		}
 	}
 
+	private RegisteredClient fetchRegisteredClient(HttpServletRequest request) throws OAuth2AuthorizationException {
+		String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID);
+		if (StringUtils.isEmpty(clientId)) {
+			throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
+		}
+
+		RegisteredClient client = this.registeredClientRepository.findByClientId(clientId);
+		if (client==null) {
+			throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
+		}
+
+		boolean isAuthorizationGrantAllowed = Stream.of(client.getAuthorizationGrantTypes())
+				.anyMatch(grantType -> grantType.contains(AuthorizationGrantType.AUTHORIZATION_CODE));
+		if (!isAuthorizationGrantAllowed) {
+			throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
+		}
+
+		return client;
+
+	}
+
+	private OAuth2Authorization buildOAuth2Authorization(Authentication auth, RegisteredClient client,
+			OAuth2AuthorizationRequest authorizationRequest, String code) {
+		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(client)
+					.principalName(auth.getPrincipal().toString())
+					.attribute(TokenType.AUTHORIZATION_CODE.getValue(), code)
+					.attributes(attirbutesMap -> attirbutesMap.putAll(authorizationRequest.getAttributes()))
+					.build();
+
+		return authorization;
+	}
+
+
+	private void validateAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) {
+		String redirectUri = authorizationRequest.getRedirectUri();
+		if (StringUtils.isEmpty(redirectUri) && client.getRedirectUris().size() > 1) {
+			throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
+		}
+		if (!StringUtils.isEmpty(redirectUri) && !client.getRedirectUris().contains(redirectUri)) {
+			throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
+		}
+	}
+
+	private String getRedirectUri(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) {
+		return !StringUtils.isEmpty(authorizationRequest.getRedirectUri())
+		? authorizationRequest.getRedirectUri()
+		: client.getRedirectUris().stream().findFirst().get();
+	}
+
+	private void sendCodeOnSuccess(HttpServletRequest request, HttpServletResponse response,
+			OAuth2AuthorizationRequest authorizationRequest, String redirectUri, String code) throws IOException {
+		UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri)
+				.queryParam(OAuth2ParameterNames.CODE, code);
+		if (!StringUtils.isEmpty(authorizationRequest.getState())) {
+			redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
+		}
+
+		String finalRedirectUri = redirectUriBuilder.toUriString();
+		this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectUri);
+	}
+
+	private void sendErrorInResponse(HttpServletResponse response, OAuth2Error authorizationError) throws IOException {
+		int errorStatus = -1;
+		String errorCode = authorizationError.getErrorCode();
+		if (errorCode.equals(OAuth2ErrorCodes.ACCESS_DENIED)) {
+			errorStatus=HttpStatus.FORBIDDEN.value();
+		}
+		else {
+			errorStatus=HttpStatus.INTERNAL_SERVER_ERROR.value();
+		}
+		response.sendError(errorStatus, authorizationError.getErrorCode());
+	}
+
+	private void sendErrorInRedirect(HttpServletRequest request, HttpServletResponse response,
+			OAuth2AuthorizationRequest authorizationRequest, OAuth2Error authorizationError,
+			String redirectUri) throws IOException {
+		UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri)
+				.queryParam(OAuth2ParameterNames.ERROR, authorizationError.getErrorCode());
+
+		if (!StringUtils.isEmpty(authorizationRequest.getState())) {
+			redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
+		}
+
+		String finalRedirectURI = redirectUriBuilder.toUriString();
+		this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectURI);
+	}
 }

+ 55 - 0
core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java

@@ -0,0 +1,55 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.web;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.LinkedHashSet;
+import java.util.Set;
+
+import javax.servlet.http.HttpServletRequest;
+
+import org.springframework.core.convert.converter.Converter;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.util.StringUtils;
+
+/**
+ * @author Paurav Munshi
+ * @since 0.0.1
+ * @see Converter
+ */
+public class OAuth2AuthorizationRequestConverter implements Converter<HttpServletRequest, OAuth2AuthorizationRequest> {
+
+	@Override
+	public OAuth2AuthorizationRequest convert(HttpServletRequest request) {
+		String scope = request.getParameter(OAuth2ParameterNames.SCOPE);
+		Set<String> scopes = !StringUtils.isEmpty(scope)
+				? new LinkedHashSet<String>(Arrays.asList(scope.split(" ")))
+				: Collections.emptySet();
+
+		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.clientId(request.getParameter(OAuth2ParameterNames.CLIENT_ID))
+				.redirectUri(request.getParameter(OAuth2ParameterNames.REDIRECT_URI))
+				.scopes(scopes)
+				.state(request.getParameter(OAuth2ParameterNames.STATE))
+				.authorizationUri(request.getServletPath())
+				.build();
+
+		return authorizationRequest;
+	}
+
+}

+ 36 - 0
core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java

@@ -46,4 +46,40 @@ public class TestRegisteredClients {
 				.scope("profile")
 				.scope("email");
 	}
+
+	public static RegisteredClient.Builder validAuthorizationGrantRegisteredClient() {
+		return RegisteredClient.withId("valid_client_id")
+				.clientId("valid_client")
+				.clientSecret("valid_secret")
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+				.redirectUri("http://localhost:8080/test-application/callback")
+				.scope("openid")
+				.scope("profile")
+				.scope("email");
+	}
+
+	public static RegisteredClient.Builder validAuthorizationGrantClientMultiRedirectUris() {
+		return RegisteredClient.withId("valid_client_multi_uri_id")
+				.clientId("valid_client_multi_uri")
+				.clientSecret("valid_secret")
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+				.redirectUri("http://localhost:8080/test-application/callback")
+				.redirectUri("http://localhost:8080/another-test-application/callback")
+				.scope("openid")
+				.scope("profile")
+				.scope("email");
+	}
+
+	public static RegisteredClient.Builder validClientCredentialsGrantRegisteredClient() {
+		return RegisteredClient.withId("valid_cc_client_id")
+				.clientId("valid_cc_client")
+				.clientSecret("valid_secret")
+				.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
+				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+				.scope("openid")
+				.scope("profile")
+				.scope("email");
+	}
 }

+ 371 - 0
core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java

@@ -0,0 +1,371 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.web;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import javax.servlet.FilterChain;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.http.HttpStatus;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.crypto.keygen.StringKeyGenerator;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+
+
+/**
+ * Tests for {@link OAuth2AuthorizationEndpointFilter}.
+ *
+ * @author Paurav Munshi
+ * @since 0.0.1
+ */
+
+public class OAuth2AuthorizationEndpointFilterTest {
+
+	private static final String VALID_CLIENT = "valid_client";
+	private static final String VALID_CLIENT_MULTI_URI = "valid_client_multi_uri";
+	private static final String VALID_CC_CLIENT = "valid_cc_client";
+
+	private OAuth2AuthorizationEndpointFilter filter;
+
+	private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class);
+	private StringKeyGenerator codeGenerator = mock(StringKeyGenerator.class);
+	private RegisteredClientRepository registeredClientRepository = mock(RegisteredClientRepository.class);
+	private Authentication authentication = mock(Authentication.class);
+
+	@Before
+	public void setUp() {
+		this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService);
+		this.filter.setCodeGenerator(this.codeGenerator);
+
+		SecurityContextHolder.getContext().setAuthentication(this.authentication);
+	}
+
+	@Test
+	public void constructorWhenRegisteredClientRepositoryIsNullThenIllegalArgumentExceptionIsThrows() throws Exception {
+		assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService))
+			.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void constructorWhenAuthorizationServiceIsNullThenIllegalArgumentExceptionIsThrows() throws Exception {
+		assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null))
+			.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void setAuthorizationEndpointMatcherWhenAuthorizationEndpointMatcherIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
+		assertThatThrownBy(() ->this.filter.setAuthorizationEndpointMatcher(null))
+			.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void setAuthorizationRedirectStrategyWhenAuthorizationRedirectStrategyIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
+		assertThatThrownBy(() ->this.filter.setAuthorizationRedirectStrategy(null))
+			.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void setAuthorizationRequestConverterWhenAuthorizationRequestConverterIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
+		assertThatThrownBy(() ->this.filter.setAuthorizationRequestConverter(null))
+			.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void setCodeGeneratorWhenCodeGeneratorIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
+		assertThatThrownBy(() ->this.filter.setCodeGenerator(null))
+			.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void doFilterWhenValidRequestIsReceivedThenResponseRedirectedToRedirectURIWithCode() throws Exception {
+		MockHttpServletRequest request = getValidMockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
+		when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
+		when(this.codeGenerator.generateKey()).thenReturn("sample_code");
+		when(this.authentication.getPrincipal()).thenReturn("test-user");
+		when(this.authentication.isAuthenticated()).thenReturn(true);
+
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(this.authentication).isAuthenticated();
+		verify(this.registeredClientRepository).findByClientId(VALID_CLIENT);
+		verify(this.authorizationService).save(any(OAuth2Authorization.class));
+		verify(this.codeGenerator).generateKey();
+		verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+		assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate");
+
+	}
+
+	@Test
+	public void doFilterWhenValidRequestWithBlankRedirectURIIsReceivedThenResponseRedirectedToConfiguredRedirectURI() throws Exception {
+		MockHttpServletRequest request = getValidMockHttpServletRequest();
+		request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
+		when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
+		when(this.codeGenerator.generateKey()).thenReturn("sample_code");
+		when(this.authentication.getPrincipal()).thenReturn("test-user");
+		when(this.authentication.isAuthenticated()).thenReturn(true);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(this.authentication).isAuthenticated();
+		verify(this.registeredClientRepository).findByClientId(VALID_CLIENT);
+		verify(this.authorizationService).save(any(OAuth2Authorization.class));
+		verify(this.codeGenerator).generateKey();
+		verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+		assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate");
+
+	}
+
+	@Test
+	public void doFilterWhenRedirectURINotPresentAndClientHasMulitipleUrisThenErrorIsSentInResponse() throws Exception {
+		MockHttpServletRequest request = getValidMockHttpServletRequest();
+		request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT_MULTI_URI);
+		request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantClientMultiRedirectUris().build();
+		when(this.registeredClientRepository.findByClientId(VALID_CLIENT_MULTI_URI)).thenReturn(registeredClient);
+		when(this.authentication.isAuthenticated()).thenReturn(true);
+
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(this.authentication, times(1)).isAuthenticated();
+		verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT_MULTI_URI);
+		verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
+		verify(this.codeGenerator, times(0)).generateKey();
+		verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
+		assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
+
+	}
+
+	@Test
+	public void doFilterWhenRequestedRedirectUriNotConfiguredInClientThenErrorSentInResponse() throws Exception {
+		MockHttpServletRequest request = getValidMockHttpServletRequest();
+		request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/not-configred-app/callback");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
+		when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
+		when(this.authentication.isAuthenticated()).thenReturn(true);
+
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(this.authentication, times(1)).isAuthenticated();
+		verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT);
+		verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
+		verify(this.codeGenerator, times(0)).generateKey();
+		verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
+		assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
+
+	}
+
+	@Test
+	public void doFilterWhenClientIdDoesNotSupportAuthorizationGrantFlowThenErrorSentInResponse() throws Exception {
+		MockHttpServletRequest request = getValidMockHttpServletRequest();
+		request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CC_CLIENT);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		RegisteredClient registeredClient = TestRegisteredClients.validClientCredentialsGrantRegisteredClient().build();
+		when(this.registeredClientRepository.findByClientId(VALID_CC_CLIENT)).thenReturn(registeredClient);
+		when(this.authentication.isAuthenticated()).thenReturn(true);
+
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(this.authentication, times(1)).isAuthenticated();
+		verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CC_CLIENT);
+		verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
+		verify(this.codeGenerator, times(0)).generateKey();
+		verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
+		assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
+
+	}
+
+	@Test
+	public void doFilterWhenClientIdIsMissinInRequestThenErrorSentInResponse() throws Exception {
+		MockHttpServletRequest request = getValidMockHttpServletRequest();
+		request.setParameter(OAuth2ParameterNames.CLIENT_ID, "");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		when(this.authentication.isAuthenticated()).thenReturn(true);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(this.authentication).isAuthenticated();
+		verify(this.registeredClientRepository, times(0)).findByClientId(anyString());
+		verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
+		verify(this.codeGenerator, times(0)).generateKey();
+		verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
+		assertThat(response.getContentAsString()).isEmpty();
+		assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
+
+	}
+
+	@Test
+	public void doFilterWhenUnregisteredClientInRequestThenErrorIsSentInResponse() throws Exception {
+		MockHttpServletRequest request = getValidMockHttpServletRequest();
+		request.setParameter(OAuth2ParameterNames.CLIENT_ID, "unregistered_client");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		when(this.registeredClientRepository.findByClientId("unregistered_client")).thenReturn(null);
+		when(this.codeGenerator.generateKey()).thenReturn("sample_code");
+		when(this.authentication.isAuthenticated()).thenReturn(true);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(this.authentication).isAuthenticated();
+		verify(this.registeredClientRepository, times(1)).findByClientId("unregistered_client");
+		verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
+		verify(this.codeGenerator, times(0)).generateKey();
+		verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
+		assertThat(response.getContentAsString()).isEmpty();
+		assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
+
+	}
+
+	@Test
+	public void doFilterWhenUnauthenticatedUserInRequestThenErrorIsSentInResponse() throws Exception {
+		MockHttpServletRequest request = getValidMockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		when(authentication.isAuthenticated()).thenReturn(false);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(this.authentication).isAuthenticated();
+		verify(this.registeredClientRepository, times(0)).findByClientId(anyString());
+		verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
+		verify(this.codeGenerator, times(0)).generateKey();
+		verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
+		assertThat(response.getContentAsString()).isEmpty();
+		assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
+
+	}
+
+	@Test
+	public void doFilterWhenRequestEndPointIsNotAuthorizationEndpointThenFilterShouldProceedWithFilterChain() throws Exception {
+		MockHttpServletRequest request = getValidMockHttpServletRequest();
+		request.setServletPath("/custom/authorize");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
+		spyFilter.doFilter(request, response, filterChain);
+
+		verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+		verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
+		verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
+	}
+
+	@Test
+	public void doFilterWhenResponseTypeIsNotPresentInRequestThenErrorIsSentInRedirectURIQueryParameter() throws Exception {
+		MockHttpServletRequest request = getValidMockHttpServletRequest();
+		request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
+		spyFilter.doFilter(request, response, filterChain);
+
+		verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
+		verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
+		verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+	}
+
+	@Test
+	public void doFilterWhenResponseTypeInRequestIsUnsupportedThenErrorIsSentInRedirectURIQueryParameter() throws Exception {
+		MockHttpServletRequest request = getValidMockHttpServletRequest();
+		request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "token");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
+		spyFilter.doFilter(request, response, filterChain);
+
+		verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
+		verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
+		verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+	}
+
+	private MockHttpServletRequest getValidMockHttpServletRequest() {
+
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT);
+		request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "code");
+		request.setParameter(OAuth2ParameterNames.SCOPE, "openid profile email");
+		request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/test-application/callback");
+		request.setParameter(OAuth2ParameterNames.STATE, "teststate");
+		request.setServletPath("/oauth2/authorize");
+
+		return request;
+
+
+	}
+
+}