package com.ruoyi.websocket.config; import java.util.Collections; import java.util.Enumeration; import java.util.Map; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.stereotype.Component; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.HandshakeInterceptor; import com.ruoyi.common.constant.Constants; import com.ruoyi.common.core.domain.model.LoginUser; import com.ruoyi.common.utils.StringUtils; import com.ruoyi.framework.web.service.TokenService; /** * WebSocket 握手拦截器:校验 JWT Token(支持 query 参数 token 或 Authorization header) * * @author ruoyi */ @Component public class WebSocketHandshakeHandler implements HandshakeInterceptor { private final TokenService tokenService; public WebSocketHandshakeHandler(TokenService tokenService) { this.tokenService = tokenService; } @Override public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws Exception { if (!(request instanceof ServletServerHttpRequest)) { return false; } ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request; HttpServletRequest req = servletRequest.getServletRequest(); String token = req.getParameter("token"); if (StringUtils.isEmpty(token)) { String auth = req.getHeader("Authorization"); if (StringUtils.isNotEmpty(auth) && auth.startsWith(Constants.TOKEN_PREFIX)) { token = auth.substring(Constants.TOKEN_PREFIX.length()).trim(); } } if (StringUtils.isEmpty(token)) { return false; } final String tokenFinal = token; HttpServletRequest wrappedReq = new HttpServletRequestWrapper(req) { @Override public String getHeader(String name) { if ("Authorization".equalsIgnoreCase(name)) { return Constants.TOKEN_PREFIX + tokenFinal; } return super.getHeader(name); } @Override public Enumeration getHeaders(String name) { if ("Authorization".equalsIgnoreCase(name)) { return Collections.enumeration(Collections.singletonList(Constants.TOKEN_PREFIX + tokenFinal)); } return super.getHeaders(name); } }; LoginUser loginUser = tokenService.getLoginUser(wrappedReq); if (loginUser == null) { return false; } attributes.put("loginUser", loginUser); return true; } @Override public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) { } }