springboot整合websocket
程序员文章站
2024-03-23 10:40:16
...
首先,引入websocket的依赖包:
<!-- Spring Boot WebSocket starter. No <version> is given, so it is presumably managed by the Spring Boot parent/BOM (not shown here). -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
接着,编写websocket的配置类:
/**
 * WebSocket configuration.
 *
 * <p>Maps the order handler to {@code /websocket/orders}, puts {@link WebSocketInterceptor}
 * in front of its handshake, and publishes the handler as the {@code orderWebSocketHandler} bean.
 */
@Configuration
@EnableWebSocket
public class WebsocketConfig implements WebSocketConfigurer {

    private static final Logger logger = LoggerFactory.getLogger(WebsocketConfig.class);

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // In a @Configuration class Spring intercepts @Bean method calls, so this should
        // resolve to the container's handler bean rather than a fresh instance.
        WebSocketHandler ordersHandler = orderProcessWebSocketHandler();
        registry
                .addHandler(ordersHandler, "/websocket/orders")
                .addInterceptors(new WebSocketInterceptor())
                // NOTE(review): "*" accepts handshakes from any Origin; consider restricting it.
                .setAllowedOrigins("*");
        logger.info("register websocket success!!!");
    }

    /**
     * Declares the order-processing handler bean.
     *
     * <p>NOTE(review): OrderWebSocketHandler is also annotated {@code @Component}; confirm the
     * two bean definitions named "orderWebSocketHandler" do not clash.
     *
     * @return a new {@link OrderWebSocketHandler}
     */
    @Bean("orderWebSocketHandler")
    public WebSocketHandler orderProcessWebSocketHandler() {
        WebSocketHandler handler = new OrderWebSocketHandler();
        return handler;
    }
}
配置类里声明了websocket的处理类,并将其与url和拦截器绑定起来。
/**
 * Handshake interceptor that tags each WebSocket connection with the caller's {@code shopId}.
 *
 * <p>The id is read from the {@code shopId} request header, falling back to the
 * {@code shopId} request parameter, and stored in the session attributes under
 * {@code "USER_IDENTIFICATION"}. Handshakes without a usable (non-blank) id are rejected.
 */
public class WebSocketInterceptor implements HandshakeInterceptor {
    /** Session-attribute key for the connection id; must match the key the handler reads. */
    private static final String WEBSOCKET_USER_IDENTIFICATION = "USER_IDENTIFICATION";
    /** Name of both the request header and the request parameter that carry the id. */
    private static final String SHOP_ID = "shopId";

    /**
     * Extracts the shopId and stores it in {@code attributes}.
     *
     * @return {@code true} to proceed with the handshake; {@code false} if the request is not
     *         servlet-based or carries no non-blank shopId
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
            Map<String, Object> attributes) throws Exception {
        if (!(request instanceof ServletServerHttpRequest)) {
            return false;
        }
        ServletServerHttpRequest req = (ServletServerHttpRequest) request;
        // A blank header must not shadow the parameter, and a blank id must not become a
        // registry key shared by every client that happens to send one.
        String id = trimToNull(req.getServletRequest().getHeader(SHOP_ID));
        if (id == null) {
            id = trimToNull(req.getServletRequest().getParameter(SHOP_ID));
        }
        if (id == null) {
            // NOTE(review): no status is set on the response here; consider 400/401 so
            // clients can tell why the handshake was refused.
            return false;
        }
        attributes.put(WEBSOCKET_USER_IDENTIFICATION, id);
        return true;
    }

    /** Nothing to do once the handshake has completed. */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
            Exception exception) {
        // intentionally empty
    }

    /** Returns {@code value} trimmed, or {@code null} if it is null or blank. */
    private static String trimToNull(String value) {
        if (value == null) {
            return null;
        }
        String trimmed = value.trim();
        return trimmed.isEmpty() ? null : trimmed;
    }
}
拦截器实现了HandshakeInterceptor接口,该接口提供握手前(beforeHandshake)和握手后(afterHandshake)两个方法,分别在建立连接前后进行处理。这里我们在beforeHandshake中对websocket连接进行预处理:从请求头(若没有则从请求参数)中读取shopId,作为该连接的标识;若两者都没有,则返回false,拒绝握手。
接着看我们websocket的处理类:
/**
 * Handler for order WebSocket connections.
 *
 * <p>Connection handling (connect, close, error, send) is inherited from
 * {@link BaseWebsocketHandler}; this subclass only supplies its own session registry.
 */
@Component
public class OrderWebSocketHandler extends BaseWebsocketHandler {

    /**
     * Registry of order sessions: connection id (shopId) to every session opened with it.
     * Values are lists because one shopId may be connected from several places at once.
     * Declared static, so every instance of this class shares the same registry.
     */
    private static final ConcurrentHashMap<String, List<WebSocketSession>> orderClients =
            new ConcurrentHashMap<String, List<WebSocketSession>>();

    /** Supplies the shared order registry in place of the base class's per-instance map. */
    @Override
    public ConcurrentHashMap<String, List<WebSocketSession>> getMap() {
        return orderClients;
    }
}
orderClients存储了所有对应的websocket连接,key就是前面提到的连接标识shopId。value采用list形式,是因为我们的项目支持多点登录:同一个shopId可以在多处登录,从而建立多个不同的websocket连接。
建立连接、连接异常、连接关闭、发送消息等具体的websocket处理逻辑,都放在了OrderWebSocketHandler的父类BaseWebsocketHandler中。
/**
 * Base handler for text WebSocket endpoints. It registers sessions on connect and
 * unregisters them on close or transport error. It echoes incoming text back to the
 * sender and offers helpers for pushing messages to connected users.
 *
 * <p>Sessions are grouped in the map returned by {@link #getMap()}, keyed by the value
 * stored in the session attribute {@code "USER_IDENTIFICATION"}. Subclasses may override
 * {@link #getMap()} to supply their own registry.
 */
public class BaseWebsocketHandler extends TextWebSocketHandler {
    private static final Logger logger = LoggerFactory.getLogger(BaseWebsocketHandler.class);

    /**
     * Session-attribute key holding the user identification; must match the key written
     * during the handshake.
     */
    private static final String WEBSOCKET_USER_IDENTIFICATION = "USER_IDENTIFICATION";

    /** Default session registry; unused by subclasses that override {@link #getMap()}. */
    private ConcurrentHashMap<String, List<WebSocketSession>> map = new ConcurrentHashMap<>();

    public BaseWebsocketHandler() {
        super();
    }

    /**
     * Reads the user identification from the session attributes.
     *
     * @param webSession the WebSocket session
     * @return the identification, or {@code null} if it is missing or cannot be read as a String
     */
    private String getUserToken(WebSocketSession webSession) {
        try {
            return (String) webSession.getAttributes().get(WEBSOCKET_USER_IDENTIFICATION);
        } catch (Exception e) {
            // Throwable as the trailing argument: SLF4J fills "{}" with the message and
            // logs the stack trace (string concatenation left "{}" unfilled and lost the trace).
            logger.error("获取用户标识报错==》{}", e.getMessage(), e);
            return null;
        }
    }

    /** Registers the new session under its user identification, if one is present. */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        String id = getUserToken(session);
        if (id != null) {
            WebSocketUtils.addOnline(getMap(), id, session);
            logger.info("websocket成功建立链接绑定用户{}信息", id);
        }
    }

    /** Passes the received text to WebSocketUtils.receive and echoes it back prefixed with "server:". */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        String payload = message.getPayload();
        WebSocketUtils.receive(getUserToken(session), payload);
        try {
            // Echo the payload text. Concatenating the TextMessage object itself would send
            // its toString() debug representation instead of the client's text.
            session.sendMessage(new TextMessage("server:" + payload));
        } catch (Exception e) {
            logger.error("发送给客户端消息报错+exception==={}", e.getMessage(), e);
        }
    }

    /** Logs the error, closes the session if still open, and always unregisters it. */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        // Only one placeholder: the trailing Throwable is logged with its stack trace.
        logger.warn("websocket链接出现异常,message={}", exception.getMessage(), exception);
        try {
            if (session.isOpen()) {
                session.close();
            }
        } finally {
            // Unregister even if close() throws.
            WebSocketUtils.remove(getMap(), getUserToken(session), session);
        }
    }

    /** Unregisters the closed session. */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        logger.info("websocket链接关闭status={}", status);
        WebSocketUtils.remove(getMap(), getUserToken(session), session);
    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

    /**
     * Sends {@code message} to every session registered under its id.
     *
     * @param message the message, carrying the target id and the payload
     * @return the result of {@link WebSocketUtils#sendMessage}
     */
    public boolean sendMessageToUsers(ToClientMessage message) {
        return WebSocketUtils.sendMessage(getMap(), message);
    }

    /**
     * Sends {@code message} to every session registered under {@code id}.
     *
     * @param id      the target user identification
     * @param message the payload, serialized by WebSocketUtils
     * @return the result of {@link WebSocketUtils#sendMessage}
     */
    public <T> boolean sendMessageToUsers(String id, T message) {
        return WebSocketUtils.sendMessage(getMap(), id, message);
    }

    /** Returns the session registry used by this handler. */
    public ConcurrentHashMap<String, List<WebSocketSession>> getMap() {
        return map;
    }

    /** Replaces the default registry; has no effect on subclasses whose getMap() ignores this field. */
    public void setMap(ConcurrentHashMap<String, List<WebSocketSession>> map) {
        this.map = map;
    }
}
看到这里,大致应该明白了websocket的使用流程。我们把通用的websocket处理逻辑抽象出来放在父类中,是为了方便扩展:如果需要增加更多类型的websocket连接,只需要增加一个子类即可。从OrderWebSocketHandler可以看出,子类中的内容极其简单。
最后,我们将一些公共方法抽取到了工具类WebSocketUtils中:
/**
 * Static helpers for WebSocket session registries and messaging.
 *
 * <p>A registry maps a connection id (the user identification stored during the handshake)
 * to every session opened with that id. An id can have several sessions because one user may
 * be connected from several places at once.
 */
public class WebSocketUtils {
    private static final Logger logger = LoggerFactory.getLogger(WebSocketUtils.class);

    /**
     * Returns the sessions bound to an id.
     *
     * @param map the session registry; may be {@code null}
     * @param id  the connection id; may be {@code null}
     * @return the registry's live list for {@code id}, or a new empty list when the registry
     *         is null/empty, {@code id} is null, or no entry exists
     */
    public static List<WebSocketSession> getSessions(ConcurrentHashMap<String, List<WebSocketSession>> map, String id) {
        // ConcurrentHashMap.get(null) throws, so a null id is treated as "no sessions".
        if (null == map || map.isEmpty() || null == id) {
            return new CopyOnWriteArrayList<>();
        }
        List<WebSocketSession> list = map.get(id);
        if (null == list) {
            return new CopyOnWriteArrayList<>();
        }
        return list;
    }

    /**
     * Returns the ids currently present in the registry. While the registry is only modified
     * through {@link #addOnline} and {@link #remove}, an id is dropped as soon as its last
     * session is removed, so no session-less ids linger. The misspelled method name is kept
     * for backward compatibility.
     *
     * @return a live view of the registry's keys
     */
    public static Set<String> geOnlinetAllUser(ConcurrentHashMap<String, List<WebSocketSession>> map) {
        return map.keySet();
    }

    /**
     * Registers {@code session} under {@code id}.
     *
     * <p>compute() runs atomically per key. Two sessions connecting for the same new id
     * therefore cannot each create a list and overwrite each other, which is the lost update
     * a get-then-put allows.
     */
    public static void addOnline(ConcurrentHashMap<String, List<WebSocketSession>> map, String id, WebSocketSession session) {
        map.compute(id, (key, sessions) -> {
            List<WebSocketSession> target = sessions != null ? sessions : new CopyOnWriteArrayList<>();
            target.add(session);
            return target;
        });
    }

    /**
     * Unregisters {@code session}. When {@code id} is {@code null}, every id's list is searched.
     * An id whose list becomes empty is removed from the registry.
     *
     * @param map     the session registry
     * @param id      the session's connection id, or {@code null} if unknown
     * @param session the session to remove
     */
    public static void remove(ConcurrentHashMap<String, List<WebSocketSession>> map, String id, WebSocketSession session) {
        if (null != id) {
            removeFromKey(map, id, session);
        } else {
            // ConcurrentHashMap iterators are weakly consistent, so mutating while iterating is safe.
            for (String key : map.keySet()) {
                removeFromKey(map, key, session);
            }
        }
    }

    /** Atomically removes {@code session} from {@code key}'s list, dropping the key once the list is empty. */
    private static void removeFromKey(ConcurrentHashMap<String, List<WebSocketSession>> map, String key,
            WebSocketSession session) {
        map.computeIfPresent(key, (k, sessions) -> {
            sessions.remove(session);
            return sessions.isEmpty() ? null : sessions;
        });
    }

    /**
     * Handles a message received from a client. Currently this only logs it.
     *
     * @param id      the sender's connection id
     * @param message the text payload
     */
    public static void receive(String id, String message) {
        logger.info("收到id={}发来的消息message={}", id, message);
    }

    /**
     * Sends {@code message.getMessage()} as JSON to every open session registered under
     * {@code message.getId()}.
     *
     * @param map     the session registry
     * @param message carries the target id and the payload
     * @return {@code true} if every open session was sent to successfully (vacuously true
     *         when there are none); {@code false} if any send failed
     */
    public static boolean sendMessage(ConcurrentHashMap<String, List<WebSocketSession>> map, ToClientMessage message) {
        return sendMessage(map, message.getId(), message.getMessage());
    }

    /**
     * Sends {@code message} as JSON to every open session registered under {@code id}.
     * A failing session is logged and skipped, so it does not stop delivery to the others.
     *
     * <p>NOTE(review): WebSocketSession.sendMessage is not safe for concurrent use on one
     * session. If pushes and echoes can overlap, consider wrapping sessions in
     * ConcurrentWebSocketSessionDecorator.
     *
     * @param map     the session registry
     * @param id      the target connection id
     * @param message the payload, serialized with fastjson
     * @return {@code true} if every open session was sent to successfully (vacuously true
     *         when there are none); {@code false} if any send failed
     */
    public static <T> boolean sendMessage(ConcurrentHashMap<String, List<WebSocketSession>> map, String id, T message) {
        logger.info("准备给{}发送消息", id);
        List<WebSocketSession> sessions = getSessions(map, id);
        if (sessions.isEmpty()) {
            return true;
        }
        // The payload is identical for every session, so serialize it only once.
        TextMessage textMessage = new TextMessage(JSON.toJSONString(message));
        boolean allSent = true;
        for (WebSocketSession session : sessions) {
            if (!session.isOpen()) {
                continue;
            }
            try {
                session.sendMessage(textMessage);
                logger.info("发送消息id={}, session={}, message={}", id, remoteHost(session), message);
            } catch (IOException | RuntimeException e) {
                allSent = false;
                // Trailing Throwable without a placeholder: SLF4J logs its stack trace.
                logger.error("发送消息出错id={}, session={}, message={}", id, remoteHost(session), message, e);
            }
        }
        return allSent;
    }

    /** Returns the session's remote host address for logging, or "unknown" if unavailable. */
    private static String remoteHost(WebSocketSession session) {
        if (session.getRemoteAddress() == null || session.getRemoteAddress().getAddress() == null) {
            return "unknown";
        }
        return session.getRemoteAddress().getAddress().getHostAddress();
    }
}
至此,springboot整合websocket的代码完成。