四时宝库

A knowledge treasure trove for programmers

A look at Advanced RAG in langchain4j (lanib-4)

This article takes a look at Advanced RAG in langchain4j.

Core flow

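  • The UserMessage is converted into an original Query
  • The QueryTransformer transforms the original Query into one or more Queries
  • Each Query is routed by the QueryRouter to one or more ContentRetrievers
  • Each ContentRetriever retrieves the Content relevant to its Query
  • The ContentAggregator merges all retrieved Content into a single, final ranked list
  • This list of Content is injected into the original UserMessage
  • Finally, the UserMessage containing the original query plus the injected relevant Content is sent to the LLM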
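To make the seven steps concrete, here is a minimal, self-contained sketch that runs them with plain java.util.function.Function values instead of langchain4j's interfaces. Everything in it is hypothetical: the class MiniRagPipeline, the helper concatDistinct, and modeling queries and contents as Strings. The injected text only loosely imitates a content injector; this is not the library's API.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.function.Function;

// Hypothetical, simplified pipeline: queries and contents are plain Strings,
// a retriever is a Function from query text to a ranked list of contents.
class MiniRagPipeline {

    static String augment(String userMessage,
                          Function<String, List<String>> transformer,                      // plays QueryTransformer
                          Function<String, List<Function<String, List<String>>>> router,   // plays QueryRouter
                          Function<List<List<String>>, List<String>> aggregator) {         // plays ContentAggregator
        List<List<String>> retrieved = new ArrayList<>();
        // Steps 1-2: the user message becomes the original query, which is transformed into one or more queries.
        for (String query : transformer.apply(userMessage)) {
            // Steps 3-4: each query is routed to one or more retrievers, each returning its ranked contents.
            for (Function<String, List<String>> retriever : router.apply(query)) {
                retrieved.add(retriever.apply(query));
            }
        }
        // Step 5: merge everything into one list.
        List<String> contents = aggregator.apply(retrieved);
        // Step 6: inject the contents into the original message; step 7 would send this text to the LLM.
        return userMessage + "\n\nAnswer using the following information:\n" + String.join("\n", contents);
    }

    // A deliberately trivial aggregator: concatenate in order and drop duplicates.
    static List<String> concatDistinct(List<List<String>> lists) {
        LinkedHashSet<String> seen = new LinkedHashSet<>();
        lists.forEach(seen::addAll);
        return new ArrayList<>(seen);
    }
}
```

Passing the four stages in as functions mirrors how DefaultRetrievalAugmentor lets each stage be swapped independently.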

Example

import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentParser;
import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.parser.TextDocumentParser;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.bgesmallenv15q.BgeSmallEnV15QuantizedEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.rag.query.router.LanguageModelQueryRouter;
import dev.langchain4j.rag.query.router.QueryRouter;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
// Assistant and the static helpers OPENAI_API_KEY, toPath(...), startConversationWith(...)
// come from the shared package of the langchain4j-examples project, not from langchain4j itself.
import shared.Assistant;

import java.nio.file.Path;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static dev.langchain4j.data.document.loader.FileSystemDocumentLoader.loadDocument;
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;
import static shared.Utils.*;

public class _02_Advanced_RAG_with_Query_Routing_Example {

    /**
     * Please refer to {@link Naive_RAG_Example} for a basic context.
     * <p>
     * Advanced RAG in LangChain4j is described here: https://github.com/langchain4j/langchain4j/pull/538
     * <p>
     * This example showcases the implementation of a more advanced RAG application
     * using a technique known as "query routing".
     * <p>
     * Often, private data is spread across multiple sources and formats.
     * This might include internal company documentation on Confluence, your project's code in a Git repository,
     * a relational database with user data, or a search engine with the products you sell, among others.
     * In a RAG flow that utilizes data from multiple sources, you will likely have multiple
     * {@link EmbeddingStore}s or {@link ContentRetriever}s.
     * While you could route each user query to all available {@link ContentRetriever}s,
     * this approach might be inefficient and counterproductive.
     * <p>
     * "Query routing" is the solution to this challenge. It involves directing a query to the most appropriate
     * {@link ContentRetriever} (or several). Routing can be implemented in various ways:
     * - Using rules (e.g., depending on the user's privileges, location, etc.).
     * - Using keywords (e.g., if a query contains words X1, X2, X3, route it to {@link ContentRetriever} X, etc.).
     * - Using semantic similarity (see EmbeddingModelTextClassifierExample in this repository).
     * - Using an LLM to make a routing decision.
     * <p>
     * For scenarios 1, 2, and 3, you can implement a custom {@link QueryRouter}.
     * For scenario 4, this example will demonstrate how to use a {@link LanguageModelQueryRouter}.
     */
    public static void main(String[] args) {

        Assistant assistant = createAssistant();

        // First, ask "What is the legacy of John Doe?"
        // Then, ask "Can I cancel my reservation?"
        // Now, see the logs to observe how the queries are routed to different retrievers.

        startConversationWith(assistant);
    }

    private static Assistant createAssistant() {

        EmbeddingModel embeddingModel = new BgeSmallEnV15QuantizedEmbeddingModel();

        // Let's create a separate embedding store specifically for biographies.
        EmbeddingStore<TextSegment> biographyEmbeddingStore =
                embed(toPath("documents/biography-of-john-doe.txt"), embeddingModel);
        ContentRetriever biographyContentRetriever = EmbeddingStoreContentRetriever.builder()
                .embeddingStore(biographyEmbeddingStore)
                .embeddingModel(embeddingModel)
                .maxResults(2)
                .minScore(0.6)
                .build();

        // Additionally, let's create a separate embedding store dedicated to terms of use.
        EmbeddingStore<TextSegment> termsOfUseEmbeddingStore =
                embed(toPath("documents/miles-of-smiles-terms-of-use.txt"), embeddingModel);
        ContentRetriever termsOfUseContentRetriever = EmbeddingStoreContentRetriever.builder()
                .embeddingStore(termsOfUseEmbeddingStore)
                .embeddingModel(embeddingModel)
                .maxResults(2)
                .minScore(0.6)
                .build();

        ChatLanguageModel chatLanguageModel = OpenAiChatModel.builder()
                .apiKey(OPENAI_API_KEY)
                .modelName(GPT_4_O_MINI)
                .build();

        // Let's create a query router.
        Map<ContentRetriever, String> retrieverToDescription = new HashMap<>();
        retrieverToDescription.put(biographyContentRetriever, "biography of John Doe");
        retrieverToDescription.put(termsOfUseContentRetriever, "terms of use of car rental company");
        QueryRouter queryRouter = new LanguageModelQueryRouter(chatLanguageModel, retrieverToDescription);

        RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
                .queryRouter(queryRouter)
                .build();

        return AiServices.builder(Assistant.class)
                .chatLanguageModel(chatLanguageModel)
                .retrievalAugmentor(retrievalAugmentor)
                .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
                .build();
    }

    private static EmbeddingStore<TextSegment> embed(Path documentPath, EmbeddingModel embeddingModel) {
        DocumentParser documentParser = new TextDocumentParser();
        Document document = loadDocument(documentPath, documentParser);

        DocumentSplitter splitter = DocumentSplitters.recursive(300, 0);
        List<TextSegment> segments = splitter.split(document);

        List<Embedding> embeddings = embeddingModel.embedAll(segments).content();

        EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
        embeddingStore.addAll(embeddings, segments);
        return embeddingStore;
    }
}

Here a DefaultRetrievalAugmentor is configured with a LanguageModelQueryRouter, which chooses between two ContentRetrievers: biographyContentRetriever and termsOfUseContentRetriever.
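The Javadoc of the example also lists keyword-based routing as an option you would implement as a custom QueryRouter. A minimal sketch of that idea follows. It assumes retrievers are identified by name (a real implementation would implement QueryRouter.route(Query) and return Collection<ContentRetriever>); the class KeywordQueryRouter, its crude tokenizer and its fallback list are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

// Hypothetical keyword router: each retriever name is associated with trigger keywords.
class KeywordQueryRouter {

    private final Map<String, Set<String>> retrieverToKeywords;
    private final List<String> fallback; // returned when no keyword matches

    KeywordQueryRouter(Map<String, Set<String>> retrieverToKeywords, List<String> fallback) {
        this.retrieverToKeywords = new LinkedHashMap<>(retrieverToKeywords);
        this.fallback = List.copyOf(fallback);
    }

    List<String> route(String queryText) {
        // Crude tokenization: lower-case, split on anything that is not a letter or a digit.
        Set<String> words = new HashSet<>(Arrays.asList(
                queryText.toLowerCase(Locale.ROOT).split("[^\\p{L}\\p{N}]+")));
        List<String> selected = new ArrayList<>();
        for (Map.Entry<String, Set<String>> entry : retrieverToKeywords.entrySet()) {
            // Route to this retriever if any of its keywords occurs in the query.
            if (entry.getValue().stream().anyMatch(words::contains)) {
                selected.add(entry.getKey());
            }
        }
        return selected.isEmpty() ? fallback : selected;
    }
}
```

Compared with LanguageModelQueryRouter, this costs no LLM call per query, at the price of missing paraphrases that share no keyword with the lists.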

Source code analysis

RetrievalAugmentor

dev/langchain4j/rag/RetrievalAugmentor.java

@Experimental
public interface RetrievalAugmentor {

    /**
     * Augments the {@link ChatMessage} provided in the {@link AugmentationRequest} with retrieved {@link Content}s.
     * <p>
     * This method has a default implementation in order to temporarily support
     * current custom implementations of {@code RetrievalAugmentor}. The default implementation will be removed soon.
     *
     * @param augmentationRequest The {@code AugmentationRequest} containing the {@code ChatMessage} to augment.
     * @return The {@link AugmentationResult} containing the augmented {@code ChatMessage}.
     */
    default AugmentationResult augment(AugmentationRequest augmentationRequest) {
        if (!(augmentationRequest.chatMessage() instanceof UserMessage)) {
            throw runtime("Please implement 'AugmentationResult augment(AugmentationRequest)' method " +
                    "in order to augment " + augmentationRequest.chatMessage().getClass());
        }
        UserMessage augmented = augment((UserMessage) augmentationRequest.chatMessage(), augmentationRequest.metadata());
        return AugmentationResult.builder()
                .chatMessage(augmented)
                .build();
    }

    /**
     * Augments the provided {@link UserMessage} with retrieved content.
     *
     * @param userMessage The {@link UserMessage} to be augmented.
     * @param metadata    The {@link Metadata} that may be useful or necessary for retrieval and augmentation.
     * @return The augmented {@link UserMessage}.
     * @deprecated Use/implement {@link #augment(AugmentationRequest)} instead.
     */
    @Deprecated
    UserMessage augment(UserMessage userMessage, Metadata metadata);
}

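The RetrievalAugmentor interface defines the augment(AugmentationRequest augmentationRequest) method, the entry point of RAG in langchain4j: it retrieves the Content relevant to the AugmentationRequest and augments the chat message with it. The method has a default implementation whose main purpose is to adapt the deprecated augment(UserMessage userMessage, Metadata metadata) method; it only handles a UserMessage and throws for any other message type.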
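Giving the new method a default body that delegates to the old, now deprecated, abstract method is a common way to evolve an interface without breaking existing implementations. A self-contained sketch of the same pattern, using hypothetical simplified types (Message, UserMsg, AiMsg, Augmentor, LegacyAugmentor) in place of ChatMessage, UserMessage and RetrievalAugmentor:

```java
// Hypothetical, simplified message types standing in for ChatMessage / UserMessage / AiMessage.
interface Message { String text(); }
record UserMsg(String text) implements Message {}
record AiMsg(String text) implements Message {}

interface Augmentor {

    // New entry point: accepts any message; the default implementation bridges to the legacy method.
    default Message augment(Message message) {
        if (!(message instanceof UserMsg userMsg)) {
            // Like RetrievalAugmentor, only user messages can go through the legacy path.
            throw new IllegalStateException(
                    "Please implement augment(Message) to augment " + message.getClass().getSimpleName());
        }
        return augment(userMsg);
    }

    // Legacy method that existing implementations already provide.
    @Deprecated
    UserMsg augment(UserMsg userMsg);
}

// An "old" implementation that only knows the deprecated method still works through the new entry point.
class LegacyAugmentor implements Augmentor {
    @Override
    @Deprecated
    public UserMsg augment(UserMsg userMsg) {
        return new UserMsg(userMsg.text() + "\n[retrieved context]");
    }
}
```

Once the default body is removed, as the Javadoc announces, implementations like LegacyAugmentor will have to override the new method themselves.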

DefaultRetrievalAugmentor

dev/langchain4j/rag/DefaultRetrievalAugmentor.java

public class DefaultRetrievalAugmentor implements RetrievalAugmentor {

    private static final Logger log = LoggerFactory.getLogger(DefaultRetrievalAugmentor.class);

    private final QueryTransformer queryTransformer;
    private final QueryRouter queryRouter;
    private final ContentAggregator contentAggregator;
    private final ContentInjector contentInjector;
    private final Executor executor;

    public DefaultRetrievalAugmentor(QueryTransformer queryTransformer,
                                     QueryRouter queryRouter,
                                     ContentAggregator contentAggregator,
                                     ContentInjector contentInjector,
                                     Executor executor) {
        this.queryTransformer = getOrDefault(queryTransformer, DefaultQueryTransformer::new);
        this.queryRouter = ensureNotNull(queryRouter, "queryRouter");
        this.contentAggregator = getOrDefault(contentAggregator, DefaultContentAggregator::new);
        this.contentInjector = getOrDefault(contentInjector, DefaultContentInjector::new);
        this.executor = getOrDefault(executor, DefaultRetrievalAugmentor::createDefaultExecutor);
    }

    private static ExecutorService createDefaultExecutor() {
        return new ThreadPoolExecutor(
            0, Integer.MAX_VALUE,
            1, SECONDS,
            new SynchronousQueue<>()
        );
    }

    /**
     * @deprecated use {@link #augment(AugmentationRequest)} instead.
     */
    @Override
    @Deprecated
    public UserMessage augment(UserMessage userMessage, Metadata metadata) {
        AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
        return (UserMessage) augment(augmentationRequest).chatMessage();
    }

    @Override
    public AugmentationResult augment(AugmentationRequest augmentationRequest) {

        ChatMessage chatMessage = augmentationRequest.chatMessage();
        Metadata metadata = augmentationRequest.metadata();

        Query originalQuery = Query.from(chatMessage.text(), metadata);

        Collection<Query> queries = queryTransformer.transform(originalQuery);
        logQueries(originalQuery, queries);

        Map<Query, Collection<List<Content>>> queryToContents = process(queries);

        List<Content> contents = contentAggregator.aggregate(queryToContents);
        log(queryToContents, contents);

        ChatMessage augmentedChatMessage = contentInjector.inject(contents, chatMessage);
        log(augmentedChatMessage);

        return AugmentationResult.builder()
            .chatMessage(augmentedChatMessage)
            .contents(contents)
            .build();
    }

    //......

}    

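DefaultRetrievalAugmentor implements the RetrievalAugmentor interface and holds a queryTransformer, queryRouter, contentAggregator, contentInjector and executor. Its augment method first turns the chatMessage into originalQuery, then calls queryTransformer.transform to obtain the queries, then calls process to retrieve the Contents for each query, then calls contentAggregator.aggregate to aggregate them, and finally calls contentInjector.inject to inject the contents into the chatMessage, producing augmentedChatMessage. The returned AugmentationResult contains both augmentedChatMessage and the injected contents.

The constructor requires queryRouter to be non-null. A null queryTransformer defaults to DefaultQueryTransformer, a null contentAggregator to DefaultContentAggregator, and a null contentInjector to DefaultContentInjector. A null executor defaults to a ThreadPoolExecutor with corePoolSize 0, maximumPoolSize Integer.MAX_VALUE, keepAliveTime 1s and a SynchronousQueue as its work queue, which is effectively a cached thread pool: threads are created on demand and reclaimed after one second of idleness.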
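A small demonstration of why this behaves like a cached thread pool, using only java.util.concurrent (DefaultExecutorDemo and threadsNeededFor are illustrative names): a SynchronousQueue cannot hold tasks, so every task submitted while all existing threads are busy gets a brand-new thread.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

class DefaultExecutorDemo {

    // Same parameters as DefaultRetrievalAugmentor.createDefaultExecutor():
    // no core threads, unbounded maximum, 1 second keep-alive, direct hand-off queue.
    static ThreadPoolExecutor createDefaultExecutor() {
        return new ThreadPoolExecutor(0, Integer.MAX_VALUE, 1, TimeUnit.SECONDS, new SynchronousQueue<>());
    }

    // Submits `tasks` tasks that block until released and reports how many threads the pool created.
    static int threadsNeededFor(int tasks) {
        ThreadPoolExecutor executor = createDefaultExecutor();
        CountDownLatch release = new CountDownLatch(1);
        try {
            for (int i = 0; i < tasks; i++) {
                executor.execute(() -> {
                    try {
                        release.await(); // keep the thread busy until the pool has been counted
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    }
                });
            }
            return executor.getPoolSize();
        } finally {
            release.countDown();
            executor.shutdown();
            try {
                executor.awaitTermination(5, TimeUnit.SECONDS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }
}
```

Executors.newCachedThreadPool() uses the same bounds and queue with a 60-second keep-alive; here idle threads go away after one second. The trade-off is that concurrency is unbounded, so many queries routed to many retrievers turn directly into many threads; passing your own Executor to the constructor avoids that.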

    private Map<Query, Collection<List<Content>>> process(Collection<Query> queries) {
        if (queries.size() == 1) {
            Query query = queries.iterator().next();
            Collection<ContentRetriever> retrievers = queryRouter.route(query);
            if (retrievers.size() == 1) {
                ContentRetriever contentRetriever = retrievers.iterator().next();
                List<Content> contents = contentRetriever.retrieve(query);
                return singletonMap(query, singletonList(contents));
            } else if (retrievers.size() > 1) {
                Collection<List<Content>> contents = retrieveFromAll(retrievers, query).join();
                return singletonMap(query, contents);
            } else {
                return emptyMap();
            }
        } else if (queries.size() > 1) {
            Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents = new ConcurrentHashMap<>();
            queries.forEach(query -> {
                CompletableFuture<Collection<List<Content>>> futureContents =
                    supplyAsync(() -> {
                            Collection<ContentRetriever> retrievers = queryRouter.route(query);
                            log(query, retrievers);
                            return retrievers;
                        },
                        executor
                    ).thenCompose(retrievers -> retrieveFromAll(retrievers, query));
                queryToFutureContents.put(query, futureContents);
            });
            return join(queryToFutureContents);
        } else {
            return emptyMap();
        }
    }

    private CompletableFuture<Collection<List<Content>>> retrieveFromAll(Collection<ContentRetriever> retrievers,
                                                                         Query query) {
        List<CompletableFuture<List<Content>>> futureContents = retrievers.stream()
            .map(retriever -> supplyAsync(() -> retrieve(retriever, query), executor))
            .collect(Collectors.toList());

        return allOf(futureContents.toArray(new CompletableFuture[0]))
            .thenApply(ignored ->
                futureContents.stream()
                    .map(CompletableFuture::join)
                    .collect(Collectors.toList()));
    }

    private static List<Content> retrieve(ContentRetriever retriever, Query query) {
        List<Content> contents = retriever.retrieve(query);
        log(query, retriever, contents);
        return contents;
    }

    private static Map<Query, Collection<List<Content>>> join(
        Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents) {
        return allOf(queryToFutureContents.values().toArray(new CompletableFuture[0]))
            .thenApply(ignored ->
                queryToFutureContents.entrySet().stream()
                    .collect(toMap(
                        Map.Entry::getKey,
                        entry -> entry.getValue().join()
                    ))
            ).join();
    }    

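The process method calls queryRouter.route(query) to find the retrievers each query should go to, then calls retrieve on each of those ContentRetrievers to obtain its List<Content>. It treats a single query and several queries differently. With a single query, routing runs on the calling thread and a single retriever is also called directly; only when there are several retrievers are they queried concurrently on the executor. With several queries, routing and retrieval for every query run concurrently on the executor, and join waits for all of them to finish.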
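The fan-out in the multi-query branch can be reproduced with plain CompletableFuture. In this sketch queries and contents are Strings and a retriever is a Function from query to ranked contents; ParallelRetrieval and its method signatures are hypothetical simplifications, not langchain4j's API.

```java
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;
import java.util.stream.Collectors;

class ParallelRetrieval {

    static Map<String, List<List<String>>> process(
            Collection<String> queries,
            Function<String, List<Function<String, List<String>>>> router,
            ExecutorService executor) {

        Map<String, CompletableFuture<List<List<String>>>> futures = new LinkedHashMap<>();
        for (String query : queries) {
            // Route asynchronously, then fan out to every selected retriever (no blocking inside tasks).
            futures.put(query, CompletableFuture.supplyAsync(() -> router.apply(query), executor)
                    .thenCompose(retrievers -> retrieveFromAll(retrievers, query, executor)));
        }
        // Wait for all queries, then unwrap the already completed futures.
        CompletableFuture.allOf(futures.values().toArray(new CompletableFuture[0])).join();
        return futures.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().join()));
    }

    static CompletableFuture<List<List<String>>> retrieveFromAll(
            List<Function<String, List<String>>> retrievers, String query, ExecutorService executor) {
        List<CompletableFuture<List<String>>> perRetriever = retrievers.stream()
                .map(r -> CompletableFuture.supplyAsync(() -> r.apply(query), executor))
                .collect(Collectors.toList());
        // allOf completes once every retriever has answered; results keep the retriever order.
        return CompletableFuture.allOf(perRetriever.toArray(new CompletableFuture[0]))
                .thenApply(ignored -> perRetriever.stream()
                        .map(CompletableFuture::join)
                        .collect(Collectors.toList()));
    }
}
```

Chaining with thenCompose instead of calling join inside a task keeps pool threads from blocking on each other, which matters less with an unbounded pool but avoids deadlock with a bounded one.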

QueryTransformer

dev/langchain4j/rag/query/transformer/QueryTransformer.java

@Experimental
public interface QueryTransformer {

    /**
     * Transforms the given {@link Query} into one or multiple {@link Query}s.
     *
     * @param query The {@link Query} to be transformed.
     * @return A collection of one or more {@link Query}s derived from the original {@link Query}.
     */
    Collection<Query> transform(Query query);
}

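QueryTransformer defines the transform method, which modifies or expands the original Query. Known techniques include query compression, query expansion, query rewriting, step-back prompting and Hypothetical Document Embeddings (HyDE). Its implementations are DefaultQueryTransformer, CompressingQueryTransformer and ExpandingQueryTransformer.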

DefaultQueryTransformer

dev/langchain4j/rag/query/transformer/DefaultQueryTransformer.java

public class DefaultQueryTransformer implements QueryTransformer {

    @Override
    public Collection<Query> transform(Query query) {
        return singletonList(query);
    }
}

DefaultQueryTransformer simply wraps the query in a single-element list and returns it, i.e. it performs no transformation.

CompressingQueryTransformer

dev/langchain4j/rag/query/transformer/CompressingQueryTransformer.java

public class CompressingQueryTransformer implements QueryTransformer {

    public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
            """
                    Read and understand the conversation between the User and the AI. \
                    Then, analyze the new query from the User. \
                    Identify all relevant details, terms, and context from both the conversation and the new query. \
                    Reformulate this query into a clear, concise, and self-contained format suitable for information retrieval.
                    
                    Conversation:
                    {{chatMemory}}
                    
                    User query: {{query}}
                    
                    It is very important that you provide only reformulated query and nothing else! \
                    Do not prepend a query with anything!"""
    );

    protected final PromptTemplate promptTemplate;
    protected final ChatLanguageModel chatLanguageModel;

    public CompressingQueryTransformer(ChatLanguageModel chatLanguageModel) {
        this(chatLanguageModel, DEFAULT_PROMPT_TEMPLATE);
    }

    public CompressingQueryTransformer(ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate) {
        this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel");
        this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
    }

    public static CompressingQueryTransformerBuilder builder() {
        return new CompressingQueryTransformerBuilder();
    }

    @Override
    public Collection<Query> transform(Query query) {

        List<ChatMessage> chatMemory = query.metadata().chatMemory();
        if (chatMemory.isEmpty()) {
            // no need to compress if there are no previous messages
            return singletonList(query);
        }

        Prompt prompt = createPrompt(query, format(chatMemory));
        String compressedQueryText = chatLanguageModel.chat(prompt.text());
        Query compressedQuery = query.metadata() == null
                ? Query.from(compressedQueryText)
                : Query.from(compressedQueryText, query.metadata());
        return singletonList(compressedQuery);
    }

    protected String format(List<ChatMessage> chatMemory) {
        return chatMemory.stream()
                .map(this::format)
                .filter(Objects::nonNull)
                .collect(joining("\n"));
    }

    protected String format(ChatMessage message) {
        if (message instanceof UserMessage) {
            return "User: " + message.text();
        } else if (message instanceof AiMessage aiMessage) {
            if (aiMessage.hasToolExecutionRequests()) {
                return null;
            }
            return "AI: " + aiMessage.text();
        } else {
            return null;
        }
    }

    protected Prompt createPrompt(Query query, String chatMemory) {
        Map<String, Object> variables = new HashMap<>();
        variables.put("query", query.text());
        variables.put("chatMemory", chatMemory);
        return promptTemplate.apply(variables);
    }

    //......
}    


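CompressingQueryTransformer is only useful when ChatMemory is enabled: if the query's metadata contains no previous messages, it returns the query unchanged. Otherwise it fills DEFAULT_PROMPT_TEMPLATE with the query and the formatted chat history and asks the LLM to compress both into a single, self-contained query.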
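How the history is rendered into the {{chatMemory}} placeholder can be shown with a small sketch. It uses hypothetical records (User, Ai) in place of UserMessage and AiMessage and applies the same rules as format(...) above: user turns become "User: ...", AI turns become "AI: ...", and AI turns that carry tool execution requests (as well as any other message type) are dropped.

```java
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

// Hypothetical, simplified chat messages.
interface ChatMsg {}
record User(String text) implements ChatMsg {}
record Ai(String text, boolean hasToolExecutionRequests) implements ChatMsg {}

class CompressionPrompt {

    // Joins the surviving turns with newlines, as CompressingQueryTransformer.format(List) does.
    static String format(List<ChatMsg> chatMemory) {
        return chatMemory.stream()
                .map(CompressionPrompt::formatMessage)
                .filter(Objects::nonNull)
                .collect(Collectors.joining("\n"));
    }

    static String formatMessage(ChatMsg message) {
        if (message instanceof User user) {
            return "User: " + user.text();
        }
        if (message instanceof Ai ai) {
            // Tool-call turns carry no useful prose for the rewrite, so they are skipped.
            return ai.hasToolExecutionRequests() ? null : "AI: " + ai.text();
        }
        return null;
    }
}
```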

ExpandingQueryTransformer

dev/langchain4j/rag/query/transformer/ExpandingQueryTransformer.java

public class ExpandingQueryTransformer implements QueryTransformer {

    public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
            """
                    Generate {{n}} different versions of a provided user query. \
                    Each version should be worded differently, using synonyms or alternative sentence structures, \
                    but they should all retain the original meaning. \
                    These versions will be used to retrieve relevant documents. \
                    It is very important to provide each query version on a separate line, \
                    without enumerations, hyphens, or any additional formatting!
                    User query: {{query}}"""
    );
    public static final int DEFAULT_N = 3;

    protected final ChatLanguageModel chatLanguageModel;
    protected final PromptTemplate promptTemplate;
    protected final int n;

    public ExpandingQueryTransformer(ChatLanguageModel chatLanguageModel) {
        this(chatLanguageModel, DEFAULT_PROMPT_TEMPLATE, DEFAULT_N);
    }

    public ExpandingQueryTransformer(ChatLanguageModel chatLanguageModel, int n) {
        this(chatLanguageModel, DEFAULT_PROMPT_TEMPLATE, n);
    }

    public ExpandingQueryTransformer(ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate) {
        this(chatLanguageModel, ensureNotNull(promptTemplate, "promptTemplate"), DEFAULT_N);
    }

    public ExpandingQueryTransformer(ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate, Integer n) {
        this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel");
        this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
        this.n = ensureGreaterThanZero(getOrDefault(n, DEFAULT_N), "n");
    }

    public static ExpandingQueryTransformerBuilder builder() {
        return new ExpandingQueryTransformerBuilder();
    }

    @Override
    public Collection<Query> transform(Query query) {
        Prompt prompt = createPrompt(query);
        String response = chatLanguageModel.chat(prompt.text());
        List<String> queries = parse(response);
        return queries.stream()
                .map(queryText -> query.metadata() == null
                        ? Query.from(queryText)
                        : Query.from(queryText, query.metadata()))
                .collect(toList());
    }

    protected Prompt createPrompt(Query query) {
        Map<String, Object> variables = new HashMap<>();
        variables.put("query", query.text());
        variables.put("n", n);
        return promptTemplate.apply(variables);
    }

    protected List<String> parse(String queries) {
        return stream(queries.split("\n"))
                .filter(Utils::isNotNullOrBlank)
                .collect(toList());
    }

    //......
}    

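ExpandingQueryTransformer uses DEFAULT_PROMPT_TEMPLATE to ask the LLM for n differently worded versions of the user query (n defaults to 3). It splits the reply on newlines, drops blank lines, and turns each remaining line into a Query that keeps the original metadata.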
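The parsing step is simple enough to reproduce directly. A sketch (ExpansionParser is an illustrative name) that, like parse(...) above, splits the LLM reply on newlines and drops blank lines without trimming the rest:

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

class ExpansionParser {

    // One query per line; blank lines are dropped, other lines are kept as-is (not trimmed).
    static List<String> parse(String llmResponse) {
        return Arrays.stream(llmResponse.split("\n"))
                .filter(line -> !line.isBlank())
                .collect(Collectors.toList());
    }
}
```

Because every line becomes a query verbatim, a reply that ignores the prompt and numbers its lines ("1. ...") would carry the numbers into retrieval, which is why the default prompt insists on no enumerations, hyphens or other formatting.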

QueryRouter

dev/langchain4j/rag/query/router/QueryRouter.java

@Experimental
public interface QueryRouter {

    /**
     * Routes the given {@link Query} to one or multiple {@link ContentRetriever}s.
     *
     * @param query The {@link Query} to be routed.
     * @return A collection of one or more {@link ContentRetriever}s to which the {@link Query} should be routed.
     */
    Collection<ContentRetriever> route(Query query);
}

QueryRouter defines the route method, which returns the ContentRetrievers a query should be routed to. Its implementations are DefaultQueryRouter and LanguageModelQueryRouter.

DefaultQueryRouter

dev/langchain4j/rag/query/router/DefaultQueryRouter.java

public class DefaultQueryRouter implements QueryRouter {

    private final Collection<ContentRetriever> contentRetrievers;

    public DefaultQueryRouter(ContentRetriever... contentRetrievers) {
        this(asList(contentRetrievers));
    }

    public DefaultQueryRouter(Collection<ContentRetriever> contentRetrievers) {
        this.contentRetrievers = unmodifiableCollection(ensureNotEmpty(contentRetrievers, "contentRetrievers"));
    }

    @Override
    public Collection<ContentRetriever> route(Query query) {
        return contentRetrievers;
    }
}

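DefaultQueryRouter takes the contentRetrievers in its constructor (they must not be empty), and its route method always returns all of them, regardless of the query.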

LanguageModelQueryRouter

dev/langchain4j/rag/query/router/LanguageModelQueryRouter.java

public class LanguageModelQueryRouter implements QueryRouter {

    private static final Logger log = LoggerFactory.getLogger(LanguageModelQueryRouter.class);

    public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
            """
                    Based on the user query, determine the most suitable data source(s) \
                    to retrieve relevant information from the following options:
                    {{options}}
                    It is very important that your answer consists of either a single number \
                    or multiple numbers separated by commas and nothing else!
                    User query: {{query}}"""
    );

    protected final ChatLanguageModel chatLanguageModel;
    protected final PromptTemplate promptTemplate;
    protected final String options;
    protected final Map<Integer, ContentRetriever> idToRetriever;
    protected final FallbackStrategy fallbackStrategy;

    public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel,
                                    Map<ContentRetriever, String> retrieverToDescription) {
        this(chatLanguageModel, retrieverToDescription, DEFAULT_PROMPT_TEMPLATE, DO_NOT_ROUTE);
    }

    public LanguageModelQueryRouter(ChatLanguageModel chatLanguageModel,
                                    Map<ContentRetriever, String> retrieverToDescription,
                                    PromptTemplate promptTemplate,
                                    FallbackStrategy fallbackStrategy) {
        this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel");
        ensureNotEmpty(retrieverToDescription, "retrieverToDescription");
        this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);

        Map<Integer, ContentRetriever> idToRetriever = new HashMap<>();
        StringBuilder optionsBuilder = new StringBuilder();
        int id = 1;
        for (Map.Entry<ContentRetriever, String> entry : retrieverToDescription.entrySet()) {
            idToRetriever.put(id, ensureNotNull(entry.getKey(), "ContentRetriever"));

            if (id > 1) {
                optionsBuilder.append("\n");
            }
            optionsBuilder.append(id);
            optionsBuilder.append(": ");
            optionsBuilder.append(ensureNotBlank(entry.getValue(), "ContentRetriever description"));

            id++;
        }
        this.idToRetriever = idToRetriever;
        this.options = optionsBuilder.toString();
        this.fallbackStrategy = getOrDefault(fallbackStrategy, DO_NOT_ROUTE);
    }

    public static LanguageModelQueryRouterBuilder builder() {
        return new LanguageModelQueryRouterBuilder();
    }

    @Override
    public Collection<ContentRetriever> route(Query query) {
        Prompt prompt = createPrompt(query);
        try {
            String response = chatLanguageModel.chat(prompt.text());
            return parse(response);
        } catch (Exception e) {
            log.warn("Failed to route query '{}'", query.text(), e);
            return fallback(query, e);
        }
    }

    protected Collection<ContentRetriever> fallback(Query query, Exception e) {
        return switch (fallbackStrategy) {
            case DO_NOT_ROUTE -> {
                log.debug("Fallback: query '{}' will not be routed", query.text());
                yield emptyList();
            }
            case ROUTE_TO_ALL -> {
                log.debug("Fallback: query '{}' will be routed to all available content retrievers", query.text());
                yield new ArrayList<>(idToRetriever.values());
            }
            default -> throw new RuntimeException(e);
        };
    }

    protected Prompt createPrompt(Query query) {
        Map<String, Object> variables = new HashMap<>();
        variables.put("query", query.text());
        variables.put("options", options);
        return promptTemplate.apply(variables);
    }

    protected Collection<ContentRetriever> parse(String choices) {
        return stream(choices.split(","))
                .map(String::trim)
                .map(Integer::parseInt)
                .map(idToRetriever::get)
                .collect(toList());
    }

    /**
     * Strategy applied if the call to the LLM fails or if the LLM does not return a valid response.
     * It could be because it was formatted improperly, or it is unclear where to route.
     */
    public enum FallbackStrategy {

        /**
         * In this case, the {@link Query} will not be routed to any {@link ContentRetriever},
         * thus skipping the RAG flow. No content will be appended to the original {@link UserMessage}.
         */
        DO_NOT_ROUTE,

        /**
         * In this case, the {@link Query} will be routed to all {@link ContentRetriever}s.
         */
        ROUTE_TO_ALL,

        /**
         * In this case, an original exception will be re-thrown, and the RAG flow will fail.
         */
        FAIL
    }

    //......
 }   

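LanguageModelQueryRouter uses chatLanguageModel to make the routing decision. Its constructor takes a chatLanguageModel and a Map<ContentRetriever, String>, where each String describes its ContentRetriever so that the model can decide which retriever(s) to route to. The descriptions are numbered from 1 and rendered into the {{options}} part of the prompt, and the model's answer (for example "1" or "1, 2") is parsed back into retrievers. A fallbackStrategy specifies what happens when the call to chatLanguageModel fails or its answer cannot be parsed (both happen inside the same try block); the default is DO_NOT_ROUTE.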
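The bookkeeping around the LLM call (numbering the descriptions, parsing the answer, falling back) can be sketched without an actual model. In this hypothetical NumberedRouter, retrievers are identified by name, the LLM is a UnaryOperator<String>, and the prompt text is simplified; only the numbering, parsing and fallback rules follow the code above.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;

class NumberedRouter {

    enum Fallback { DO_NOT_ROUTE, ROUTE_TO_ALL, FAIL }

    private final Map<Integer, String> idToRetriever = new LinkedHashMap<>();
    private final String options;
    private final Fallback fallback;

    NumberedRouter(Map<String, String> retrieverToDescription, Fallback fallback) {
        StringBuilder sb = new StringBuilder();
        int id = 1;
        for (Map.Entry<String, String> e : retrieverToDescription.entrySet()) {
            idToRetriever.put(id, e.getKey());
            if (id > 1) sb.append("\n");
            sb.append(id).append(": ").append(e.getValue()); // e.g. "1: biography of John Doe"
            id++;
        }
        this.options = sb.toString();
        this.fallback = fallback;
    }

    String options() { return options; }

    List<String> route(String query, UnaryOperator<String> llm) {
        String prompt = "Options:\n" + options + "\nUser query: " + query; // simplified prompt
        try {
            // "1, 2" -> [1, 2] -> retriever names; a non-numeric answer throws NumberFormatException.
            return Arrays.stream(llm.apply(prompt).split(","))
                    .map(String::trim)
                    .map(Integer::parseInt)
                    .map(idToRetriever::get)
                    .collect(Collectors.toList());
        } catch (Exception e) {
            return switch (fallback) {
                case DO_NOT_ROUTE -> List.of();
                case ROUTE_TO_ALL -> new ArrayList<>(idToRetriever.values());
                case FAIL -> throw new RuntimeException(e);
            };
        }
    }
}
```

As in the original parse(...), a well-formed but out-of-range number such as "3" does not throw: idToRetriever::get simply yields null for it, so the fallback is not triggered.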

ContentRetriever

dev/langchain4j/rag/content/retriever/ContentRetriever.java

public interface ContentRetriever {

    /**
     * Retrieves relevant {@link Content}s using a given {@link Query}.
     * The {@link Content}s are sorted by relevance, with the most relevant {@link Content}s appearing
     * at the beginning of the returned {@code List}.
     *
     * @param query The {@link Query} to use for retrieval.
     * @return A list of retrieved {@link Content}s.
     */
    List<Content> retrieve(Query query);
}

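The ContentRetriever interface defines the retrieve method, which returns the Contents relevant to the given query, sorted by relevance with the most relevant first. Its implementations include EmbeddingStoreContentRetriever and WebSearchContentRetriever.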
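To illustrate what settings such as maxResults(2) and minScore(0.6) in the example mean for an embedding-based retriever, here is a toy in-memory version over hand-written vectors. ToyVectorRetriever is hypothetical: the real EmbeddingStoreContentRetriever embeds the query with an EmbeddingModel and delegates the search to an EmbeddingStore. The (cos + 1) / 2 rescaling is an assumption about how langchain4j's in-memory store turns cosine similarity into the relevance score compared with minScore.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

class ToyVectorRetriever {

    private final Map<String, double[]> textToVector;
    private final int maxResults;
    private final double minScore;

    ToyVectorRetriever(Map<String, double[]> textToVector, int maxResults, double minScore) {
        this.textToVector = Map.copyOf(textToVector);
        this.maxResults = maxResults;
        this.minScore = minScore;
    }

    static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    // Assumption: relevance = (cosine + 1) / 2, which maps [-1, 1] onto [0, 1].
    static double relevance(double[] a, double[] b) {
        return (cosine(a, b) + 1) / 2;
    }

    // Most relevant first, as the ContentRetriever contract requires; weak matches are filtered out.
    List<String> retrieve(double[] queryVector) {
        return textToVector.entrySet().stream()
                .map(e -> Map.entry(e.getKey(), relevance(queryVector, e.getValue())))
                .filter(e -> e.getValue() >= minScore)
                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
                .limit(maxResults)
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
    }
}
```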

ContentAggregator

dev/langchain4j/rag/content/aggregator/ContentAggregator.java

@Experimental
public interface ContentAggregator {

    /**
     * Aggregates all {@link Content}s retrieved by all {@link ContentRetriever}s using all {@link Query}s.
     * The {@link Content}s, both on input and output, are sorted by relevance,
     * with the most relevant {@link Content}s appearing at the beginning of {@code List}.
     *
     * @param queryToContents A map from a {@link Query} to all {@code List<Content>} retrieved with that {@link Query}.
     *                        Given that each {@link Query} can be routed to multiple {@link ContentRetriever}s, the
     *                        value of this map is a {@code Collection<List<Content>>}
     *                        rather than a simple {@code List<Content>}.
     * @return A list of aggregated {@link Content}s.
     */
    List<Content> aggregate(Map<Query, Collection<List<Content>>> queryToContents);
}

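The ContentAggregator interface defines the aggregate method, which aggregates queryToContents into a single list of Content. The purpose of this step is to make sure the Contents passed to the LLM are the most relevant ones and free of redundancy. Effective techniques include re-ranking (ReRankingContentAggregator) and Reciprocal Rank Fusion (ReciprocalRankFuser, used by both DefaultContentAggregator and ReRankingContentAggregator).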

DefaultContentAggregator

dev/langchain4j/rag/content/aggregator/DefaultContentAggregator.java

public class DefaultContentAggregator implements ContentAggregator {

    @Override
    public List<Content> aggregate(Map<Query, Collection<List<Content>>> queryToContents) {

        // First, for each query, fuse all contents retrieved from different sources using that query.
        Map<Query, List<Content>> fused = fuse(queryToContents);

        // Then, fuse all contents retrieved using all queries
        return ReciprocalRankFuser.fuse(fused.values());
    }

    protected Map<Query, List<Content>> fuse(Map<Query, Collection<List<Content>>> queryToContents) {
        Map<Query, List<Content>> fused = new LinkedHashMap<>();
        for (Query query : queryToContents.keySet()) {
            Collection<List<Content>> contents = queryToContents.get(query);
            fused.put(query, ReciprocalRankFuser.fuse(contents));
        }
        return fused;
    }
}

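DefaultContentAggregator performs a two-stage fuse. In the first stage, for each query, it merges all the List<Content> retrieved with that query into a single list; in the second stage, it merges these per-query lists (the results of the first stage) into one final list. Both stages use ReciprocalRankFuser.fuse.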
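Reciprocal Rank Fusion scores each item by summing 1 / (k + rank) over every ranked list it appears in (ranks start at 1), so items ranked high in several lists rise to the top without needing comparable raw scores. A self-contained sketch of RRF and of the two-stage fuse, with contents as Strings; the class name Rrf and k = 60 (the constant commonly used for RRF) are assumptions, not taken from the code shown here.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class Rrf {

    static List<String> fuse(Collection<List<String>> rankedLists, int k) {
        Map<String, Double> scores = new LinkedHashMap<>();
        for (List<String> list : rankedLists) {
            for (int i = 0; i < list.size(); i++) {
                int rank = i + 1; // 1-based rank within this list
                scores.merge(list.get(i), 1.0 / (k + rank), Double::sum);
            }
        }
        List<String> fused = new ArrayList<>(scores.keySet());
        // Highest accumulated score first; List.sort is stable, so ties keep first-seen order.
        fused.sort(Comparator.comparingDouble(scores::get).reversed());
        return fused;
    }

    // Stage 1: per query, fuse the lists from all retrievers. Stage 2: fuse the per-query results.
    static List<String> aggregate(Map<String, Collection<List<String>>> queryToContents) {
        List<List<String>> perQuery = new ArrayList<>();
        for (Collection<List<String>> lists : queryToContents.values()) {
            perQuery.add(fuse(lists, 60));
        }
        return fuse(perQuery, 60);
    }
}
```

For example, fusing [x, y] and [y, z] yields [y, x, z]: y appears in both lists, and x beats z because it is ranked first in its list.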

ReRankingContentAggregator

dev/langchain4j/rag/content/aggregator/ReRankingContentAggregator.java

public class ReRankingContentAggregator implements ContentAggregator {

    public static final Function<Map<Query, Collection<List<Content>>>, Query> DEFAULT_QUERY_SELECTOR =
            (queryToContents) -> {
                if (queryToContents.size() > 1) {
                    throw illegalArgument(
                            "The 'queryToContents' contains %s queries, making the re-ranking ambiguous. " +
                                    "Because there are multiple queries, it is unclear which one should be " +
                                    "used for re-ranking. Please provide a 'querySelector' in the constructor/builder.",
                            queryToContents.size()
                    );
                }
                return queryToContents.keySet().iterator().next();
            };

    private final ScoringModel scoringModel;
    private final Function<Map<Query, Collection<List<Content>>>, Query> querySelector;
    private final Double minScore;
    private final Integer maxResults;

    public ReRankingContentAggregator(ScoringModel scoringModel) {
        this(scoringModel, DEFAULT_QUERY_SELECTOR, null);
    }

    public ReRankingContentAggregator(ScoringModel scoringModel,
                                      Function<Map<Query, Collection<List<Content>>>, Query> querySelector,
                                      Double minScore) {
        this(scoringModel, querySelector, minScore, null);
    }

    public ReRankingContentAggregator(ScoringModel scoringModel,
                                      Function<Map<Query, Collection<List<Content>>>, Query> querySelector,
                                      Double minScore,
                                      Integer maxResults) {
        this.scoringModel = ensureNotNull(scoringModel, "scoringModel");
        this.querySelector = getOrDefault(querySelector, DEFAULT_QUERY_SELECTOR);
        this.minScore = minScore;
        this.maxResults = getOrDefault(maxResults, Integer.MAX_VALUE);
    }

    public static ReRankingContentAggregatorBuilder builder() {
        return new ReRankingContentAggregatorBuilder();
    }

    @Override
    public List<Content> aggregate(Map<Query, Collection<List<Content>>> queryToContents) {

        if (queryToContents.isEmpty()) {
            return emptyList();
        }

        // Select a query against which all contents will be re-ranked
        Query query = querySelector.apply(queryToContents);

        // For each query, fuse all contents retrieved from different sources using that query
        Map<Query, List<Content>> queryToFusedContents = fuse(queryToContents);

        // Fuse all contents retrieved using all queries
        List<Content> fusedContents = ReciprocalRankFuser.fuse(queryToFusedContents.values());

        if (fusedContents.isEmpty()) {
            return fusedContents;
        }

        // Re-rank all the fused contents against the query selected by the query selector
        return reRankAndFilter(fusedContents, query);
    }

    protected Map<Query, List<Content>> fuse(Map<Query, Collection<List<Content>>> queryToContents) {
        Map<Query, List<Content>> fused = new LinkedHashMap<>();
        for (Query query : queryToContents.keySet()) {
            Collection<List<Content>> contents = queryToContents.get(query);
            fused.put(query, ReciprocalRankFuser.fuse(contents));
        }
        return fused;
    }

    protected List<Content> reRankAndFilter(List<Content> contents, Query query) {

        List<TextSegment> segments = contents.stream()
                .map(Content::textSegment)
                .collect(Collectors.toList());

        List<Double> scores = scoringModel.scoreAll(segments, query.text()).content();

        // Pair every segment with its relevance score
        Map<TextSegment, Double> segmentToScore = new HashMap<>();
        for (int i = 0; i < segments.size(); i++) {
            segmentToScore.put(segments.get(i), scores.get(i));
        }

        // Drop scores below minScore, sort highest first, record the score as metadata, cap at maxResults
        return segmentToScore.entrySet().stream()
                .filter(entry -> minScore == null || entry.getValue() >= minScore)
                .sorted(Map.Entry.<TextSegment, Double>comparingByValue().reversed())
                .map(entry -> Content.from(entry.getKey(), Map.of(RERANKED_SCORE, entry.getValue())))
                .limit(maxResults)
                .collect(Collectors.toList());
    }

    //......
}    


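ReRankingContentAggregator re-ranks contents with a ScoringModel such as Cohere's: the ScoringModel scores each Content against a Query. If there are multiple Queries (for example, when an ExpandingQueryTransformer is used), a querySelector must be provided to choose the single Query against which all Contents are ranked; alternatively, you can implement a custom aggregator that scores each Content against the Query that was used to retrieve it and re-ranks based on those scores. Its aggregate method first selects a query via the querySelector, then fuses in two stages with ReciprocalRankFuser (first, for each query, the lists returned by different retrievers; then across all queries), and finally uses the scoringModel to score the fused contents against the selected query, drops those below minScore, sorts the rest by score in descending order, attaches the score as RERANKED_SCORE metadata, and returns at most maxResults of them.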
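The aggregate flow described above can be sketched with the standard library alone. This is not the langchain4j API: Strings stand in for Query and Content, a toy word-overlap scorer (`overlap`, hypothetical) stands in for a real ScoringModel, and the RRF constant k = 60 is an assumption (it is the conventional value for reciprocal rank fusion).

```java
import java.util.*;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

// Stdlib-only sketch of the ReRankingContentAggregator flow (not the langchain4j API).
// Assumptions: Strings stand in for Query and Content; a toy word-overlap scorer
// stands in for a real ScoringModel such as Cohere's re-ranker.
public class ReRankSketch {

    // Mirrors DEFAULT_QUERY_SELECTOR: refuse to guess when several queries were used.
    static final Function<Map<String, Collection<List<String>>>, String> DEFAULT_QUERY_SELECTOR = m -> {
        if (m.size() > 1) {
            throw new IllegalArgumentException("Multiple queries; please provide a querySelector");
        }
        return m.keySet().iterator().next();
    };

    // Reciprocal Rank Fusion: an item earns 1 / (k + rank) for every list it appears in (rank starts at 1).
    static List<String> rrf(Collection<List<String>> lists, int k) {
        Map<String, Double> scores = new LinkedHashMap<>();
        for (List<String> list : lists) {
            for (int i = 0; i < list.size(); i++) {
                scores.merge(list.get(i), 1.0 / (k + i + 1), Double::sum);
            }
        }
        return scores.entrySet().stream()
                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
    }

    // Toy scorer (assumption): fraction of query words that occur in the content.
    static double overlap(String content, String query) {
        Set<String> words = new HashSet<>(Arrays.asList(content.toLowerCase().split("\\W+")));
        String[] queryWords = query.toLowerCase().split("\\W+");
        long hits = Arrays.stream(queryWords).filter(words::contains).count();
        return (double) hits / queryWords.length;
    }

    static List<String> aggregate(Map<String, Collection<List<String>>> queryToContents,
                                  Function<Map<String, Collection<List<String>>>, String> querySelector,
                                  BiFunction<String, String, Double> scorer,
                                  Double minScore, int maxResults) {
        if (queryToContents.isEmpty()) {
            return List.of();
        }
        String query = querySelector.apply(queryToContents);

        // Stage 1: for each query, fuse the lists returned by different retrievers.
        List<List<String>> perQuery = new ArrayList<>();
        for (Collection<List<String>> lists : queryToContents.values()) {
            perQuery.add(rrf(lists, 60));
        }
        // Stage 2: fuse across all queries.
        List<String> fused = rrf(perQuery, 60);

        // Re-rank against the selected query, drop low scores, keep the best maxResults.
        return fused.stream()
                .map(c -> Map.entry(c, scorer.apply(c, query)))
                .filter(e -> minScore == null || e.getValue() >= minScore)
                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
                .limit(maxResults)
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        Map<String, Collection<List<String>>> input = new LinkedHashMap<>();
        input.put("cancel reservation", List.of(
                List.of("You can cancel a reservation up to 24 hours before", "John Doe was born in 1950"),
                List.of("Refunds are issued after you cancel")));
        System.out.println(aggregate(input, DEFAULT_QUERY_SELECTOR, ReRankSketch::overlap, 0.5, 2));
    }
}
```

Running `main` prints the two cancellation-related snippets, highest score first; the biography snippet scores 0 and is removed by the minScore filter. Passing a map with two queries to `DEFAULT_QUERY_SELECTOR` throws, which is why a custom querySelector is needed with query expansion.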

ContentInjector

dev/langchain4j/rag/content/injector/ContentInjector.java

@Experimental
public interface ContentInjector {

    /**
     * Injects given {@link Content}s into a given {@link ChatMessage}.
     *
     * This method has a default implementation in order to temporarily support
     * current custom implementations of {@code ContentInjector}. The default implementation will be removed soon.
     *
     * @param contents    The list of {@link Content} to be injected.
     * @param chatMessage The {@link ChatMessage} into which the {@link Content}s are to be injected.
     *                    Can be either a {@link UserMessage} or a {@link SystemMessage}.
     * @return The {@link UserMessage} with the injected {@link Content}s.
     */
    default ChatMessage inject(List<Content> contents, ChatMessage chatMessage) {
        if (!(chatMessage instanceof UserMessage)) {
            throw runtime("Please implement 'ChatMessage inject(List, ChatMessage)' method " +
                    "in order to inject contents into " + chatMessage);
        }
        return inject(contents, (UserMessage) chatMessage);
    }

    /**
     * Injects given {@link Content}s into a given {@link UserMessage}.
     *
     * @param contents    The list of {@link Content} to be injected.
     * @param userMessage The {@link UserMessage} into which the {@link Content}s are to be injected.
     * @return The {@link UserMessage} with the injected {@link Content}s.
     * @deprecated Use/implement {@link #inject(List, ChatMessage)} instead.
     */
    @Deprecated
    UserMessage inject(List<Content> contents, UserMessage userMessage);
}

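ContentInjector defines the inject method, which injects the contents into a chat message and returns a new ChatMessage. The default inject(List, ChatMessage) only accepts a UserMessage, which it passes on to the deprecated inject(List, UserMessage); for any other message type it throws, so an implementation that must handle other types (such as SystemMessage) has to override it.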
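The default method acts as a backward-compatibility bridge. The stdlib-only sketch below mimics that pattern with simplified stand-in types: `ChatMessage`, `UserMessage`, `SystemMessage` and `Injector` here are hypothetical records and interfaces, not the langchain4j classes. A legacy implementation that only provides the UserMessage overload keeps working through the new entry point, while other message types are rejected.

```java
import java.util.List;

// Stdlib-only sketch of the default-method bridge used by ContentInjector.
// All types below are simplified, hypothetical stand-ins for the langchain4j classes.
public class InjectorBridgeSketch {

    interface ChatMessage {}
    record UserMessage(String text) implements ChatMessage {}
    record SystemMessage(String text) implements ChatMessage {}

    interface Injector {
        // New entry point: by default it only handles UserMessage and delegates to the legacy overload.
        default ChatMessage inject(List<String> contents, ChatMessage message) {
            if (!(message instanceof UserMessage)) {
                throw new IllegalStateException(
                        "Please implement 'ChatMessage inject(List, ChatMessage)' to inject contents into " + message);
            }
            return inject(contents, (UserMessage) message);
        }

        // Legacy overload that older custom implementations already provide.
        UserMessage inject(List<String> contents, UserMessage message);
    }

    public static void main(String[] args) {
        // A legacy-style implementation that only knows how to handle UserMessage.
        Injector legacy = (contents, message) ->
                new UserMessage(message.text() + "\n\n" + String.join("\n\n", contents));

        ChatMessage user = new UserMessage("Can I cancel my reservation?");
        System.out.println(legacy.inject(List.of("Cancellations are free up to 24h before."), user));

        ChatMessage system = new SystemMessage("You are a helpful assistant.");
        try {
            legacy.inject(List.of("ignored"), system);
        } catch (IllegalStateException e) {
            System.out.println("Rejected: " + e.getMessage());
        }
    }
}
```

Because the legacy interface has a single abstract method, old implementations (even lambdas) compile unchanged; only callers that pass non-user messages need a new implementation.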

DefaultContentInjector

dev/langchain4j/rag/content/injector/DefaultContentInjector.java

public class DefaultContentInjector implements ContentInjector {

    public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
            """
                    {{userMessage}}

                    Answer using the following information:
                    {{contents}}""");

    private final PromptTemplate promptTemplate;
    private final List<String> metadataKeysToInclude;

    public DefaultContentInjector() {
        this(DEFAULT_PROMPT_TEMPLATE, null);
    }

    public DefaultContentInjector(List<String> metadataKeysToInclude) {
        this(DEFAULT_PROMPT_TEMPLATE, ensureNotEmpty(metadataKeysToInclude, "metadataKeysToInclude"));
    }

    public DefaultContentInjector(PromptTemplate promptTemplate) {
        this(ensureNotNull(promptTemplate, "promptTemplate"), null);
    }

    public DefaultContentInjector(PromptTemplate promptTemplate, List<String> metadataKeysToInclude) {
        this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
        this.metadataKeysToInclude = copyIfNotNull(metadataKeysToInclude);
    }

    public static DefaultContentInjectorBuilder builder() {
        return new DefaultContentInjectorBuilder();
    }

    @Override
    public ChatMessage inject(List<Content> contents, ChatMessage chatMessage) {

        if (contents.isEmpty()) {
            return chatMessage;
        }

        Prompt prompt = createPrompt(chatMessage, contents);
        if (chatMessage instanceof UserMessage message && isNotNullOrBlank(message.name())) {
            return prompt.toUserMessage(message.name());
        }

        return prompt.toUserMessage();
    }

    protected Prompt createPrompt(ChatMessage chatMessage, List<Content> contents) {
        return createPrompt((UserMessage) chatMessage, contents);
    }

    /**
     * @deprecated use {@link #inject(List, ChatMessage)} instead.
     */
    @Override
    @Deprecated
    public UserMessage inject(List<Content> contents, UserMessage userMessage) {

        if (contents.isEmpty()) {
            return userMessage;
        }

        Prompt prompt = createPrompt(userMessage, contents);
        if (isNotNullOrBlank(userMessage.name())) {
            return prompt.toUserMessage(userMessage.name());
        }
        return prompt.toUserMessage();
    }

    /**
     * @deprecated implement/override {@link #createPrompt(ChatMessage, List)} instead.
     */
    @Deprecated
    protected Prompt createPrompt(UserMessage userMessage, List<Content> contents) {
        Map<String, Object> variables = new HashMap<>();
        variables.put("userMessage", userMessage.singleText());
        variables.put("contents", format(contents));
        return promptTemplate.apply(variables);
    }

    protected String format(List<Content> contents) {
        return contents.stream().map(this::format).collect(joining("\n\n"));
    }

    protected String format(Content content) {

        TextSegment segment = content.textSegment();

        if (isNullOrEmpty(metadataKeysToInclude)) {
            return segment.text();
        }

        String segmentContent = segment.text();
        String segmentMetadata = format(segment.metadata());

        return format(segmentContent, segmentMetadata);
    }

    protected String format(Metadata metadata) {
        StringBuilder formattedMetadata = new StringBuilder();
        for (String metadataKey : metadataKeysToInclude) {
            String metadataValue = metadata.getString(metadataKey);
            if (metadataValue != null) {
                if (!formattedMetadata.isEmpty()) {
                    formattedMetadata.append("\n");
                }
                formattedMetadata.append(metadataKey).append(": ").append(metadataValue);
            }
        }
        return formattedMetadata.toString();
    }

    protected String format(String segmentContent, String segmentMetadata) {
        return segmentMetadata.isEmpty()
                ? segmentContent
                : String.format("content: %s\n%s", segmentContent, segmentMetadata);
    }

    //......
}    

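DefaultContentInjector uses its promptTemplate to inject the contents into the userMessage. With the default template, the result is the original message text, followed by "Answer using the following information:" and the contents separated by blank lines. If metadataKeysToInclude is set, each content is rendered as "content: ..." followed by one "key: value" line for every listed metadata key that is present. An empty contents list returns the message unchanged, and the user's name, if set, is preserved.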
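To make the output format concrete, here is a stdlib-only sketch of what the default template produces. It assumes a content is just text plus a string metadata map (the `Content` record below is a hypothetical stand-in) and fills the `{{...}}` placeholders with `String.replace`, a simplification of PromptTemplate.

```java
import java.util.*;
import java.util.stream.Collectors;

// Stdlib-only sketch of DefaultContentInjector's output with the default template.
// Assumptions: Content is a hypothetical text + metadata record, and placeholders are
// filled with String.replace instead of the real PromptTemplate.
public class InjectorPromptSketch {

    record Content(String text, Map<String, String> metadata) {}

    // Same text as DEFAULT_PROMPT_TEMPLATE once the text block's indentation is stripped.
    static final String TEMPLATE = "{{userMessage}}\n\nAnswer using the following information:\n{{contents}}";

    static String format(Content c, List<String> metadataKeysToInclude) {
        if (metadataKeysToInclude == null || metadataKeysToInclude.isEmpty()) {
            return c.text();
        }
        // One "key: value" line per requested key that is actually present.
        String meta = metadataKeysToInclude.stream()
                .filter(k -> c.metadata().get(k) != null)
                .map(k -> k + ": " + c.metadata().get(k))
                .collect(Collectors.joining("\n"));
        return meta.isEmpty() ? c.text() : String.format("content: %s\n%s", c.text(), meta);
    }

    static String inject(String userMessage, List<Content> contents, List<String> keys) {
        if (contents.isEmpty()) {
            return userMessage; // nothing retrieved: the message passes through unchanged
        }
        String joined = contents.stream().map(c -> format(c, keys)).collect(Collectors.joining("\n\n"));
        return TEMPLATE.replace("{{userMessage}}", userMessage).replace("{{contents}}", joined);
    }

    public static void main(String[] args) {
        List<Content> contents = List.of(
                new Content("Cancellations are free up to 24h before.", Map.of("file_name", "terms.txt")),
                new Content("John Doe was born in 1950.", Map.of()));
        System.out.println(inject("Can I cancel my reservation?", contents, List.of("file_name")));
    }
}
```

The second content has no `file_name` metadata, so it is emitted as plain text, matching the `segmentMetadata.isEmpty()` branch of DefaultContentInjector.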

Summary

The entry point of langchain4j's Advanced RAG is RetrievalAugmentor; its default implementation, DefaultRetrievalAugmentor, holds a QueryTransformer, a QueryRouter, a ContentAggregator, a ContentInjector, and an Executor.

  • The DefaultRetrievalAugmentor constructor requires queryRouter to be non-null; the QueryRouter holds the ContentRetrievers, and its route method returns the ContentRetrievers to use for a given Query
  • If queryTransformer is null, DefaultQueryTransformer is used
  • If contentAggregator is null, DefaultContentAggregator is used
  • If contentInjector is null, DefaultContentInjector is used
  • If executor is null, a ThreadPoolExecutor is created with corePoolSize 0, maximumPoolSize Integer.MAX_VALUE, keepAliveTime 1 second, and a SynchronousQueue as its workQueue
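The components listed above can be wired together roughly as follows. This is a stdlib-only sketch with hypothetical stand-in interfaces (Strings replace Query, Content and UserMessage), not DefaultRetrievalAugmentor itself. It reuses the default executor parameters listed above, runs only the retriever calls of each query in parallel, and its aggregator simply concatenates and de-duplicates results instead of applying reciprocal rank fusion as DefaultContentAggregator does.

```java
import java.util.*;
import java.util.concurrent.*;
import java.util.stream.Collectors;

// Stdlib-only sketch of the RetrievalAugmentor flow. Every interface is a simplified,
// hypothetical stand-in (Strings replace Query, Content and UserMessage).
public class AugmentorSketch {

    interface QueryTransformer { Collection<String> transform(String query); }
    interface ContentRetriever { List<String> retrieve(String query); }
    interface QueryRouter { Collection<ContentRetriever> route(String query); }
    interface ContentAggregator { List<String> aggregate(Map<String, Collection<List<String>>> queryToContents); }
    interface ContentInjector { String inject(List<String> contents, String userMessage); }

    // Same parameters as the default executor: core 0, max Integer.MAX_VALUE, keep-alive 1s, SynchronousQueue.
    static ExecutorService defaultExecutor() {
        return new ThreadPoolExecutor(0, Integer.MAX_VALUE, 1, TimeUnit.SECONDS, new SynchronousQueue<>());
    }

    static String augment(String userMessage, QueryTransformer transformer, QueryRouter router,
                          ContentAggregator aggregator, ContentInjector injector, ExecutorService executor) {
        Map<String, Collection<List<String>>> queryToContents = new LinkedHashMap<>();
        for (String query : transformer.transform(userMessage)) {
            // Fan out: every retriever chosen by the router runs on the executor.
            List<CompletableFuture<List<String>>> futures = router.route(query).stream()
                    .map(retriever -> CompletableFuture.supplyAsync(() -> retriever.retrieve(query), executor))
                    .collect(Collectors.toList());
            List<List<String>> results = futures.stream()
                    .map(CompletableFuture::join)
                    .collect(Collectors.toList());
            queryToContents.put(query, results);
        }
        // Merge everything into one list, then inject it into the user message.
        return injector.inject(aggregator.aggregate(queryToContents), userMessage);
    }

    public static void main(String[] args) {
        ContentRetriever biography = q -> List.of("John Doe was a writer.");
        ContentRetriever terms = q -> List.of("Cancellations are free up to 24h before.");
        // Keyword-based routing as a simple stand-in for LanguageModelQueryRouter.
        QueryRouter router = q -> q.contains("cancel") ? List.of(terms) : List.of(biography);
        QueryTransformer identity = q -> List.of(q); // one query, unchanged
        ContentAggregator flatten = m -> m.values().stream()
                .flatMap(Collection::stream)
                .flatMap(List::stream)
                .distinct()
                .collect(Collectors.toList());
        ContentInjector injector = (contents, msg) ->
                msg + "\n\nAnswer using the following information:\n" + String.join("\n\n", contents);

        ExecutorService executor = defaultExecutor();
        try {
            System.out.println(augment("Can I cancel my reservation?", identity, router, flatten, injector, executor));
        } finally {
            executor.shutdown();
        }
    }
}
```

A SynchronousQueue holds no tasks, so with maximumPoolSize set to Integer.MAX_VALUE each submitted retrieval either reuses an idle thread or starts a new one, and idle threads exit after one second. That suits short bursts of I/O-bound retriever calls.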

doc

  • Advanced RAG
