From 5443e2c76435e3424ad76ff683b3b85f9455db1d Mon Sep 17 00:00:00 2001 From: wuchunlei Date: Thu, 22 May 2025 17:53:19 +0800 Subject: [PATCH] =?UTF-8?q?ai=E9=97=AE=E7=AD=94=E6=94=B9=E6=88=90websocket?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pom.xml | 4 ++ .../peanut/common/utils/RagFlowApiUtil.java | 23 +++---- .../java/com/peanut/config/ShiroConfig.java | 2 +- .../java/com/peanut/config/WebSocket.java | 64 +++++++++++++++++++ .../com/peanut/config/WebSocketConfig.java | 16 +++++ .../controller/RagFlowApiController.java | 12 ++-- 6 files changed, 97 insertions(+), 24 deletions(-) create mode 100644 src/main/java/com/peanut/config/WebSocket.java create mode 100644 src/main/java/com/peanut/config/WebSocketConfig.java diff --git a/pom.xml b/pom.xml index 02644484..b9a64dd6 100644 --- a/pom.xml +++ b/pom.xml @@ -76,6 +76,10 @@ jacob 1.18 + + org.springframework.boot + spring-boot-starter-websocket + org.springframework.boot diff --git a/src/main/java/com/peanut/common/utils/RagFlowApiUtil.java b/src/main/java/com/peanut/common/utils/RagFlowApiUtil.java index d6ce206c..7e7b7eeb 100644 --- a/src/main/java/com/peanut/common/utils/RagFlowApiUtil.java +++ b/src/main/java/com/peanut/common/utils/RagFlowApiUtil.java @@ -1,9 +1,9 @@ package com.peanut.common.utils; import com.alibaba.fastjson.JSONObject; +import com.peanut.config.WebSocket; import com.peanut.modules.common.entity.AiChatContent; import com.peanut.modules.common.service.AiChatContentService; -import org.apache.commons.lang.StringUtils; import org.apache.http.Consts; import org.apache.http.HttpEntity; import org.apache.http.client.methods.CloseableHttpResponse; @@ -18,7 +18,6 @@ import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; import org.springframework.transaction.interceptor.TransactionAspectSupport; import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -33,6 +32,8 @@ public class RagFlowApiUtil { private String authorization; @Autowired private AiChatContentService aiChatContentService; + @Autowired + private WebSocket webSocket; //聊天助手列表 public List> getChatAssistants(String chatId) throws Exception{ @@ -85,15 +86,6 @@ public class RagFlowApiUtil { if ("0".equals(jsonObject.get("code").toString())){ List l = jsonObject.getJSONArray("data"); list.addAll(l); -// for (Object o : l) { -// Map map = new HashMap<>(); -// Map m = (Map)o; -// map.put("chatId",chatId); -// map.put("id",m.get("id")); -// map.put("name",m.get("name")); -// map.put("messages",m.get("messages")); -// list.add(map); -// } } return list; } @@ -146,7 +138,7 @@ public class RagFlowApiUtil { } //与助手聊天流式 - public Flux chatToAssistantStream(String chatId,String chatName,String sessionId,String sessionName,String question,String patientName) { + public void chatToAssistantStream(String chatId,String chatName,String sessionId,String sessionName,String question,String patientName) { try { String userId = ShiroUtils.getUId()+""; Map entity = new HashMap<>(); @@ -165,7 +157,7 @@ public class RagFlowApiUtil { content.setContent(question); aiChatContentService.save(content); List list = new ArrayList<>(); - return WebClient.create().post() + WebClient.create().post() .uri(url+"/api/v1/chats/"+chatId+"/completions") .header("Authorization", authorization) .header("Content-Type", "application/json;chartset=utf-8") @@ -176,11 +168,13 @@ public class RagFlowApiUtil { JSONObject jsonObject = JSONObject.parseObject(data); if ("0".equals(jsonObject.get("code").toString())){ if (!"true".equals(jsonObject.get("data").toString())){ + webSocket.sendMessage(data); list.add(((JSONObject)jsonObject.get("data")).get("answer").toString()); } } }) .doFinally(data -> { + webSocket.sendMessage("{\"code\":0,\"data\":true}"); AiChatContent answer = new AiChatContent(); answer.setUserId(Integer.parseInt(userId)); answer.setChatAssistantId(chatId); @@ -191,12 +185,11 @@ public class RagFlowApiUtil { answer.setType(1); answer.setContent(list.get(list.size()-1)); aiChatContentService.save(answer); - }); + }).subscribe(); }catch (Exception e){ e.printStackTrace(); TransactionAspectSupport.currentTransactionStatus().setRollbackOnly(); } - return null; } diff --git a/src/main/java/com/peanut/config/ShiroConfig.java b/src/main/java/com/peanut/config/ShiroConfig.java index 5256f933..0ba92ff8 100644 --- a/src/main/java/com/peanut/config/ShiroConfig.java +++ b/src/main/java/com/peanut/config/ShiroConfig.java @@ -17,7 +17,6 @@ import org.apache.shiro.spring.web.ShiroFilterFactoryBean; import org.apache.shiro.web.mgt.DefaultWebSecurityManager; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; - import javax.servlet.Filter; import java.util.HashMap; import java.util.LinkedHashMap; @@ -58,6 +57,7 @@ public class ShiroConfig { filterMap.put("/common/apkConfig/getApkUrl","anon");//获取apk下载地址 filterMap.put("/common/sysFeedback/addSysFeedback","anon");//问题反馈-密码问题 + filterMap.put("/websocket/**","anon"); filterMap.put("/oss/**","anon"); filterMap.put("/image/**","anon"); diff --git a/src/main/java/com/peanut/config/WebSocket.java b/src/main/java/com/peanut/config/WebSocket.java new file mode 100644 index 00000000..fd764cf1 --- /dev/null +++ b/src/main/java/com/peanut/config/WebSocket.java @@ -0,0 +1,64 @@ +package com.peanut.config; + + +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; +import javax.websocket.OnClose; +import javax.websocket.OnMessage; +import javax.websocket.OnOpen; +import javax.websocket.Session; +import javax.websocket.server.ServerEndpoint; +import java.util.concurrent.CopyOnWriteArraySet; + + +/** + *WebSocket的服务端 + */ + +@Slf4j +@Component +@ServerEndpoint("/websocket") +public class WebSocket { + + + private Session session; + + private static CopyOnWriteArraySet webSocketSet = new CopyOnWriteArraySet<>(); + + @OnOpen + public void onOpen(Session session) { + this.session = session; + webSocketSet.add(this); + log.info("【websocket消息】有新的连接, 总数:{}", webSocketSet.size()); + } + + //前端关闭时一个websocket时 + @OnClose + public void onClose() { + webSocketSet.remove(this); + log.info("【websocket消息】连接断开, 总数:{}", webSocketSet.size()); + } + + //前端向后端发送消息 + @OnMessage + public void onMessage(String message) { + log.info("【websocket消息】收到客户端发来的消息:{}", message); + } + + //新增一个方法用于主动向客户端发送消息 + public static void sendMessage(String message) { + for (WebSocket webSocket: webSocketSet) { + log.info("【websocket消息】, message={}", message); + try { + webSocket.session.getBasicRemote().sendText(message); + } catch (Exception e) { + e.printStackTrace(); + } + } + } + + + + + +} diff --git a/src/main/java/com/peanut/config/WebSocketConfig.java b/src/main/java/com/peanut/config/WebSocketConfig.java new file mode 100644 index 00000000..bb884d69 --- /dev/null +++ b/src/main/java/com/peanut/config/WebSocketConfig.java @@ -0,0 +1,16 @@ +package com.peanut.config; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.socket.config.annotation.EnableWebSocket; +import org.springframework.web.socket.server.standard.ServerEndpointExporter; + +@Configuration +@EnableWebSocket +public class WebSocketConfig{ + @Bean + public ServerEndpointExporter serverEndpointExporter() { + return new ServerEndpointExporter(); + } +} + diff --git a/src/main/java/com/peanut/modules/common/controller/RagFlowApiController.java b/src/main/java/com/peanut/modules/common/controller/RagFlowApiController.java index dfd5eb4a..aa41790f 100644 --- a/src/main/java/com/peanut/modules/common/controller/RagFlowApiController.java +++ b/src/main/java/com/peanut/modules/common/controller/RagFlowApiController.java @@ -9,17 +9,12 @@ import com.peanut.modules.common.service.AiChatContentService; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang.StringUtils; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.http.MediaType; import org.springframework.transaction.annotation.Transactional; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; -import reactor.core.publisher.Flux; -import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; @Slf4j @RestController("commonRagFlowApi") @@ -82,10 +77,11 @@ public class RagFlowApiController { } //与助手聊天流式 - @RequestMapping(value = "/chatToAssistantStream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + @RequestMapping(value = "/chatToAssistantStream") @Transactional - public Flux chatToAssistantStream(String chatId,String chatName,String sessionId,String sessionName,String question,String patientName){ - return ragFlowApiUtil.chatToAssistantStream(chatId,chatName,sessionId,sessionName,question,patientName); + public R chatToAssistantStream(String chatId,String chatName,String sessionId,String sessionName,String question,String patientName){ + ragFlowApiUtil.chatToAssistantStream(chatId,chatName,sessionId,sessionName,question,patientName); + return R.ok(); }