|
@@ -0,0 +1,129 @@
|
|
|
|
+/*
|
|
|
|
+ * Copyright 2020-2023 the original author or authors.
|
|
|
|
+ *
|
|
|
|
+ * Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
+ * you may not use this file except in compliance with the License.
|
|
|
|
+ * You may obtain a copy of the License at
|
|
|
|
+ *
|
|
|
|
+ * https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
+ *
|
|
|
|
+ * Unless required by applicable law or agreed to in writing, software
|
|
|
|
+ * distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
+ * See the License for the specific language governing permissions and
|
|
|
|
+ * limitations under the License.
|
|
|
|
+ */
|
|
|
|
+package sample.extgrant;
|
|
|
|
+
|
|
|
|
+import org.springframework.security.authentication.AuthenticationProvider;
|
|
|
|
+import org.springframework.security.core.Authentication;
|
|
|
|
+import org.springframework.security.core.AuthenticationException;
|
|
|
|
+import org.springframework.security.oauth2.core.ClaimAccessor;
|
|
|
|
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
|
|
|
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
|
|
|
+import org.springframework.security.oauth2.core.OAuth2Error;
|
|
|
|
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
|
|
|
+import org.springframework.security.oauth2.core.OAuth2Token;
|
|
|
|
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
|
|
|
|
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
|
|
|
|
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
|
|
|
|
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
|
|
|
|
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
|
|
|
|
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
|
|
|
+import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder;
|
|
|
|
+import org.springframework.security.oauth2.server.authorization.token.DefaultOAuth2TokenContext;
|
|
|
|
+import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenContext;
|
|
|
|
+import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenGenerator;
|
|
|
|
+import org.springframework.util.Assert;
|
|
|
|
+
|
|
|
|
+public class CustomCodeGrantAuthenticationProvider implements AuthenticationProvider {
|
|
|
|
+ // @fold:on
|
|
|
|
+ private final OAuth2AuthorizationService authorizationService;
|
|
|
|
+ private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;
|
|
|
|
+
|
|
|
|
+ public CustomCodeGrantAuthenticationProvider(OAuth2AuthorizationService authorizationService,
|
|
|
|
+ OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator) {
|
|
|
|
+ Assert.notNull(authorizationService, "authorizationService cannot be null");
|
|
|
|
+ Assert.notNull(tokenGenerator, "tokenGenerator cannot be null");
|
|
|
|
+ this.authorizationService = authorizationService;
|
|
|
|
+ this.tokenGenerator = tokenGenerator;
|
|
|
|
+ }
|
|
|
|
+ // @fold:off
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ public Authentication authenticate(Authentication authentication) throws AuthenticationException {
|
|
|
|
+ CustomCodeGrantAuthenticationToken customCodeGrantAuthentication =
|
|
|
|
+ (CustomCodeGrantAuthenticationToken) authentication;
|
|
|
|
+
|
|
|
|
+ // Ensure the client is authenticated
|
|
|
|
+ OAuth2ClientAuthenticationToken clientPrincipal =
|
|
|
|
+ getAuthenticatedClientElseThrowInvalidClient(customCodeGrantAuthentication);
|
|
|
|
+ RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
|
|
|
|
+
|
|
|
|
+ // Ensure the client is configured to use this authorization grant type
|
|
|
|
+ if (!registeredClient.getAuthorizationGrantTypes().contains(customCodeGrantAuthentication.getGrantType())) {
|
|
|
|
+ throw new OAuth2AuthenticationException(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // TODO Validate the code parameter
|
|
|
|
+
|
|
|
|
+ // Generate the access token
|
|
|
|
+ OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder()
|
|
|
|
+ .registeredClient(registeredClient)
|
|
|
|
+ .principal(clientPrincipal)
|
|
|
|
+ .authorizationServerContext(AuthorizationServerContextHolder.getContext())
|
|
|
|
+ .tokenType(OAuth2TokenType.ACCESS_TOKEN)
|
|
|
|
+ .authorizationGrantType(customCodeGrantAuthentication.getGrantType())
|
|
|
|
+ .authorizationGrant(customCodeGrantAuthentication)
|
|
|
|
+ .build();
|
|
|
|
+
|
|
|
|
+ OAuth2Token generatedAccessToken = this.tokenGenerator.generate(tokenContext);
|
|
|
|
+ if (generatedAccessToken == null) {
|
|
|
|
+ OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
|
|
|
|
+ "The token generator failed to generate the access token.", null);
|
|
|
|
+ throw new OAuth2AuthenticationException(error);
|
|
|
|
+ }
|
|
|
|
+ OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
|
|
|
|
+ generatedAccessToken.getTokenValue(), generatedAccessToken.getIssuedAt(),
|
|
|
|
+ generatedAccessToken.getExpiresAt(), null);
|
|
|
|
+
|
|
|
|
+ // Initialize the OAuth2Authorization
|
|
|
|
+ OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.withRegisteredClient(registeredClient)
|
|
|
|
+ .principalName(clientPrincipal.getName())
|
|
|
|
+ .authorizationGrantType(customCodeGrantAuthentication.getGrantType());
|
|
|
|
+ if (generatedAccessToken instanceof ClaimAccessor) {
|
|
|
|
+ authorizationBuilder.token(accessToken, (metadata) ->
|
|
|
|
+ metadata.put(
|
|
|
|
+ OAuth2Authorization.Token.CLAIMS_METADATA_NAME,
|
|
|
|
+ ((ClaimAccessor) generatedAccessToken).getClaims())
|
|
|
|
+ );
|
|
|
|
+ } else {
|
|
|
|
+ authorizationBuilder.accessToken(accessToken);
|
|
|
|
+ }
|
|
|
|
+ OAuth2Authorization authorization = authorizationBuilder.build();
|
|
|
|
+
|
|
|
|
+ // Save the OAuth2Authorization
|
|
|
|
+ this.authorizationService.save(authorization);
|
|
|
|
+
|
|
|
|
+ return new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ public boolean supports(Class<?> authentication) {
|
|
|
|
+ return CustomCodeGrantAuthenticationToken.class.isAssignableFrom(authentication);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // @fold:on
|
|
|
|
+ private static OAuth2ClientAuthenticationToken getAuthenticatedClientElseThrowInvalidClient(Authentication authentication) {
|
|
|
|
+ OAuth2ClientAuthenticationToken clientPrincipal = null;
|
|
|
|
+ if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication.getPrincipal().getClass())) {
|
|
|
|
+ clientPrincipal = (OAuth2ClientAuthenticationToken) authentication.getPrincipal();
|
|
|
|
+ }
|
|
|
|
+ if (clientPrincipal != null && clientPrincipal.isAuthenticated()) {
|
|
|
|
+ return clientPrincipal;
|
|
|
|
+ }
|
|
|
|
+ throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT);
|
|
|
|
+ }
|
|
|
|
+ // @fold:off
|
|
|
|
+
|
|
|
|
+}
|