引言

在当今信息爆炸的时代,如何从海量文档中快速准确地检索出相关信息是一个重要课题。RAG(Retrieval-Augmented Generation)技术结合了信息检索和文本生成的优势,能够有效地解决这一问题。本文将介绍如何使用Java实现一个基于bge-m3模型的RAG系统,完成文档向量化和相似度匹配功能。

一、RAG系统概述

RAG系统主要由以下几个部分组成:

  1. 文档读取模块:支持多种格式文档的读取

  2. 文本分片模块:将长文本分割为适合处理的片段

  3. 向量化模块:将文本转换为向量表示

  4. 相似度计算模块:计算查询文本与文档片段的相似度

  5. 结果整合模块:返回最相关的文本片段

二、核心代码实现

1. 文档读取模块

ReadFiles类负责从指定目录读取各种格式的文档:

java

复制

下载

public class ReadFiles {
    public static List<String> readDirectoryFiles(String directoryPath) throws IOException {
        List<String> resultList = new ArrayList<>();
        try (DirectoryStream<Path> stream = Files.newDirectoryStream(Paths.get(directoryPath))) {
            for (Path path : stream) {
                if (Files.isRegularFile(path)) {
                    String fileName = path.getFileName().toString();
                    String fileContent = "";
                    // 根据文件类型调用不同的读取方法
                    if (fileName.toLowerCase().endsWith(".txt")) {
                        fileContent = readTextFile(path);
                    } else if (fileName.toLowerCase().endsWith(".docx")) {
                        fileContent = readDocxFile(path);
                    }
                    // 其他格式处理...
                    resultList.add(fileContent);
                }
            }
        }
        return resultList;
    }
    // 各种格式的具体读取方法...
}

该模块支持TXT、DOCX、DOC和PDF等多种格式,确保系统能够处理常见的文档类型。

2. 文本分片模块

SubString类负责将长文本分割为适合处理的片段:

java

复制

下载

public class SubString {
    public static List<String> splitBySentenceWithinLimit(String input, Integer maxNumber) {
        List<String> result = new ArrayList<>();
        int currentPosition = 0;
        int totalLength = input.length();
        
        while (currentPosition < totalLength) {
            int endPosition = Math.min(currentPosition + maxNumber, totalLength);
            String currentChunk = input.substring(currentPosition, endPosition);
            // 查找最后一个句号的位置
            int lastDotIndex = currentChunk.lastIndexOf('。');
            if (lastDotIndex != -1) {
                int actualEnd = currentPosition + lastDotIndex + 1;
                result.add(input.substring(currentPosition, actualEnd));
                currentPosition = actualEnd;
            } else {
                result.add(currentChunk);
                currentPosition = endPosition;
            }
        }
        return result;
    }
}

该方法确保每个文本片段不超过指定长度,并且尽量在句子边界处分割,保持语义完整性。

3. 向量化与相似度计算

核心功能由Collection类实现:

java

复制

下载

public class Collection {
    // 添加文档并生成嵌入向量
    public void addDocuments(List<Document> documents) {
        int numCores = Runtime.getRuntime().availableProcessors();
        ExecutorService executor = Executors.newFixedThreadPool(numCores);
        // 多线程处理文档向量化
        for (Document doc : documents) {
            executor.submit(() -> {
                List<Double> embedding = embeddingFunction.generateEmbedding(doc.getContent());
                doc.setEmbedding(embedding);
                this.documentList.add(doc);
            });
        }
        executor.shutdown();
    }
    
    // 查询相似文档
    public List<Result> query(String queryText, int topK) {
        List<Double> queryEmbedding = embeddingFunction.generateEmbedding(queryText);
        List<Result> results = new ArrayList<>();
        // 计算余弦相似度
        for (Document doc : documentList) {
            double similarity = cosineSimilarity(queryEmbedding, doc.getEmbedding());
            results.add(new Result(doc, similarity));
        }
        // 按相似度排序并返回topK结果
        results.sort((a, b) -> Double.compare(b.getSimilarity(), a.getSimilarity()));
        return results.subList(0, Math.min(topK, results.size()));
    }
}

4. 与Ollama交互的嵌入函数

OllamaEmbeddingFunction类实现了与Ollama服务的交互:

java

复制

下载

public class OllamaEmbeddingFunction implements EmbeddingFunction {
    @Override
    public List<Double> generateEmbedding(String text) throws Exception {
        Map<String, Object> requestBody = new HashMap<>();
        requestBody.put("model", model);
        requestBody.put("prompt", text);
        RequestBody body = RequestBody.create(mapper.writeValueAsString(requestBody), 
            MediaType.parse("application/json"));
        Request request = new Request.Builder().url(url).post(body).build();
        
        try (Response response = client.newCall(request).execute()) {
            JsonNode root = mapper.readTree(response.body().bytes());
            JsonNode embeddingNode = root.get("embedding");
            List<Double> embedding = new ArrayList<>();
            for (JsonNode node : embeddingNode) {
                embedding.add(node.asDouble());
            }
            return embedding;
        }
    }
}

三、系统使用示例

java

复制

下载

public class Rag {
    public static void main(String[] args) {
        // 读取文档
        List<String> inputContentList = ReadFiles.readFilesWork("src/main/resources/");
        // 设置参数
        String queryText = "乌鸦怎么喝到的水?";
        // 执行RAG查询
        String knowledge = Rag.ragWork(inputContentList, 500, queryText, 5, 0.35);
        // 构建最终提示
        String content = "你的角色是:问题答复专家。\n" + 
                         "你的任务是:根据已知指示答复我的问题。\n" + 
                         "已知知识:\n" + knowledge + "\n" + 
                         "根据以上要求,我输入的问题是:" + queryText;
        System.err.println(content);
    }
}

四、性能优化

  1. 多线程处理:在向量化文档时使用多线程,充分利用多核CPU

  2. 批量处理:一次性处理多个文档,减少网络请求开销

  3. 合理分片:控制每个文本片段的大小,平衡处理效率和语义完整性

五、应用场景

  1. 智能客服系统

  2. 企业内部知识库检索

  3. 学术文献检索与分析

  4. 法律条文查询系统

结语

本文介绍了一个基于Java实现的RAG系统,展示了从文档读取、文本处理到向量化检索的完整流程。该系统具有以下特点:

  1. 支持多种文档格式

  2. 高效的文本分片策略

  3. 基于bge-m3模型的向量化能力

  4. 灵活的相似度检索功能

读者可以根据实际需求调整参数,如分片大小、相似度阈值等,以获得最佳效果。完整代码已在上文展示,可直接用于项目开发或作为学习参考。

Logo

加入社区!打开量化的大门,首批课程上线啦!

更多推荐