
Subgraph Summarization

Tool for performing subgraph summarization.

SubgraphSummarizationInput

Bases: BaseModel

SubgraphSummarizationInput is a Pydantic model representing an input for summarizing a given textualized subgraph.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tool_call_id | Annotated[str, InjectedToolCallId] | Tool call ID. | required |
| state | Annotated[dict, InjectedState] | Injected state. | required |
| prompt | str | Prompt to interact with the backend. | required |
| extraction_name | str | Name assigned to the subgraph extraction process when the subgraph_extraction tool is invoked. | required |
Source code in aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py
class SubgraphSummarizationInput(BaseModel):
    """
    SubgraphSummarizationInput is a Pydantic model representing an input for
    summarizing a given textualized subgraph.

    Args:
        tool_call_id: Tool call ID.
        state: Injected state.
        prompt: Prompt to interact with the backend.
        extraction_name: Name assigned to the subgraph extraction process
    """

    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
        description="Tool call ID."
    )
    state: Annotated[dict, InjectedState] = Field(description="Injected state.")
    prompt: str = Field(description="Prompt to interact with the backend.")
    extraction_name: str = Field(
        description="""Name assigned to the subgraph extraction process
                                 when the subgraph_extraction tool is invoked."""
    )
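
The schema can be exercised on its own to validate a set of tool arguments. Below is a minimal sketch with placeholder values throughout; in a running agent, tool_call_id and state are injected by LangGraph via InjectedToolCallId and InjectedState rather than supplied by the LLM, and the import path mirrors the source file location.

from aiagents4pharma.talk2knowledgegraphs.tools.subgraph_summarization import (
    SubgraphSummarizationInput,
)

# Hypothetical argument values used purely to exercise Pydantic validation.
args = SubgraphSummarizationInput(
    tool_call_id="call_001",  # placeholder; normally injected at runtime
    state={"llm_model": None, "dic_extracted_graph": []},  # placeholder injected state
    prompt="Summarize the extracted subgraph.",
    extraction_name="subgraph_1",  # hypothetical extraction name
)
print(args.extraction_name)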

SubgraphSummarizationTool

Bases: BaseTool

This tool performs subgraph summarization over a textualized graph to highlight the most important information for responding to the user's prompt.

Source code in aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py
class SubgraphSummarizationTool(BaseTool):
    """
    This tool performs subgraph summarization over textualized graph to highlight the most
    important information in responding to user's prompt.
    """

    name: str = "subgraph_summarization"
    description: str = """A tool to perform subgraph summarization over textualized graph
                        for responding to user's follow-up prompt(s)."""
    args_schema: Type[BaseModel] = SubgraphSummarizationInput

    def _run(
        self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        prompt: str,
        extraction_name: str,
    ):
        """
        Run the subgraph summarization tool.

        Args:
            tool_call_id: The tool call ID.
            state: The injected state.
            prompt: The prompt to interact with the backend.
            extraction_name: The name assigned to the subgraph extraction process.
        """
        logger.log(
            logging.INFO, "Invoking subgraph_summarization tool for %s", extraction_name
        )

        # Load hydra configuration
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(
                config_name="config", overrides=["tools/subgraph_summarization=default"]
            )
            cfg = cfg.tools.subgraph_summarization

        # Load the extracted graph
        extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
        # logger.log(logging.INFO, "Extracted graph: %s", extracted_graph)

        # Prepare prompt template
        prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", cfg.prompt_subgraph_summarization),
                ("human", "{input}"),
            ]
        )

        # Prepare chain
        chain = prompt_template | state["llm_model"] | StrOutputParser()

        # Invoke the chain to summarize the textualized subgraph
        response = chain.invoke(
            {
                "input": prompt,
                "textualized_subgraph": extracted_graph[extraction_name]["graph_text"],
            }
        )

        # Store the response as graph_summary in the extracted graph
        for key, value in extracted_graph.items():
            if key == extraction_name:
                value["graph_summary"] = response

        # Prepare the dictionary of updated state
        dic_updated_state_for_model = {}
        for key, value in {
            "dic_extracted_graph": list(extracted_graph.values()),
        }.items():
            if value:
                dic_updated_state_for_model[key] = value

        # Return the updated state of the tool
        return Command(
            update=dic_updated_state_for_model
            | {
                # update the message history
                "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)]
            }
        )
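
A minimal usage sketch follows. It assumes the agent state already contains an llm_model and a dic_extracted_graph entry produced earlier by the subgraph_extraction tool; the model choice, extraction entry, and prompt below are placeholders, and running it requires the packaged hydra configuration and LLM credentials to be available.

from langchain_openai import ChatOpenAI

from aiagents4pharma.talk2knowledgegraphs.tools.subgraph_summarization import (
    SubgraphSummarizationTool,
)

tool = SubgraphSummarizationTool()

# Placeholder state mirroring the keys accessed in _run.
state = {
    "llm_model": ChatOpenAI(model="gpt-4o-mini", temperature=0),
    "dic_extracted_graph": [
        {
            "name": "subgraph_1",  # hypothetical extraction name
            "graph_text": "(TNF) -[interacts_with]-> (TNFRSF1A)",  # toy textualized subgraph
        }
    ],
}

# Outside a LangGraph ToolNode, the injected arguments are supplied explicitly
# as part of a tool-call payload; inside a graph they are injected automatically.
command = tool.invoke(
    {
        "type": "tool_call",
        "id": "call_001",
        "name": "subgraph_summarization",
        "args": {
            "prompt": "Summarize the key interactions in this subgraph.",
            "extraction_name": "subgraph_1",
            "state": state,
        },
    }
)

# The returned Command carries the updated state and an appended ToolMessage.
print(command.update["dic_extracted_graph"][0]["graph_summary"])
print(command.update["messages"][-1].content)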

_run(tool_call_id, state, prompt, extraction_name)

Run the subgraph summarization tool.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tool_call_id | Annotated[str, InjectedToolCallId] | The tool call ID. | required |
| state | Annotated[dict, InjectedState] | The injected state. | required |
| prompt | str | The prompt to interact with the backend. | required |
| extraction_name | str | The name assigned to the subgraph extraction process. | required |
Source code in aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py
def _run(
    self,
    tool_call_id: Annotated[str, InjectedToolCallId],
    state: Annotated[dict, InjectedState],
    prompt: str,
    extraction_name: str,
):
    """
    Run the subgraph summarization tool.

    Args:
        tool_call_id: The tool call ID.
        state: The injected state.
        prompt: The prompt to interact with the backend.
        extraction_name: The name assigned to the subgraph extraction process.
    """
    logger.log(
        logging.INFO, "Invoking subgraph_summarization tool for %s", extraction_name
    )

    # Load hydra configuration
    with hydra.initialize(version_base=None, config_path="../configs"):
        cfg = hydra.compose(
            config_name="config", overrides=["tools/subgraph_summarization=default"]
        )
        cfg = cfg.tools.subgraph_summarization

    # Load the extracted graph
    extracted_graph = {dic["name"]: dic for dic in state["dic_extracted_graph"]}
    # logger.log(logging.INFO, "Extracted graph: %s", extracted_graph)

    # Prepare prompt template
    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system", cfg.prompt_subgraph_summarization),
            ("human", "{input}"),
        ]
    )

    # Prepare chain
    chain = prompt_template | state["llm_model"] | StrOutputParser()

    # Invoke the chain to summarize the textualized subgraph
    response = chain.invoke(
        {
            "input": prompt,
            "textualized_subgraph": extracted_graph[extraction_name]["graph_text"],
        }
    )

    # Store the response as graph_summary in the extracted graph
    for key, value in extracted_graph.items():
        if key == extraction_name:
            value["graph_summary"] = response

    # Prepare the dictionary of updated state
    dic_updated_state_for_model = {}
    for key, value in {
        "dic_extracted_graph": list(extracted_graph.values()),
    }.items():
        if value:
            dic_updated_state_for_model[key] = value

    # Return the updated state of the tool
    return Command(
        update=dic_updated_state_for_model
        | {
            # update the message history
            "messages": [ToolMessage(content=response, tool_call_id=tool_call_id)]
        }
    )
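
Internally, _run builds a summarization chain from the hydra-configured system prompt, the LLM stored in state["llm_model"], and a string output parser, then invokes it with both the user prompt ("input") and the stored "graph_text" ("textualized_subgraph"). The sketch below reproduces that plumbing in isolation; the system prompt shown is a stand-in for cfg.prompt_subgraph_summarization, which is assumed to expose a {textualized_subgraph} placeholder, and the model is a placeholder for state["llm_model"].

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

# Stand-in system prompt; the packaged prompt wording differs.
system_prompt = (
    "You summarize a biomedical knowledge subgraph for the user.\n"
    "Subgraph:\n{textualized_subgraph}"
)

chain = (
    ChatPromptTemplate.from_messages([("system", system_prompt), ("human", "{input}")])
    | ChatOpenAI(model="gpt-4o-mini", temperature=0)  # placeholder for state["llm_model"]
    | StrOutputParser()
)

summary = chain.invoke(
    {
        "input": "Which nodes are most relevant to TNF signaling?",
        "textualized_subgraph": "(TNF) -[interacts_with]-> (TNFRSF1A)",
    }
)
print(summary)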