Skip to content

Custom plots

Tool for plotting a custom y-axis of a simulation plot.

CustomPlotterInput

Bases: BaseModel

Input schema for the custom plotter tool.

Source code in `aiagents4pharma/talk2biomodels/tools/custom_plotter.py` (lines 72–80):
class CustomPlotterInput(BaseModel):
    """
    Input schema for the custom plotter tool.

    Attributes:
        question: Description of the plot requested by the user.
        sys_bio_model: Model data passed on to ``load_biomodel``; optional
            (defaults to ``None``).
        simulation_name: Name of the stored simulation whose results
            should be plotted.
        state: Graph state, injected at call time via ``InjectedState``
            rather than supplied as a regular tool argument.
    """
    question: str = Field(description="Description of the plot")
    # The default is None, so the annotation must admit None as well;
    # otherwise an explicit null from the caller fails validation.
    sys_bio_model: Union[None, ModelData] = Field(description="model data",
                                                  default=None)
    simulation_name: str = Field(description="Name assigned to the simulation")
    state: Annotated[dict, InjectedState]

CustomPlotterTool

Bases: BaseTool

Tool for plotting a custom y-axis (a user-selected subset of species) of a simulation plot.

Source code in `aiagents4pharma/talk2biomodels/tools/custom_plotter.py` (lines 88–157):
class CustomPlotterTool(BaseTool):
    """
    Tool for custom plotting the y-axis of a plot.

    Given the name of a simulation whose results are stored in the graph
    state, it asks the LLM (via ``extract_relevant_species``) which of the
    simulated species the user wants to see, and returns only those columns
    (preceded by ``Time``) together with the model units.
    """
    name: str = "custom_plotter"
    description: str = '''A visualization tool designed to extract and display a subset
                        of the larger simulation plot generated by the simulate_model tool.
                        It allows users to specify particular species for the y-axis, 
                        providing a more targeted view of key species without the clutter 
                        of the full plot.'''
    args_schema: Type[BaseModel] = CustomPlotterInput
    response_format: str = "content_and_artifact"

    def _run(self,
             question: str,
             sys_bio_model: ModelData,
             simulation_name: str,
             state: Annotated[dict, InjectedState]
             ) -> Tuple[str, dict]:
        """
        Run the tool.

        Args:
            question (str): The question about the custom plot.
            sys_bio_model (ModelData): The model data.
            simulation_name (str): The name assigned to the simulation.
            state (dict): The state of the graph. This method reads
                ``sbml_file_path`` and ``dic_simulated_data``;
                ``extract_relevant_species`` additionally reads ``llm_model``.

        Returns:
            Tuple[str, dict]: A short message (``"Custom plot <name>"``) and
            an artifact dict holding ``dic_data`` (the selected columns as a
            list of records, ``Time`` first) merged with the model units.

        Raises:
            ValueError: If no stored simulation is named ``simulation_name``,
                or if no simulated species matches the question.
        """
        logger.log(logging.INFO, "Calling custom_plotter tool %s, %s", question, sys_bio_model)
        # Load the model, using the last SBML file path recorded in the state (if any)
        sbml_file_paths = state['sbml_file_path']
        sbml_file_path = sbml_file_paths[-1] if sbml_file_paths else None
        model_object = load_biomodel(sys_bio_model, sbml_file_path=sbml_file_path)
        # Get the simulated data of the first stored simulation with this name
        simulation_data = next(
            (entry['data'] for entry in state["dic_simulated_data"]
             if entry.get('name') == simulation_name),
            None)
        if simulation_data is None:
            raise ValueError(f"No simulation named '{simulation_name}' "
                             "was found in the stored simulation results.")
        df = pd.DataFrame(simulation_data)
        species_names = df.columns.tolist()
        # Exclude the time column
        species_names.remove('Time')
        logger.log(logging.INFO, "Species names: %s", species_names)
        # Extract the relevant species from the user question
        results = extract_relevant_species(question, species_names, state)
        logger.log(logging.INFO, "Relevant species returned by the LLM: %s", results)
        # An empty list carries the same meaning as None: nothing relevant.
        if not results.relevant_species:
            raise ValueError("No species found in the simulation results "
                             "that matches the user prompt.")
        # Keep only species present in the simulation results, preserving
        # the order in which the LLM listed them.
        available_species = set(species_names)
        extracted_species = [species for species in results.relevant_species
                             if species in available_species]
        logger.info("Extracted species: %s", extracted_species)
        # Include the time column as the first column
        columns = ['Time'] + extracted_species
        return f"Custom plot {simulation_name}", {
            'dic_data': df[columns].to_dict(orient='records')
        } | get_model_units(model_object)

_run(question, sys_bio_model, simulation_name, state)

Run the tool.

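Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `question` | `str` | The question about the custom plot. | *required* |
| `sys_bio_model` | `ModelData` | The model data. | *required* |
| `simulation_name` | `str` | The name assigned to the simulation. | *required* |
| `state` | `dict` | The state of the graph. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `Tuple[str, Union[None, List[str]]]` | The answer to the question: a message of the form `Custom plot <simulation_name>` and an artifact dictionary. The artifact is a dictionary, not a list. It holds `dic_data` (the selected species columns, preceded by `Time`, as a list of records) merged with the model units. |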

Source code in `aiagents4pharma/talk2biomodels/tools/custom_plotter.py` (lines 101–157):
def _run(self,
         question: str,
         sys_bio_model: ModelData,
         simulation_name: str,
         state: Annotated[dict, InjectedState]
         ) -> Tuple[str, dict]:
    """
    Run the tool.

    Args:
        question (str): The question about the custom plot.
        sys_bio_model (ModelData): The model data.
        simulation_name (str): The name assigned to the simulation.
        state (dict): The state of the graph. This function reads
            ``sbml_file_path`` and ``dic_simulated_data``;
            ``extract_relevant_species`` additionally reads ``llm_model``.

    Returns:
        Tuple[str, dict]: A short message (``"Custom plot <name>"``) and
        an artifact dict holding ``dic_data`` (the selected columns as a
        list of records, ``Time`` first) merged with the model units.

    Raises:
        ValueError: If no stored simulation is named ``simulation_name``,
            or if no simulated species matches the question.
    """
    logger.log(logging.INFO, "Calling custom_plotter tool %s, %s", question, sys_bio_model)
    # Load the model, using the last SBML file path recorded in the state (if any)
    sbml_file_paths = state['sbml_file_path']
    sbml_file_path = sbml_file_paths[-1] if sbml_file_paths else None
    model_object = load_biomodel(sys_bio_model, sbml_file_path=sbml_file_path)
    # Get the simulated data of the first stored simulation with this name
    simulation_data = next(
        (entry['data'] for entry in state["dic_simulated_data"]
         if entry.get('name') == simulation_name),
        None)
    if simulation_data is None:
        raise ValueError(f"No simulation named '{simulation_name}' "
                         "was found in the stored simulation results.")
    df = pd.DataFrame(simulation_data)
    species_names = df.columns.tolist()
    # Exclude the time column
    species_names.remove('Time')
    logger.log(logging.INFO, "Species names: %s", species_names)
    # Extract the relevant species from the user question
    results = extract_relevant_species(question, species_names, state)
    logger.log(logging.INFO, "Relevant species returned by the LLM: %s", results)
    # An empty list carries the same meaning as None: nothing relevant.
    if not results.relevant_species:
        raise ValueError("No species found in the simulation results "
                         "that matches the user prompt.")
    # Keep only species present in the simulation results, preserving
    # the order in which the LLM listed them.
    available_species = set(species_names)
    extracted_species = [species for species in results.relevant_species
                         if species in available_species]
    logger.info("Extracted species: %s", extracted_species)
    # Include the time column as the first column
    columns = ['Time'] + extracted_species
    return f"Custom plot {simulation_name}", {
        'dic_data': df[columns].to_dict(orient='records')
    } | get_model_units(model_object)

extract_relevant_species(question, species_names, state)

Extract the relevant species from the user question.

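Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `question` | `str` | The user question. | *required* |
| `species_names` | `list` | The species names available in the simulation results. | *required* |
| `state` | `dict` | The state of the graph. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `CustomHeader` | The relevant species, as structured LLM output. Its `relevant_species` attribute is either a list of names drawn from `species_names` or `None`. |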

Source code in `aiagents4pharma/talk2biomodels/tools/custom_plotter.py` (lines 22–70):
def extract_relevant_species(question, species_names, state):
    """
    Extract the relevant species from the user question.

    Args:
        question (str): The user question.
        species_names (list): The species names available in the simulation results.
        state (dict): The state of the graph.

    Returns:
        CustomHeader: The relevant species
    """
    # In the following code, we extract the species
    # from the user question. We use Literal to restrict
    # the species names to the ones available in the
    # simulation results.
    class CustomHeader(BaseModel):
        """
        A list of species based on user question.

        This is a Pydantic model that restricts the species
        names to the ones available in the simulation results.

        If no species is relevant, set the attribute
        `relevant_species` to None.
        """
        relevant_species: Union[None, List[Literal[*species_names]]] = Field(
                description="This is a list of species based on the user question."
                "It is restricted to the species available in the simulation results."
                "If no species is relevant, set this attribute to None."
                "If the user asks for very specific species (for example, using the"
                "keyword `only` in the question), set this attribute to correspond "
                "to the species available in the simulation results, otherwise set it to None."
                )
    # Load hydra configuration
    with hydra.initialize(version_base=None, config_path="../configs"):
        cfg = hydra.compose(config_name='config',
                            overrides=['tools/custom_plotter=default'])
        cfg = cfg.tools.custom_plotter
    # Get the system prompt
    system_prompt = cfg.system_prompt_custom_header
    # Create an instance of the LLM model
    logging.log(logging.INFO, "LLM model: %s", state['llm_model'])
    llm = state['llm_model']
    llm_with_structured_output = llm.with_structured_output(CustomHeader)
    prompt = ChatPromptTemplate.from_messages([("system", system_prompt),
                                               ("human", "{input}")])
    few_shot_structured_llm = prompt | llm_with_structured_output
    return few_shot_structured_llm.invoke(question)