|
@@ -30,6 +30,7 @@ import java.util.Map;
|
|
|
* {@link OAuth2AuthorizationRequest} in the {@code HttpSession}.
|
|
|
*
|
|
|
* @author Joe Grandja
|
|
|
+ * @author Rob Winch
|
|
|
* @since 5.0
|
|
|
* @see AuthorizationRequestRepository
|
|
|
* @see OAuth2AuthorizationRequest
|
|
@@ -37,17 +38,16 @@ import java.util.Map;
|
|
|
public final class HttpSessionOAuth2AuthorizationRequestRepository implements AuthorizationRequestRepository<OAuth2AuthorizationRequest> {
|
|
|
private static final String DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME =
|
|
|
HttpSessionOAuth2AuthorizationRequestRepository.class.getName() + ".AUTHORIZATION_REQUEST";
|
|
|
+
|
|
|
private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME;
|
|
|
|
|
|
@Override
|
|
|
public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
|
|
|
Assert.notNull(request, "request cannot be null");
|
|
|
- Assert.hasText(request.getParameter(OAuth2ParameterNames.STATE), "state parameter cannot be empty");
|
|
|
+ String stateParameter = getStateParameter(request);
|
|
|
+ Assert.hasText(stateParameter, "state parameter cannot be empty");
|
|
|
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
|
|
|
- if (authorizationRequests != null) {
|
|
|
- return authorizationRequests.get(request.getParameter(OAuth2ParameterNames.STATE));
|
|
|
- }
|
|
|
- return null;
|
|
|
+ return authorizationRequests.get(stateParameter);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -59,35 +59,46 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
|
|
|
this.removeAuthorizationRequest(request);
|
|
|
return;
|
|
|
}
|
|
|
- Assert.hasText(authorizationRequest.getState(), "authorizationRequest.state cannot be empty");
|
|
|
- Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request, true);
|
|
|
- authorizationRequests.put(authorizationRequest.getState(), authorizationRequest);
|
|
|
+ String state = authorizationRequest.getState();
|
|
|
+ Assert.hasText(state, "authorizationRequest.state cannot be empty");
|
|
|
+ Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
|
|
|
+ authorizationRequests.put(state, authorizationRequest);
|
|
|
+ request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest request) {
|
|
|
Assert.notNull(request, "request cannot be null");
|
|
|
- OAuth2AuthorizationRequest authorizationRequest = this.loadAuthorizationRequest(request);
|
|
|
- if (authorizationRequest != null) {
|
|
|
- Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
|
|
|
- authorizationRequests.remove(authorizationRequest.getState());
|
|
|
+ String stateParameter = getStateParameter(request);
|
|
|
+ if (stateParameter == null) {
|
|
|
+ return null;
|
|
|
}
|
|
|
- return authorizationRequest;
|
|
|
+ Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
|
|
|
+ OAuth2AuthorizationRequest originalRequest = authorizationRequests.remove(stateParameter);
|
|
|
+ request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
|
|
|
+ return originalRequest;
|
|
|
}
|
|
|
|
|
|
- private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpServletRequest request) {
|
|
|
- return this.getAuthorizationRequests(request, false);
|
|
|
+ /**
|
|
|
+ * Gets the state parameter from the {@link HttpServletRequest}
|
|
|
+ * @param request the request to use
|
|
|
+ * @return the state parameter or null if not found
|
|
|
+ */
|
|
|
+ private String getStateParameter(HttpServletRequest request) {
|
|
|
+ return request.getParameter(OAuth2ParameterNames.STATE);
|
|
|
}
|
|
|
|
|
|
- private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpServletRequest request, boolean createSession) {
|
|
|
- Map<String, OAuth2AuthorizationRequest> authorizationRequests = null;
|
|
|
- HttpSession session = request.getSession(createSession);
|
|
|
- if (session != null) {
|
|
|
- authorizationRequests = (Map<String, OAuth2AuthorizationRequest>) session.getAttribute(this.sessionAttributeName);
|
|
|
- if (authorizationRequests == null) {
|
|
|
- authorizationRequests = new HashMap<>();
|
|
|
- session.setAttribute(this.sessionAttributeName, authorizationRequests);
|
|
|
- }
|
|
|
+ /**
|
|
|
+ * Gets a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} to an {@link OAuth2AuthorizationRequest}
|
|
|
+ * @param request
|
|
|
+ * @return a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} to an {@link OAuth2AuthorizationRequest}.
|
|
|
+ */
|
|
|
+ private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpServletRequest request) {
|
|
|
+ HttpSession session = request.getSession(false);
|
|
|
+ Map<String, OAuth2AuthorizationRequest> authorizationRequests = session == null ? null :
|
|
|
+ (Map<String, OAuth2AuthorizationRequest>) session.getAttribute(this.sessionAttributeName);
|
|
|
+ if (authorizationRequests == null) {
|
|
|
+ return new HashMap<>();
|
|
|
}
|
|
|
return authorizationRequests;
|
|
|
}
|