diff --git a/src/main/java/com/peanut/common/utils/RagFlowApiUtil.java b/src/main/java/com/peanut/common/utils/RagFlowApiUtil.java index fc2ed516..0f40a821 100644 --- a/src/main/java/com/peanut/common/utils/RagFlowApiUtil.java +++ b/src/main/java/com/peanut/common/utils/RagFlowApiUtil.java @@ -217,5 +217,104 @@ public class RagFlowApiUtil { } } + //代理列表 + public List> getChatAgents(String agentId) throws Exception{ + CloseableHttpClient httpClient = HttpClients.createDefault(); + HttpGet get = new HttpGet(url+"/api/v1/agents?id="+agentId); + get.setHeader("Authorization", authorization); + get.setHeader("Content-Type", "application/json;chartset=utf-8"); + CloseableHttpResponse response = httpClient.execute(get); + int statusCode = response.getStatusLine().getStatusCode(); + if (statusCode >= 400) { + throw new RuntimeException("API调用失败,状态码:" + statusCode); + } + HttpEntity responseEntity = response.getEntity(); + String responseString = EntityUtils.toString(responseEntity, Consts.UTF_8); + JSONObject jsonObject = JSONObject.parseObject(responseString); + List> list = new ArrayList(); + if ("0".equals(jsonObject.get("code").toString())){ + List l = jsonObject.getJSONArray("data"); + for (Object o : l) { + Map map = new HashMap<>(); + Map m = (Map)o; + map.put("id",m.get("id")); + map.put("name",m.get("title")); + list.add(map); + } + } + return list; + } + + //新建代理对话 + public String createAgentChat(Map params) throws Exception{ + CloseableHttpClient httpClient = HttpClients.createDefault(); + String agentId = params.get("agentId").toString(); + HttpPost post = new HttpPost(url+"/api/v1/agents/"+agentId+"/sessions?user_id="+ShiroUtils.getUId()); + post.setHeader("Authorization", authorization); + post.setHeader("Content-Type", "application/json;chartset=utf-8"); + CloseableHttpResponse response = httpClient.execute(post); + int statusCode = response.getStatusLine().getStatusCode(); + if (statusCode >= 400) { + throw new RuntimeException("API调用失败,状态码:" + statusCode); + } + HttpEntity responseEntity = response.getEntity(); + String responseString = EntityUtils.toString(responseEntity, Consts.UTF_8); + JSONObject jsonObject = JSONObject.parseObject(responseString); + if ("0".equals(jsonObject.get("code").toString())){ + return ((JSONObject)jsonObject.get("data")).get("id").toString(); + }else { + return jsonObject.get("message").toString(); + } + } + + //与代理聊天流式 + public void chatToAgentStream(String agentId,String sessionId,String question) { + try { + String userId = ShiroUtils.getUId()+""; + Map entity = new HashMap<>(); + entity.put("question", question); + entity.put("stream", true); + entity.put("session_id", sessionId); + entity.put("user_id", userId); + AiChatContent content = new AiChatContent(); + content.setUserId(Integer.parseInt(userId)); + content.setChatAssistantId(agentId); + content.setChatId(sessionId); + content.setType(0); + content.setContent(question); + aiChatContentService.save(content); + List list = new ArrayList<>(); + WebClient.create().post() + .uri(url+"/api/v1/agents/"+agentId+"/completions") + .header("Authorization", authorization) + .header("Content-Type", "application/json;chartset=utf-8") + .bodyValue(JSONObject.toJSONString(entity)) + .retrieve() + .bodyToFlux(String.class) + .doOnNext(data -> { + JSONObject jsonObject = JSONObject.parseObject(data); + if ("0".equals(jsonObject.get("code").toString())){ + if (!"true".equals(jsonObject.get("data").toString())){ + webSocket.sendMessage(data); + list.add(((JSONObject)jsonObject.get("data")).get("answer").toString()); + } + } + }) + .doFinally(data -> { + webSocket.sendMessage("{\"code\":0,\"data\":true}"); + AiChatContent answer = new AiChatContent(); + answer.setUserId(Integer.parseInt(userId)); + answer.setChatAssistantId(agentId); + answer.setChatId(sessionId); + answer.setType(1); + answer.setContent(list.get(list.size()-1)); + aiChatContentService.save(answer); + }).subscribe(); + }catch (Exception e){ + e.printStackTrace(); + TransactionAspectSupport.currentTransactionStatus().setRollbackOnly(); + } + } + } diff --git a/src/main/java/com/peanut/modules/common/controller/RagFlowApiController.java b/src/main/java/com/peanut/modules/common/controller/RagFlowApiController.java index 8f4cf148..859f9c75 100644 --- a/src/main/java/com/peanut/modules/common/controller/RagFlowApiController.java +++ b/src/main/java/com/peanut/modules/common/controller/RagFlowApiController.java @@ -123,6 +123,28 @@ public class RagFlowApiController { return R.ok(); } + //代理列表 + @RequestMapping("/getChatAgents") + public R getChatAgents() throws Exception{ + List> list = ragFlowApiUtil.getChatAgents(""); + return R.ok().put("list",list); + } + + //创建代理会话 + @RequestMapping("/createAgentChat") + public R createAgentChat(@RequestBody Map params) throws Exception{ + String agentId = ragFlowApiUtil.createAgentChat(params); + return R.ok().put("id",agentId); + } + + //与代理聊天流式 + @RequestMapping(value = "/chatToAgentStream") + @Transactional + public R chatToAgentStream(String agentId,String sessionId,String question){ + ragFlowApiUtil.chatToAgentStream(agentId,sessionId,question); + return R.ok(); + } +