Skip to content

Custom plots

Tool for plotting a custom figure.

CustomPlotterInput

Bases: BaseModel

Input schema for the `custom_plotter` tool.

Source code in `aiagents4pharma/talk2biomodels/tools/custom_plotter.py` (lines 19–25)
class CustomPlotterInput(BaseModel):
    """
    Arguments accepted by the custom_plotter tool.

    Attributes:
        question: Free-text description of the plot to draw.
        simulation_name: Name under which the simulation results were stored.
        state: Graph state, marked with InjectedState so it is supplied at
            runtime rather than by the model.
    """
    question: str = Field(..., description="Description of the plot")
    simulation_name: str = Field(..., description="Name assigned to the simulation")
    state: Annotated[dict, InjectedState]

CustomPlotterTool

Bases: BaseTool

Tool for making custom plots

Source code in `aiagents4pharma/talk2biomodels/tools/custom_plotter.py` (lines 33–100)
class CustomPlotterTool(BaseTool):
    """
    Tool for making custom plots
    """
    name: str = "custom_plotter"
    description: str = "A tool to make custom plots of the simulation results"
    args_schema: Type[BaseModel] = CustomPlotterInput
    response_format: str = "content_and_artifact"

    def _run(self,
             question: str,
             simulation_name: str,
             state: Annotated[dict, InjectedState]
             ) -> Tuple[str, Union[None, List[str]]]:
        """
        Run the tool.

        Args:
            question (str): The question about the custom plot.
            state (dict): The state of the graph.

        Returns:
            str: The answer to the question
        """
        logger.log(logging.INFO, "Calling custom_plotter tool %s", question)
        dic_simulated_data = {}
        for data in state["dic_simulated_data"]:
            for key in data:
                if key not in dic_simulated_data:
                    dic_simulated_data[key] = []
                dic_simulated_data[key] += [data[key]]
        # Create a pandas dataframe from the dictionary
        df = pd.DataFrame.from_dict(dic_simulated_data)
        # Get the simulated data for the current tool call
        df = pd.DataFrame(
                df[df['name'] == simulation_name]['data'].iloc[0]
                )
        # df = pd.DataFrame.from_dict(state['dic_simulated_data'])
        species_names = df.columns.tolist()
        # Exclude the time column
        species_names.remove('Time')
        # In the following code, we extract the species
        # from the user question. We use Literal to restrict
        # the species names to the ones available in the
        # simulation results.
        class CustomHeader(TypedDict):
            """
            A list of species based on user question.
            """
            relevant_species: Union[None, List[Literal[*species_names]]] = Field(
                    description="""List of species based on user question.
                    If no relevant species are found, it will be None.""")
        # Create an instance of the LLM model
        llm = ChatOpenAI(model=state['llm_model'], temperature=0)
        llm_with_structured_output = llm.with_structured_output(CustomHeader)
        results = llm_with_structured_output.invoke(question)
        extracted_species = []
        # Extract the species from the results
        # that are available in the simulation results
        for species in results['relevant_species']:
            if species in species_names:
                extracted_species.append(species)
        logger.info("Extracted species: %s", extracted_species)
        if len(extracted_species) == 0:
            return "No species found in the simulation results that matches the user prompt.", None
        # Include the time column
        extracted_species.insert(0, 'Time')
        return f"Custom plot {simulation_name}", df[extracted_species].to_dict(orient='records')

_run(question, simulation_name, state)

Run the tool.

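Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `question` | `str` | The question about the custom plot. | *required* |
| `simulation_name` | `str` | Name assigned to the simulation whose results are plotted. | *required* |
| `state` | `dict` | The state of the graph. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `Tuple[str, Union[None, List[dict]]]` | A status message, together with the `Time` column and the matched species as a list of row records; the second element is `None` when no species matched. |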

Source code in `aiagents4pharma/talk2biomodels/tools/custom_plotter.py` (lines 42–100)
def _run(self,
         question: str,
         simulation_name: str,
         state: Annotated[dict, InjectedState]
         ) -> Tuple[str, Union[None, List[str]]]:
    """
    Run the tool.

    Args:
        question (str): The question about the custom plot.
        state (dict): The state of the graph.

    Returns:
        str: The answer to the question
    """
    logger.log(logging.INFO, "Calling custom_plotter tool %s", question)
    dic_simulated_data = {}
    for data in state["dic_simulated_data"]:
        for key in data:
            if key not in dic_simulated_data:
                dic_simulated_data[key] = []
            dic_simulated_data[key] += [data[key]]
    # Create a pandas dataframe from the dictionary
    df = pd.DataFrame.from_dict(dic_simulated_data)
    # Get the simulated data for the current tool call
    df = pd.DataFrame(
            df[df['name'] == simulation_name]['data'].iloc[0]
            )
    # df = pd.DataFrame.from_dict(state['dic_simulated_data'])
    species_names = df.columns.tolist()
    # Exclude the time column
    species_names.remove('Time')
    # In the following code, we extract the species
    # from the user question. We use Literal to restrict
    # the species names to the ones available in the
    # simulation results.
    class CustomHeader(TypedDict):
        """
        A list of species based on user question.
        """
        relevant_species: Union[None, List[Literal[*species_names]]] = Field(
                description="""List of species based on user question.
                If no relevant species are found, it will be None.""")
    # Create an instance of the LLM model
    llm = ChatOpenAI(model=state['llm_model'], temperature=0)
    llm_with_structured_output = llm.with_structured_output(CustomHeader)
    results = llm_with_structured_output.invoke(question)
    extracted_species = []
    # Extract the species from the results
    # that are available in the simulation results
    for species in results['relevant_species']:
        if species in species_names:
            extracted_species.append(species)
    logger.info("Extracted species: %s", extracted_species)
    if len(extracted_species) == 0:
        return "No species found in the simulation results that matches the user prompt.", None
    # Include the time column
    extracted_species.insert(0, 'Time')
    return f"Custom plot {simulation_name}", df[extracted_species].to_dict(orient='records')