WebSocket 单体分布式

1.架构

采用单体架构(如需微服务架构请自行拆分),支持分布式部署。技术栈:nacos + redis + websocket

1.1 项目结构

websocket-new-module
└─src
    └─main
       ├─java
       │  └─com
       │      └─yan
       │          ├─config
       │          ├─constant
       │          ├─domain
       │          ├─endpoint
       │          └─service
       └─resources

1.2 pom 代码内容

<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build descriptor for the websocket-new-module application. -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
	xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<parent>
		<groupId>org.springframework.boot</groupId>
		<artifactId>spring-boot-starter-parent</artifactId>
		<version>2.7.9</version>
		<relativePath/> <!-- lookup parent from repository -->
	</parent>
	<groupId>com.yan</groupId>
	<artifactId>websocket-new-module</artifactId>
	<version>0.0.1-SNAPSHOT</version>
	<name>websocket-new-module</name>
	<description>websocket-new-module</description>
	<properties>
		<java.version>1.8</java.version>
		<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
		<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>

		<!-- NOTE(review): this differs from the parent version (2.7.9) declared above; confirm which Spring Boot version is intended. -->
		<spring-boot.version>2.7.18</spring-boot.version>
		<project-module.version>0.0.1-SNAPSHOT</project-module.version>
		<spring-cloud.version>2021.0.5</spring-cloud.version>
		<alibaba-cloud.version>2021.0.5.0</alibaba-cloud.version>

		<hutool.version>5.8.20</hutool.version>

	</properties>

	<!-- Imported BOMs: versions of the Spring Boot / Spring Cloud / Spring Cloud Alibaba artifacts below are taken from here. -->
	<dependencyManagement>
		<dependencies>
			<dependency>
				<groupId>org.springframework.boot</groupId>
				<artifactId>spring-boot-dependencies</artifactId>
				<version>${spring-boot.version}</version>
				<type>pom</type>
				<scope>import</scope>
			</dependency>
			<!-- Spring Cloud Version BOM -->
			<dependency>
				<groupId>org.springframework.cloud</groupId>
				<artifactId>spring-cloud-dependencies</artifactId>
				<version>${spring-cloud.version}</version>
				<type>pom</type>
				<scope>import</scope>
			</dependency>
			<dependency>
				<groupId>com.alibaba.cloud</groupId>
				<artifactId>spring-cloud-alibaba-dependencies</artifactId>
				<version>${alibaba-cloud.version}</version>
				<type>pom</type>
				<scope>import</scope>
			</dependency>
		</dependencies>
	</dependencyManagement>

	<dependencies>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-websocket</artifactId>
		</dependency>

		<!-- nacos (start) -->
		<!-- Used to load the bootstrap configuration -->
		<dependency>
			<groupId>org.springframework.cloud</groupId>
			<artifactId>spring-cloud-starter-bootstrap</artifactId>
		</dependency>
		<!-- Spring Cloud LoadBalancer dependency -->
		<dependency>
			<groupId>org.springframework.cloud</groupId>
			<artifactId>spring-cloud-starter-loadbalancer</artifactId>
		</dependency>
		<!-- Spring Cloud Starter Alibaba Nacos Discovery -->
		<dependency>
			<groupId>com.alibaba.cloud</groupId>
			<artifactId>spring-cloud-starter-alibaba-nacos-discovery</artifactId>
		</dependency>
		<!-- Spring Cloud Starter Alibaba Nacos Config -->
		<dependency>
			<groupId>com.alibaba.cloud</groupId>
			<artifactId>spring-cloud-starter-alibaba-nacos-config</artifactId>
		</dependency>
		<!-- nacos (end) -->

		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-data-redis</artifactId>
		</dependency>

		<!-- Version is pinned explicitly through the hutool.version property above. -->
		<dependency>
			<groupId>cn.hutool</groupId>
			<artifactId>hutool-all</artifactId>
			<version>${hutool.version}</version>
		</dependency>

		<dependency>
			<groupId>org.projectlombok</groupId>
			<artifactId>lombok</artifactId>
		</dependency>

		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-test</artifactId>
			<scope>test</scope>
		</dependency>

	</dependencies>

	<build>
		<plugins>
			<plugin>
				<groupId>org.springframework.boot</groupId>
				<artifactId>spring-boot-maven-plugin</artifactId>
			</plugin>
		</plugins>
	</build>

</project>

2.公共接口配置

package com.yan.service;

import cn.hutool.core.util.ObjectUtil;
import cn.hutool.extra.spring.SpringUtil;
import com.fasterxml.jackson.annotation.JsonIgnore;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.Accessors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Marker;
import org.springframework.core.env.Environment;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;

/**
 * Mixin interface that gives an implementing bean lifecycle logging and a set of
 * SLF4J-style logging shortcuts.
 *
 * <p>The {@code trace}/{@code debug}/{@code info}/{@code warn} shortcuts are suppressed when
 * {@link #isProd()} returns {@code true}; the {@code error} shortcuts always log. Every shortcut
 * forwards all of its arguments to the matching overload of the logger returned by
 * {@link #getLogger()}.
 *
 * @Author yan
 * @Date 2024/9/22 10:46:15 AM
 */
public interface AbstractBean {

    /** Holder pairing the implementing class with the logger created for it. */
    @Data
    @Accessors(chain = true)
    @NoArgsConstructor
    @AllArgsConstructor
    class LogBean {
        private org.slf4j.Logger logger;
        private Class<?> aClass;
    }

    /**
     * Builds a new {@link LogBean} for the runtime class of this object.
     *
     * @return a holder with the runtime class and a logger obtained from {@link LoggerFactory}
     */
    @JsonIgnore
    default LogBean getLogBean() {
        return new LogBean()
                .setAClass(this.getClass())
                .setLogger(LoggerFactory.getLogger(this.getClass()));
    }

    /** @return the logger for the runtime class of this object */
    @JsonIgnore
    default Logger getLogger() {
        return getLogBean().getLogger();
    }

    /** @return the runtime class of this object */
    @JsonIgnore
    default Class<?> getAClass() {
        return getLogBean().getAClass();
    }

    /**
     * Initialization callback: logs "init &lt;SimpleName&gt; ..." unless running in prod.
     */
    @PostConstruct
    default void init() {
        LogBean logBean = getLogBean();
        if (!isProd()) {
            logBean.getLogger().info("init {} ...", logBean.getAClass().getSimpleName());
        }
    }

    /**
     * Destruction callback: logs "destroy &lt;SimpleName&gt; ..." unless running in prod.
     */
    @PreDestroy
    default void destroy() {
        LogBean logBean = getLogBean();
        if (!isProd()) {
            logBean.getLogger().info("destroy {} ...", logBean.getAClass().getSimpleName());
        }
    }

    /**
     * Tells whether the application runs with the {@code prod} profile.
     *
     * <p>The check compares the raw {@code spring.profiles.active} property with {@code "prod"},
     * so a value that lists several profiles (e.g. {@code "prod,foo"}) does not count as prod.
     *
     * @return {@code true} when {@code spring.profiles.active} equals {@code "prod"}
     */
    default boolean isProd() {
        Environment env = SpringUtil.getBean(Environment.class);
        String active = env.getProperty("spring.profiles.active");
        return ObjectUtil.equals("prod", active);
    }

    // ------------------------------------------------------------------ trace

    default boolean isTraceEnabled() {
        return getLogger().isTraceEnabled();
    }

    default void trace(String var1) {
        if (!isProd()) {
            getLogger().trace(var1);
        }
    }

    default void trace(String var1, Object var2) {
        if (!isProd()) {
            getLogger().trace(var1, var2);
        }
    }

    default void trace(String var1, Object var2, Object var3) {
        if (!isProd()) {
            getLogger().trace(var1, var2, var3);
        }
    }

    default void trace(String var1, Object... var2) {
        if (!isProd()) {
            getLogger().trace(var1, var2);
        }
    }

    default void trace(String var1, Throwable var2) {
        if (!isProd()) {
            getLogger().trace(var1, var2);
        }
    }

    default boolean isTraceEnabled(Marker var1) {
        return getLogger().isTraceEnabled(var1);
    }

    default void trace(Marker var1, String var2) {
        if (!isProd()) {
            getLogger().trace(var1, var2);
        }
    }

    default void trace(Marker var1, String var2, Object var3) {
        if (!isProd()) {
            getLogger().trace(var1, var2, var3);
        }
    }

    default void trace(Marker var1, String var2, Object var3, Object var4) {
        if (!isProd()) {
            getLogger().trace(var1, var2, var3, var4);
        }
    }

    default void trace(Marker var1, String var2, Object... var3) {
        if (!isProd()) {
            // Forward the format arguments; without them the placeholders stay unfilled.
            getLogger().trace(var1, var2, var3);
        }
    }

    default void trace(Marker var1, String var2, Throwable var3) {
        if (!isProd()) {
            // Forward the throwable so its stack trace is logged.
            getLogger().trace(var1, var2, var3);
        }
    }

    // ------------------------------------------------------------------ debug

    default boolean isDebugEnabled() {
        return getLogger().isDebugEnabled();
    }

    default void debug(String var1) {
        if (!isProd()) {
            getLogger().debug(var1);
        }
    }

    default void debug(String var1, Object var2) {
        if (!isProd()) {
            getLogger().debug(var1, var2);
        }
    }

    default void debug(String var1, Object var2, Object var3) {
        if (!isProd()) {
            getLogger().debug(var1, var2, var3);
        }
    }

    default void debug(String var1, Object... var2) {
        if (!isProd()) {
            getLogger().debug(var1, var2);
        }
    }

    default void debug(String var1, Throwable var2) {
        if (!isProd()) {
            getLogger().debug(var1, var2);
        }
    }

    default boolean isDebugEnabled(Marker var1) {
        return getLogger().isDebugEnabled(var1);
    }

    default void debug(Marker var1, String var2) {
        if (!isProd()) {
            getLogger().debug(var1, var2);
        }
    }

    default void debug(Marker var1, String var2, Object var3) {
        if (!isProd()) {
            getLogger().debug(var1, var2, var3);
        }
    }

    default void debug(Marker var1, String var2, Object var3, Object var4) {
        if (!isProd()) {
            getLogger().debug(var1, var2, var3, var4);
        }
    }

    default void debug(Marker var1, String var2, Object... var3) {
        if (!isProd()) {
            getLogger().debug(var1, var2, var3);
        }
    }

    default void debug(Marker var1, String var2, Throwable var3) {
        // Guarded like every other debug overload, and the throwable is forwarded.
        if (!isProd()) {
            getLogger().debug(var1, var2, var3);
        }
    }

    // ------------------------------------------------------------------- info

    default boolean isInfoEnabled() {
        return getLogger().isInfoEnabled();
    }

    default void info(String var1) {
        if (!isProd()) {
            getLogger().info(var1);
        }
    }

    default void info(String var1, Object var2) {
        if (!isProd()) {
            getLogger().info(var1, var2);
        }
    }

    default void info(String var1, Object var2, Object var3) {
        if (!isProd()) {
            getLogger().info(var1, var2, var3);
        }
    }

    default void info(String var1, Object... var2) {
        if (!isProd()) {
            getLogger().info(var1, var2);
        }
    }

    default void info(String var1, Throwable var2) {
        if (!isProd()) {
            getLogger().info(var1, var2);
        }
    }

    default boolean isInfoEnabled(Marker var1) {
        return getLogger().isInfoEnabled(var1);
    }

    default void info(Marker var1, String var2) {
        if (!isProd()) {
            getLogger().info(var1, var2);
        }
    }

    default void info(Marker var1, String var2, Object var3) {
        if (!isProd()) {
            getLogger().info(var1, var2, var3);
        }
    }

    default void info(Marker var1, String var2, Object var3, Object var4) {
        if (!isProd()) {
            getLogger().info(var1, var2, var3, var4);
        }
    }

    default void info(Marker var1, String var2, Object... var3) {
        if (!isProd()) {
            getLogger().info(var1, var2, var3);
        }
    }

    default void info(Marker var1, String var2, Throwable var3) {
        if (!isProd()) {
            getLogger().info(var1, var2, var3);
        }
    }

    // ------------------------------------------------------------------- warn

    default boolean isWarnEnabled() {
        return getLogger().isWarnEnabled();
    }

    default void warn(String var1) {
        if (!isProd()) {
            getLogger().warn(var1);
        }
    }

    default void warn(String var1, Object var2) {
        if (!isProd()) {
            getLogger().warn(var1, var2);
        }
    }

    default void warn(String var1, Object... var2) {
        if (!isProd()) {
            getLogger().warn(var1, var2);
        }
    }

    default void warn(String var1, Object var2, Object var3) {
        if (!isProd()) {
            getLogger().warn(var1, var2, var3);
        }
    }

    default void warn(String var1, Throwable var2) {
        if (!isProd()) {
            getLogger().warn(var1, var2);
        }
    }

    default boolean isWarnEnabled(Marker var1) {
        return getLogger().isWarnEnabled(var1);
    }

    default void warn(Marker var1, String var2) {
        if (!isProd()) {
            getLogger().warn(var1, var2);
        }
    }

    default void warn(Marker var1, String var2, Object var3) {
        if (!isProd()) {
            getLogger().warn(var1, var2, var3);
        }
    }

    default void warn(Marker var1, String var2, Object var3, Object var4) {
        if (!isProd()) {
            getLogger().warn(var1, var2, var3, var4);
        }
    }

    default void warn(Marker var1, String var2, Object... var3) {
        if (!isProd()) {
            getLogger().warn(var1, var2, var3);
        }
    }

    default void warn(Marker var1, String var2, Throwable var3) {
        if (!isProd()) {
            getLogger().warn(var1, var2, var3);
        }
    }

    // ------------------------------------------------------------------ error
    // Error logging is never suppressed, regardless of the active profile.

    default boolean isErrorEnabled() {
        return getLogger().isErrorEnabled();
    }

    default void error(String var1) {
        getLogger().error(var1);
    }

    default void error(String var1, Object var2) {
        getLogger().error(var1, var2);
    }

    default void error(String var1, Object var2, Object var3) {
        getLogger().error(var1, var2, var3);
    }

    default void error(String var1, Object... var2) {
        getLogger().error(var1, var2);
    }

    default void error(String var1, Throwable var2) {
        getLogger().error(var1, var2);
    }

    default boolean isErrorEnabled(Marker var1) {
        return getLogger().isErrorEnabled(var1);
    }

    default void error(Marker var1, String var2) {
        getLogger().error(var1, var2);
    }

    default void error(Marker var1, String var2, Object var3) {
        getLogger().error(var1, var2, var3);
    }

    default void error(Marker var1, String var2, Object var3, Object var4) {
        getLogger().error(var1, var2, var3, var4);
    }

    default void error(Marker var1, String var2, Object... var3) {
        getLogger().error(var1, var2, var3);
    }

    default void error(Marker var1, String var2, Throwable var3) {
        getLogger().error(var1, var2, var3);
    }
}

3.Redis配置

package com.yan.config;

import cn.hutool.core.date.DatePattern;
import cn.hutool.core.util.ObjectUtil;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.deser.std.DateDeserializers;
import com.fasterxml.jackson.databind.jsontype.impl.LaissezFaireSubTypeValidator;
import com.fasterxml.jackson.databind.ser.std.DateSerializer;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.fasterxml.jackson.datatype.jsr310.deser.LocalDateDeserializer;
import com.fasterxml.jackson.datatype.jsr310.deser.LocalDateTimeDeserializer;
import com.fasterxml.jackson.datatype.jsr310.ser.LocalDateSerializer;
import com.fasterxml.jackson.datatype.jsr310.ser.LocalDateTimeSerializer;
import com.yan.service.AbstractBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.GenericJackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;

import java.text.DateFormat;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.Date;

/**
 * @Author yan
 * @Date 2025/2/5 15:02:07
 * @Description
 */
@Configuration
//@EnableAspectJAutoProxy
//@EnableCaching // 开启Spring Redis Cache,使用注解驱动缓存机制
public class RedisConfig implements AbstractBean {

    @Bean
    @Primary
    @SuppressWarnings(value = {"unchecked", "rawtypes"})
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
        return initRedisTemplate(connectionFactory);
    }
    /**
     * 初始化RedisTemplate
     *
     * @param connectionFactory
     * @return
     */
    public RedisTemplate<String, Object> initRedisTemplate(RedisConnectionFactory connectionFactory) {
        info("init RedisTemplate");
        RedisTemplate<String, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(connectionFactory);
        // 使用Jackson进行JSON序列化
        GenericJackson2JsonRedisSerializer serializer = new GenericJackson2JsonRedisSerializer(initObjectMapper(null));

        // 使用StringRedisSerializer来序列化和反序列化redis的key值
        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
        template.setKeySerializer(stringRedisSerializer);
        template.setValueSerializer(serializer);

        // Hash的key也采用StringRedisSerializer的序列化方式
        template.setHashKeySerializer(stringRedisSerializer);
        template.setHashValueSerializer(serializer);

        template.afterPropertiesSet();
        return template;
    }

    /**
     * 初始化ObjectMapper
     * @return
     */
    public ObjectMapper initObjectMapper(ObjectMapper objectMapper) {
        objectMapper = ObjectUtil.isEmpty(objectMapper)? new ObjectMapper():objectMapper;
        // 设置可见性
        objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        // 序列化后添加类信息(不配置,序列化后就是一个Json字符串)
        objectMapper.activateDefaultTyping(LaissezFaireSubTypeValidator.instance, ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.PROPERTY);
        // 将日期序列化为可读字符串而不是时间戳
        objectMapper.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS);
        // 设置时间模块(格式化,不设置,则输出默认格式)
        JavaTimeModule timeModule = new JavaTimeModule();
        // LocalDateTime
        String zoneId = "Asia/Shanghai";
        timeModule.addSerializer(LocalDateTime.class, new LocalDateTimeSerializer(DateTimeFormatter.ofPattern(DatePattern.NORM_DATETIME_PATTERN).withZone(ZoneId.of(zoneId))));
        timeModule.addDeserializer(LocalDateTime.class, new LocalDateTimeDeserializer(DateTimeFormatter.ofPattern(DatePattern.NORM_DATETIME_PATTERN).withZone(ZoneId.of(zoneId))));
        // LocalDate
        timeModule.addSerializer(LocalDate.class, new LocalDateSerializer(DateTimeFormatter.ofPattern(DatePattern.NORM_DATE_PATTERN).withZone(ZoneId.of(zoneId))));
        timeModule.addDeserializer(LocalDate.class, new LocalDateDeserializer(DateTimeFormatter.ofPattern(DatePattern.NORM_DATE_PATTERN).withZone(ZoneId.of(zoneId))));
        // Date
        timeModule.addSerializer(Date.class, new DateSerializer(false, DateFormat.getDateInstance()));
        timeModule.addDeserializer(Date.class, new DateDeserializers.DateDeserializer());
        // 设置自定义时间模块
        objectMapper.registerModule(timeModule);
        return objectMapper;
    }

}

4.WebSocket Nacos 配置Bean

package com.yan.config;

import com.alibaba.cloud.nacos.NacosDiscoveryProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;

/**
 * WebSocket-related bean definitions.
 *
 * @Author yan
 * @Date 2024/3/7 17:17
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig {

    /**
     * Declares the {@link ServerEndpointExporter} bean.
     *
     * @return a new exporter instance
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        final ServerEndpointExporter exporter = new ServerEndpointExporter();
        return exporter;
    }

    /**
     * Declares a {@link NacosDiscoveryProperties} bean (original note: "register to Nacos").
     *
     * @return a new, default-constructed properties object
     */
    @Bean
    public NacosDiscoveryProperties nacosProperties() {
        final NacosDiscoveryProperties discoveryProperties = new NacosDiscoveryProperties();
        return discoveryProperties;
    }
}

4.1常量

package com.yan.constant;

/**
 * Redis key / channel prefixes shared by the WebSocket module.
 *
 * @Author yan
 * @Date 2025/2/4 15:56:57
 */
public interface WebSocket {

    /** Prefix of the per-instance hash key; the instance id is appended. */
    String WS_INSTANCE = "ws:instance:";

    /** Prefix of the per-user key; the user id is appended. */
    String WS_USER = "ws:user:";

    /** Prefix of the per-instance message channel; the instance id is appended. */
    String WS_MSG = "ws:msg:";
}

4.2消息类

package com.yan.domain;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.Accessors;
import java.util.Date;

/**
 * WebSocket message entity.
 *
 * @Author yan
 * @Date 2025/2/4 1:48:04
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@Accessors(chain = true)
public class Message {

    /* Message type constants (extend as the business requires). */
    public static final String TYPE_TEXT = "text";
    public static final String TYPE_IMAGE = "image";
    public static final String TYPE_SYSTEM = "system";

    /** Sender id (e.g. a user id). */
    private String senderId;

    /** Receiver id (user id, group id or channel id). */
    private String targetId;

    /** Payload (text, JSON, or Base64-encoded binary data). */
    private String content;

    /** Message type (text/image/system...); defaults to {@link #TYPE_TEXT}. */
    private String type = TYPE_TEXT;

    /** Message timestamp; defaults to the moment the object is created. */
    private Date timestamp = new Date();

}

4.3 WebSocketEndpoint(核心代码)

package com.yan.endpoint;

import cn.hutool.extra.spring.SpringUtil;
import cn.hutool.json.JSONUtil;
import com.alibaba.cloud.nacos.NacosDiscoveryProperties;
import com.yan.domain.Message;
import com.yan.constant.WebSocket;
import com.yan.service.AbstractBean;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.annotation.Resource;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

@ServerEndpoint("/ws/{userId}")
@Component
public class WebSocketEndpoint implements AbstractBean {


    // 本地存储当前实例的在线用户(线程安全)
    private static final Set<String> LOCAL_USER_IDS = ConcurrentHashMap.newKeySet();

    // 本地会话存储(线程安全)
    private static final ConcurrentHashMap<String, Session> SESSION_MAP = new ConcurrentHashMap<>();

    // 静态依赖注入
    @Resource
    private RedisTemplate<String, String> redisTemplate = getRedisTemplate();
    @Resource
    private NacosDiscoveryProperties nacosProperties = getNacosDiscoveryProperties();



    // --- 生命周期管理 ---
    @PostConstruct
    public void init() {
        redisTemplate = getRedisTemplate();
    }

    @PreDestroy
    public void destroy() {
        // 清理Redis中当前实例的所有数据
        redisTemplate.delete(WebSocket.WS_INSTANCE + getInstanceId());
        LOCAL_USER_IDS.forEach(userId ->
                redisTemplate.delete(WebSocket.WS_USER + userId)
        );
    }
    // 当前实例ID(格式:ip:port)
    public static String getInstanceId(){
        NacosDiscoveryProperties nacosDiscoveryProperties = getNacosDiscoveryProperties();
        return  nacosDiscoveryProperties.getIp() + ":" + nacosDiscoveryProperties.getPort();
    }
    public static RedisTemplate getRedisTemplate() {
        return SpringUtil.getBean(RedisTemplate.class);
    }
    public static NacosDiscoveryProperties getNacosDiscoveryProperties() {
        return SpringUtil.getBean(NacosDiscoveryProperties.class);
    }
    // --- WebSocket事件处理 ---
    @OnOpen
    public void onOpen(Session session, @PathParam("userId") String userId) {
        info("WebSocket连接已建立,用户ID:{}", userId);
        // 存储会话和用户信息
        SESSION_MAP.put(session.getId(), session);
        LOCAL_USER_IDS.add(userId);

        // 更新Redis映射(用户->实例)
        redisTemplate.opsForValue().set(WebSocket.WS_USER + userId, getInstanceId());

        // 存储实例的会话信息(实例->会话ID:用户ID)
        String sessionId = session.getId();
        String key = WebSocket.WS_INSTANCE + getInstanceId();
        redisTemplate.opsForHash().put(key, sessionId, userId);
    }

    @OnClose
    public void onClose(Session session, @PathParam("userId") String userId) {
        info("WebSocket连接已关闭,用户ID:{}", userId);
        // 清理本地存储
        SESSION_MAP.remove(session.getId());
        LOCAL_USER_IDS.remove(userId);

        // 清理Redis数据
        redisTemplate.opsForHash().delete(WebSocket.WS_INSTANCE + getInstanceId(), session.getId());
        redisTemplate.delete(WebSocket.WS_USER + userId);
    }

    @OnMessage
    public void onMessage(String message, Session session, @PathParam("userId") String senderId) {
        info("收到消息: {}", message);
        try {
            Message msg = JSONUtil.toBean(message, Message.class);
            String targetUserId = msg.getTargetId();

            // 查询目标用户所在实例
            String targetInstanceId = redisTemplate.opsForValue().get(WebSocket.WS_USER + targetUserId);

            if (targetInstanceId == null) {
                handleOfflineMessage(msg);
                return;
            }

            if (targetInstanceId.equals(getInstanceId())) {
                sendLocalMessage(targetUserId, msg); // 本地发送
            } else {
                sendCrossInstanceMessage(targetInstanceId, msg); // 跨实例发送
            }
        } catch (Exception e) {
            // 异常处理逻辑
        }
    }

    // --- 消息发送方法 ---
    public void sendLocalMessage(String userId, Message msg) {
        info("[发送本地消息]sendLocalMessage");
        Session session = findSessionByUserId(userId);
        if (session != null && session.isOpen()) {
            String toJsonStr = JSONUtil.toJsonStr(msg);
            //String compressed = compress(message);
            session.getAsyncRemote().sendText(toJsonStr);
        }
    }

    private void sendCrossInstanceMessage(String targetInstanceId, Message msg) {
        info("[发送跨实例消息]sendCrossInstanceMessage");
        redisTemplate.convertAndSend(WebSocket.WS_MSG + targetInstanceId, msg);
    }

    // --- 辅助方法 ---
    private static Session findSessionByUserId(String userId) {
        RedisTemplate<String, String> redisTemplate = getRedisTemplate();
        return SESSION_MAP.values().stream()
                .filter(session -> {
                    String sessionUserId = (String) redisTemplate.opsForHash()
                            .get(WebSocket.WS_INSTANCE + getInstanceId(), session.getId());
                    return userId.equals(sessionUserId);
                })
                .findFirst()
                .orElse(null);
    }

    private void handleOfflineMessage(Message msg) {
        // 离线消息存储逻辑(如存入数据库)
    }

    // 处理来自Redis的跨实例消息
    public void handleRedisMessage(String messageJson) {
        Message msg = JSONUtil.toBean(messageJson, Message.class);
        Session session = findSessionByUserId(msg.getTargetId());
        if (session != null) {
            info("[处理跨实例消息]handleRedisMessage");
            session.getAsyncRemote().sendText(messageJson);
        }
    }

    // 获取当前实例在线用户列表
    public static Set<String> getLocalOnlineUsers() {
        return Collections.unmodifiableSet(LOCAL_USER_IDS);
    }
}

5.Redis跨实例处理消息

采用 Redis 频道订阅的方式处理跨实例消息,并通过实例ID判断该消息是否需要由当前实例处理。

package com.yan.config;

import cn.hutool.extra.spring.SpringUtil;
import com.yan.constant.WebSocket;
import com.yan.endpoint.WebSocketEndpoint;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.listener.PatternTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
/**
 * Subscribes this instance to the Redis channels used for cross-instance WebSocket messages.
 *
 * @Author yan
 * @Date 2025/2/5 13:02:50
 */
@Configuration
public class RedisMessageConfig {

    /**
     * Creates the listener container. It subscribes to every {@code ws:msg:*} channel but only
     * hands a message to {@link WebSocketEndpoint#handleRedisMessage(String)} when the channel
     * is the one addressed to the current instance.
     *
     * @param connectionFactory the Redis connection factory
     * @return the configured listener container
     */
    @Bean
    public RedisMessageListenerContainer redisMessageListenerContainer(
            RedisConnectionFactory connectionFactory
    ) {
        RedisMessageListenerContainer container = new RedisMessageListenerContainer();
        container.setConnectionFactory(connectionFactory);

        MessageListener listener = new MessageListener() {
            @Override
            public void onMessage(Message message, byte[] pattern) {
                // Decode explicitly as UTF-8 instead of relying on the platform charset.
                String channel = new String(message.getChannel(), java.nio.charset.StandardCharsets.UTF_8);
                String ownChannel = WebSocket.WS_MSG + WebSocketEndpoint.getInstanceId();
                if (!ownChannel.equals(channel)) {
                    // Addressed to another instance.
                    return;
                }
                String body = new String(message.getBody(), java.nio.charset.StandardCharsets.UTF_8);
                SpringUtil.getBean(WebSocketEndpoint.class).handleRedisMessage(body);
            }
        };

        // Subscribe to all ws:msg:* channels.
        container.addMessageListener(listener, new PatternTopic(WebSocket.WS_MSG + "*"));
        return container;
    }
}

6. yml配置文件

6.1 bootstrap.yml

# Base bootstrap configuration: application name, active profile, Nacos discovery/config.
spring:
  application:
    name: websocket

  profiles:
    active: dev

  main:
    allow-circular-references: true

  cloud:
    nacos:
      discovery:
        service: ${spring.application.name}
        server-addr: ${nacos.server-addr}
        username: ${nacos.username}
        # Must reference nacos.password (it previously pointed at nacos.username).
        password: ${nacos.password}
      config:
        server-addr: ${nacos.server-addr}
        username: ${nacos.username}
        password: ${nacos.password}
        refresh-enabled: true
        refreshable-dataids: ${spring.cloud.nacos.discovery.service:${spring.application.name}}-${spring.profiles.active:test}.${spring.cloud.nacos.config.file-extension:yaml}


# Values referenced by the ${nacos.*} placeholders above.
nacos:
  server-addr: 127.0.0.1:8848
  username: nacos
  password: nacos

server:
  port: 18900
  servlet:
    context-path: /websocket

6.2 bootstrap-dev.yml

# Settings for the "dev" profile (presumably loaded because spring.profiles.active is dev; confirm).
# Nacos connection values, resolved by the ${nacos.*} placeholders in bootstrap.yml.
nacos:
  server-addr: 192.168.3.22:8848
  username: nacos
  password: nacos

# Redis connection used by the application.
spring:
  redis:
    host: 192.168.3.22
    database: 1

6.3 bootstrap-test.yml

# Settings for the "test" profile.
# Nacos connection values, resolved by the ${nacos.*} placeholders in bootstrap.yml.
nacos:
  server-addr: 192.168.3.22:8848
  username: nacos
  password: nacos

# Redis connection used by the application.
spring:
  redis:
    host: 192.168.3.22
    database: 1

# Different port from bootstrap.yml (18900).
server:
  port: 18800

7.可支持gateway网关扩展(如使用NGINX为网关请忽略)

7.1 gateway网关 yml配置

# When Swagger docs are aggregated through Nacos, the /v3/api-docs path can keep the gateway
# route from finding each service's api-docs, so Swagger document generation fails.
# Workaround: configure the path directly here (application.yaml); routing then works.
server:
  servlet:
    context-path: /test
springfox:
  documentation:
    open-api:
      v3:
        path: /v3/api-docs

spring:
  cloud:
    nacos:
      server-addr: 127.0.0.1:8848
    gateway:
      discovery:
        # Remote calls fail when started under Docker; set the address of the Docker host here.
        #        ip: ${LOCAL_IP_ADDRESS:192.168.200.128}
        locator:
          # Convert service names to lower case.
          lower-case-service-id: true
          # Automatic forwarding: http://<gateway-host>:<gateway-port>/<service-name>/<api-path>
          # is forwarded to http://<service-name>/<api-path>.
          # Use false for commercial/production use.
          # Note: if server.servlet.context-path is configured, automatic routing fails, hence the two adjustments below.
          enabled: true
          # Override the filter chain to fix the 404 of the default locator.enabled=true routing
          # when the project sets server.servlet.context-path.
          # In real production, nginx can strip server.servlet.context-path before the gateway handles the request.
          filters[0]: PreserveHostHeader
      routes:
        # Redirects the listed paths to the /home page.
        - id: gateway
          uri: no://op
          predicates:
            # NOTE(review): the trailing ",false" looks unusual for a Path predicate; confirm it is intended.
            - Path=${server.servlet.context-path}/websocket,${server.servlet.context-path},false
          filters:
            # local
            - RedirectTo=302, http://127.0.0.1:${server.port}${server.servlet.context-path}/home

        - id: static-resource
          uri: lb://gateway
          order: 1
          predicates:
            - Path=${server.servlet.context-path}/assert/**
          filters:
            - StripPrefix=1
   
        # Routes <context-path>/websocket/** to the "websocket" service, dropping the first path segment.
        - id: websocket
          uri: lb://websocket
          order: 2
          metadata:
            summary: websocket服务
            description: websocket服务
            document-uri: /doc.html
          predicates:
            - Path=${server.servlet.context-path}/websocket/**
          filters:
            - StripPrefix=1

7.2 gateway网关 跨域配置处理(注意nginx也需要处理跨域)

package com.yan.gateway.filter;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.json.JSONUtil;
import com.yan.abstractinterface.bean.AbstractBean;
import com.yan.utils.json.JsonUtils;
import com.yan.utils.object.ObjectUtils;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.filter.NettyWriteResponseFilter;
import org.springframework.core.Ordered;

import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.util.*;
import java.util.stream.Collectors;


/**
 * Global gateway filter that removes duplicate values from every response header (for example
 * CORS headers added both by the gateway and by the downstream service).
 *
 * <p>The de-duplication is registered with {@link ServerHttpResponse#beforeCommit}, so it runs
 * on the real response headers just before they are written, while they are still writable.
 *
 * @Author yan
 * @Date 2024/11/14 12:34:07 PM
 */
@Slf4j
public class CorsResponseHeaderFilter implements GlobalFilter, Ordered, AbstractBean {

    /**
     * @return an order two positions ahead of {@link NettyWriteResponseFilter}, so this filter
     *         runs before the response is written
     */
    @Override
    public int getOrder() {
        return NettyWriteResponseFilter.WRITE_RESPONSE_FILTER_ORDER - 2;
    }

    /**
     * Builds the deferred action that de-duplicates the values of each response header.
     * Failures are logged and do not fail the exchange.
     *
     * @param exchange the current exchange whose response headers are processed
     * @param chain    the filter chain (unused; kept for signature compatibility)
     * @return a {@code Mono} that performs the de-duplication when subscribed
     */
    public Mono<Void> getFromRunnableMono(ServerWebExchange exchange, GatewayFilterChain chain) {
        return Mono.fromRunnable(() -> {
            try {
                HttpHeaders headers = exchange.getResponse().getHeaders();
                info("处理前的header头: {}", JSONUtil.toJsonStr(headers));

                // Snapshot first: the header map must not be modified while it is iterated.
                Map<String, List<String>> deduplicated = new LinkedHashMap<>();
                headers.forEach((key, values) ->
                        deduplicated.put(key, values.stream().distinct().collect(Collectors.toList())));

                headers.clear();
                headers.putAll(deduplicated);

                info("处理后的header头: {}", JSONUtil.toJsonStr(headers));
            } catch (Exception e) {
                error("去除重复请求头异常: ", e);
            }
        });
    }

    /**
     * Registers the header de-duplication as a before-commit action and continues the chain.
     *
     * @param exchange the current exchange
     * @param chain    the remaining filter chain
     * @return the completion signal of the rest of the chain
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        // Editing headers after chain.filter(...) completes is too late (they are read-only by
        // then), and editing a copy has no effect on the response; beforeCommit avoids both.
        exchange.getResponse().beforeCommit(() -> getFromRunnableMono(exchange, chain));
        return chain.filter(exchange);
    }

}

8.测试

在线测试工具

8.1以实例测试

注:目标用户在线时才能测试成功;离线消息尚未处理(可根据实际情况入库保存)。
ws://127.0.0.1:18800/websocket/ws/1
发送消息:{"senderId":"1","targetId":"2","content":"110","type":"text"}
ws://127.0.0.1:18900/websocket/ws/2
发送消息:{"senderId":"2","targetId":"1","content":"220","type":"text"}
WS_Snipaste_2025-02-05_16-59-52.jpg
WS_Snipaste_2025-02-05_16-59-27.jpg

8.2 以网关测试

ws://127.0.0.1:10010/test/websocket/ws/1
ws://127.0.0.1:10010/test/websocket/ws/2
GWS——Snipaste_2025-02-05_17-19-08.jpg
用户连接时网关会处理跨域请求头并打印前后差异
G_WS_Snipaste_2025-02-05_17-17-14.jpg

注:可通过不断连接、断开,使两个用户处于不同实例上进行测试
连接到不同实例时Redis key如下所示

R_Snipaste_2025-02-05_17-26-54.jpg
用户1发送消息给用户2 测试如下为成功:
Send12——Snipaste_2025-02-05_17-34-35.jpg
Out——Snipaste_2025-02-05_17-37-40.jpg