[websocket] Support token authentication through header (#4515)

Signed-off-by: Florian Hotze <dev@florianhotze.com>
pull/4532/head
Florian Hotze 2024-12-31 17:00:40 +01:00 committed by GitHub
parent 139a2e20a1
commit 25ca43d165
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 61 additions and 2 deletions

View File

@ -248,6 +248,13 @@ public class AuthFilter implements ContainerRequestFilter {
}
}
public @Nullable SecurityContext getSecurityContext(@Nullable String bearerToken) throws AuthenticationException {
if (bearerToken == null) {
return null;
}
return authenticateBearerToken(bearerToken);
}
public @Nullable SecurityContext getSecurityContext(HttpServletRequest request, boolean allowQueryToken)
throws AuthenticationException, IOException {
String altTokenHeader = request.getHeader(ALT_AUTH_HEADER);

View File

@ -13,8 +13,11 @@
package org.openhab.core.io.websocket;
import java.io.IOException;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.servlet.Servlet;
import javax.servlet.ServletException;
@ -43,10 +46,19 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* The {@link CommonWebSocketServlet} provides the servlet for WebSocket connections
* The {@link CommonWebSocketServlet} provides the servlet for WebSocket connections.
*
* <p>
* Clients can authorize in two ways:
* <ul>
* <li>By setting <code>org.openhab.ws.accessToken.base64.</code> + base64-encoded access token and the
* {@link CommonWebSocketServlet#WEBSOCKET_PROTOCOL_DEFAULT} in the <code>Sec-WebSocket-Protocol</code> header.</li>
* <li>By providing the access token as query parameter <code>accessToken</code>.</li>
* </ul>
*
* @author Jan N. Klug - Initial contribution
* @author Miguel Álvarez Díez - Refactor into a common servlet
* @author Florian Hotze - Support passing access token through Sec-WebSocket-Protocol header
*/
@NonNullByDefault
@HttpWhiteboardServletName(CommonWebSocketServlet.SERVLET_PATH)
@ -55,6 +67,11 @@ import org.slf4j.LoggerFactory;
public class CommonWebSocketServlet extends WebSocketServlet {
private static final long serialVersionUID = 1L;
public static final String SEC_WEBSOCKET_PROTOCOL_HEADER = "Sec-WebSocket-Protocol";
public static final String WEBSOCKET_PROTOCOL_DEFAULT = "org.openhab.ws.protocol.default";
private static final Pattern WEBSOCKET_ACCESS_TOKEN_PATTERN = Pattern
.compile("org.openhab.ws.accessToken.base64.(?<base64>[A-Za-z0-9+/]*)");
public static final String SERVLET_PATH = "/ws";
public static final String DEFAULT_ADAPTER_ID = EventWebSocketAdapter.ADAPTER_ID;
@ -94,7 +111,31 @@ public class CommonWebSocketServlet extends WebSocketServlet {
if (servletUpgradeRequest == null || servletUpgradeResponse == null) {
return null;
}
if (isAuthorizedRequest(servletUpgradeRequest)) {
String accessToken = null;
String secWebSocketProtocolHeader = servletUpgradeRequest.getHeader(SEC_WEBSOCKET_PROTOCOL_HEADER);
if (secWebSocketProtocolHeader != null) { // if the client sends the Sec-WebSocket-Protocol header
// respond with the default protocol
servletUpgradeResponse.setHeader(SEC_WEBSOCKET_PROTOCOL_HEADER, WEBSOCKET_PROTOCOL_DEFAULT);
// extract the base64 encoded access token from the requested protocols
Matcher matcher = WEBSOCKET_ACCESS_TOKEN_PATTERN.matcher(secWebSocketProtocolHeader);
if (matcher.find() && matcher.group("base64") != null) {
String base64 = matcher.group("base64");
try {
accessToken = new String(Base64.getDecoder().decode(base64));
} catch (IllegalArgumentException e) {
logger.warn("Invalid base64 encoded access token in Sec-WebSocket-Protocol header from {}.",
servletUpgradeRequest.getRemoteAddress());
return null;
}
} else {
logger.warn("Invalid use of Sec-WebSocket-Protocol header from {}.",
servletUpgradeRequest.getRemoteAddress());
return null;
}
}
if (accessToken != null ? isAuthorizedRequest(accessToken) : isAuthorizedRequest(servletUpgradeRequest)) {
String requestPath = servletUpgradeRequest.getRequestURI().getPath();
String pathPrefix = SERVLET_PATH + "/";
boolean useDefaultAdapter = requestPath.equals(pathPrefix) || !requestPath.startsWith(pathPrefix);
@ -122,6 +163,17 @@ public class CommonWebSocketServlet extends WebSocketServlet {
return null;
}
private boolean isAuthorizedRequest(String bearerToken) {
try {
var securityContext = authFilter.getSecurityContext(bearerToken);
return securityContext != null
&& (securityContext.isUserInRole(Role.USER) || securityContext.isUserInRole(Role.ADMIN));
} catch (AuthenticationException e) {
logger.warn("Error handling WebSocket authorization", e);
return false;
}
}
private boolean isAuthorizedRequest(ServletUpgradeRequest servletUpgradeRequest) {
try {
var securityContext = authFilter.getSecurityContext(servletUpgradeRequest.getHttpServletRequest(),