AuthorizationHeaderUtil.java

package org.entando.kubernetes.security.oauth2;

import java.net.URI;
import java.time.Duration;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.http.converter.FormHttpMessageConverter;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
import org.springframework.stereotype.Component;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;

@Component
public class AuthorizationHeaderUtil {

    private final OAuth2AuthorizedClientService clientService;
    private final RestTemplateBuilder restTemplateBuilder;
    private final Logger log = LoggerFactory.getLogger(AuthorizationHeaderUtil.class);

    public AuthorizationHeaderUtil(OAuth2AuthorizedClientService clientService,
            RestTemplateBuilder restTemplateBuilder) {
        this.clientService = clientService;
        this.restTemplateBuilder = restTemplateBuilder;
    }

    public Optional<String> getAuthorizationHeader() {
        Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
        if (authentication instanceof OAuth2AuthenticationToken) {
            OAuth2AuthenticationToken oauthToken = (OAuth2AuthenticationToken) authentication;
            String name = oauthToken.getName();
            String registrationId = oauthToken.getAuthorizedClientRegistrationId();
            OAuth2AuthorizedClient client = clientService.loadAuthorizedClient(
                    registrationId,
                    name);

            if (null == client) {
                throw new OAuth2AuthorizationException(new OAuth2Error("access_denied", "The token is expired", null));
            }
            OAuth2AccessToken accessToken = client.getAccessToken();

            if (accessToken != null) {
                String tokenType = accessToken.getTokenType().getValue();
                String accessTokenValue = accessToken.getTokenValue();
                if (isExpired(accessToken)) {
                    log.info("AccessToken expired, refreshing automatically");
                    accessTokenValue = refreshToken(client, oauthToken);
                    if (null == accessTokenValue) {
                        SecurityContextHolder.getContext().setAuthentication(null);
                        throw new OAuth2AuthorizationException(
                                new OAuth2Error("access_denied", "The token is expired", null));
                    }
                }
                String authorizationHeaderValue = String.format("%s %s", tokenType, accessTokenValue);
                return Optional.of(authorizationHeaderValue);
            }

        } else if (authentication instanceof JwtAuthenticationToken) {
            JwtAuthenticationToken accessToken = (JwtAuthenticationToken) authentication;
            String tokenValue = accessToken.getToken().getTokenValue();
            String authorizationHeaderValue = String
                    .format("%s %s", OAuth2AccessToken.TokenType.BEARER.getValue(), tokenValue);
            return Optional.of(authorizationHeaderValue);
        }
        return Optional.empty();
    }

    private String refreshToken(OAuth2AuthorizedClient client, OAuth2AuthenticationToken oauthToken) {
        OAuth2AccessTokenResponse atr = refreshTokenClient(client);
        if (atr == null || atr.getAccessToken() == null) {
            log.info("Failed to refresh token for user");
            return null;
        }

        OAuth2RefreshToken refreshToken =
                atr.getRefreshToken() != null ? atr.getRefreshToken() : client.getRefreshToken();
        OAuth2AuthorizedClient updatedClient = new OAuth2AuthorizedClient(
                client.getClientRegistration(),
                client.getPrincipalName(),
                atr.getAccessToken(),
                refreshToken
        );

        clientService.saveAuthorizedClient(updatedClient, oauthToken);
        return atr.getAccessToken().getTokenValue();
    }

    private OAuth2AccessTokenResponse refreshTokenClient(OAuth2AuthorizedClient currentClient) {

        MultiValueMap<String, String> formParameters = new LinkedMultiValueMap<>();
        formParameters.add(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue());
        formParameters.add(OAuth2ParameterNames.REFRESH_TOKEN, currentClient.getRefreshToken().getTokenValue());
        formParameters.add(OAuth2ParameterNames.CLIENT_ID, currentClient.getClientRegistration().getClientId());
        RequestEntity<?> requestEntity = RequestEntity
                .post(URI.create(currentClient.getClientRegistration().getProviderDetails().getTokenUri()))
                .contentType(MediaType.APPLICATION_FORM_URLENCODED)
                .body(formParameters);
        try {
            RestTemplate r = restTemplate(currentClient.getClientRegistration().getClientId(),
                    currentClient.getClientRegistration().getClientSecret());
            ResponseEntity<OAuthIdpTokenResponseDTO> responseEntity = r
                    .exchange(requestEntity, OAuthIdpTokenResponseDTO.class);
            return toOAuth2AccessTokenResponse(responseEntity.getBody());
        } catch (OAuth2AuthorizationException e) {
            throw new OAuth2AuthenticationException(e.getError(), "Unable to refresh the token", e);
        }
    }

    private OAuth2AccessTokenResponse toOAuth2AccessTokenResponse(OAuthIdpTokenResponseDTO oauthIdpResponse) {
        Map<String, Object> additionalParameters = new HashMap<>();
        additionalParameters.put("id_token", oauthIdpResponse.getIdToken());
        additionalParameters.put("not-before-policy", oauthIdpResponse.getNotBefore());
        additionalParameters.put("refresh_expires_in", oauthIdpResponse.getRefreshExpiresIn());
        additionalParameters.put("session_state", oauthIdpResponse.getSessionState());
        return OAuth2AccessTokenResponse.withToken(oauthIdpResponse.getAccessToken())
                .expiresIn(oauthIdpResponse.getExpiresIn())
                .refreshToken(oauthIdpResponse.getRefreshToken())
                .scopes(Pattern.compile("\\s").splitAsStream(oauthIdpResponse.getScope()).collect(Collectors.toSet()))
                .tokenType(OAuth2AccessToken.TokenType.BEARER)
                .additionalParameters(additionalParameters)
                .build();
    }

    private RestTemplate restTemplate(String clientId, String clientSecret) {
        return restTemplateBuilder
                .additionalMessageConverters(
                        new FormHttpMessageConverter(),
                        new OAuth2AccessTokenResponseHttpMessageConverter())
                .errorHandler(new OAuth2ErrorResponseErrorHandler())
                .basicAuthentication(clientId, clientSecret)
                .build();
    }

    private boolean isExpired(OAuth2AccessToken accessToken) {
        Instant now = Instant.now();
        Instant expiresAt = accessToken.getExpiresAt();
        return now.isAfter(expiresAt.minus(Duration.ofMinutes(1L)));
    }
}