# Retrieve Chunks

Retrieve relevant chunks from a vector store using MMR (Maximal Marginal Relevance).

retrieve_relevant_chunks(self, query, paper_ids=None, top_k=25, mmr_diversity=1.0)

Retrieve the most relevant chunks for a query using maximal marginal relevance.

Parameters:

Name Type Description Default
query str

Query string

required
paper_ids Optional[List[str]]

Optional list of paper IDs to filter by

None
top_k int

Number of chunks to retrieve

25
mmr_diversity float

Diversity parameter for MMR (higher = more diverse)

1.0

Returns:

Type Description
List[Document]

List of document chunks

Source code in aiagents4pharma/talk2scholars/tools/pdf/utils/retrieve_chunks.py
(lines 21–83)
def retrieve_relevant_chunks(
    self,
    query: str,
    paper_ids: Optional[List[str]] = None,
    top_k: int = 25,
    mmr_diversity: float = 1.00,
) -> List[Document]:
    """
    Select the most relevant chunks for a query via maximal marginal relevance.

    Args:
        query: Query string
        paper_ids: Optional list of paper IDs to filter by
        top_k: Number of chunks to retrieve
        mmr_diversity: Diversity parameter for MMR (higher = more diverse)

    Returns:
        List of document chunks
    """
    # Without a vector store there is nothing to search against.
    if not self.vector_store:
        logger.error("Failed to build vector store")
        return []

    if paper_ids:
        logger.info("Filtering retrieval to papers: %s", paper_ids)

    # Embed the query once up front.
    logger.info("Embedding query using model: %s", type(self.embedding_model).__name__)
    query_vector = np.array(self.embedding_model.embed_query(query))

    # Restrict candidates to the requested papers; no filter means all chunks.
    candidates = [
        doc
        for doc in self.documents.values()
        if not paper_ids or doc.metadata["paper_id"] in paper_ids
    ]
    if not candidates:
        logger.warning("No documents found after filtering by paper_ids.")
        return []

    # Pull each chunk embedding from the cache, computing and storing any
    # that have not been embedded yet.
    logger.info("Retrieving embeddings for %d chunks...", len(candidates))
    candidate_vectors = []
    for doc in candidates:
        cache_key = f"{doc.metadata['paper_id']}_{doc.metadata['chunk_id']}"
        if cache_key not in self.embeddings:
            logger.info("Embedding missing chunk %s", cache_key)
            self.embeddings[cache_key] = self.embedding_model.embed_documents(
                [doc.page_content]
            )[0]
        candidate_vectors.append(self.embeddings[cache_key])

    # Rank candidates with MMR, then map selected indices back to documents.
    selected = maximal_marginal_relevance(
        query_vector,
        candidate_vectors,
        k=top_k,
        lambda_mult=mmr_diversity,
    )
    results = [candidates[i] for i in selected]
    logger.info("Retrieved %d chunks using MMR", len(results))
    return results