diff --git a/README.md b/README.md index 6c474c7a..6d83b948 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ Our toolkit currently consists of three intelligent agents, each designed to sim - Forward simulation of both internal and open-source models (BioModels). - Adjust parameters within the model to simulate different conditions. - Query simulation results. +- Extract model information such as species, parameters, units and description. ### 2. Talk2Cells *(Work in Progress)* diff --git a/aiagents4pharma/talk2biomodels/agents/t2b_agent.py b/aiagents4pharma/talk2biomodels/agents/t2b_agent.py index 89b9413b..d5614332 100644 --- a/aiagents4pharma/talk2biomodels/agents/t2b_agent.py +++ b/aiagents4pharma/talk2biomodels/agents/t2b_agent.py @@ -16,6 +16,7 @@ from ..tools.simulate_model import SimulateModelTool from ..tools.custom_plotter import CustomPlotterTool from ..tools.ask_question import AskQuestionTool +from ..tools.parameter_scan import ParameterScanTool from ..states.state_talk2biomodels import Talk2Biomodels # Initialize logger @@ -35,17 +36,13 @@ def agent_t2b_node(state: Annotated[dict, InjectedState]): return response # Define the tools - simulate_model = SimulateModelTool() - custom_plotter = CustomPlotterTool() - ask_question = AskQuestionTool() - search_model = SearchModelsTool() - get_modelinfo = GetModelInfoTool() tools = ToolNode([ - simulate_model, - ask_question, - custom_plotter, - search_model, - get_modelinfo + SimulateModelTool(), + AskQuestionTool(), + CustomPlotterTool(), + SearchModelsTool(), + GetModelInfoTool(), + ParameterScanTool() ]) # Define the model diff --git a/aiagents4pharma/talk2biomodels/models/basico_model.py b/aiagents4pharma/talk2biomodels/models/basico_model.py index 43e1a95c..25f8b109 100644 --- a/aiagents4pharma/talk2biomodels/models/basico_model.py +++ b/aiagents4pharma/talk2biomodels/models/basico_model.py @@ -48,52 +48,49 @@ def check_biomodel_id_or_sbml_file_path(self): self.name = basico.model_info.get_model_name(model=self.copasi_model) return self - def simulate(self, - parameters: Optional[Dict[str, Union[float, int]]] = None, - duration: Union[int, float] = 10, - interval: int = 10 - ) -> pd.DataFrame: + def update_parameters(self, parameters: Dict[str, Union[float, int]]) -> None: + """ + Update model parameters with new values. + """ + # Update parameters in the model + for param_name, param_value in parameters.items(): + # check if the param_name is not None + if param_name is None: + continue + # if param is a kinetic parameter + df_all_params = basico.model_info.get_parameters(model=self.copasi_model) + if param_name in df_all_params.index.tolist(): + basico.model_info.set_parameters(name=param_name, + exact=True, + initial_value=param_value, + model=self.copasi_model) + # if param is a species + else: + basico.model_info.set_species(name=param_name, + exact=True, + initial_concentration=param_value, + model=self.copasi_model) + + def simulate(self, duration: Union[int, float] = 10, interval: int = 10) -> pd.DataFrame: """ Simulate the COPASI model over a specified range of time points. Args: - parameters: Dictionary of model parameters to update before simulation. duration: Duration of the simulation in time units. interval: Interval between time points in the simulation. Returns: Pandas DataFrame with time-course simulation results. """ - - # Update parameters in the model - if parameters: - for param_name, param_value in parameters.items(): - # check if the param_name is not None - if param_name is None: - continue - # if param is a kinectic parameter - df_all_params = basico.model_info.get_parameters(model=self.copasi_model) - if param_name in df_all_params.index.tolist(): - basico.model_info.set_parameters(name=param_name, - exact=True, - initial_value=param_value, - model=self.copasi_model) - # if param is a species - else: - basico.model_info.set_species(name=param_name, - exact=True, - initial_concentration=param_value, - model=self.copasi_model) - # Run the simulation and return results df_result = basico.run_time_course(model=self.copasi_model, intervals=interval, duration=duration) - # Replace curly braces in column headers with square brackets - # Because curly braces in the world of LLMS are used for - # structured output - df_result.columns = df_result.columns.str.replace('{', '[', regex=False).\ - str.replace('}', ']', regex=False) + # # Replace curly braces in column headers with square brackets + # # Because curly braces in the world of LLMS are used for + # # structured output + # df_result.columns = df_result.columns.str.replace('{', '[', regex=False).\ + # str.replace('}', ']', regex=False) # Reset the index df_result.reset_index(inplace=True) # Store the simulation results diff --git a/aiagents4pharma/talk2biomodels/models/sys_bio_model.py b/aiagents4pharma/talk2biomodels/models/sys_bio_model.py index fce48985..4e7faeb6 100644 --- a/aiagents4pharma/talk2biomodels/models/sys_bio_model.py +++ b/aiagents4pharma/talk2biomodels/models/sys_bio_model.py @@ -35,18 +35,21 @@ def get_model_metadata(self) -> Dict[str, Union[str, int]]: Returns: dict: Dictionary with model metadata """ + @abstractmethod + def update_parameters(self, parameters: Dict[str, Union[float, int]]) -> None: + """ + Abstract method to update model parameters. + + Args: + parameters: Dictionary of parameter values. + """ @abstractmethod - def simulate(self, - parameters: Dict[str, Union[float, int]], - duration: Union[int, float]) -> List[float]: + def simulate(self, duration: Union[int, float]) -> List[float]: """ Abstract method to run a simulation of the model. - This method should be implemented to simulate model - behavior based on the provided parameters. Args: - parameters: Dictionary of parameter values. duration: Duration of the simulation. Returns: diff --git a/aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py b/aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py index e997f985..967de186 100644 --- a/aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +++ b/aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py @@ -12,13 +12,13 @@ class Talk2Biomodels(AgentState): """ The state for the Talk2BioModels agent. """ - model_id: Annotated[list, operator.add] - # sbml_file_path: str + llm_model: str # A StateGraph may receive a concurrent updates # which is not supported by the StateGraph. # Therefore, we need to use Annotated to specify # the operator for the sbml_file_path field. # https://langchain-ai.github.io/langgraph/troubleshooting/errors/INVALID_CONCURRENT_GRAPH_UPDATE/ + model_id: Annotated[list, operator.add] sbml_file_path: Annotated[list, operator.add] dic_simulated_data: Annotated[list[dict], operator.add] - llm_model: str + dic_scanned_data: Annotated[list[dict], operator.add] diff --git a/aiagents4pharma/talk2biomodels/tests/test_basico_model.py b/aiagents4pharma/talk2biomodels/tests/test_basico_model.py index 06eba53a..7a8c9311 100644 --- a/aiagents4pharma/talk2biomodels/tests/test_basico_model.py +++ b/aiagents4pharma/talk2biomodels/tests/test_basico_model.py @@ -19,13 +19,14 @@ def test_with_biomodel_id(model): Test initialization of BasicoModel with biomodel_id. """ assert model.biomodel_id == 64 + model.update_parameters(parameters={'Pyruvate': 0.5, 'KmPFKF6P': 1.5}) + df_species = basico.model_info.get_species(model=model.copasi_model) + assert df_species.loc['Pyruvate', 'initial_concentration'] == 0.5 + df_parameters = basico.model_info.get_parameters(model=model.copasi_model) + assert df_parameters.loc['KmPFKF6P', 'initial_value'] == 1.5 # check if the simulation results are a pandas DataFrame object - assert isinstance(model.simulate(parameters={'Pyruvate': 0.5, 'KmPFKF6P': 1.5}, - duration=2, - interval=2), - pd.DataFrame) - assert isinstance(model.simulate(parameters={None: None}, duration=2, interval=2), - pd.DataFrame) + assert isinstance(model.simulate(duration=2, interval=2), pd.DataFrame) + model.update_parameters(parameters={None: None}) assert model.description == basico.biomodels.get_model_info(model.biomodel_id)["description"] def test_with_sbml_file(): @@ -35,8 +36,6 @@ def test_with_sbml_file(): model_object = BasicoModel(sbml_file_path="./BIOMD0000000064_url.xml") assert model_object.sbml_file_path == "./BIOMD0000000064_url.xml" assert isinstance(model_object.simulate(duration=2, interval=2), pd.DataFrame) - assert isinstance(model_object.simulate(parameters={'NADH': 0.5}, duration=2, interval=2), - pd.DataFrame) def test_check_biomodel_id_or_sbml_file_path(): ''' diff --git a/aiagents4pharma/talk2biomodels/tests/test_langgraph.py b/aiagents4pharma/talk2biomodels/tests/test_langgraph.py index 65ea5997..9b38b296 100644 --- a/aiagents4pharma/talk2biomodels/tests/test_langgraph.py +++ b/aiagents4pharma/talk2biomodels/tests/test_langgraph.py @@ -119,6 +119,68 @@ def test_simulate_model_tool(): # Check if the data of the second model contains assert 'mTORC2' in dic_simulated_data[1]['data'] +def test_param_scan_tool(): + ''' + In this test, we will test the parameter_scan tool. + We will prompt it to scan the parameter `kIL6RBind` + from 1 to 100 in steps of 10, record the changes + in the concentration of the species `Ab{serum}` in + model 537. + + We will pass the inaccuarate parameter (`KIL6Rbind`) + and species names (just `Ab`) to the tool to test + if it can deal with it. + + We expect the agent to first invoke the parameter_scan + tool and raise an error. It will then invoke another + tool get_modelinfo to get the correct parameter + and species names. Finally, the agent will reinvoke + the parameter_scan tool with the correct parameter + and species names. + + ''' + unique_id = 123 + app = get_app(unique_id) + config = {"configurable": {"thread_id": unique_id}} + app.update_state(config, {"llm_model": "gpt-4o-mini"}) + prompt = """How will the value of Ab in model 537 change + if the param kIL6Rbind is varied from 1 to 100 in steps of 10? + Set the initial `DoseQ2W` concentration to 300. + Reset the IL6{serum} concentration to 100 every 500 hours. + Assume that the model is simulated for 2016 hours with + an interval of 2016.""" + # Invoke the agent + app.invoke( + {"messages": [HumanMessage(content=prompt)]}, + config=config + ) + current_state = app.get_state(config) + reversed_messages = current_state.values["messages"][::-1] + # Loop through the reversed messages until a + # ToolMessage is found. + df = pd.DataFrame(columns=['name', 'status', 'content']) + names = [] + statuses = [] + contents = [] + for msg in reversed_messages: + # Assert that the message is a ToolMessage + # and its status is "error" + if not isinstance(msg, ToolMessage): + continue + names.append(msg.name) + statuses.append(msg.status) + contents.append(msg.content) + df = pd.DataFrame({'name': names, 'status': statuses, 'content': contents}) + # print (df) + assert any((df["status"] == "error") & + (df["name"] == "parameter_scan") & + (df["content"].str.startswith("Error: ValueError('Invalid parameter name:"))) + assert any((df["status"] == "success") & + (df["name"] == "parameter_scan") & + (df["content"].str.startswith("Parameter scan results of"))) + assert any((df["status"] == "success") & + (df["name"] == "get_modelinfo")) + def test_integration(): ''' Test the integration of the tools. @@ -184,9 +246,9 @@ def test_integration(): reversed_messages = current_state.values["messages"][::-1] # Loop through the reversed messages # until a ToolMessage is found. - expected_header = ['Time', 'CRP[serum]', 'CRPExtracellular'] + expected_header = ['Time', 'CRP{serum}', 'CRPExtracellular'] expected_header += ['CRP Suppression (%)', 'CRP (% of baseline)'] - expected_header += ['CRP[liver]'] + expected_header += ['CRP{liver}'] predicted_artifact = [] for msg in reversed_messages: if isinstance(msg, ToolMessage): diff --git a/aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py b/aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py index 271c9aa9..0394884f 100644 --- a/aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +++ b/aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py @@ -16,6 +16,8 @@ class TestBioModel(SysBioModel): sbml_file_path: Optional[str] = Field(None, description="Path to an SBML file") name: Optional[str] = Field(..., description="Name of the model") description: Optional[str] = Field("", description="Description of the model") + param1: Optional[float] = Field(0.0, description="Parameter 1") + param2: Optional[float] = Field(0.0, description="Parameter 2") def get_model_metadata(self) -> Dict[str, Union[str, int]]: ''' @@ -23,15 +25,18 @@ def get_model_metadata(self) -> Dict[str, Union[str, int]]: ''' return self.biomodel_id - def simulate(self, - parameters: Dict[str, Union[float, int]], - duration: Union[int, float]) -> List[float]: + def update_parameters(self, parameters): + ''' + Update the model parameters. + ''' + self.param1 = parameters.get('param1', 0.0) + self.param2 = parameters.get('param2', 0.0) + + def simulate(self, duration: Union[int, float]) -> List[float]: ''' Simulate the model. ''' - param1 = parameters.get('param1', 0.0) - param2 = parameters.get('param2', 0.0) - return [param1 + param2 * t for t in range(int(duration))] + return [self.param1 + self.param2 * t for t in range(int(duration))] def test_get_model_metadata(): ''' @@ -53,5 +58,6 @@ def test_simulate(): Test the simulate method of the BioModel class. ''' model = TestBioModel(biomodel_id=123, name="Test Model", description="A test model") - results = model.simulate(parameters={'param1': 1.0, 'param2': 2.0}, duration=4.0) + model.update_parameters({'param1': 1.0, 'param2': 2.0}) + results = model.simulate(duration=4.0) assert results == [1.0, 3.0, 5.0, 7.0] diff --git a/aiagents4pharma/talk2biomodels/tools/__init__.py b/aiagents4pharma/talk2biomodels/tools/__init__.py index e860d4e3..7aa21427 100644 --- a/aiagents4pharma/talk2biomodels/tools/__init__.py +++ b/aiagents4pharma/talk2biomodels/tools/__init__.py @@ -6,4 +6,5 @@ from . import ask_question from . import custom_plotter from . import get_modelinfo +from . import parameter_scan from . import load_biomodel diff --git a/aiagents4pharma/talk2biomodels/tools/get_modelinfo.py b/aiagents4pharma/talk2biomodels/tools/get_modelinfo.py index d0fedb46..f582ee79 100644 --- a/aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +++ b/aiagents4pharma/talk2biomodels/tools/get_modelinfo.py @@ -47,8 +47,10 @@ class GetModelInfoTool(BaseTool): """ This tool ise used extract model information. """ - name: str = "get_parameters" - description: str = "A tool for extracting model information." + name: str = "get_modelinfo" + description: str = """A tool for extracting name, + description, species, parameters, + compartments, and units from a model.""" args_schema: Type[BaseModel] = GetModelInfoInput def _run(self, @@ -81,7 +83,7 @@ def _run(self, # Extract species from the model if requested_model_info.species: df_species = basico.model_info.get_species(model=model_obj.copasi_model) - dic_results['Species'] = df_species.index.tolist() + dic_results['Species'] = df_species['display_name'].tolist() dic_results['Species'] = ','.join(dic_results['Species']) # Extract parameters from the model diff --git a/aiagents4pharma/talk2biomodels/tools/parameter_scan.py b/aiagents4pharma/talk2biomodels/tools/parameter_scan.py new file mode 100644 index 00000000..b982906d --- /dev/null +++ b/aiagents4pharma/talk2biomodels/tools/parameter_scan.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python3 + +""" +Tool for parameter scan. +""" + +import logging +from dataclasses import dataclass +from typing import Type, Union, List, Annotated +import pandas as pd +import basico +from pydantic import BaseModel, Field +from langgraph.types import Command +from langgraph.prebuilt import InjectedState +from langchain_core.tools import BaseTool +from langchain_core.messages import ToolMessage +from langchain_core.tools.base import InjectedToolCallId +from .load_biomodel import ModelData, load_biomodel + +# Initialize logger +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +@dataclass +class TimeData: + """ + Dataclass for storing the time data. + """ + duration: Union[int, float] = 100 + interval: Union[int, float] = 10 + +@dataclass +class SpeciesData: + """ + Dataclass for storing the species data. + """ + species_name: List[str] = Field(description="species name", default=[]) + species_concentration: List[Union[int, float]] = Field( + description="initial species concentration", + default=[]) + +@dataclass +class TimeSpeciesNameConcentration: + """ + Dataclass for storing the time, species name, and concentration data. + """ + time: Union[int, float] = Field(description="time point where the event occurs") + species_name: str = Field(description="species name") + species_concentration: Union[int, float] = Field( + description="species concentration at the time point") + +@dataclass +class ReocurringData: + """ + Dataclass for species that reoccur. In other words, the concentration + of the species resets to a certain value after a certain time interval. + """ + data: List[TimeSpeciesNameConcentration] = Field( + description="time, name, and concentration data of species that reoccur", + default=[]) + +@dataclass +class ParameterScanData(BaseModel): + """ + Dataclass for storing the parameter scan data. + """ + species_names: List[str] = Field(description="species names to scan", + default=[]) + parameter_name: str = Field(description="Parameter name to scan", + default_factory=None) + parameter_values: List[Union[int, float]] = Field( + description="Parameter values to scan", + default_factory=None) + +@dataclass +class ArgumentData: + """ + Dataclass for storing the argument data. + """ + time_data: TimeData = Field(description="time data", default=None) + species_data: SpeciesData = Field( + description="species name and initial concentration data", + default=None) + reocurring_data: ReocurringData = Field( + description="""Concentration and time data of species that reoccur + For example, a species whose concentration resets to a certain value + after a certain time interval""") + parameter_scan_data: ParameterScanData = Field( + description="parameter scan data", + default=None) + scan_name: str = Field( + description="""An AI assigned `_` separated name of + the parameter scan experiment based on human query""") + +def add_rec_events(model_object, reocurring_data): + """ + Add reocurring events to the model. + """ + for row in reocurring_data.data: + tp, sn, sc = row.time, row.species_name, row.species_concentration + basico.add_event(f'{sn}_{tp}', + f'Time > {tp}', + [[sn, str(sc)]], + model=model_object.copasi_model) + +def make_list_dic_scanned_data(dic_param_scan, arg_data, sys_bio_model, tool_call_id): + """ + Prepare the list dictionary of scanned data + that will be passed to the state of the graph. + + Args: + dic_param_scan: Dictionary of parameter scan results. + arg_data: The argument data. + sys_bio_model: The model data. + tool_call_id: The tool call ID. + + Returns: + list: List of dictionary of scanned data. + """ + list_dic_scanned_data = [] + for species_name, df_param_scan in dic_param_scan.items(): + logger.log(logging.INFO, "Parameter scan results for %s with shape %s", + species_name, + df_param_scan.shape) + # Prepare the list dictionary of scanned data + # that will be passed to the state of the graph + list_dic_scanned_data.append({ + 'name': arg_data.scan_name+':'+species_name, + 'source': sys_bio_model.biomodel_id if sys_bio_model.biomodel_id else 'upload', + 'tool_call_id': tool_call_id, + 'data': df_param_scan.to_dict() + }) + return list_dic_scanned_data +def run_parameter_scan(model_object, arg_data, dic_species_data, duration, interval) -> dict: + """ + Run parameter scan on the model. + + Args: + model_object: The model object. + arg_data: The argument data. + dic_species_data: Dictionary of species data. + duration: Duration of the simulation. + interval: Interval between time points in the simulation. + + Returns: + dict: Dictionary of parameter scan results. Each key is a species name + and each value is a DataFrame containing the results of the parameter scan. + """ + # Extract all parameter names from the model and verify if the given parameter name is valid + df_all_parameters = basico.model_info.get_parameters(model=model_object.copasi_model) + all_parameters = df_all_parameters.index.tolist() + if arg_data.parameter_scan_data.parameter_name not in all_parameters: + logger.error( + "Invalid parameter name: %s", arg_data.parameter_scan_data.parameter_name) + raise ValueError( + f"Invalid parameter name: {arg_data.parameter_scan_data.parameter_name}") + # Extract all species name from the model and verify if the given species name is valid + df_all_species = basico.model_info.get_species(model=model_object.copasi_model) + all_species = df_all_species['display_name'].tolist() + # Dictionary to store the parameter scan results + dic_param_scan_results = {} + for species_name in arg_data.parameter_scan_data.species_names: + if species_name not in all_species: + logger.error("Invalid species name: %s", species_name) + raise ValueError(f"Invalid species name: {species_name}") + # Update the fixed model species and parameters + # These are the initial conditions of the model + # set by the user + model_object.update_parameters(dic_species_data) + # Initialize empty DataFrame to store results + # of the parameter scan + df_param_scan = pd.DataFrame() + for param_value in arg_data.parameter_scan_data.parameter_values: + # Update the parameter value in the model + model_object.update_parameters( + {arg_data.parameter_scan_data.parameter_name: param_value}) + # Simulate the model + model_object.simulate(duration=duration, interval=interval) + # If the column name 'Time' is not present in the results DataFrame + if 'Time' not in df_param_scan.columns: + df_param_scan['Time'] = model_object.simulation_results['Time'] + # Add the simulation results to the results DataFrame + col_name = f"{arg_data.parameter_scan_data.parameter_name}_{param_value}" + df_param_scan[col_name] = model_object.simulation_results[species_name] + + logger.log(logging.INFO, "Parameter scan results with shape %s", df_param_scan.shape) + # Add the results of the parameter scan to the dictionary + dic_param_scan_results[species_name] = df_param_scan + # return df_param_scan + return dic_param_scan_results + +class ParameterScanInput(BaseModel): + """ + Input schema for the ParameterScan tool. + """ + sys_bio_model: ModelData = Field(description="model data", + default=None) + arg_data: ArgumentData = Field(description= + """time, species, and reocurring data + as well as the parameter scan name and + data""", + default=None) + tool_call_id: Annotated[str, InjectedToolCallId] + state: Annotated[dict, InjectedState] + +# Note: It's important that every field has type hints. BaseTool is a +# Pydantic class and not having type hints can lead to unexpected behavior. +class ParameterScanTool(BaseTool): + """ + Tool for parameter scan. + """ + name: str = "parameter_scan" + description: str = """A tool to perform parameter scan + of a list of parameter values for a given species.""" + args_schema: Type[BaseModel] = ParameterScanInput + + def _run(self, + tool_call_id: Annotated[str, InjectedToolCallId], + state: Annotated[dict, InjectedState], + sys_bio_model: ModelData = None, + arg_data: ArgumentData = None + ) -> Command: + """ + Run the tool. + + Args: + tool_call_id (str): The tool call ID. This is injected by the system. + state (dict): The state of the tool. + sys_bio_model (ModelData): The model data. + arg_data (ArgumentData): The argument data. + + Returns: + Command: The updated state of the tool. + """ + logger.log(logging.INFO, "Calling parameter_scan tool %s, %s", + sys_bio_model, arg_data) + sbml_file_path = state['sbml_file_path'][-1] if len(state['sbml_file_path']) > 0 else None + model_object = load_biomodel(sys_bio_model, + sbml_file_path=sbml_file_path) + # Prepare the dictionary of species data + # that will be passed to the simulate method + # of the BasicoModel class + duration = 100.0 + interval = 10 + dic_species_data = {} + if arg_data: + # Prepare the dictionary of species data + if arg_data.species_data is not None: + dic_species_data = dict(zip(arg_data.species_data.species_name, + arg_data.species_data.species_concentration)) + # Add reocurring events (if any) to the model + if arg_data.reocurring_data is not None: + add_rec_events(model_object, arg_data.reocurring_data) + # Set the duration and interval + if arg_data.time_data is not None: + duration = arg_data.time_data.duration + interval = arg_data.time_data.interval + + # Run the parameter scan + dic_param_scan = run_parameter_scan(model_object, + arg_data, + dic_species_data, + duration, + interval) + + logger.log(logging.INFO, "Parameter scan results ready") + # Prepare the list dictionary of scanned data + list_dic_scanned_data = make_list_dic_scanned_data(dic_param_scan, + arg_data, + sys_bio_model, + tool_call_id) + # Prepare the dictionary of updated state for the model + dic_updated_state_for_model = {} + for key, value in { + "model_id": [sys_bio_model.biomodel_id], + "sbml_file_path": [sbml_file_path], + "dic_scanned_data": list_dic_scanned_data, + }.items(): + if value: + dic_updated_state_for_model[key] = value + # Return the updated state + return Command( + update=dic_updated_state_for_model|{ + # update the message history + "messages": [ + ToolMessage( + content=f"Parameter scan results of {arg_data.scan_name}", + tool_call_id=tool_call_id + ) + ], + } + ) diff --git a/aiagents4pharma/talk2biomodels/tools/simulate_model.py b/aiagents4pharma/talk2biomodels/tools/simulate_model.py index dabede7e..12938921 100644 --- a/aiagents4pharma/talk2biomodels/tools/simulate_model.py +++ b/aiagents4pharma/talk2biomodels/tools/simulate_model.py @@ -138,7 +138,7 @@ def _run(self, # of the BasicoModel class duration = 100.0 interval = 10 - dic_species_data = None + dic_species_data = {} if arg_data: # Prepare the dictionary of species data if arg_data.species_data is not None: @@ -151,22 +151,21 @@ def _run(self, if arg_data.time_data is not None: duration = arg_data.time_data.duration interval = arg_data.time_data.interval - + # Update the model parameters + model_object.update_parameters(dic_species_data) + logger.log(logging.INFO, + "Following species/parameters updated in the model %s", + dic_species_data) # Simulate the model - df = model_object.simulate( - parameters=dic_species_data, - duration=duration, - interval=interval - ) - + df = model_object.simulate(duration=duration, interval=interval) + logger.log(logging.INFO, "Simulation results ready with shape %s", df.shape) dic_simulated_data = { 'name': arg_data.simulation_name, 'source': sys_bio_model.biomodel_id if sys_bio_model.biomodel_id else 'upload', 'tool_call_id': tool_call_id, 'data': df.to_dict() } - - # Prepare the dictionary of updated state for the model + # Prepare the dictionary of updated state dic_updated_state_for_model = {} for key, value in { "model_id": [sys_bio_model.biomodel_id], @@ -175,7 +174,6 @@ def _run(self, }.items(): if value: dic_updated_state_for_model[key] = value - # Return the updated state of the tool return Command( update=dic_updated_state_for_model|{ diff --git a/app/frontend/streamlit_app_talk2biomodels.py b/app/frontend/streamlit_app_talk2biomodels.py index 4207ae19..83f5c12a 100644 --- a/app/frontend/streamlit_app_talk2biomodels.py +++ b/app/frontend/streamlit_app_talk2biomodels.py @@ -250,6 +250,10 @@ # These may contain additional visuals that # need to be displayed to the user. print("ToolMessage", msg) + # Skip the Tool message if it is an error message + if msg.status == "error": + continue + # Create a unique message id to identify the tool call # msg.name is the name of the tool # msg.tool_call_id is the unique id of the tool call @@ -276,67 +280,35 @@ # print (df_selected) else: continue - # # Add Time column to the custom headers - # custom_headers = msg.artifact - # if custom_headers: - # if 'Time' not in msg.artifact: - # custom_headers = ['Time'] + custom_headers - # # Make df with only the custom headers - # df_selected = df_simulated[custom_headers] - # else: - # continue - # Display the toggle button to suppress the table - streamlit_utils.render_toggle( - key="toggle_plotly_"+uniq_msg_id, - toggle_text="Show Plot", - toggle_state=True, - save_toggle=True) - # Display the plotly chart - streamlit_utils.render_plotly( - df_selected, - key="plotly_"+uniq_msg_id, - title=msg.content, - # tool_name=msg.name, - # tool_call_id=msg.tool_call_id, - save_chart=True) # Display the toggle button to suppress the table - streamlit_utils.render_toggle( - key="toggle_table_"+uniq_msg_id, - toggle_text="Show Table", - toggle_state=False, - save_toggle=True) - # Display the table - streamlit_utils.render_table( - df_selected, - key="dataframe_"+uniq_msg_id, - # tool_name=msg.name, - # tool_call_id=msg.tool_call_id, - save_table=True) - st.empty() - # elif msg.name in ["ask_question"]: - # # df_simulated = pd.DataFrame.from_dict( - # # current_state.values["dic_simulated_data"]) - # dic_simulated = current_state.values["dic_simulated_data"] - # # print (dic_simulated) - # print (msg.tool_call_id) - # for entry in dic_simulated: - # print (entry.keys()) - # if msg.tool_call_id in entry: - # df_simulated = pd.DataFrame.from_dict(entry[msg.tool_call_id]['data']) - # break - # # Display the toggle button to suppress the table - # streamlit_utils.render_toggle( - # key="toggle_table_"+uniq_msg_id, - # toggle_text="Show Table", - # toggle_state=False, - # save_toggle=True) - # # Display the table - # streamlit_utils.render_table( - # df_simulated, - # key="dataframe_"+uniq_msg_id, - # tool_name=msg.name, - # save_table=True) - # st.empty() + streamlit_utils.render_table_plotly( + uniq_msg_id, msg.content, df_selected) + elif msg.name == "parameter_scan": + # Convert the scanned data to a single dictionary + print ('-', len(current_state.values["dic_scanned_data"])) + dic_scanned_data = {} + for data in current_state.values["dic_scanned_data"]: + print ('-', data['name']) + for key in data: + if key not in dic_scanned_data: + dic_scanned_data[key] = [] + dic_scanned_data[key] += [data[key]] + # Create a pandas dataframe from the dictionary + df_scanned_data = pd.DataFrame.from_dict(dic_scanned_data) + # Get the scanned data for the current tool call + df_scanned_current_tool_call = pd.DataFrame( + df_scanned_data[df_scanned_data['tool_call_id'] == msg.tool_call_id]) + # df_scanned_current_tool_call.drop_duplicates() + # print (df_scanned_current_tool_call) + for count in range(0, len(df_scanned_current_tool_call.index)): + # Get the scanned data for the current tool call + df_selected = pd.DataFrame( + df_scanned_data[df_scanned_data['tool_call_id'] == msg.tool_call_id]['data'].iloc[count]) + # Display the toggle button to suppress the table + streamlit_utils.render_table_plotly( + uniq_msg_id+'_'+str(count), + df_scanned_current_tool_call['name'].iloc[count], + df_selected) # Collect feedback and display the thumbs feedback if st.session_state.get("run_id"): feedback = streamlit_feedback( diff --git a/app/frontend/utils/streamlit_utils.py b/app/frontend/utils/streamlit_utils.py index def75a24..76fc2cb9 100644 --- a/app/frontend/utils/streamlit_utils.py +++ b/app/frontend/utils/streamlit_utils.py @@ -22,6 +22,44 @@ def submit_feedback(user_response): ) st.info("Your feedback is on its way to the developers. Thank you!", icon="🚀") +def render_table_plotly(uniq_msg_id, content, df_selected): + """ + Function to render the table and plotly chart in the chat. + + Args: + uniq_msg_id: str: The unique message id + msg: dict: The message object + df_selected: pd.DataFrame: The selected dataframe + """ + # Display the toggle button to suppress the table + render_toggle( + key="toggle_plotly_"+uniq_msg_id, + toggle_text="Show Plot", + toggle_state=True, + save_toggle=True) + # Display the plotly chart + render_plotly( + df_selected, + key="plotly_"+uniq_msg_id, + title=content, + # tool_name=msg.name, + # tool_call_id=msg.tool_call_id, + save_chart=True) + # Display the toggle button to suppress the table + render_toggle( + key="toggle_table_"+uniq_msg_id, + toggle_text="Show Table", + toggle_state=False, + save_toggle=True) + # Display the table + render_table( + df_selected, + key="dataframe_"+uniq_msg_id, + # tool_name=msg.name, + # tool_call_id=msg.tool_call_id, + save_table=True) + st.empty() + def render_toggle(key: str, toggle_text: str, toggle_state: bool,