WebSocket 单体分布式
1.架构
采用单体架构(如需微服务架构请自行拆分),支持分布式部署。技术栈:Nacos + Redis + WebSocket。
1.1 项目结构
websocket-new-module
└─src
└─main
├─java
│ └─com
│ └─yan
│ ├─config
│ ├─constant
│ ├─domain
│ ├─endpoint
│ └─service
└─resources
1.2 pom 代码内容
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <!-- Keep the parent version equal to ${spring-boot.version} so plugin versions (from the parent)
         and dependency versions (from the imported BOM below) come from the same Boot release. -->
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.7.18</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <groupId>com.yan</groupId>
    <artifactId>websocket-new-module</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>websocket-new-module</name>
    <description>websocket-new-module</description>
    <properties>
        <java.version>1.8</java.version>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
        <spring-boot.version>2.7.18</spring-boot.version>
        <project-module.version>0.0.1-SNAPSHOT</project-module.version>
        <spring-cloud.version>2021.0.5</spring-cloud.version>
        <alibaba-cloud.version>2021.0.5.0</alibaba-cloud.version>
        <hutool.version>5.8.20</hutool.version>
    </properties>
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-dependencies</artifactId>
                <version>${spring-boot.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
            <!-- Spring Cloud Version BOM -->
            <dependency>
                <groupId>org.springframework.cloud</groupId>
                <artifactId>spring-cloud-dependencies</artifactId>
                <version>${spring-cloud.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
            <dependency>
                <groupId>com.alibaba.cloud</groupId>
                <artifactId>spring-cloud-alibaba-dependencies</artifactId>
                <version>${alibaba-cloud.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-websocket</artifactId>
        </dependency>
        <!--nacos-->
        <!-- enables loading of bootstrap.yml -->
        <dependency>
            <groupId>org.springframework.cloud</groupId>
            <artifactId>spring-cloud-starter-bootstrap</artifactId>
        </dependency>
        <!-- Spring Cloud LoadBalancer -->
        <dependency>
            <groupId>org.springframework.cloud</groupId>
            <artifactId>spring-cloud-starter-loadbalancer</artifactId>
        </dependency>
        <!-- Spring Cloud Starter Alibaba Nacos Discovery -->
        <dependency>
            <groupId>com.alibaba.cloud</groupId>
            <artifactId>spring-cloud-starter-alibaba-nacos-discovery</artifactId>
        </dependency>
        <!-- Spring Cloud Starter Alibaba Nacos Config -->
        <dependency>
            <groupId>com.alibaba.cloud</groupId>
            <artifactId>spring-cloud-starter-alibaba-nacos-config</artifactId>
        </dependency>
        <!--nacos-->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>
        <dependency>
            <groupId>cn.hutool</groupId>
            <artifactId>hutool-all</artifactId>
            <version>${hutool.version}</version>
        </dependency>
        <!-- compile-time only: annotation processor, not needed at runtime -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <configuration>
                    <!-- keep Lombok out of the executable jar -->
                    <excludes>
                        <exclude>
                            <groupId>org.projectlombok</groupId>
                            <artifactId>lombok</artifactId>
                        </exclude>
                    </excludes>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
2.公共接口配置
package com.yan.service;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.extra.spring.SpringUtil;
import com.fasterxml.jackson.annotation.JsonIgnore;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.Accessors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Marker;
import org.springframework.core.env.Environment;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
/**
 * Mixin for beans that want a class-scoped SLF4J logger plus pass-through logging helpers.
 *
 * <p>trace/debug/info/warn calls are suppressed when {@link #isProd()} returns true;
 * error calls are always forwarded. Every helper obtains its logger via {@link #getLogBean()}.
 *
 * @Author yan
 * @Date 2024/9/22 10:46:15 AM
 */
public interface AbstractBean {

    /** Holder pairing a logger with the class it was created for. */
    @Data
    @Accessors(chain = true)
    @NoArgsConstructor
    @AllArgsConstructor
    class LogBean {
        private org.slf4j.Logger logger;
        private Class<?> aClass;
    }

    /**
     * Creates a new {@link LogBean} for the runtime class of this object.
     *
     * @return holder with the runtime class and its SLF4J logger
     */
    @JsonIgnore
    default LogBean getLogBean() {
        return new LogBean()
                .setAClass(this.getClass())
                .setLogger(LoggerFactory.getLogger(this.getClass()));
    }

    /** @return the logger for the runtime class of this object */
    @JsonIgnore
    default Logger getLogger() {
        return getLogBean().getLogger();
    }

    /** @return the runtime class of this object */
    @JsonIgnore
    default Class<?> getAClass() {
        return getLogBean().getAClass();
    }

    /**
     * Initialization hook: logs "init {SimpleName} ..." outside the prod profile.
     */
    @PostConstruct
    default void init() {
        LogBean logBean = getLogBean();
        if (!isProd()) {
            logBean.getLogger().info("init {} ...", logBean.getAClass().getSimpleName());
        }
    }

    /**
     * Destruction hook: logs "destroy {SimpleName} ..." outside the prod profile.
     */
    @PreDestroy
    default void destroy() {
        LogBean logBean = getLogBean();
        if (!isProd()) {
            logBean.getLogger().info("destroy {} ...", logBean.getAClass().getSimpleName());
        }
    }

    /**
     * Tells whether the {@code prod} profile is active.
     *
     * <p>Uses {@link Environment#getActiveProfiles()} so that comma-separated values such as
     * {@code "prod,redis"} are recognized; comparing the raw {@code spring.profiles.active}
     * string would miss them.
     *
     * @return true if "prod" is one of the active profiles
     */
    default boolean isProd() {
        Environment env = SpringUtil.getBean(Environment.class);
        for (String profile : env.getActiveProfiles()) {
            if (ObjectUtil.equals("prod", profile)) {
                return true;
            }
        }
        return false;
    }

    // ---------------------------------------------------------------- TRACE (suppressed in prod)

    default boolean isTraceEnabled() {
        return getLogger().isTraceEnabled();
    }

    default void trace(String var1) {
        if (!isProd()) {
            getLogger().trace(var1);
        }
    }

    default void trace(String var1, Object var2) {
        if (!isProd()) {
            getLogger().trace(var1, var2);
        }
    }

    default void trace(String var1, Object var2, Object var3) {
        if (!isProd()) {
            getLogger().trace(var1, var2, var3);
        }
    }

    default void trace(String var1, Object... var2) {
        if (!isProd()) {
            getLogger().trace(var1, var2);
        }
    }

    default void trace(String var1, Throwable var2) {
        if (!isProd()) {
            getLogger().trace(var1, var2);
        }
    }

    default boolean isTraceEnabled(Marker var1) {
        return getLogger().isTraceEnabled(var1);
    }

    default void trace(Marker var1, String var2) {
        if (!isProd()) {
            getLogger().trace(var1, var2);
        }
    }

    default void trace(Marker var1, String var2, Object var3) {
        if (!isProd()) {
            getLogger().trace(var1, var2, var3);
        }
    }

    default void trace(Marker var1, String var2, Object var3, Object var4) {
        if (!isProd()) {
            getLogger().trace(var1, var2, var3, var4);
        }
    }

    default void trace(Marker var1, String var2, Object... var3) {
        if (!isProd()) {
            // pass the arguments through (previously dropped)
            getLogger().trace(var1, var2, var3);
        }
    }

    default void trace(Marker var1, String var2, Throwable var3) {
        if (!isProd()) {
            // pass the throwable through (previously dropped)
            getLogger().trace(var1, var2, var3);
        }
    }

    // ---------------------------------------------------------------- DEBUG (suppressed in prod)

    default boolean isDebugEnabled() {
        return getLogger().isDebugEnabled();
    }

    default void debug(String var1) {
        if (!isProd()) {
            getLogger().debug(var1);
        }
    }

    default void debug(String var1, Object var2) {
        if (!isProd()) {
            getLogger().debug(var1, var2);
        }
    }

    default void debug(String var1, Object var2, Object var3) {
        if (!isProd()) {
            getLogger().debug(var1, var2, var3);
        }
    }

    default void debug(String var1, Object... var2) {
        if (!isProd()) {
            getLogger().debug(var1, var2);
        }
    }

    default void debug(String var1, Throwable var2) {
        if (!isProd()) {
            getLogger().debug(var1, var2);
        }
    }

    default boolean isDebugEnabled(Marker var1) {
        return getLogger().isDebugEnabled(var1);
    }

    default void debug(Marker var1, String var2) {
        if (!isProd()) {
            getLogger().debug(var1, var2);
        }
    }

    default void debug(Marker var1, String var2, Object var3) {
        if (!isProd()) {
            getLogger().debug(var1, var2, var3);
        }
    }

    default void debug(Marker var1, String var2, Object var3, Object var4) {
        if (!isProd()) {
            getLogger().debug(var1, var2, var3, var4);
        }
    }

    default void debug(Marker var1, String var2, Object... var3) {
        if (!isProd()) {
            getLogger().debug(var1, var2, var3);
        }
    }

    default void debug(Marker var1, String var2, Throwable var3) {
        // same prod guard as every other debug overload (was missing here)
        if (!isProd()) {
            getLogger().debug(var1, var2, var3);
        }
    }

    // ---------------------------------------------------------------- INFO (suppressed in prod)

    default boolean isInfoEnabled() {
        return getLogger().isInfoEnabled();
    }

    default void info(String var1) {
        if (!isProd()) {
            getLogger().info(var1);
        }
    }

    default void info(String var1, Object var2) {
        if (!isProd()) {
            getLogger().info(var1, var2);
        }
    }

    default void info(String var1, Object var2, Object var3) {
        if (!isProd()) {
            getLogger().info(var1, var2, var3);
        }
    }

    default void info(String var1, Object... var2) {
        if (!isProd()) {
            getLogger().info(var1, var2);
        }
    }

    default void info(String var1, Throwable var2) {
        if (!isProd()) {
            getLogger().info(var1, var2);
        }
    }

    default boolean isInfoEnabled(Marker var1) {
        return getLogger().isInfoEnabled(var1);
    }

    default void info(Marker var1, String var2) {
        if (!isProd()) {
            getLogger().info(var1, var2);
        }
    }

    default void info(Marker var1, String var2, Object var3) {
        if (!isProd()) {
            getLogger().info(var1, var2, var3);
        }
    }

    default void info(Marker var1, String var2, Object var3, Object var4) {
        if (!isProd()) {
            getLogger().info(var1, var2, var3, var4);
        }
    }

    default void info(Marker var1, String var2, Object... var3) {
        if (!isProd()) {
            getLogger().info(var1, var2, var3);
        }
    }

    default void info(Marker var1, String var2, Throwable var3) {
        if (!isProd()) {
            getLogger().info(var1, var2, var3);
        }
    }

    // ---------------------------------------------------------------- WARN (suppressed in prod)

    default boolean isWarnEnabled() {
        return getLogger().isWarnEnabled();
    }

    default void warn(String var1) {
        if (!isProd()) {
            getLogger().warn(var1);
        }
    }

    default void warn(String var1, Object var2) {
        if (!isProd()) {
            getLogger().warn(var1, var2);
        }
    }

    default void warn(String var1, Object... var2) {
        if (!isProd()) {
            getLogger().warn(var1, var2);
        }
    }

    default void warn(String var1, Object var2, Object var3) {
        if (!isProd()) {
            getLogger().warn(var1, var2, var3);
        }
    }

    default void warn(String var1, Throwable var2) {
        if (!isProd()) {
            getLogger().warn(var1, var2);
        }
    }

    default boolean isWarnEnabled(Marker var1) {
        return getLogger().isWarnEnabled(var1);
    }

    default void warn(Marker var1, String var2) {
        if (!isProd()) {
            getLogger().warn(var1, var2);
        }
    }

    default void warn(Marker var1, String var2, Object var3) {
        if (!isProd()) {
            getLogger().warn(var1, var2, var3);
        }
    }

    default void warn(Marker var1, String var2, Object var3, Object var4) {
        if (!isProd()) {
            getLogger().warn(var1, var2, var3, var4);
        }
    }

    default void warn(Marker var1, String var2, Object... var3) {
        if (!isProd()) {
            getLogger().warn(var1, var2, var3);
        }
    }

    default void warn(Marker var1, String var2, Throwable var3) {
        if (!isProd()) {
            getLogger().warn(var1, var2, var3);
        }
    }

    // ---------------------------------------------------------------- ERROR (always logged)

    default boolean isErrorEnabled() {
        return getLogger().isErrorEnabled();
    }

    default void error(String var1) {
        getLogger().error(var1);
    }

    default void error(String var1, Object var2) {
        getLogger().error(var1, var2);
    }

    default void error(String var1, Object var2, Object var3) {
        getLogger().error(var1, var2, var3);
    }

    default void error(String var1, Object... var2) {
        getLogger().error(var1, var2);
    }

    default void error(String var1, Throwable var2) {
        getLogger().error(var1, var2);
    }

    default boolean isErrorEnabled(Marker var1) {
        return getLogger().isErrorEnabled(var1);
    }

    default void error(Marker var1, String var2) {
        getLogger().error(var1, var2);
    }

    default void error(Marker var1, String var2, Object var3) {
        getLogger().error(var1, var2, var3);
    }

    default void error(Marker var1, String var2, Object var3, Object var4) {
        getLogger().error(var1, var2, var3, var4);
    }

    default void error(Marker var1, String var2, Object... var3) {
        getLogger().error(var1, var2, var3);
    }

    default void error(Marker var1, String var2, Throwable var3) {
        getLogger().error(var1, var2, var3);
    }
}
3.Redis配置
package com.yan.config;
import cn.hutool.core.date.DatePattern;
import cn.hutool.core.util.ObjectUtil;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.deser.std.DateDeserializers;
import com.fasterxml.jackson.databind.jsontype.impl.LaissezFaireSubTypeValidator;
import com.fasterxml.jackson.databind.ser.std.DateSerializer;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.fasterxml.jackson.datatype.jsr310.deser.LocalDateDeserializer;
import com.fasterxml.jackson.datatype.jsr310.deser.LocalDateTimeDeserializer;
import com.fasterxml.jackson.datatype.jsr310.ser.LocalDateSerializer;
import com.fasterxml.jackson.datatype.jsr310.ser.LocalDateTimeSerializer;
import com.yan.service.AbstractBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.GenericJackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import java.text.DateFormat;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.Date;
/**
 * Redis configuration: registers the primary {@code RedisTemplate<String, Object>} with
 * String keys and Jackson-JSON values.
 *
 * @Author yan
 * @Date 2025/2/5 15:02:07
 */
@Configuration
//@EnableAspectJAutoProxy
//@EnableCaching // enables Spring's annotation-driven Redis cache
public class RedisConfig implements AbstractBean {

    /** Time zone used by every date/time (de)serializer configured below. */
    private static final String ZONE_ID = "Asia/Shanghai";

    /**
     * Primary template, registered under the bean name {@code redisTemplate}.
     *
     * @param connectionFactory Redis connection factory
     * @return the configured template
     */
    @Bean
    @Primary
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
        return initRedisTemplate(connectionFactory);
    }

    /**
     * Builds a RedisTemplate with String keys / hash keys and JSON values / hash values.
     *
     * @param connectionFactory Redis connection factory
     * @return an initialized template ({@code afterPropertiesSet()} already called)
     */
    public RedisTemplate<String, Object> initRedisTemplate(RedisConnectionFactory connectionFactory) {
        info("init RedisTemplate");
        RedisTemplate<String, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(connectionFactory);
        // values: Jackson JSON
        GenericJackson2JsonRedisSerializer serializer = new GenericJackson2JsonRedisSerializer(initObjectMapper(null));
        // keys: plain strings
        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
        template.setKeySerializer(stringRedisSerializer);
        template.setValueSerializer(serializer);
        // hash keys as strings, hash values as JSON
        template.setHashKeySerializer(stringRedisSerializer);
        template.setHashValueSerializer(serializer);
        template.afterPropertiesSet();
        return template;
    }

    /**
     * Configures an ObjectMapper for Redis values.
     *
     * @param objectMapper mapper to configure in place, or null to create a new one
     * @return the configured mapper
     */
    public ObjectMapper initObjectMapper(ObjectMapper objectMapper) {
        ObjectMapper mapper = objectMapper == null ? new ObjectMapper() : objectMapper;
        // serialize/deserialize all fields regardless of visibility
        mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        // Embed class names so values come back as their original types.
        // SECURITY NOTE(review): NON_FINAL default typing with LaissezFaireSubTypeValidator accepts any
        // class named in the payload; safe only while Redis content is fully trusted. Consider a
        // BasicPolymorphicTypeValidator restricted to application packages.
        mapper.activateDefaultTyping(LaissezFaireSubTypeValidator.instance, ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.PROPERTY);
        // dates as readable strings instead of epoch numbers
        mapper.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS);

        ZoneId zone = ZoneId.of(ZONE_ID);
        DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ofPattern(DatePattern.NORM_DATETIME_PATTERN).withZone(zone);
        DateTimeFormatter dateFormatter = DateTimeFormatter.ofPattern(DatePattern.NORM_DATE_PATTERN).withZone(zone);

        JavaTimeModule timeModule = new JavaTimeModule();
        // LocalDateTime: yyyy-MM-dd HH:mm:ss
        timeModule.addSerializer(LocalDateTime.class, new LocalDateTimeSerializer(dateTimeFormatter));
        timeModule.addDeserializer(LocalDateTime.class, new LocalDateTimeDeserializer(dateTimeFormatter));
        // LocalDate: yyyy-MM-dd
        timeModule.addSerializer(LocalDate.class, new LocalDateSerializer(dateFormatter));
        timeModule.addDeserializer(LocalDate.class, new LocalDateDeserializer(dateFormatter));
        // java.util.Date: the previous DateFormat.getDateInstance() was locale dependent and dropped the
        // time of day, and the default deserializer likely could not parse it back. Serializer and
        // deserializer now share the same pattern and zone so values round-trip.
        timeModule.addSerializer(Date.class, new DateSerializer(false, newLegacyDateFormat()));
        timeModule.addDeserializer(Date.class, new DateDeserializers.DateDeserializer(
                DateDeserializers.DateDeserializer.instance, newLegacyDateFormat(), DatePattern.NORM_DATETIME_PATTERN));

        mapper.registerModule(timeModule);
        return mapper;
    }

    /** Creates a yyyy-MM-dd HH:mm:ss DateFormat in {@link #ZONE_ID} for java.util.Date values. */
    private static DateFormat newLegacyDateFormat() {
        DateFormat format = new java.text.SimpleDateFormat(DatePattern.NORM_DATETIME_PATTERN);
        format.setTimeZone(java.util.TimeZone.getTimeZone(ZoneId.of(ZONE_ID)));
        return format;
    }
}
4.WebSocket Nacos 配置Bean
package com.yan.config;
import com.alibaba.cloud.nacos.NacosDiscoveryProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
/**
 * WebSocket infrastructure beans.
 *
 * @Author yan
 * @Date 2024/3/7 17:17
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig {

    /**
     * Exporter that registers {@code @ServerEndpoint}-annotated beans with the servlet
     * container's standard (JSR-356) WebSocket runtime.
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        ServerEndpointExporter exporter = new ServerEndpointExporter();
        return exporter;
    }

    /**
     * Nacos discovery properties bean (registration to Nacos). The WebSocket endpoint reads
     * its ip and port to build the instance ID.
     * NOTE(review): presumably bound from spring.cloud.nacos.discovery.* — confirm it does not
     * duplicate the bean provided by the Nacos auto-configuration.
     */
    @Bean
    public NacosDiscoveryProperties nacosProperties() {
        NacosDiscoveryProperties properties = new NacosDiscoveryProperties();
        return properties;
    }
}
4.1常量
package com.yan.constant;
/**
 * Redis key and channel prefixes shared by the WebSocket cluster.
 *
 * @Author yan
 * @Date 2025/2/4 15:56:57
 */
public interface WebSocket {

    /** Hash key prefix per instance: ws:instance:{instanceId} -> { sessionId : userId }. */
    public static final String WS_INSTANCE = "ws:instance:";

    /** String key prefix per user: ws:user:{userId} -> instanceId holding that user. */
    public static final String WS_USER = "ws:user:";

    /** Pub/Sub channel prefix: ws:msg:{instanceId}. */
    public static final String WS_MSG = "ws:msg:";
}
4.2消息类
package com.yan.domain;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.Accessors;
import java.util.Date;
/**
 * WebSocket message payload, sent to clients and forwarded between instances.
 *
 * @Author yan
 * @Date 2025/2/4 1:48:04
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@Accessors(chain = true)
public class Message {

    /** Type: plain text (the default). */
    public static final String TYPE_TEXT = "text";
    /** Type: image. */
    public static final String TYPE_IMAGE = "image";
    /** Type: system notification. */
    public static final String TYPE_SYSTEM = "system";

    /** Sender ID, e.g. a user ID. */
    private String senderId;

    /** Recipient ID: a user, group or channel ID. */
    private String targetId;

    /** Payload: text, JSON, or Base64-encoded binary data. */
    private String content;

    /** One of the TYPE_* constants; defaults to {@link #TYPE_TEXT}. */
    private String type = TYPE_TEXT;

    /** Creation time, initialized to the moment the object is constructed. */
    private Date timestamp = new Date(System.currentTimeMillis());
}
4.3 WebSocketEndpoint(核心代码)
package com.yan.endpoint;
import cn.hutool.extra.spring.SpringUtil;
import cn.hutool.json.JSONUtil;
import com.alibaba.cloud.nacos.NacosDiscoveryProperties;
import com.yan.domain.Message;
import com.yan.constant.WebSocket;
import com.yan.service.AbstractBean;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.annotation.Resource;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
/**
 * JSR-356 WebSocket endpoint at {@code /ws/{userId}} that routes messages between users who may
 * be connected to different instances of this service.
 *
 * <p>Routing state in Redis:
 * <ul>
 *   <li>{@code ws:user:{userId}} -> instance ID (ip:port) holding the user's connection</li>
 *   <li>{@code ws:instance:{instanceId}} hash -> sessionId : userId</li>
 * </ul>
 * Messages for users on another instance are published to {@code ws:msg:{instanceId}} and
 * delivered there through {@link #handleRedisMessage(String)}.
 */
@ServerEndpoint("/ws/{userId}")
@Component
public class WebSocketEndpoint implements AbstractBean {

    /** Name of the path parameter carrying the user ID. */
    private static final String USER_ID_PARAM = "userId";

    // User IDs with at least one open session on this instance (thread-safe).
    private static final Set<String> LOCAL_USER_IDS = ConcurrentHashMap.newKeySet();

    // Open sessions on this instance, keyed by WebSocket session ID (thread-safe).
    private static final ConcurrentHashMap<String, Session> SESSION_MAP = new ConcurrentHashMap<>();

    // By default the JSR-356 container creates a fresh endpoint object per connection without Spring
    // injection, so these initializers fetch the beans from the context themselves; @Resource only
    // affects the Spring-managed singleton.
    @Resource
    private RedisTemplate<String, String> redisTemplate = getRedisTemplate();

    @Resource
    private NacosDiscoveryProperties nacosProperties = getNacosDiscoveryProperties();

    // --- lifecycle ---

    @PostConstruct
    public void init() {
        redisTemplate = getRedisTemplate();
    }

    /**
     * Removes this instance's routing data from Redis on shutdown.
     */
    @PreDestroy
    public void destroy() {
        String instanceId = getInstanceId();
        redisTemplate.delete(WebSocket.WS_INSTANCE + instanceId);
        // Only drop user->instance mappings that still point here; a user may already have
        // reconnected to another instance.
        LOCAL_USER_IDS.forEach(userId -> deleteUserMappingIfOwned(userId, instanceId));
    }

    /** @return the ID of this instance, formatted "ip:port" from the Nacos discovery properties */
    public static String getInstanceId() {
        NacosDiscoveryProperties nacosDiscoveryProperties = getNacosDiscoveryProperties();
        return nacosDiscoveryProperties.getIp() + ":" + nacosDiscoveryProperties.getPort();
    }

    public static RedisTemplate getRedisTemplate() {
        return SpringUtil.getBean(RedisTemplate.class);
    }

    public static NacosDiscoveryProperties getNacosDiscoveryProperties() {
        return SpringUtil.getBean(NacosDiscoveryProperties.class);
    }

    // --- WebSocket events ---

    /**
     * Registers the session locally and records user->instance and session->user in Redis.
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("userId") String userId) {
        info("WebSocket连接已建立,用户ID:{}", userId);
        String instanceId = getInstanceId();
        SESSION_MAP.put(session.getId(), session);
        LOCAL_USER_IDS.add(userId);
        // user -> instance
        redisTemplate.opsForValue().set(WebSocket.WS_USER + userId, instanceId);
        // instance -> (sessionId : userId)
        redisTemplate.opsForHash().put(WebSocket.WS_INSTANCE + instanceId, session.getId(), userId);
    }

    /**
     * Removes the session locally and from Redis. The user mapping is dropped only when the user
     * has no other session on this instance and the mapping still points to this instance.
     */
    @OnClose
    public void onClose(Session session, @PathParam("userId") String userId) {
        info("WebSocket连接已关闭,用户ID:{}", userId);
        String instanceId = getInstanceId();
        SESSION_MAP.remove(session.getId());
        redisTemplate.opsForHash().delete(WebSocket.WS_INSTANCE + instanceId, session.getId());
        if (findSessionByUserId(userId) == null) {
            LOCAL_USER_IDS.remove(userId);
            deleteUserMappingIfOwned(userId, instanceId);
        }
    }

    /**
     * Routes an incoming message to the target user: locally, via Redis to another instance, or
     * to offline handling when the user is not connected anywhere.
     */
    @OnMessage
    public void onMessage(String message, Session session, @PathParam("userId") String senderId) {
        info("收到消息: {}", message);
        try {
            Message msg = JSONUtil.toBean(message, Message.class);
            // The connection's path identity is authoritative; do not trust a client-supplied senderId.
            msg.setSenderId(senderId);
            String targetUserId = msg.getTargetId();
            if (targetUserId == null || targetUserId.isEmpty()) {
                warn("消息缺少targetId,已忽略: {}", message);
                return;
            }
            // look up the instance holding the target user
            String targetInstanceId = redisTemplate.opsForValue().get(WebSocket.WS_USER + targetUserId);
            if (targetInstanceId == null) {
                handleOfflineMessage(msg);
                return;
            }
            if (targetInstanceId.equals(getInstanceId())) {
                sendLocalMessage(targetUserId, msg);
            } else {
                sendCrossInstanceMessage(targetInstanceId, msg);
            }
        } catch (Exception e) {
            // previously swallowed silently
            error("处理消息异常,发送者ID:{}", senderId, e);
        }
    }

    // --- sending ---

    /** Sends {@code msg} as JSON to an open local session of {@code userId}, if any. */
    public void sendLocalMessage(String userId, Message msg) {
        info("[发送本地消息]sendLocalMessage");
        Session session = findSessionByUserId(userId);
        if (session != null && session.isOpen()) {
            String toJsonStr = JSONUtil.toJsonStr(msg);
            session.getAsyncRemote().sendText(toJsonStr);
        }
    }

    /**
     * Publishes {@code msg} to the target instance's channel as the same Hutool JSON that local
     * delivery sends. This bypasses the template's Jackson value serializer, which would add
     * {@code @class} type metadata and a different date shape.
     */
    private void sendCrossInstanceMessage(String targetInstanceId, Message msg) {
        info("[发送跨实例消息]sendCrossInstanceMessage");
        byte[] channel = (WebSocket.WS_MSG + targetInstanceId).getBytes(java.nio.charset.StandardCharsets.UTF_8);
        byte[] body = JSONUtil.toJsonStr(msg).getBytes(java.nio.charset.StandardCharsets.UTF_8);
        redisTemplate.execute((org.springframework.data.redis.core.RedisCallback<Long>) connection -> connection.publish(channel, body));
    }

    // --- helpers ---

    /**
     * Finds the first local session opened by {@code userId}. The user ID is read from the
     * session's own path parameters, so no Redis round trip is needed per session.
     *
     * @return the session, or null if none (or if userId is null)
     */
    private static Session findSessionByUserId(String userId) {
        if (userId == null) {
            return null;
        }
        return SESSION_MAP.values().stream()
                .filter(session -> userId.equals(session.getPathParameters().get(USER_ID_PARAM)))
                .findFirst()
                .orElse(null);
    }

    /**
     * Deletes ws:user:{userId} only if it still points to {@code instanceId}.
     * NOTE(review): check-then-delete is not atomic; use a Lua script if strict atomicity is needed.
     */
    private void deleteUserMappingIfOwned(String userId, String instanceId) {
        String key = WebSocket.WS_USER + userId;
        if (instanceId.equals(redisTemplate.opsForValue().get(key))) {
            redisTemplate.delete(key);
        }
    }

    private void handleOfflineMessage(Message msg) {
        // offline message storage (e.g. persist to a database) — not implemented
    }

    /**
     * Delivers a message received from Redis to the local session of its target user, using the
     * same JSON shape as {@link #sendLocalMessage(String, Message)}.
     */
    public void handleRedisMessage(String messageJson) {
        Message msg = JSONUtil.toBean(messageJson, Message.class);
        Session session = findSessionByUserId(msg.getTargetId());
        if (session != null && session.isOpen()) {
            info("[处理跨实例消息]handleRedisMessage");
            session.getAsyncRemote().sendText(JSONUtil.toJsonStr(msg));
        }
    }

    /** @return read-only view of user IDs with an open session on this instance */
    public static Set<String> getLocalOnlineUsers() {
        return Collections.unmodifiableSet(LOCAL_USER_IDS);
    }
}
5.Redis跨实例处理消息
采用 Redis 发布/订阅(Pub/Sub)处理跨实例消息:消息发布到目标用户所在实例的频道 ws:msg:{实例ID}。所有实例都订阅 ws:msg:* 模式,但只把消息投递给本实例上存在会话的目标用户。
package com.yan.config;
import cn.hutool.extra.spring.SpringUtil;
import com.yan.constant.WebSocket;
import com.yan.endpoint.WebSocketEndpoint;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.listener.PatternTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
/**
 * Subscribes to cross-instance WebSocket messages published on Redis Pub/Sub.
 *
 * @Author yan
 * @Date 2025/2/5 13:02:50
 */
@Configuration
public class RedisMessageConfig {

    /**
     * Listener container subscribed to the {@code ws:msg:*} pattern. Every instance receives every
     * publication, so the listener forwards only messages on this instance's own channel
     * ({@code ws:msg:{instanceId}}) to the WebSocket endpoint.
     *
     * @param connectionFactory Redis connection factory
     * @return the listener container
     */
    @Bean
    public RedisMessageListenerContainer redisMessageListenerContainer(RedisConnectionFactory connectionFactory) {
        RedisMessageListenerContainer container = new RedisMessageListenerContainer();
        container.setConnectionFactory(connectionFactory);
        MessageListener listener = (message, pattern) -> {
            String channel = new String(message.getChannel(), java.nio.charset.StandardCharsets.UTF_8);
            // The instance ID is resolved per message: the port is only known after the web server starts.
            String ownChannel = WebSocket.WS_MSG + WebSocketEndpoint.getInstanceId();
            if (!ownChannel.equals(channel)) {
                return;
            }
            // decode explicitly as UTF-8 instead of the platform default charset
            String body = new String(message.getBody(), java.nio.charset.StandardCharsets.UTF_8);
            SpringUtil.getBean(WebSocketEndpoint.class).handleRedisMessage(body);
        };
        container.addMessageListener(listener, new PatternTopic(WebSocket.WS_MSG + "*"));
        return container;
    }
}
6. yml配置文件
6.1 bootstrap.yml
spring:
  application:
    name: websocket
  profiles:
    active: dev
  main:
    allow-circular-references: true
  cloud:
    nacos:
      discovery:
        service: ${spring.application.name}
        server-addr: ${nacos.server-addr}
        username: ${nacos.username}
        # was ${nacos.username}: the password must come from nacos.password
        password: ${nacos.password}
      config:
        server-addr: ${nacos.server-addr}
        username: ${nacos.username}
        password: ${nacos.password}
        refresh-enabled: true
        refreshable-dataids: ${spring.cloud.nacos.discovery.service:${spring.application.name}}-${spring.profiles.active:test}.${spring.cloud.nacos.config.file-extension:yaml}
# Nacos connection settings referenced above; override per profile
nacos:
  server-addr: 127.0.0.1:8848
  username: nacos
  password: nacos
server:
  port: 18900
  servlet:
    context-path: /websocket
6.2 bootstrap-dev.yml
# dev profile: Nacos and Redis on the shared dev host
nacos:
  server-addr: 192.168.3.22:8848
  username: nacos
  password: nacos
spring:
  redis:
    host: 192.168.3.22
    database: 1
6.3 bootstrap-test.yml
# test profile: same backing services as dev, different HTTP port so a second instance can run locally
nacos:
  server-addr: 192.168.3.22:8848
  username: nacos
  password: nacos
spring:
  redis:
    host: 192.168.3.22
    database: 1
server:
  port: 18800
7.可支持gateway网关扩展(如使用NGINX为网关请忽略)
7.1 gateway网关 yml配置
# When aggregating Swagger docs through Nacos, requests for /v3/api-docs could not be
# routed to the owning service's api-docs, so Swagger doc generation failed.
# Fix: set the path below directly in application.yaml; gateway routing then works.
server:
  servlet:
    context-path: /test
springfox:
  documentation:
    open-api:
      v3:
        path: /v3/api-docs
spring:
  cloud:
    nacos:
      server-addr: 127.0.0.1:8848
    gateway:
      discovery:
        # When started in Docker, remote calls fail unless the Docker host's address is set
        # ip: ${LOCAL_IP_ADDRESS:192.168.200.128}
        locator:
          # lower-case service IDs in generated routes
          lower-case-service-id: true
          # Auto-forwarding: http://gateway-host:gateway-port/service-name/api-path
          # is rewritten to http://service-name/api-path
          # recommended false in production
          # Note: setting server.servlet.context-path breaks automatic routing, hence the two changes here:
          enabled: true
          # Override the filter chain to fix the 404s the default locator routes produce when
          # server.servlet.context-path is set.
          # In production nginx can strip server.servlet.context-path first and then hand off to the gateway.
          filters[0]: PreserveHostHeader
      routes:
        - id: gateway
          uri: no://op
          predicates:
            - Path=${server.servlet.context-path}/websocket,${server.servlet.context-path},false
          filters:
            # local
            - RedirectTo=302, http://127.0.0.1:${server.port}${server.servlet.context-path}/home
        - id: static-resource
          uri: lb://gateway
          order: 1
          predicates:
            - Path=${server.servlet.context-path}/assert/**
          filters:
            - StripPrefix=1
        - id: websocket
          uri: lb://websocket
          order: 2
          metadata:
            summary: websocket服务
            description: websocket服务
            document-uri: /doc.html
          predicates:
            - Path=${server.servlet.context-path}/websocket/**
          filters:
            - StripPrefix=1
7.2 gateway网关 跨域配置处理(注意:nginx 也需要处理跨域)
package com.yan.gateway.filter;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.json.JSONUtil;
import com.yan.abstractinterface.bean.AbstractBean;
import com.yan.utils.json.JsonUtils;
import com.yan.utils.object.ObjectUtils;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.filter.NettyWriteResponseFilter;
import org.springframework.core.Ordered;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import java.util.*;
import java.util.stream.Collectors;
/**
 * Global gateway filter that removes duplicate response header values, e.g. an
 * Access-Control-Allow-Origin sent by both the gateway and the downstream service.
 *
 * <p>NOTE(review): the class carries no stereotype annotation here; it must be registered as a
 * bean elsewhere for the gateway to pick it up.
 *
 * @Author yan
 * @Date 2024/11/14 12:34:07 PM
 */
@Slf4j
public class CorsResponseHeaderFilter implements GlobalFilter, Ordered, AbstractBean {

    /**
     * Places this filter just inside NettyWriteResponseFilter. Post-processing ({@code then})
     * runs innermost-first, so the dedup step executes before the response is written while its
     * headers are still mutable. The previous order (-2) ran it after the write, on read-only
     * headers.
     */
    @Override
    public int getOrder() {
        return NettyWriteResponseFilter.WRITE_RESPONSE_FILTER_ORDER + 1;
    }

    /**
     * Creates a deferred task that de-duplicates the values of every response header in place.
     * Errors are logged, not propagated.
     *
     * @param exchange current exchange whose response headers are rewritten
     * @param chain    unused; kept for signature compatibility
     * @return a Mono that performs the rewrite when subscribed
     */
    public Mono<Void> getFromRunnableMono(ServerWebExchange exchange, GatewayFilterChain chain) {
        return Mono.fromRunnable(() -> {
            try {
                HttpHeaders headers = exchange.getResponse().getHeaders();
                info("处理前的header头: {}", JSONUtil.toJsonStr(headers));
                // build a new map (preserving header order) rather than mutating entries in place
                Map<String, List<String>> deduplicated = new LinkedHashMap<>();
                headers.forEach((key, values) -> deduplicated.put(key,
                        values.stream().distinct().collect(Collectors.toList())));
                headers.clear();
                headers.putAll(deduplicated);
                info("处理后的header头: {}", JSONUtil.toJsonStr(headers));
            } catch (Exception e) {
                error("去除重复请求头异常: ", e);
            }
        });
    }

    /**
     * Lets the rest of the chain run, then de-duplicates the real response headers. The earlier
     * ServerHttpResponseDecorator returned a fresh HttpHeaders copy on every getHeaders() call,
     * so the rewrite never reached the actual response.
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        return chain.filter(exchange).then(getFromRunnableMono(exchange, chain));
    }
}
8.测试
8.1以实例测试
注:只有目标用户在线时才能测试成功;离线消息尚未处理(可根据实际情况持久化到数据库)。
ws://127.0.0.1:18800/websocket/ws/1
发送消息:{"senderId":"1","targetId":"2","content":"110","type":"text"}
ws://127.0.0.1:18900/websocket/ws/2
发送消息:{"senderId":"2","targetId":"1","content":"220","type":"text"}
8.2 以网关测试
ws://127.0.0.1:10010/test/websocket/ws/1
ws://127.0.0.1:10010/test/websocket/ws/2
用户连接时网关会处理跨域请求头并打印前后差异
注:可以通过反复连接、断开,使两个用户连接到不同的实例上进行测试。
连接到不同实例时Redis key如下所示
用户1发送消息给用户2 测试如下为成功: