|
@@ -19,6 +19,10 @@ import org.springframework.security.core.GrantedAuthority;
|
|
import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
|
|
import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
|
|
import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
|
|
import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
|
|
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
|
|
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
|
|
|
|
+import org.springframework.util.Assert;
|
|
|
|
+
|
|
|
|
+import java.util.HashMap;
|
|
|
|
+import java.util.Map;
|
|
|
|
|
|
/**
|
|
/**
|
|
* A {@link GrantedAuthority} that is associated with an {@link OidcUser}.
|
|
* A {@link GrantedAuthority} that is associated with an {@link OidcUser}.
|
|
@@ -40,7 +44,7 @@ public class OidcUserAuthority extends OAuth2UserAuthority {
|
|
}
|
|
}
|
|
|
|
|
|
public OidcUserAuthority(String authority, OidcIdToken idToken, OidcUserInfo userInfo) {
|
|
public OidcUserAuthority(String authority, OidcIdToken idToken, OidcUserInfo userInfo) {
|
|
- super(authority, OidcUser.collectClaims(idToken, userInfo));
|
|
|
|
|
|
+ super(authority, collectClaims(idToken, userInfo));
|
|
this.idToken = idToken;
|
|
this.idToken = idToken;
|
|
this.userInfo = userInfo;
|
|
this.userInfo = userInfo;
|
|
}
|
|
}
|
|
@@ -82,4 +86,14 @@ public class OidcUserAuthority extends OAuth2UserAuthority {
|
|
result = 31 * result + (this.getUserInfo() != null ? this.getUserInfo().hashCode() : 0);
|
|
result = 31 * result + (this.getUserInfo() != null ? this.getUserInfo().hashCode() : 0);
|
|
return result;
|
|
return result;
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+ static Map<String, Object> collectClaims(OidcIdToken idToken, OidcUserInfo userInfo) {
|
|
|
|
+ Assert.notNull(idToken, "idToken cannot be null");
|
|
|
|
+ Map<String, Object> claims = new HashMap<>();
|
|
|
|
+ if (userInfo != null) {
|
|
|
|
+ claims.putAll(userInfo.getClaims());
|
|
|
|
+ }
|
|
|
|
+ claims.putAll(idToken.getClaims());
|
|
|
|
+ return claims;
|
|
|
|
+ }
|
|
}
|
|
}
|