diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 864b02f5..8f17818e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -70,7 +70,7 @@ jobs: run: pip3 install -r requirements.txt # Adjust this according to your project - name: Run tests with coverage - run: coverage run -m pytest --cache-clear aiagents4pharma + run: coverage run -m pytest --cache-clear aiagents4pharma/talk2biomodels/tests/ - name: Check coverage run: | @@ -105,7 +105,7 @@ jobs: run: pip3 install -r requirements.txt # Adjust this according to your project - name: Run tests with coverage - run: coverage run -m pytest --cache-clear + run: coverage run -m pytest --cache-clear aiagents4pharma/talk2biomodels/tests/ - name: Check coverage run: | diff --git a/aiagents4pharma/talk2biomodels/tests/test_ask_question.py b/aiagents4pharma/talk2biomodels/tests/test_ask_question.py index 55a54f08..d506ecb5 100644 --- a/aiagents4pharma/talk2biomodels/tests/test_ask_question.py +++ b/aiagents4pharma/talk2biomodels/tests/test_ask_question.py @@ -4,7 +4,6 @@ import pytest import streamlit as st -from langchain_core.callbacks import CallbackManagerForToolRun from ..tools.ask_question import AskQuestionTool, AskQuestionInput, ModelData from ..models.basico_model import BasicoModel @@ -13,101 +12,93 @@ def ask_question_tool_fixture(): ''' Fixture for creating an instance of AskQuestionTool. ''' - return AskQuestionTool() + return AskQuestionTool(st_session_key="test_key", + sys_bio_model=ModelData( + sbml_file_path="aiagents4pharma/talk2biomodels/tests//BIOMD0000000064_url.xml" + ) + ) + +@pytest.fixture(name="ask_question_tool_with_model_id") +def ask_question_tool__with_model_id_fixture(): + ''' + Fixture for creating an instance of AskQuestionTool. + ''' + return AskQuestionTool(st_session_key="test_key", + sys_bio_model=ModelData(modelid=64)) @pytest.fixture(name="input_data", scope="module") def input_data_fixture(): ''' Fixture for creating an instance of AskQuestionInput. ''' - return AskQuestionInput(question="What is the concentration of Pyruvate at time 5?", - sys_bio_model=ModelData(modelid=64), - st_session_key="test_key" - ) + return AskQuestionInput(question="What is the concentration of Pyruvate at time 5?") def test_run_with_sbml_file(input_data, ask_question_tool): ''' - Test the _run method of the AskQuestionTool class with a valid session key and model data. + Test the _run method of the AskQuestionTool class + with a valid session key and model data. ''' - input_data.sys_bio_model = ModelData(sbml_file_path="./BIOMD0000000064_url.xml") - result = ask_question_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) + result = ask_question_tool.invoke(input={'question':input_data.question}) assert isinstance(result, str) -def test_run_manager(input_data, ask_question_tool): +def test_run_manager(input_data, ask_question_tool_with_model_id): ''' Test the run manager of the AskQuestionTool class. ''' - run_manager = CallbackManagerForToolRun(run_id=1, handlers=[], inheritable_handlers=False) - result = ask_question_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key, - run_manager=run_manager) - assert isinstance(result, str) - run_manager = CallbackManagerForToolRun(run_id=1, - handlers=[], - inheritable_handlers=False, - metadata={"prompt": "Answer the question carefully."}) - result = ask_question_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key, - run_manager=run_manager) + ask_question_tool_with_model_id.metadata = { + "prompt": "Answer the question carefully." + } + result = ask_question_tool_with_model_id.invoke(input={'question':input_data.question}) assert isinstance(result, str) def test_run_with_no_model_data_at_all(input_data, ask_question_tool): ''' - Test the _run method of the AskQuestionTool class with a valid session key and model data. + Test the _run method of the AskQuestionTool class + with a valid session key and NO model data. ''' - result = ask_question_tool.call_run(question=input_data.question, - st_session_key=input_data.st_session_key) + ask_question_tool.sys_bio_model = ModelData() + result = ask_question_tool.invoke(input={'question':input_data.question}) assert isinstance(result, str) def test_run_with_session_key(input_data, ask_question_tool): ''' - Test the _run method of the AskQuestionTool class with a missing session key. + Test the _run method of the AskQuestionTool class + with a missing session key. ''' - input_data.sys_bio_model = ModelData(modelid=64) - result = ask_question_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) + del st.session_state["test_key"] + result = ask_question_tool.invoke(input={'question':input_data.question}) assert isinstance(result, str) def test_run_with_none_key(input_data, ask_question_tool): ''' - Test the _run method of the AskQuestionTool class with a None session key. + Test the _run method of the AskQuestionTool class + with a None session key. ''' - input_data.st_session_key = None - result = ask_question_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) + ask_question_tool.st_session_key = None + result = ask_question_tool.invoke(input={'question':input_data.question}) assert isinstance(result, str) - input_data.sys_bio_model = ModelData() - result = ask_question_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) + ask_question_tool.sys_bio_model = ModelData() + result = ask_question_tool.invoke(input={'question':input_data.question}) # No model data or object in the streeamlit key assert result == "Please provide a valid model object or \ Streamlit session key that contains the model object." - input_data.st_session_key = "test_key" # delete the session key form the session state - st.session_state.pop(input_data.st_session_key, None) - result = ask_question_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) - assert result == f"Session key {input_data.st_session_key} " \ - "not found in Streamlit session state." + del st.session_state["test_key"] + ask_question_tool.st_session_key = "test_key" + result = ask_question_tool.invoke(input={'question':input_data.question}) + expected_result = f"Session key {ask_question_tool.st_session_key} " + expected_result += "not found in Streamlit session state." + assert result == expected_result def test_run_with_a_simulated_model(input_data, ask_question_tool): ''' - Test the _run method of the AskQuestionTool class with a valid session key and model data. + Test the _run method of the AskQuestionTool class + with a valid session key and model data. ''' model = BasicoModel(model_id=64) model.simulate(duration=2, interval=2) - input_data.sys_bio_model = ModelData(model_object=model) - result = ask_question_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) + ask_question_tool.sys_bio_model = ModelData(model_object=model) + result = ask_question_tool.invoke(input={'question':input_data.question}) assert isinstance(result, str) def test_get_metadata(ask_question_tool): diff --git a/aiagents4pharma/talk2biomodels/tests/test_basico_model.py b/aiagents4pharma/talk2biomodels/tests/test_basico_model.py index 25c89b95..fc5100e5 100644 --- a/aiagents4pharma/talk2biomodels/tests/test_basico_model.py +++ b/aiagents4pharma/talk2biomodels/tests/test_basico_model.py @@ -32,8 +32,10 @@ def test_with_sbml_file(): """ Test initialization of BasicoModel with sbml_file_path. """ - model_object = BasicoModel(sbml_file_path="./BIOMD0000000064_url.xml") - assert model_object.sbml_file_path == "./BIOMD0000000064_url.xml" + model_object = BasicoModel( + sbml_file_path="aiagents4pharma/talk2biomodels/tests/BIOMD0000000064_url.xml") + assert model_object.sbml_file_path == \ + "aiagents4pharma/talk2biomodels/tests/BIOMD0000000064_url.xml" assert isinstance(model_object.simulate(duration=2, interval=2), pd.DataFrame) assert isinstance(model_object.simulate(parameters={'NADH': 0.5}, duration=2, interval=2), pd.DataFrame) diff --git a/aiagents4pharma/talk2biomodels/tests/test_custom_plotter.py b/aiagents4pharma/talk2biomodels/tests/test_custom_plotter.py new file mode 100644 index 00000000..1d3589f4 --- /dev/null +++ b/aiagents4pharma/talk2biomodels/tests/test_custom_plotter.py @@ -0,0 +1,42 @@ +''' +Test cases for plot_figure.py +''' + +import pytest +import streamlit as st +from ..tools.custom_plotter import CustomPlotterTool +from ..models.basico_model import BasicoModel + +ST_SESSION_KEY = "test_key" + +@pytest.fixture(name="custom_plotter_tool") +def custom_plotter_tool_fixture(): + ''' + Fixture for creating an instance of custom_plotter_tool. + ''' + return CustomPlotterTool(st_session_key=ST_SESSION_KEY) + +def test_tool(custom_plotter_tool): + ''' + Test the tool custom_plotter_tool. + ''' + custom_plotter = custom_plotter_tool + st.session_state[ST_SESSION_KEY] = None + response = custom_plotter.invoke(input={ + 'question': 'Plot only Th cells related species' + }) + assert response == "Please run the simulation first before plotting the figure." + st.session_state[ST_SESSION_KEY] = BasicoModel(model_id=537) + response = custom_plotter.invoke(input={ + 'question': 'Plot only Th cells related species' + }) + assert response == "Please run the simulation first before plotting the figure." + st.session_state[ST_SESSION_KEY].simulate() + response = custom_plotter.invoke(input={ + 'question': 'Plot only T helper cells related species' + }) + assert response.startswith("No species found in the simulation") + response = custom_plotter.invoke(input={ + 'question': 'Plot only antibodies' + }) + assert response.startswith("Plotted the figure") diff --git a/aiagents4pharma/talk2biomodels/tests/test_fetch_params.py b/aiagents4pharma/talk2biomodels/tests/test_fetch_params.py new file mode 100644 index 00000000..6c37e2e5 --- /dev/null +++ b/aiagents4pharma/talk2biomodels/tests/test_fetch_params.py @@ -0,0 +1,26 @@ +''' +Test cases for plot_figure.py +''' + +import streamlit as st +from ..models.basico_model import BasicoModel +from ..tools.fetch_parameters import FetchParametersTool + +ST_SESSION_KEY = "test_key" +MODEL_OBJ = BasicoModel(model_id=537) + +def test_tool_fetch_params(): + ''' + Test the tool fetch_params. + ''' + st.session_state[ST_SESSION_KEY] = MODEL_OBJ + fetch_params = FetchParametersTool(st_session_key=ST_SESSION_KEY) + response = fetch_params.invoke(input={ + 'fetch_species': True, + 'fetch_parameters': True + }) + # Check if response is a dictionary + # with keys 'Species' and 'Parameters' + assert isinstance(response, dict) + assert 'Species' in response + assert 'Parameters' in response diff --git a/aiagents4pharma/talk2biomodels/tests/test_model_description.py b/aiagents4pharma/talk2biomodels/tests/test_model_description.py index 1afc33b5..f0a3121e 100644 --- a/aiagents4pharma/talk2biomodels/tests/test_model_description.py +++ b/aiagents4pharma/talk2biomodels/tests/test_model_description.py @@ -4,7 +4,6 @@ import pytest import streamlit as st -from langchain_core.callbacks import CallbackManagerForToolRun from ..tools.model_description import ModelDescriptionInput, ModelDescriptionTool, ModelData from ..models.basico_model import BasicoModel @@ -13,7 +12,7 @@ def model_description_tool_fixture(): ''' Fixture for creating an instance of ModelDescriptionTool. ''' - return ModelDescriptionTool() + return ModelDescriptionTool(st_session_key="test_key") @pytest.fixture(name="input_data") def input_data_fixture(): @@ -21,9 +20,7 @@ def input_data_fixture(): Fixture for creating an instance of AskQuestionInput. ''' return ModelDescriptionInput(question="Describe the model", - sys_bio_model=ModelData(modelid=64), - st_session_key="test_key" - ) + sys_bio_model=ModelData(model_id=64)) @pytest.fixture(name="basico_model", scope="module") def basico_model_fixture(): @@ -34,25 +31,17 @@ def basico_model_fixture(): def test_run_with_missing_session_key(input_data, model_description_tool): ''' - Test the _run method of the ModelDescriptionTool class with a missing session key. - ''' - # Delete the session key from the session state. - st.session_state.pop(input_data.st_session_key, None) - result = model_description_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) - assert isinstance(result, str) - -def test_model_data_initialization(): - """ - Test the initialization of the ModelData class. - """ - model_data = ModelData(modelid=1, - sbml_file_path="path/to/file", - model_object=BasicoModel(model_id=1)) - assert model_data.modelid == 1 - assert model_data.sbml_file_path == "path/to/file" - assert isinstance(model_data.model_object, BasicoModel) + Test the _run method of the ModelDescriptionTool class + with a missing session key. + ''' + if 'test_key' in st.session_state: + del st.session_state['test_key'] + result = model_description_tool.invoke(input={ + 'question':input_data.question, + }) + expected_result = f"Session key {model_description_tool.st_session_key} " + expected_result += "not found in Streamlit session state." + assert result == expected_result def test_check_model_object(basico_model): """ @@ -79,88 +68,85 @@ def test_run_with_none_key_no_model_data(input_data, model_description_tool): ''' st.session_state["test_key"] = None input_data.sys_bio_model = ModelData() - result = model_description_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) + result = model_description_tool.invoke(input={ + 'question':input_data.question, + 'sys_bio_model':input_data.sys_bio_model, + }) assert result == "Please provide a BioModels ID or an SBML file path for the model." def test_call_run_with_different_model_data(input_data, basico_model, model_description_tool): ''' Test the _run method of the ModelDescriptionTool class with a model id. ''' - result = model_description_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) + result = model_description_tool.invoke(input={ + 'question':input_data.question, + 'sys_bio_model':input_data.sys_bio_model, + }) assert isinstance(result, str) # Test the _run method of the ModelDescriptionTool class with an SBML file. input_data = ModelDescriptionInput(question="Describe the model", - sys_bio_model=ModelData(sbml_file_path="./BIOMD0000000064_url.xml"), - st_session_key="test_key" - ) - result = model_description_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) + sys_bio_model=ModelData( + sbml_file_path="aiagents4pharma/talk2biomodels/tests/BIOMD0000000064_url.xml"), + st_session_key="test_key" + ) + result = model_description_tool.invoke(input={ + 'question':input_data.question, + 'sys_bio_model':input_data.sys_bio_model, + }) assert isinstance(result, str) # Test the _run method of the ModelDescriptionTool class with a model object. input_data = ModelDescriptionInput(question="Describe the model", sys_bio_model=ModelData(model_object=basico_model), st_session_key="test_key" ) - result = model_description_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) + result = model_description_tool.invoke(input={ + 'question':input_data.question, + 'sys_bio_model':input_data.sys_bio_model, + }) assert isinstance(result, str) def test_run_with_none_key(input_data, model_description_tool): ''' Test the _run method of the ModelDescriptionTool class with a None ''' - input_data.st_session_key = None - result = model_description_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) + model_description_tool.st_session_key = None + result = model_description_tool.invoke(input={ + 'question':input_data.question, + 'sys_bio_model':input_data.sys_bio_model, + }) assert isinstance(result, str) # sleep for 5 seconds # time.sleep(5) input_data.sys_bio_model = ModelData() - result = model_description_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) + result = model_description_tool.invoke(input={ + 'question':input_data.question, + 'sys_bio_model':input_data.sys_bio_model, + }) assert result == "Please provide a valid model object or " \ "Streamlit session key that contains the model object." # sleep for 5 seconds # time.sleep(5) - input_data.st_session_key = "test_key" + model_description_tool.st_session_key = "test_key" # delete the session key form the session state - st.session_state.pop(input_data.st_session_key, None) - result = model_description_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) - assert result == f"Session key {input_data.st_session_key} " \ + del st.session_state[model_description_tool.st_session_key] + result = model_description_tool.invoke(input={ + 'question':input_data.question, + 'sys_bio_model':input_data.sys_bio_model, + }) + assert result == f"Session key {model_description_tool.st_session_key} " \ "not found in Streamlit session state." def test_run_manager(input_data, model_description_tool): ''' Test the _run method of the ModelDescriptionTool class with a run_manager. ''' - run_manager = CallbackManagerForToolRun(run_id=2, handlers=[], inheritable_handlers=False) - result = model_description_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key, - run_manager=run_manager) - assert isinstance(result, str) - # sleep for 5 seconds - # time.sleep(5) - run_manager = CallbackManagerForToolRun(run_id=2, - handlers=[], - inheritable_handlers=False, - metadata={"prompt": '''Given: {description}, - answer the question: - {question}.'''}) - result = model_description_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key, - run_manager=run_manager) + model_description_tool.metadata = {"prompt": '''Given: {description}, + answer the question: + {question}.'''} + result = model_description_tool.invoke(input={ + 'question':input_data.question, + 'sys_bio_model':input_data.sys_bio_model, + }) assert isinstance(result, str) def test_get_metadata(model_description_tool): diff --git a/aiagents4pharma/talk2biomodels/tests/test_plot_figure.py b/aiagents4pharma/talk2biomodels/tests/test_plot_figure.py deleted file mode 100644 index b14e0378..00000000 --- a/aiagents4pharma/talk2biomodels/tests/test_plot_figure.py +++ /dev/null @@ -1,99 +0,0 @@ -''' -Test cases for plot_figure.py -''' - -import pytest -import streamlit as st -from ..tools.plot_figure import PlotImageTool, PlotImageInput, ModelData -from ..models.basico_model import BasicoModel - -@pytest.fixture(name="plot_image_tool") -def plot_image_tool_fixture(): - ''' - Fixture for creating an instance of PlotImageTool. - ''' - return PlotImageTool() - -@pytest.fixture(name="input_data", scope="module") -def input_data_fixture(): - ''' - Fixture for creating an instance of AskQuestionInput. - ''' - return PlotImageInput(question="What is the concentration of Pyruvate at time 5?", - sys_bio_model=ModelData(modelid=64), - st_session_key="test_key" - ) - -def test_call_run(input_data, plot_image_tool): - ''' - Test the _run method of the PlotImageTool class with an invalid model ID. - ''' - input_data.sys_bio_model = ModelData() - st.session_state["test_key"] = None - result = plot_image_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) - assert result == "Please run the simulation first before plotting the figure." - st.session_state["test_key"] = BasicoModel(model_id=64) - st.session_state["test_key"].simulate(duration=2, interval=2) - result = plot_image_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) - assert result == "Figure plotted successfully" - -def test_call_run_with_different_input_model_data(input_data, plot_image_tool): - ''' - Test the _run method of the PlotImageTool class with different input model data. - ''' - input_data.sys_bio_model = ModelData(modelid=64) - result = plot_image_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) - assert result == "Figure plotted successfully" - result = plot_image_tool.call_run(question=input_data.question, - sys_bio_model=ModelData(sbml_file_path="./BIOMD0000000064_url.xml"), - st_session_key=input_data.st_session_key) - assert result == "Figure plotted successfully" - result = plot_image_tool.call_run(question=input_data.question, - sys_bio_model=ModelData(model_object=BasicoModel(model_id=64)), - st_session_key=input_data.st_session_key) - assert result == "Figure plotted successfully" - # without simulation results - model = BasicoModel(model_id=64) - input_data.sys_bio_model = ModelData(model_object=model) - result = plot_image_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) - assert result == "Figure plotted successfully" - -def test_run_with_none_key(input_data, plot_image_tool): - ''' - Test the _run method of the AskQuestionTool class with a None session key. - ''' - input_data.st_session_key = None - result = plot_image_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) - assert isinstance(result, str) - input_data.sys_bio_model = ModelData() - result = plot_image_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) - assert result == "Please provide a valid model object or \ - Streamlit session key that contains the model object." - input_data.st_session_key = "test_key" - # delete the session key form the session state - st.session_state.pop(input_data.st_session_key, None) - result = plot_image_tool.call_run(question=input_data.question, - sys_bio_model=input_data.sys_bio_model, - st_session_key=input_data.st_session_key) - assert result == f"Session key {input_data.st_session_key} " \ - "not found in Streamlit session state." - -def test_get_metadata(plot_image_tool): - ''' - Test the get_metadata method of the PlotImageTool class. - ''' - metadata = plot_image_tool.get_metadata() - assert metadata["name"] == "plot_figure" - assert metadata["description"] == "A tool to plot or visualize the simulation results." diff --git a/aiagents4pharma/talk2biomodels/tests/test_search_models.py b/aiagents4pharma/talk2biomodels/tests/test_search_models.py new file mode 100644 index 00000000..599c0e02 --- /dev/null +++ b/aiagents4pharma/talk2biomodels/tests/test_search_models.py @@ -0,0 +1,23 @@ +''' +Test cases for search_models.py +''' + +from ..tools.search_models import SearchModelsTool + +def test_tool_search_models(): + ''' + Test the tool search_models. + ''' + search_models = SearchModelsTool() + response = search_models.run({'query': 'Crohns Disease'}) + # Check if the response contains the BioModel ID + # BIOMD0000000537 + assert 'BIOMD0000000537' in response + +def test_get_metadata(): + ''' + Test the get_metadata method of the SearchModelsTool class. + ''' + metadata = SearchModelsTool().get_metadata() + assert metadata["name"] == "search_models" + assert metadata["description"] == "Search models based on search query." diff --git a/aiagents4pharma/talk2biomodels/tests/test_simulate_model.py b/aiagents4pharma/talk2biomodels/tests/test_simulate_model.py index 19e17a1e..69d76fcb 100644 --- a/aiagents4pharma/talk2biomodels/tests/test_simulate_model.py +++ b/aiagents4pharma/talk2biomodels/tests/test_simulate_model.py @@ -81,7 +81,7 @@ def test_run_with_valid_sbml_file_path(simulate_model_tool): Test the _run method of the SimulateModelTool class with a valid SBML file path. ''' - sbml_file_path="./BIOMD0000000064.xml" + sbml_file_path="aiagents4pharma/talk2biomodels/tests/BIOMD0000000064.xml" model_data=ModelData(sbml_file_path=sbml_file_path) time_data=TimeData(duration=100.0, interval=10) species_data=SpeciesData(species_name=["Pyruvate"], species_concentration=[1.0]) diff --git a/aiagents4pharma/talk2biomodels/tools/__init__.py b/aiagents4pharma/talk2biomodels/tools/__init__.py index d14cb661..455320dd 100644 --- a/aiagents4pharma/talk2biomodels/tools/__init__.py +++ b/aiagents4pharma/talk2biomodels/tools/__init__.py @@ -1,7 +1,10 @@ ''' This file is used to import all the modules in the package. ''' +# import everything from the module +from . import ask_question from . import simulate_model +from . import custom_plotter +from . import fetch_parameters from . import model_description -from . import ask_question -from . import plot_figure +from . import search_models diff --git a/aiagents4pharma/talk2biomodels/tools/plot_figure.py b/aiagents4pharma/talk2biomodels/tools/plot_figure.py deleted file mode 100644 index 0cd5bfe0..00000000 --- a/aiagents4pharma/talk2biomodels/tools/plot_figure.py +++ /dev/null @@ -1,135 +0,0 @@ -#!/usr/bin/env python3 - -""" -Tool for plotting a figure. -""" - -from typing import Type, Optional -from dataclasses import dataclass -import matplotlib.pyplot as plt -from pydantic import BaseModel, Field -import streamlit as st -from langchain_openai import ChatOpenAI -from langchain_core.tools import BaseTool -from langchain_core.prompts import ChatPromptTemplate -from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser -from langchain_experimental.tools import PythonAstREPLTool -from ..models.basico_model import BasicoModel - -@dataclass -class ModelData: - """ - Dataclass for storing the model data. - """ - modelid: Optional[int] = None - sbml_file_path: Optional[str] = None - model_object: Optional[BasicoModel] = None - -class PlotImageInput(BaseModel): - """ - Input schema for the PlotImage tool. - """ - question: str = Field(description="Description of the plot") - sys_bio_model: ModelData = Field(description="model data", default=None) - -# Note: It's important that every field has type hints. BaseTool is a -# Pydantic class and not having type hints can lead to unexpected behavior. -class PlotImageTool(BaseTool): - """ - Tool for plotting a figure. - """ - name: str = "plot_figure" - description: str = "A tool to plot or visualize the simulation results." - args_schema: Type[BaseModel] = PlotImageInput - st_session_key: str = None - - def _run(self, - question: str, - sys_bio_model: ModelData = ModelData()) -> str: - """ - Run the tool. - - Args: - question (str): The question to ask about the model description. - sys_bio_model (ModelData): The model data. - - Returns: - str: The answer to the question - """ - st_session_key = self.st_session_key - # Check if sys_bio_model is provided - if sys_bio_model.modelid or sys_bio_model.sbml_file_path or sys_bio_model.model_object: - if sys_bio_model.modelid: - model_object = BasicoModel(model_id=sys_bio_model.modelid) - elif sys_bio_model.sbml_file_path: - model_object = BasicoModel(sbml_file_path=sys_bio_model.sbml_file_path) - else: - model_object = sys_bio_model.model_object - if st_session_key: - st.session_state[st_session_key] = model_object - else: - # If the model_object is not provided, - # get it from the Streamlit session state - if st_session_key: - if st_session_key not in st.session_state: - return f"Session key {st_session_key} not found in Streamlit session state." - model_object = st.session_state[st_session_key] - else: - return "Please provide a valid model object or \ - Streamlit session key that contains the model object." - if model_object is None: - return "Please run the simulation first before plotting the figure." - if model_object.simulation_results is None: - model_object.simulate() - df = model_object.simulation_results - tool = PythonAstREPLTool(locals={"df": df}) - llm = ChatOpenAI(model="gpt-3.5-turbo") - llm_with_tools = llm.bind_tools([tool], tool_choice=tool.name) - system = f""" - You have access to a pandas dataframe `df`. - Here is the output of `df.head().to_markdown()`: - {df.head().to_markdown()} - Given a user question, write the Python code to - plot a figure of the answer using matplolib. - Return ONLY the valid Python code and nothing else. - The firgure size should be equal or smaller than (2, 2). - Show the grid and legend. The font size of the legend should be 6. - Also write a suitable title for the figure. The font size of the title should be 8. - The font size of the x-axis and y-axis labels should be 8. - The font size of the x-axis and y-axis ticks should be 6. - Make sure that the x-axis has at least 10 tick marks. - Use color-blind friendly colors. The figure must be of high quality. - Don't assume you have access to any libraries other - than built-in Python ones, pandas, streamlit and matplotlib. - """ - prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{question}")]) - parser = JsonOutputKeyToolsParser(key_name=tool.name, first_tool_only=True) - code_chain = prompt | llm_with_tools | parser - response = code_chain.invoke({"question": question}) - exec(response['query'], globals(), {"df": df, "plt": plt}) - # load for plotly - fig = plt.gcf() - if st_session_key: - st.pyplot(fig, use_container_width=False) - st.dataframe(df) - return "Figure plotted successfully" - - def call_run(self, - question: str, - sys_bio_model: ModelData = ModelData(), - st_session_key: str = None) -> str: - """ - Run the tool. - """ - return self._run(question=question, - sys_bio_model=sys_bio_model, - st_session_key=st_session_key) - - def get_metadata(self): - """ - Get metadata for the tool. - """ - return { - "name": self.name, - "description": self.description - }