Backport the class cast exception

This commit is contained in:
vjanakiram
2025-05-15 13:35:50 +05:30
parent 668472d872
commit 04e7b3a316

View File

@@ -75,7 +75,8 @@ import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestOperations;
class SpringBasedIdentityServiceFacade implements IdentityServiceFacade { class SpringBasedIdentityServiceFacade implements IdentityServiceFacade
{
private static final Log LOGGER = LogFactory.getLog(SpringBasedIdentityServiceFacade.class); private static final Log LOGGER = LogFactory.getLog(SpringBasedIdentityServiceFacade.class);
private static final Instant SOME_INSIGNIFICANT_DATE_IN_THE_PAST = Instant.MIN.plusSeconds(12345); private static final Instant SOME_INSIGNIFICANT_DATE_IN_THE_PAST = Instant.MIN.plusSeconds(12345);
private final Map<AuthorizationGrantType, OAuth2AccessTokenResponseClient> clients; private final Map<AuthorizationGrantType, OAuth2AccessTokenResponseClient> clients;
@@ -83,7 +84,8 @@ class SpringBasedIdentityServiceFacade implements IdentityServiceFacade {
private final JwtDecoder jwtDecoder; private final JwtDecoder jwtDecoder;
SpringBasedIdentityServiceFacade(RestOperations restOperations, ClientRegistration clientRegistration, SpringBasedIdentityServiceFacade(RestOperations restOperations, ClientRegistration clientRegistration,
JwtDecoder jwtDecoder) { JwtDecoder jwtDecoder)
{
requireNonNull(restOperations); requireNonNull(restOperations);
this.clientRegistration = requireNonNull(clientRegistration); this.clientRegistration = requireNonNull(clientRegistration);
this.jwtDecoder = requireNonNull(jwtDecoder); this.jwtDecoder = requireNonNull(jwtDecoder);
@@ -94,17 +96,21 @@ class SpringBasedIdentityServiceFacade implements IdentityServiceFacade {
} }
@Override @Override
public AccessTokenAuthorization authorize(AuthorizationGrant authorizationGrant) { public AccessTokenAuthorization authorize(AuthorizationGrant authorizationGrant)
{
final AbstractOAuth2AuthorizationGrantRequest request = createRequest(authorizationGrant); final AbstractOAuth2AuthorizationGrantRequest request = createRequest(authorizationGrant);
final OAuth2AccessTokenResponseClient client = getClient(request); final OAuth2AccessTokenResponseClient client = getClient(request);
final OAuth2AccessTokenResponse response; final OAuth2AccessTokenResponse response;
try { try
{
response = client.getTokenResponse(request); response = client.getTokenResponse(request);
} catch (OAuth2AuthorizationException e) { }
catch (OAuth2AuthorizationException e) {
LOGGER.debug("Failed to authorize against Authorization Server. Reason: " + e.getError() + "."); LOGGER.debug("Failed to authorize against Authorization Server. Reason: " + e.getError() + ".");
throw new AuthorizationException("Failed to obtain access token. " + e.getError(), e); throw new AuthorizationException("Failed to obtain access token. " + e.getError(), e);
} catch (RuntimeException e) { }
catch (RuntimeException e) {
LOGGER.warn("Failed to authorize against Authorization Server. Reason: " + e.getMessage()); LOGGER.warn("Failed to authorize against Authorization Server. Reason: " + e.getMessage());
throw new AuthorizationException("Failed to obtain access token.", e); throw new AuthorizationException("Failed to obtain access token.", e);
} }
@@ -113,7 +119,8 @@ class SpringBasedIdentityServiceFacade implements IdentityServiceFacade {
} }
@Override @Override
public Optional<OIDCUserInfo> getUserInfo(String tokenParameter, String principalAttribute) { public Optional<OIDCUserInfo> getUserInfo(String tokenParameter, String principalAttribute)
{
return Optional.ofNullable(tokenParameter) return Optional.ofNullable(tokenParameter)
.filter(Predicate.not(String::isEmpty)) .filter(Predicate.not(String::isEmpty))
.flatMap(token -> Optional.ofNullable(clientRegistration) .flatMap(token -> Optional.ofNullable(clientRegistration)
@@ -121,19 +128,24 @@ class SpringBasedIdentityServiceFacade implements IdentityServiceFacade {
.map(ClientRegistration.ProviderDetails::getUserInfoEndpoint) .map(ClientRegistration.ProviderDetails::getUserInfoEndpoint)
.map(ClientRegistration.ProviderDetails.UserInfoEndpoint::getUri) .map(ClientRegistration.ProviderDetails.UserInfoEndpoint::getUri)
.flatMap(uri -> { .flatMap(uri -> {
try { try
{
return Optional.of( return Optional.of(
new UserInfoRequest(new URI(uri), new BearerAccessToken(token)).toHTTPRequest().send()); new UserInfoRequest(new URI(uri), new BearerAccessToken(token)).toHTTPRequest().send());
} catch (IOException | URISyntaxException e) { }
catch (IOException | URISyntaxException e)
{
LOGGER.warn("Failed to get user information. Reason: " + e.getMessage()); LOGGER.warn("Failed to get user information. Reason: " + e.getMessage());
return Optional.empty(); return Optional.empty();
} }
}) })
.flatMap(httpResponse -> { .flatMap(httpResponse -> {
try { try
{
UserInfoResponse userInfoResponse = UserInfoResponse.parse(httpResponse); UserInfoResponse userInfoResponse = UserInfoResponse.parse(httpResponse);
if (userInfoResponse instanceof UserInfoErrorResponse userInfoErrorResponse) { if (userInfoResponse instanceof UserInfoErrorResponse userInfoErrorResponse)
{
String errorMessage = Optional.ofNullable(userInfoErrorResponse.getErrorObject()) String errorMessage = Optional.ofNullable(userInfoErrorResponse.getErrorObject())
.map(ErrorObject::getDescription) .map(ErrorObject::getDescription)
.orElse("No error description found"); .orElse("No error description found");
@@ -141,7 +153,9 @@ class SpringBasedIdentityServiceFacade implements IdentityServiceFacade {
throw new UserInfoException(errorMessage); throw new UserInfoException(errorMessage);
} }
return Optional.of(userInfoResponse); return Optional.of(userInfoResponse);
} catch (ParseException e) { }
catch (ParseException e)
{
LOGGER.warn("Failed to parse user info response. Reason: " + e.getMessage()); LOGGER.warn("Failed to parse user info response. Reason: " + e.getMessage());
return Optional.empty(); return Optional.empty();
} }
@@ -153,30 +167,39 @@ class SpringBasedIdentityServiceFacade implements IdentityServiceFacade {
} }
@Override @Override
public ClientRegistration getClientRegistration() { public ClientRegistration getClientRegistration()
{
return clientRegistration; return clientRegistration;
} }
@Override @Override
public DecodedAccessToken decodeToken(String token) { public DecodedAccessToken decodeToken(String token)
{
final Jwt validToken; final Jwt validToken;
try { try
{
validToken = jwtDecoder.decode(token); validToken = jwtDecoder.decode(token);
} catch (RuntimeException e) { }
catch (RuntimeException e)
{
throw new TokenDecodingException("Failed to decode token. " + e.getMessage(), e); throw new TokenDecodingException("Failed to decode token. " + e.getMessage(), e);
} }
if (LOGGER.isDebugEnabled()) { if (LOGGER.isDebugEnabled())
{
LOGGER.debug("Bearer token outcome: " + validToken.getClaims()); LOGGER.debug("Bearer token outcome: " + validToken.getClaims());
} }
return new SpringDecodedAccessToken(validToken); return new SpringDecodedAccessToken(validToken);
} }
private AbstractOAuth2AuthorizationGrantRequest createRequest(AuthorizationGrant grant) { private AbstractOAuth2AuthorizationGrantRequest createRequest(AuthorizationGrant grant)
if (grant.isPassword()) { {
if (grant.isPassword())
{
return new OAuth2PasswordGrantRequest(clientRegistration, grant.getUsername(), grant.getPassword()); return new OAuth2PasswordGrantRequest(clientRegistration, grant.getUsername(), grant.getPassword());
} }
if (grant.isRefreshToken()) { if (grant.isRefreshToken())
{
final OAuth2AccessToken expiredAccessToken = new OAuth2AccessToken( final OAuth2AccessToken expiredAccessToken = new OAuth2AccessToken(
TokenType.BEARER, TokenType.BEARER,
"JUST_FOR_FULFILLING_THE_SPRING_API", "JUST_FOR_FULFILLING_THE_SPRING_API",
@@ -188,7 +211,8 @@ class SpringBasedIdentityServiceFacade implements IdentityServiceFacade {
clientRegistration.getScopes()); clientRegistration.getScopes());
} }
if (grant.isAuthorizationCode()) { if (grant.isAuthorizationCode())
{
final OAuth2AuthorizationExchange authzExchange = new OAuth2AuthorizationExchange( final OAuth2AuthorizationExchange authzExchange = new OAuth2AuthorizationExchange(
OAuth2AuthorizationRequest.authorizationCode() OAuth2AuthorizationRequest.authorizationCode()
.clientId(clientRegistration.getClientId()) .clientId(clientRegistration.getClientId())
@@ -205,31 +229,36 @@ class SpringBasedIdentityServiceFacade implements IdentityServiceFacade {
throw new UnsupportedOperationException("Unsupported grant type."); throw new UnsupportedOperationException("Unsupported grant type.");
} }
private OAuth2AccessTokenResponseClient getClient(AbstractOAuth2AuthorizationGrantRequest request) { private OAuth2AccessTokenResponseClient getClient(AbstractOAuth2AuthorizationGrantRequest request)
{
final AuthorizationGrantType grantType = request.getGrantType(); final AuthorizationGrantType grantType = request.getGrantType();
final OAuth2AccessTokenResponseClient client = clients.get(grantType); final OAuth2AccessTokenResponseClient client = clients.get(grantType);
if (client == null) { if (client == null)
{
throw new UnsupportedOperationException("Unsupported grant type `" + grantType + "`."); throw new UnsupportedOperationException("Unsupported grant type `" + grantType + "`.");
} }
return client; return client;
} }
private static OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> createAuthorizationCodeClient( private static OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> createAuthorizationCodeClient(
RestOperations rest) { RestOperations rest)
{
final DefaultAuthorizationCodeTokenResponseClient client = new DefaultAuthorizationCodeTokenResponseClient(); final DefaultAuthorizationCodeTokenResponseClient client = new DefaultAuthorizationCodeTokenResponseClient();
client.setRestOperations(rest); client.setRestOperations(rest);
return client; return client;
} }
private static OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> createRefreshTokenClient( private static OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> createRefreshTokenClient(
RestOperations rest) { RestOperations rest)
{
final DefaultRefreshTokenTokenResponseClient client = new DefaultRefreshTokenTokenResponseClient(); final DefaultRefreshTokenTokenResponseClient client = new DefaultRefreshTokenTokenResponseClient();
client.setRestOperations(rest); client.setRestOperations(rest);
return client; return client;
} }
private static OAuth2AccessTokenResponseClient<OAuth2PasswordGrantRequest> createPasswordClient(RestOperations rest, private static OAuth2AccessTokenResponseClient<OAuth2PasswordGrantRequest> createPasswordClient(RestOperations rest,
ClientRegistration clientRegistration) { ClientRegistration clientRegistration)
{
final DefaultPasswordTokenResponseClient client = new DefaultPasswordTokenResponseClient(); final DefaultPasswordTokenResponseClient client = new DefaultPasswordTokenResponseClient();
client.setRestOperations(rest); client.setRestOperations(rest);
Optional.of(clientRegistration) Optional.of(clientRegistration)
@@ -247,7 +276,8 @@ class SpringBasedIdentityServiceFacade implements IdentityServiceFacade {
} }
private static Converter<OAuth2PasswordGrantRequest, MultiValueMap<String, String>> audienceParameterConverter( private static Converter<OAuth2PasswordGrantRequest, MultiValueMap<String, String>> audienceParameterConverter(
String audienceValue) { String audienceValue)
{
return (grantRequest) -> { return (grantRequest) -> {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(); MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set("audience", audienceValue); parameters.set("audience", audienceValue);
@@ -256,20 +286,24 @@ class SpringBasedIdentityServiceFacade implements IdentityServiceFacade {
}; };
} }
private static class SpringAccessTokenAuthorization implements AccessTokenAuthorization { private static class SpringAccessTokenAuthorization implements AccessTokenAuthorization
{
private final OAuth2AccessTokenResponse tokenResponse; private final OAuth2AccessTokenResponse tokenResponse;
private SpringAccessTokenAuthorization(OAuth2AccessTokenResponse tokenResponse) { private SpringAccessTokenAuthorization(OAuth2AccessTokenResponse tokenResponse)
{
this.tokenResponse = requireNonNull(tokenResponse); this.tokenResponse = requireNonNull(tokenResponse);
} }
@Override @Override
public AccessToken getAccessToken() { public AccessToken getAccessToken()
{
return new SpringAccessToken(tokenResponse.getAccessToken()); return new SpringAccessToken(tokenResponse.getAccessToken());
} }
@Override @Override
public String getRefreshTokenValue() { public String getRefreshTokenValue()
{
return Optional.of(tokenResponse) return Optional.of(tokenResponse)
.map(OAuth2AccessTokenResponse::getRefreshToken) .map(OAuth2AccessTokenResponse::getRefreshToken)
.map(AbstractOAuth2Token::getTokenValue) .map(AbstractOAuth2Token::getTokenValue)
@@ -277,34 +311,41 @@ class SpringBasedIdentityServiceFacade implements IdentityServiceFacade {
} }
} }
private static class SpringAccessToken implements AccessToken { private static class SpringAccessToken implements AccessToken
{
private final AbstractOAuth2Token token; private final AbstractOAuth2Token token;
private SpringAccessToken(AbstractOAuth2Token token) { private SpringAccessToken(AbstractOAuth2Token token)
{
this.token = requireNonNull(token); this.token = requireNonNull(token);
} }
@Override @Override
public String getTokenValue() { public String getTokenValue()
{
return token.getTokenValue(); return token.getTokenValue();
} }
@Override @Override
public Instant getExpiresAt() { public Instant getExpiresAt()
{
return token.getExpiresAt(); return token.getExpiresAt();
} }
} }
private static class SpringDecodedAccessToken extends SpringAccessToken implements DecodedAccessToken { private static class SpringDecodedAccessToken extends SpringAccessToken implements DecodedAccessToken
{
private final Jwt jwt; private final Jwt jwt;
private SpringDecodedAccessToken(Jwt jwt) { private SpringDecodedAccessToken(Jwt jwt)
{
super(jwt); super(jwt);
this.jwt = jwt; this.jwt = jwt;
} }
@Override @Override
public Object getClaim(String claim) { public Object getClaim(String claim)
{
return jwt.getClaim(claim); return jwt.getClaim(claim);
} }
} }