|
@@ -17,46 +17,56 @@ package org.springframework.security.oauth2.client.registration;
|
|
|
|
|
|
import org.springframework.util.Assert;
|
|
|
|
|
|
+import java.util.ArrayList;
|
|
|
import java.util.Collections;
|
|
|
+import java.util.HashMap;
|
|
|
import java.util.List;
|
|
|
-import java.util.Optional;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.stream.Collectors;
|
|
|
|
|
|
/**
|
|
|
- * A basic implementation of a {@link ClientRegistrationRepository} that accepts
|
|
|
- * a <code>List</code> of {@link ClientRegistration}(s) via it's constructor and stores it <i>in-memory</i>.
|
|
|
+ * A {@link ClientRegistrationRepository} that stores {@link ClientRegistration}(s) <i>in-memory</i>.
|
|
|
*
|
|
|
* @author Joe Grandja
|
|
|
* @since 5.0
|
|
|
* @see ClientRegistration
|
|
|
*/
|
|
|
public final class InMemoryClientRegistrationRepository implements ClientRegistrationRepository {
|
|
|
- private final List<ClientRegistration> clientRegistrations;
|
|
|
+ private final ClientRegistrationIdentifierStrategy<String> identifierStrategy = new ClientAliasIdentifierStrategy();
|
|
|
+ private final Map<String, ClientRegistration> registrations;
|
|
|
|
|
|
- public InMemoryClientRegistrationRepository(List<ClientRegistration> clientRegistrations) {
|
|
|
- Assert.notEmpty(clientRegistrations, "clientRegistrations cannot be empty");
|
|
|
- this.clientRegistrations = Collections.unmodifiableList(clientRegistrations);
|
|
|
+ public InMemoryClientRegistrationRepository(List<ClientRegistration> registrations) {
|
|
|
+ Assert.notEmpty(registrations, "registrations cannot be empty");
|
|
|
+ Map<String, ClientRegistration> registrationsMap = new HashMap<>();
|
|
|
+ registrations.forEach(registration -> {
|
|
|
+ String identifier = this.identifierStrategy.getIdentifier(registration);
|
|
|
+ if (registrationsMap.containsKey(identifier)) {
|
|
|
+ throw new IllegalArgumentException("ClientRegistration must be unique. Found duplicate identifier: " + identifier);
|
|
|
+ }
|
|
|
+ registrationsMap.put(identifier, registration);
|
|
|
+ });
|
|
|
+ this.registrations = Collections.unmodifiableMap(registrationsMap);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public ClientRegistration getRegistrationByClientId(String clientId) {
|
|
|
- Optional<ClientRegistration> clientRegistration =
|
|
|
- this.clientRegistrations.stream()
|
|
|
- .filter(c -> c.getClientId().equals(clientId))
|
|
|
- .findFirst();
|
|
|
- return clientRegistration.isPresent() ? clientRegistration.get() : null;
|
|
|
+ public List<ClientRegistration> getRegistrationsByClientId(String clientId) {
|
|
|
+ Assert.hasText(clientId, "clientId cannot be empty");
|
|
|
+ return this.registrations.values().stream()
|
|
|
+ .filter(registration -> registration.getClientId().equals(clientId))
|
|
|
+ .collect(Collectors.toList());
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public ClientRegistration getRegistrationByClientAlias(String clientAlias) {
|
|
|
- Optional<ClientRegistration> clientRegistration =
|
|
|
- this.clientRegistrations.stream()
|
|
|
- .filter(c -> c.getClientAlias().equals(clientAlias))
|
|
|
- .findFirst();
|
|
|
- return clientRegistration.isPresent() ? clientRegistration.get() : null;
|
|
|
+ Assert.hasText(clientAlias, "clientAlias cannot be empty");
|
|
|
+ return this.registrations.values().stream()
|
|
|
+ .filter(registration -> registration.getClientAlias().equals(clientAlias))
|
|
|
+ .findFirst()
|
|
|
+ .orElseGet(null);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public List<ClientRegistration> getRegistrations() {
|
|
|
- return this.clientRegistrations;
|
|
|
+ return new ArrayList<>(this.registrations.values());
|
|
|
}
|
|
|
}
|