OAuth RFC-8628 various post-fixes (#4699)

* add timeout to oauth calls

Signed-off-by: Andrew Fiddian-Green <software@whitebear.ch>
pull/4715/head
Andrew Fiddian-Green 2025-04-09 09:42:23 +01:00 committed by GitHub
parent bda8c2a2a4
commit 013de041ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 99 additions and 37 deletions

View File

@ -66,6 +66,8 @@ public class OAuthClientServiceImpl implements OAuthClientService {
private final transient Logger logger = LoggerFactory.getLogger(OAuthClientServiceImpl.class); private final transient Logger logger = LoggerFactory.getLogger(OAuthClientServiceImpl.class);
private final Object refreshTokenProcessLock = new Object();
private @NonNullByDefault({}) OAuthStoreHandler storeHandler; private @NonNullByDefault({}) OAuthStoreHandler storeHandler;
// Constructor params - static // Constructor params - static
@ -291,11 +293,26 @@ public class OAuthClientServiceImpl implements OAuthClientService {
} }
@Override @Override
public synchronized AccessTokenResponse refreshToken() throws OAuthException, IOException, OAuthResponseException { public AccessTokenResponse refreshToken() throws OAuthException, IOException, OAuthResponseException {
if (isClosed()) { if (isClosed()) {
throw new OAuthException(EXCEPTION_MESSAGE_CLOSED); throw new OAuthException(EXCEPTION_MESSAGE_CLOSED);
} }
return refreshTokenInner(true);
}
/**
* Inner private method for refreshToken. If 'forceRefresh' is false then only fetch a new token if
* the prior token is not expired, otherwise return the prior token. If 'forceRefresh' is true
* then always fetch a new token.
*
* @param forceRefresh determines whether to force a refresh or check for token expiry
* @return either the prior AccessTokenResponse or a new one
*/
private AccessTokenResponse refreshTokenInner(boolean forceRefresh)
throws OAuthException, IOException, OAuthResponseException {
AccessTokenResponse accessTokenResponse = null;
synchronized (refreshTokenProcessLock) {
AccessTokenResponse lastAccessToken; AccessTokenResponse lastAccessToken;
try { try {
lastAccessToken = storeHandler.loadAccessTokenResponse(handle); lastAccessToken = storeHandler.loadAccessTokenResponse(handle);
@ -314,26 +331,34 @@ public class OAuthClientServiceImpl implements OAuthClientService {
throw new OAuthException("tokenUrl is required but null"); throw new OAuthException("tokenUrl is required but null");
} }
if (forceRefresh || lastAccessToken.isExpired(Instant.now(), tokenExpiresInSeconds)) {
GsonBuilder gsonBuilder = this.gsonBuilder; GsonBuilder gsonBuilder = this.gsonBuilder;
OAuthConnector connector = gsonBuilder == null ? new OAuthConnector(httpClientFactory, extraAuthFields) OAuthConnector connector = gsonBuilder == null ? new OAuthConnector(httpClientFactory, extraAuthFields)
: new OAuthConnector(httpClientFactory, extraAuthFields, gsonBuilder); : new OAuthConnector(httpClientFactory, extraAuthFields, gsonBuilder);
AccessTokenResponse accessTokenResponse = connector.grantTypeRefreshToken(tokenUrl, accessTokenResponse = connector.grantTypeRefreshToken(tokenUrl, lastAccessToken.getRefreshToken(),
lastAccessToken.getRefreshToken(), persistedParams.clientId, persistedParams.clientSecret, persistedParams.clientId, persistedParams.clientSecret, persistedParams.scope,
persistedParams.scope, Boolean.TRUE.equals(persistedParams.supportsBasicAuth)); Boolean.TRUE.equals(persistedParams.supportsBasicAuth));
// The service may not return the refresh token so use the last refresh token otherwise it's not stored. // The service may not return the refresh token so use the last refresh token otherwise it's not stored.
String refreshToken = accessTokenResponse.getRefreshToken(); String refreshToken = accessTokenResponse.getRefreshToken();
if (refreshToken == null || refreshToken.isBlank()) { if (refreshToken == null || refreshToken.isBlank()) {
accessTokenResponse.setRefreshToken(lastAccessToken.getRefreshToken()); accessTokenResponse.setRefreshToken(lastAccessToken.getRefreshToken());
} }
// store it // store it
storeHandler.saveAccessTokenResponse(handle, accessTokenResponse); storeHandler.saveAccessTokenResponse(handle, accessTokenResponse);
accessTokenRefreshListeners.forEach(l -> l.onAccessTokenResponse(accessTokenResponse)); } else {
// No need to refresh the token, just return the last one
return lastAccessToken;
}
}
notifyAccessTokenResponse(accessTokenResponse);
return accessTokenResponse; return accessTokenResponse;
} }
@Override @Override
public synchronized @Nullable AccessTokenResponse getAccessTokenResponse() public @Nullable AccessTokenResponse getAccessTokenResponse()
throws OAuthException, IOException, OAuthResponseException { throws OAuthException, IOException, OAuthResponseException {
if (isClosed()) { if (isClosed()) {
throw new OAuthException(EXCEPTION_MESSAGE_CLOSED); throw new OAuthException(EXCEPTION_MESSAGE_CLOSED);
@ -351,7 +376,7 @@ public class OAuthClientServiceImpl implements OAuthClientService {
if (lastAccessToken.isExpired(Instant.now(), tokenExpiresInSeconds) if (lastAccessToken.isExpired(Instant.now(), tokenExpiresInSeconds)
&& lastAccessToken.getRefreshToken() != null) { && lastAccessToken.getRefreshToken() != null) {
return refreshToken(); return refreshTokenInner(false);
} }
return lastAccessToken; return lastAccessToken;
} }

View File

@ -23,6 +23,7 @@ import java.time.ZoneId;
import java.time.format.DateTimeParseException; import java.time.format.DateTimeParseException;
import java.util.Base64; import java.util.Base64;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException; import java.util.concurrent.TimeoutException;
import org.eclipse.jdt.annotation.NonNullByDefault; import org.eclipse.jdt.annotation.NonNullByDefault;
@ -47,6 +48,8 @@ import com.google.gson.FieldNamingPolicy;
import com.google.gson.Gson; import com.google.gson.Gson;
import com.google.gson.GsonBuilder; import com.google.gson.GsonBuilder;
import com.google.gson.JsonDeserializer; import com.google.gson.JsonDeserializer;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonSyntaxException; import com.google.gson.JsonSyntaxException;
/** /**
@ -61,6 +64,7 @@ import com.google.gson.JsonSyntaxException;
public class OAuthConnector { public class OAuthConnector {
private static final String HTTP_CLIENT_CONSUMER_NAME = "OAuthConnector"; private static final String HTTP_CLIENT_CONSUMER_NAME = "OAuthConnector";
private static final int TIMEOUT_SECONDS = 10;
protected final HttpClientFactory httpClientFactory; protected final HttpClientFactory httpClientFactory;
@ -87,6 +91,29 @@ public class OAuthConnector {
this.extraFields = extraFields; this.extraFields = extraFields;
gson = gsonBuilder.setDateFormat(DateTimeType.DATE_PATTERN_JSON_COMPAT) gson = gsonBuilder.setDateFormat(DateTimeType.DATE_PATTERN_JSON_COMPAT)
.setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES)
.registerTypeAdapter(OAuthResponseException.class,
(JsonDeserializer<OAuthResponseException>) (json, typeOfT, context) -> {
OAuthResponseException result = new OAuthResponseException();
JsonObject jsonObject = json.getAsJsonObject();
JsonElement jsonElement;
jsonElement = jsonObject.get("error");
if (jsonElement != null) {
result.setError(jsonElement.getAsString());
}
jsonElement = jsonObject.get("error_description");
if (jsonElement != null) {
result.setErrorDescription(jsonElement.getAsString());
}
jsonElement = jsonObject.get("error_uri");
if (jsonElement != null) {
result.setErrorUri(jsonElement.getAsString());
}
jsonElement = jsonObject.get("state");
if (jsonElement != null) {
result.setState(jsonElement.getAsString());
}
return result;
})
.registerTypeAdapter(Instant.class, (JsonDeserializer<Instant>) (json, typeOfT, context) -> { .registerTypeAdapter(Instant.class, (JsonDeserializer<Instant>) (json, typeOfT, context) -> {
try { try {
return Instant.parse(json.getAsString()); return Instant.parse(json.getAsString());
@ -273,7 +300,8 @@ public class OAuthConnector {
} }
private Request getMethod(HttpClient httpClient, String tokenUrl) { private Request getMethod(HttpClient httpClient, String tokenUrl) {
Request request = httpClient.newRequest(tokenUrl).method(HttpMethod.POST); Request request = httpClient.newRequest(tokenUrl).method(HttpMethod.POST).timeout(TIMEOUT_SECONDS,
TimeUnit.SECONDS);
request.header(HttpHeader.ACCEPT, "application/json"); request.header(HttpHeader.ACCEPT, "application/json");
request.header(HttpHeader.ACCEPT_CHARSET, StandardCharsets.UTF_8.name()); request.header(HttpHeader.ACCEPT_CHARSET, StandardCharsets.UTF_8.name());
return request; return request;

View File

@ -40,6 +40,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import com.google.gson.GsonBuilder; import com.google.gson.GsonBuilder;
import com.google.gson.JsonSyntaxException;
/** /**
* The {@link OAuthConnectorRFC8628} extends {@link OAuthConnector} to implement * The {@link OAuthConnectorRFC8628} extends {@link OAuthConnector} to implement
@ -292,9 +293,10 @@ public class OAuthConnectorRFC8628 extends OAuthConnector implements AutoCloseab
request.param(PARAM_SCOPE, scopeParameter); request.param(PARAM_SCOPE, scopeParameter);
logger.trace("fetchDeviceCodeResponse() request: {}", request.getURI()); logger.trace("fetchDeviceCodeResponse() request: {}", request.getURI());
String content = null;
try { try {
ContentResponse response = request.send(); ContentResponse response = request.send();
String content = response.getContentAsString(); content = response.getContentAsString();
logger.trace("fetchDeviceCodeResponse() response: {}", content); logger.trace("fetchDeviceCodeResponse() response: {}", content);
if (response.getStatus() == HttpStatus.OK_200) { if (response.getStatus() == HttpStatus.OK_200) {
@ -310,6 +312,9 @@ public class OAuthConnectorRFC8628 extends OAuthConnector implements AutoCloseab
} }
} }
throw new OAuthException("fetchDeviceCodeResponse() error: " + response); throw new OAuthException("fetchDeviceCodeResponse() error: " + response);
} catch (JsonSyntaxException e) {
logger.warn("fetchDeviceCodeResponse() error parsing content:{}", content);
throw new OAuthException("fetchDeviceCodeResponse() error", e);
} catch (InterruptedException | TimeoutException | ExecutionException e) { } catch (InterruptedException | TimeoutException | ExecutionException e) {
throw new OAuthException("fetchDeviceCodeResponse() error", e); throw new OAuthException("fetchDeviceCodeResponse() error", e);
} }
@ -341,9 +346,10 @@ public class OAuthConnectorRFC8628 extends OAuthConnector implements AutoCloseab
request.param(PARAM_DEVICE_CODE, dcr.getDeviceCode()); request.param(PARAM_DEVICE_CODE, dcr.getDeviceCode());
logger.trace("fetchAccessTokenResponse() request: {}", request.getURI()); logger.trace("fetchAccessTokenResponse() request: {}", request.getURI());
String content = null;
try { try {
ContentResponse response = request.send(); ContentResponse response = request.send();
String content = response.getContentAsString(); content = response.getContentAsString();
logger.trace("fetchAccessTokenResponse() response: {}", content); logger.trace("fetchAccessTokenResponse() response: {}", content);
switch (response.getStatus()) { switch (response.getStatus()) {
@ -367,6 +373,9 @@ public class OAuthConnectorRFC8628 extends OAuthConnector implements AutoCloseab
* completed the verification process * completed the verification process
*/ */
return null; return null;
} catch (JsonSyntaxException e) {
logger.warn("fetchAccessTokenResponse() error parsing content:{}", content);
throw new OAuthException("fetchAccessTokenResponse() error", e);
} catch (InterruptedException | TimeoutException | ExecutionException e) { } catch (InterruptedException | TimeoutException | ExecutionException e) {
throw new OAuthException("fetchAccessTokenResponse() error", e); throw new OAuthException("fetchAccessTokenResponse() error", e);
} }