GraphRAG Reasoning

Tool for performing Graph RAG reasoning.

GraphRAGReasoningInput

Bases: BaseModel

GraphRAGReasoningInput is a Pydantic model representing an input for Graph RAG reasoning.

Parameters:

Name              Type                              Description                                          Default
state             Annotated[dict, InjectedState]    Injected state.                                      required
prompt            str                               Prompt to interact with the backend.                 required
extraction_name   str                               Name assigned to the subgraph extraction process.    required
Source code in aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py
class GraphRAGReasoningInput(BaseModel):
    """
    GraphRAGReasoningInput is a Pydantic model representing an input for Graph RAG reasoning.

    Args:
        state: Injected state.
        prompt: Prompt to interact with the backend.
        extraction_name: Name assigned to the subgraph extraction process.
    """

    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
        description="Tool call ID."
    )
    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
    prompt: str = Field(description="Prompt to interact with the backend.")
    extraction_name: str = Field(
        description="""Name assigned to the subgraph extraction process
                    when the subgraph_extraction tool is invoked."""
    )
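
For illustration, the schema can be instantiated directly. Note that tool_call_id and state are normally injected by LangGraph at runtime (via InjectedToolCallId and InjectedState) rather than supplied by the caller, so the values in this minimal sketch are purely hypothetical:

# A minimal sketch with hypothetical values; in practice tool_call_id and
# state are injected by the agent framework, not passed in by hand.
from aiagents4pharma.talk2knowledgegraphs.tools.graphrag_reasoning import (
    GraphRAGReasoningInput,
)

args = GraphRAGReasoningInput(
    tool_call_id="call_001",          # hypothetical tool call ID
    state={"uploaded_files": []},     # hypothetical injected state
    prompt="Which drugs in the extracted subgraph interact with EGFR?",
    extraction_name="subkg_12345",    # name chosen at subgraph extraction time
)
print(args.extraction_name)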

GraphRAGReasoningTool

Bases: BaseTool

This tool performs reasoning using a Graph Retrieval-Augmented Generation (RAG) approach over the user's request by considering textualized subgraph context and document context.

Source code in aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py
class GraphRAGReasoningTool(BaseTool):
    """
    This tool performs reasoning using a Graph Retrieval-Augmented Generation (RAG) approach
    over the user's request by considering textualized subgraph context and document context.
    """

    name: str = "graphrag_reasoning"
    description: str = """A tool to perform reasoning using a Graph RAG approach
                        by considering textualized subgraph context and document context."""
    args_schema: Type[BaseModel] = GraphRAGReasoningInput

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        extraction_name: str,
    ):
        """
        Run the Graph RAG reasoning tool.

        Args:
            tool_call_id: The tool call ID.
            state: The injected state.
            prompt: The prompt to interact with the backend.
            extraction_name: The name assigned to the subgraph extraction process.
        """
        logger.log(
            logging.INFO, "Invoking graphrag_reasoning tool for %s", extraction_name
        )

        # Load Hydra configuration
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(
                config_name="config", overrides=["tools/graphrag_reasoning=default"]
            )
            cfg = cfg.tools.graphrag_reasoning

        # Prepare documents
        all_docs = []
        if len(state["uploaded_files"]) != 0:
            for uploaded_file in state["uploaded_files"]:
                if uploaded_file["file_type"] == "drug_data":
                    # Load documents
                    raw_documents = PyPDFLoader(
                        file_path=uploaded_file["file_path"]
                    ).load()

                    # Split documents
                    # May need to find an optimal chunk size and overlap configuration
                    documents = RecursiveCharacterTextSplitter(
                        chunk_size=cfg.splitter_chunk_size,
                        chunk_overlap=cfg.splitter_chunk_overlap,
                    ).split_documents(raw_documents)

                    # Add documents to the list
                    all_docs.extend(documents)

        # Load the extracted graph
        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
        # logger.log(logging.INFO, "Extracted graph: %s", extracted_graph)

        # Set another prompt template
        prompt_template = ChatPromptTemplate.from_messages(
            [("system", cfg.prompt_graphrag_w_docs), ("human", "{input}")]
        )

        # Prepare chain with retrieved documents
        qa_chain = create_stuff_documents_chain(state["llm_model"], prompt_template)
        rag_chain = create_retrieval_chain(
            InMemoryVectorStore.from_documents(
                documents=all_docs, embedding=state["embedding_model"]
            ).as_retriever(
                search_type=cfg.retriever_search_type,
                search_kwargs={
                    "k": cfg.retriever_k,
                    "fetch_k": cfg.retriever_fetch_k,
                    "lambda_mult": cfg.retriever_lambda_mult,
                },
            ),
            qa_chain,
        )

        # Invoke the chain
        response = rag_chain.invoke(
            {
                "input": prompt,
                "subgraph_summary": extracted_graph[extraction_name]["graph_summary"],
            }
        )

        return Command(
            update={
                # update the message history
                "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)]
            }
        )
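
Example usage, as a sketch under assumed state contents: _run reads llm_model, embedding_model, uploaded_files, and dic_extracted_graph from the injected state, so a direct invocation must supply all four. The model choices, file path, and extraction name below are illustrative, and running this requires the Hydra config and the referenced PDF to be available:

# A hedged sketch of invoking the tool directly; in an agent run, state and
# tool_call_id are injected rather than passed explicitly.
from aiagents4pharma.talk2knowledgegraphs.tools.graphrag_reasoning import (
    GraphRAGReasoningTool,
)
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

tool = GraphRAGReasoningTool()

state = {
    "llm_model": ChatOpenAI(model="gpt-4o-mini"),          # illustrative choice
    "embedding_model": OpenAIEmbeddings(),                 # illustrative choice
    "uploaded_files": [
        {"file_type": "drug_data", "file_path": "data/drug_report.pdf"}
    ],
    "dic_extracted_graph": [
        {"name": "subkg_12345", "graph_summary": "Textualized subgraph summary ..."}
    ],
}

command = tool.invoke(
    {
        "tool_call_id": "call_001",
        "state": state,
        "prompt": "How does the uploaded report relate to the extracted subgraph?",
        "extraction_name": "subkg_12345",
    }
)
# command.update["messages"] holds the ToolMessage carrying the chain's response.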

_run(tool_call_id, state, prompt, extraction_name)

Run the Graph RAG reasoning tool.

Parameters:

Name              Type                                 Description                                              Default
tool_call_id      Annotated[str, InjectedToolCallId]   The tool call ID.                                        required
state             Annotated[dict, InjectedState]       The injected state.                                      required
prompt            str                                  The prompt to interact with the backend.                 required
extraction_name   str                                  The name assigned to the subgraph extraction process.   required
Source code in aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py
def _run(
    self,
    tool_call_id: Annotated[str, InjectedToolCallId],
    state: Annotated[dict, InjectedState],
    prompt: str,
    extraction_name: str,
):
    """
    Run the Graph RAG reasoning tool.

    Args:
        tool_call_id: The tool call ID.
        state: The injected state.
        prompt: The prompt to interact with the backend.
        extraction_name: The name assigned to the subgraph extraction process.
    """
    logger.log(
        logging.INFO, "Invoking graphrag_reasoning tool for %s", extraction_name
    )

    # Load Hydra configuration
    with hydra.initialize(version_base=None, config_path="../configs"):
        cfg = hydra.compose(
            config_name="config", overrides=["tools/graphrag_reasoning=default"]
        )
        cfg = cfg.tools.graphrag_reasoning

    # Prepare documents
    all_docs = []
    if len(state["uploaded_files"]) != 0:
        for uploaded_file in state["uploaded_files"]:
            if uploaded_file["file_type"] == "drug_data":
                # Load documents
                raw_documents = PyPDFLoader(
                    file_path=uploaded_file["file_path"]
                ).load()

                # Split documents
                # May need to find an optimal chunk size and overlap configuration
                documents = RecursiveCharacterTextSplitter(
                    chunk_size=cfg.splitter_chunk_size,
                    chunk_overlap=cfg.splitter_chunk_overlap,
                ).split_documents(raw_documents)

                # Add documents to the list
                all_docs.extend(documents)

    # Load the extracted graph
    extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
    # logger.log(logging.INFO, "Extracted graph: %s", extracted_graph)

    # Set another prompt template
    prompt_template = ChatPromptTemplate.from_messages(
        [("system", cfg.prompt_graphrag_w_docs), ("human", "{input}")]
    )

    # Prepare chain with retrieved documents
    qa_chain = create_stuff_documents_chain(state["llm_model"], prompt_template)
    rag_chain = create_retrieval_chain(
        InMemoryVectorStore.from_documents(
            documents=all_docs, embedding=state["embedding_model"]
        ).as_retriever(
            search_type=cfg.retriever_search_type,
            search_kwargs={
                "k": cfg.retriever_k,
                "fetch_k": cfg.retriever_fetch_k,
                "lambda_mult": cfg.retriever_lambda_mult,
            },
        ),
        qa_chain,
    )

    # Invoke the chain
    response = rag_chain.invoke(
        {
            "input": prompt,
            "subgraph_summary": extracted_graph[extraction_name]["graph_summary"],
        }
    )

    return Command(
        update={
            # update the message history
            "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)]
        }
    )
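
The Hydra config loaded above (tools/graphrag_reasoning=default) supplies the splitter, retriever, and prompt settings that _run reads. The sketch below lists those keys with illustrative values; the packaged default.yaml may differ. Two placeholders are required by the chain construction: the system prompt must contain {context}, which create_stuff_documents_chain fills with the retrieved documents, and {subgraph_summary}, which is supplied in the invoke payload.

# Illustrative values only; the packaged defaults may differ. The "mmr"
# search type is consistent with the fetch_k and lambda_mult kwargs used
# by the retriever above.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "splitter_chunk_size": 1000,       # characters per chunk
        "splitter_chunk_overlap": 100,     # overlap between adjacent chunks
        "retriever_search_type": "mmr",    # maximal marginal relevance
        "retriever_k": 5,                  # documents returned
        "retriever_fetch_k": 20,           # candidates fetched before MMR
        "retriever_lambda_mult": 0.5,      # relevance/diversity trade-off
        "prompt_graphrag_w_docs": (
            "Answer using the subgraph summary: {subgraph_summary}\n"
            "and the retrieved documents: {context}"
        ),
    }
)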