From d0a935af14088d3738db2917e06fb380d10bf85a Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:37:26 +0100 Subject: [PATCH 01/19] feat: initialize agents package with main and s2 agent exports --- aiagents4pharma/talk2competitors/agents/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 aiagents4pharma/talk2competitors/agents/__init__.py diff --git a/aiagents4pharma/talk2competitors/agents/__init__.py b/aiagents4pharma/talk2competitors/agents/__init__.py new file mode 100644 index 00000000..1b1a3415 --- /dev/null +++ b/aiagents4pharma/talk2competitors/agents/__init__.py @@ -0,0 +1,5 @@ +# Expose main agent and sub-agents at package level +from agents.main_agent import get_app +from agents.s2_agent import s2_agent + +__all__ = ["get_app", "s2_agent"] From 9de715d58f2290fcf5ab1466ba78694a4b49f906 Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:37:43 +0100 Subject: [PATCH 02/19] feat: implement main supervisor agent with LangGraph routing --- .../talk2competitors/agents/main_agent.py | 164 ++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 aiagents4pharma/talk2competitors/agents/main_agent.py diff --git a/aiagents4pharma/talk2competitors/agents/main_agent.py b/aiagents4pharma/talk2competitors/agents/main_agent.py new file mode 100644 index 00000000..e01d8672 --- /dev/null +++ b/aiagents4pharma/talk2competitors/agents/main_agent.py @@ -0,0 +1,164 @@ +import logging +from typing import Literal + +import requests +from dotenv import load_dotenv +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import AIMessage +from langchain_core.tools import ToolException +from langchain_openai import ChatOpenAI +from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.types import Command + +from agents.s2_agent import s2_agent +from config.config import config +from state.shared_state import talk2comp + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +load_dotenv() + + +def make_supervisor_node(llm: BaseChatModel) -> str: + """ + Creates a supervisor node following LangGraph patterns. + + Args: + llm (BaseChatModel): The language model to use for generating responses. + + Returns: + str: The supervisor node function. + """ + # options = ["FINISH", "s2_agent"] + + def supervisor_node(state: talk2comp) -> Command[Literal["s2_agent", "__end__"]]: + """ + Supervisor node that routes to appropriate sub-agents. + + Args: + state (talk2comp): The current state of the conversation. + + Returns: + Command[Literal["s2_agent", "__end__"]]: The command to execute next. + """ + logger.info("Supervisor node called") + + messages = [{"role": "system", "content": config.MAIN_AGENT_PROMPT}] + state[ + "messages" + ] + response = llm.invoke(messages) + goto = ( + "FINISH" + if not any( + kw in state["messages"][-1].content.lower() + for kw in ["search", "paper", "find"] + ) + else "s2_agent" + ) + + if goto == "FINISH": + return Command( + goto=END, + update={ + "messages": state["messages"] + + [AIMessage(content=response.content)], + "is_last_step": True, + "current_agent": None, + }, + ) + + return Command( + goto="s2_agent", + update={ + "messages": state["messages"], + "is_last_step": False, + "current_agent": "s2_agent", + }, + ) + + return supervisor_node + + +def call_s2_agent(state: talk2comp) -> Command[Literal["__end__"]]: + """ + Node for calling the S2 agent. + + Args: + state (talk2comp): The current state of the conversation. + + Returns: + Command[Literal["__end__"]]: The command to execute next. + """ + logger.info("Calling S2 agent") + try: + response = s2_agent.invoke(state) + logger.info("S2 agent completed") + return Command( + goto=END, + update={ + "messages": response["messages"], + "papers": response.get("papers", []), + "is_last_step": True, + "current_agent": "s2_agent", + }, + ) + except requests.RequestException as e: + logger.error("Network error in S2 agent: %s", str(e)) + return Command( + goto=END, + update={ + "messages": state["messages"] + + [AIMessage(content=f"Network error: {str(e)}")], + "is_last_step": True, + "current_agent": "s2_agent", + }, + ) + except ValueError as e: + logger.error("Value error in S2 agent: %s", str(e)) + return Command( + goto=END, + update={ + "messages": state["messages"] + + [AIMessage(content=f"Input error: {str(e)}")], + "is_last_step": True, + "current_agent": "s2_agent", + }, + ) + except ToolException as e: + logger.error("Tool error in S2 agent: %s", str(e)) + return Command( + goto=END, + update={ + "messages": state["messages"] + [AIMessage(content=str(e))], + "is_last_step": True, + "current_agent": "s2_agent", + }, + ) + + +def get_app(thread_id: str): + """ + Returns the langraph app with hierarchical structure. + + Args: + thread_id (str): The thread ID for the conversation. + + Returns: + The compiled langraph app. + """ + llm = ChatOpenAI(model="gpt-4o-mini", temperature=0) + workflow = StateGraph(talk2comp) + + supervisor = make_supervisor_node(llm) + workflow.add_node("supervisor", supervisor) + workflow.add_node("s2_agent", call_s2_agent) + + # Define edges + workflow.add_edge(START, "supervisor") + workflow.add_edge("s2_agent", END) + + app = workflow.compile(checkpointer=MemorySaver()) + logger.info("Main agent workflow compiled") + return app From ba4a29c6b558adf48ddd593af0460a7b2f2e5870 Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:38:00 +0100 Subject: [PATCH 03/19] feat: add Semantic Scholar agent with ReAct pattern --- .../talk2competitors/agents/s2_agent.py | 133 ++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 aiagents4pharma/talk2competitors/agents/s2_agent.py diff --git a/aiagents4pharma/talk2competitors/agents/s2_agent.py b/aiagents4pharma/talk2competitors/agents/s2_agent.py new file mode 100644 index 00000000..efd507ee --- /dev/null +++ b/aiagents4pharma/talk2competitors/agents/s2_agent.py @@ -0,0 +1,133 @@ +import logging +from typing import Literal + +import requests +from dotenv import load_dotenv +from langchain_core.messages import AIMessage +from langchain_core.tools import ToolException +from langchain_openai import ChatOpenAI +from langgraph.graph import END, START, StateGraph +from langgraph.prebuilt import create_react_agent +from langgraph.types import Command + +from config.config import config +from state.shared_state import talk2comp +from tools.s2 import s2_tools + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +load_dotenv() + + +class SemanticScholarAgent: + """ + Agent for interacting with Semantic Scholar using LangGraph and LangChain. + """ + + def __init__(self): + """ + Initializes the SemanticScholarAgent with necessary configurations. + """ + try: + logger.info("Initializing S2 Agent...") + self.llm = ChatOpenAI(model="gpt-4o-mini", temperature=0) + + # Create the tools agent using config prompt + self.tools_agent = create_react_agent( + self.llm, + tools=s2_tools, + state_schema=talk2comp, + state_modifier=config.S2_AGENT_PROMPT, + ) + + def execute_tools(state: talk2comp) -> Command[Literal["__end__"]]: + """ + Execute tools and return results. + + Args: + state (talk2comp): The current state of the conversation. + + Returns: + Command[Literal["__end__"]]: The command to execute next. + """ + logger.info("Executing tools") + try: + result = self.tools_agent.invoke(state) + logger.info("Tool execution completed") + return Command( + goto=END, + update={ + "messages": result["messages"], + "papers": result.get("papers", []), + "is_last_step": True, + }, + ) + except (requests.RequestException, ToolException) as e: + logger.error("API or tool error: %s", str(e)) + return Command( + goto=END, + update={ + "messages": [AIMessage(content=f"Error: {str(e)}")], + "is_last_step": True, + }, + ) + except ValueError as e: + logger.error("Value error: %s", str(e)) + return Command( + goto=END, + update={ + "messages": [ + AIMessage(content=f"Input validation error: {str(e)}") + ], + "is_last_step": True, + }, + ) + + # Create graph + workflow = StateGraph(talk2comp) + workflow.add_node("tools", execute_tools) + workflow.add_edge(START, "tools") + + self.graph = workflow.compile() + logger.info("S2 Agent initialized successfully") + + except Exception as e: + logger.error("Initialization error: %s", str(e)) + raise + + def invoke(self, state): + """ + Invokes the SemanticScholarAgent with the given state. + + Args: + state (talk2comp): The current state of the conversation. + + Returns: + dict: The result of the invocation, including messages and papers. + """ + try: + logger.info("Invoking S2 agent") + return self.graph.invoke(state) + except (requests.RequestException, ToolException) as e: + logger.error("Network or tool error in S2 agent: %s", str(e)) + return { + "messages": [AIMessage(content=f"Error in processing: {str(e)}")], + "papers": [], + } + except ValueError as e: + logger.error("Value error in S2 agent: %s", str(e)) + return { + "messages": [AIMessage(content=f"Invalid input: {str(e)}")], + "papers": [], + } + except RuntimeError as e: + logger.error("Runtime error in S2 agent: %s", str(e)) + return { + "messages": [AIMessage(content=f"Internal error: {str(e)}")], + "papers": [], + } + + +# Create a global instance +s2_agent = SemanticScholarAgent() From 67cdc30f7d5cad063b4c29b0eba64ee54e32ebb1 Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:38:25 +0100 Subject: [PATCH 04/19] feat: initialize Semantic Scholar tools package --- .../talk2competitors/tools/s2/__init__.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 aiagents4pharma/talk2competitors/tools/s2/__init__.py diff --git a/aiagents4pharma/talk2competitors/tools/s2/__init__.py b/aiagents4pharma/talk2competitors/tools/s2/__init__.py new file mode 100644 index 00000000..a07b8614 --- /dev/null +++ b/aiagents4pharma/talk2competitors/tools/s2/__init__.py @@ -0,0 +1,20 @@ +from tools.s2.display_results import display_results +from tools.s2.multi_paper_rec import get_multi_paper_recommendations +from tools.s2.search import search_tool +from tools.s2.single_paper_rec import get_single_paper_recommendations + +# Export all tools in a list for easy access +s2_tools = [ + search_tool, + display_results, + get_single_paper_recommendations, + get_multi_paper_recommendations, +] + +__all__ = [ + "search_tool", + "display_results", + "get_single_paper_recommendations", + "get_multi_paper_recommendations", + "s2_tools", +] From 2a60ed7173abaf29b617200536cfe1c34bcdf356 Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:38:43 +0100 Subject: [PATCH 05/19] feat: implement paper search tool with pagination --- .../talk2competitors/tools/s2/search.py | 130 ++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 aiagents4pharma/talk2competitors/tools/s2/search.py diff --git a/aiagents4pharma/talk2competitors/tools/s2/search.py b/aiagents4pharma/talk2competitors/tools/s2/search.py new file mode 100644 index 00000000..0d8f61de --- /dev/null +++ b/aiagents4pharma/talk2competitors/tools/s2/search.py @@ -0,0 +1,130 @@ +import time +from typing import Annotated, Any, Dict + +import pandas as pd +import requests +from langchain_core.messages import AIMessage +from langchain_core.tools import ToolException, tool +from langchain_core.tools.base import InjectedToolCallId +from pydantic import BaseModel, Field + +from config.config import config + + +class SearchInput(BaseModel): + """Input schema for the search papers tool.""" + + query: str = Field( + description="Search query string to find academic papers." + "Be specific and include relevant academic terms." + ) + limit: int = Field( + default=2, description="Maximum number of results to return", ge=1, le=100 + ) + tool_call_id: Annotated[str, InjectedToolCallId] + + +@tool(args_schema=SearchInput) +def search_tool( + query: str, + tool_call_id: Annotated[str, InjectedToolCallId], + limit: int = 2, +) -> Dict[str, Any]: + """ + Search for academic papers on Semantic Scholar. + + Args: + query (str): The search query string to find academic papers. + tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID. + limit (int, optional): The maximum number of results to return. Defaults to 2. + + Returns: + Dict[str, Any]: The search results and related information. + """ + print("Starting paper search...") + endpoint = f"{config.SEMANTIC_SCHOLAR_API}/paper/search" + params = { + "query": query, + "limit": min(limit, 100), + "fields": "paperId,title,abstract,year,authors,citationCount,openAccessPdf", + } + + max_retries = 3 + retry_count = 0 + retry_delay = 2 + while retry_count < max_retries: + try: + print(f"Attempt {retry_count + 1} of {max_retries}") + response = requests.get(endpoint, params=params, timeout=10) + if response.status_code == 429: + retry_count += 1 + wait_time = retry_delay * (2**retry_count) + print(f"Rate limit hit. Waiting {wait_time} seconds...") + time.sleep(wait_time) + continue + if response.status_code == 200: + print("Successful response received") + break + response.raise_for_status() + except requests.exceptions.RequestException as e: + print(f"Request failed: {str(e)}") + retry_count += 1 + if retry_count == max_retries: + raise ToolException( + f"Error searching papers after {max_retries} attempts: {str(e)}" + ) from e + time.sleep(retry_delay * (2**retry_count)) + continue + + print("Processing response...") + data = response.json() + papers = data.get("data", []) + + filtered_papers = [ + {"Paper ID": paper["paperId"], "Title": paper["title"]} + for paper in papers + if paper.get("title") and paper.get("authors") + ] + + if not filtered_papers: + return { + "papers": ["No papers found matching your query."], + "messages": [AIMessage(content="No papers found matching your query")], + "tool_calls": [ + { + "id": tool_call_id, + "type": "function", + "function": { + "name": "search_tool", + "arguments": {"query": query, "limit": limit}, + }, + } + ], + } + + df = pd.DataFrame(filtered_papers) + print("Created DataFrame with results") + print(df) + + papers = [ + f"Paper ID: {paper['Paper ID']}\nTitle: {paper['Title']}" + for paper in filtered_papers + ] + + markdown_table = df.to_markdown(tablefmt="grid") + print("Search tool execution completed") + + return { + "papers": papers, + "messages": [AIMessage(content=markdown_table)], + "tool_calls": [ + { + "id": tool_call_id, + "type": "function", + "function": { + "name": "search_tool", + "arguments": {"query": query, "limit": limit}, + }, + } + ], + } From 62cbcd71940151899980e5411d09c43129d8cacc Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:38:55 +0100 Subject: [PATCH 06/19] feat: add single paper recommendation tool --- .../tools/s2/single_paper_rec.py | 145 ++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py diff --git a/aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py b/aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py new file mode 100644 index 00000000..0ef61a76 --- /dev/null +++ b/aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py @@ -0,0 +1,145 @@ +import re +import time +from typing import Annotated, Any, Dict + +import pandas as pd +import requests +from langchain_core.messages import ToolMessage +from langchain_core.tools import ToolException, tool +from langchain_core.tools.base import InjectedToolCallId +from langgraph.types import Command +from pydantic import BaseModel, Field, field_validator + + +class SinglePaperRecInput(BaseModel): + """Input schema for single paper recommendation tool.""" + + paper_id: str = Field( + description="Semantic Scholar Paper ID to get recommendations for (40-character string)" + ) + limit: int = Field( + default=2, + description="Maximum number of recommendations to return", + ge=1, + le=500, + ) + tool_call_id: Annotated[str, InjectedToolCallId] + + @classmethod + @field_validator("paper_id") + def validate_paper_id(cls, v: str) -> str: + """ + Validates the paper ID. + + Args: + v (str): The paper ID to validate. + + Returns: + str: The validated paper ID. + + Raises: + ValueError: If the paper ID is not a 40-character hexadecimal string. + """ + if not re.match(r"^[a-f0-9]{40}$", v): + raise ValueError("Paper ID must be a 40-character hexadecimal string") + return v + + model_config = {"arbitrary_types_allowed": True} + + +@tool(args_schema=SinglePaperRecInput) +def get_single_paper_recommendations( + paper_id: str, + tool_call_id: Annotated[str, InjectedToolCallId], + limit: int = 2, +) -> Dict[str, Any]: + """ + Get paper recommendations based on a single paper. + + Args: + paper_id (str): The Semantic Scholar Paper ID to get recommendations for. + tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID. + limit (int, optional): The maximum number of recommendations to return. Defaults to 2. + + Returns: + Dict[str, Any]: The recommendations and related information. + """ + # Validate paper ID format first + if not re.match(r"^[a-f0-9]{40}$", paper_id): + raise ValueError("Paper ID must be a 40-character hexadecimal string") + print("Starting single paper recommendations search...") + + endpoint = ( + f"https://api.semanticscholar.org/recommendations/v1/papers/forpaper/{paper_id}" + ) + params = { + "limit": min(limit, 500), # Max 500 per API docs + "fields": "title,paperId,abstract,year", + "from": "all-cs", # Using all-cs pool as specified in docs + } + + max_retries = 3 + retry_count = 0 + retry_delay = 2 + + while retry_count < max_retries: + print(f"Attempt {retry_count + 1} of {max_retries}") + response = requests.get(endpoint, params=params, timeout=10) + print(f"API Response Status: {response.status_code}") + print(f"Request params: {params}") + + if response.status_code == 200: + data = response.json() + print(f"Raw API Response: {data}") + recommendations = data.get("recommendedPapers", []) + + if recommendations: + filtered_papers = [ + {"Paper ID": paper["paperId"], "Title": paper["title"]} + for paper in recommendations + if paper.get("title") and paper.get("paperId") + ] + + if filtered_papers: + df = pd.DataFrame(filtered_papers) + + papers = [ + f"Paper ID: {paper['Paper ID']}\nTitle: {paper['Title']}" + for paper in filtered_papers + ] + + markdown_table = df.to_markdown(tablefmt="grid") + + return Command( + update={ + "papers": papers, + "messages": [ + ToolMessage( + content=markdown_table, tool_call_id=tool_call_id + ) + ], + } + ) + + return Command( + update={ + "papers": [], + "messages": [ + ToolMessage( + content="No recommendations found for this paper", + tool_call_id=tool_call_id, + ) + ], + } + ) + + retry_count += 1 + if retry_count < max_retries: + wait_time = retry_delay * (2**retry_count) + print(f"Retrying in {wait_time} seconds...") + time.sleep(wait_time) + + raise ToolException( + "Error getting recommendations after " + f"{max_retries} attempts. Status code: {response.status_code}" + ) From 3ab402259ef39b286f0b9d200d4ba725a594944b Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:39:45 +0100 Subject: [PATCH 07/19] feat: implement multi-paper recommendation tool --- .../tools/s2/multi_paper_rec.py | 206 ++++++++++++++++++ 1 file changed, 206 insertions(+) create mode 100644 aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py diff --git a/aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py b/aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py new file mode 100644 index 00000000..1254ce6d --- /dev/null +++ b/aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py @@ -0,0 +1,206 @@ +import json +import time +from typing import Annotated, Any, Dict, List + +import pandas as pd +import requests +from langchain_core.messages import ToolMessage +from langchain_core.tools import tool +from langchain_core.tools.base import InjectedToolCallId +from langgraph.types import Command +from pydantic import BaseModel, Field, field_validator + + +class MultiPaperRecInput(BaseModel): + """Input schema for multiple paper recommendations tool.""" + + paper_ids: List[str] = Field( + description=("List of Semantic Scholar Paper IDs to get recommendations for") + ) + limit: int = Field( + default=2, + description="Maximum total number of recommendations to return", + ge=1, + le=500, + ) + tool_call_id: Annotated[str, InjectedToolCallId] + + @classmethod + @field_validator("paper_ids") + def validate_paper_ids(cls, v: List[str]) -> List[str]: + """ + Validates the list of paper IDs. + + Args: + v (List[str]): The list of paper IDs to validate. + + Returns: + List[str]: The validated list of paper IDs. + + Raises: + ValueError: If the list is empty, contains more than 10 IDs, or any ID has an invalid format. + """ + if not v: + raise ValueError("At least one paper ID must be provided") + if len(v) > 10: + raise ValueError("Maximum of 10 paper IDs allowed") + return v + + model_config = {"arbitrary_types_allowed": True} + + +@tool(args_schema=MultiPaperRecInput) +def get_multi_paper_recommendations( + paper_ids: List[str], + tool_call_id: Annotated[str, InjectedToolCallId], + limit: int = 2, +) -> Dict[str, Any]: + """ + Get paper recommendations based on multiple papers. + + Args: + paper_ids (List[str]): The list of paper IDs to base recommendations on. + tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID. + limit (int, optional): The maximum number of recommendations to return. Defaults to 2. + + Returns: + Dict[str, Any]: The recommendations and related information. + """ + # Validate inputs + if not paper_ids: + raise ValueError("At least one paper ID must be provided") + if len(paper_ids) > 10: + raise ValueError("Maximum of 10 paper IDs allowed") + print("Starting multi-paper recommendations search...") + + endpoint = "https://api.semanticscholar.org/recommendations/v1/papers" + headers = {"Content-Type": "application/json"} + payload = {"positivePaperIds": paper_ids, "negativePaperIds": []} + params = {"limit": min(limit, 500), "fields": "title,paperId"} + + max_retries = 3 + retry_count = 0 + retry_delay = 2 + + while retry_count < max_retries: + print(f"Attempt {retry_count + 1} of {max_retries}") + try: + response = requests.post( + endpoint, + headers=headers, + params=params, + data=json.dumps(payload), + timeout=10, + ) + print(f"API Response Status: {response.status_code}") + + if response.status_code == 200: + data = response.json() + recommendations = data.get("recommendedPapers", []) + + if not recommendations: + print("No recommendations found") + return Command( + update={ + "papers": [ + "No recommendations found for the provided papers" + ], + "messages": [ + ToolMessage( + content="No recommendations found for the provided papers", + tool_call_id=tool_call_id, + ) + ], + } + ) + + # Create a list to store the papers + papers_list = [] + for paper in recommendations: + if paper.get("title") and paper.get("paperId"): + papers_list.append( + {"Paper ID": paper["paperId"], "Title": paper["title"]} + ) + + if not papers_list: + return Command( + update={ + "papers": ["No valid recommendations found"], + "messages": [ + ToolMessage( + content="No valid recommendations found", + tool_call_id=tool_call_id, + ) + ], + } + ) + + df = pd.DataFrame(papers_list) + print("Created DataFrame with results:") + print(df) + + # Format papers for state update + formatted_papers = [ + f"Paper ID: {paper['Paper ID']}\nTitle: {paper['Title']}" + for paper in papers_list + ] + + markdown_table = df.to_markdown(tablefmt="grid") + return Command( + update={ + "papers": formatted_papers, + "messages": [ + ToolMessage( + content=markdown_table, tool_call_id=tool_call_id + ) + ], + } + ) + + if response.status_code == 404: + return Command( + update={ + "papers": ["One or more paper IDs not found"], + "messages": [ + ToolMessage( + content="One or more paper IDs not found", + tool_call_id=tool_call_id, + ) + ], + } + ) + + retry_count += 1 + if retry_count < max_retries: + wait_time = retry_delay * (2**retry_count) + print(f"Retrying in {wait_time} seconds...") + time.sleep(wait_time) + + except Exception as e: + print(f"Error: {str(e)}") + retry_count += 1 + if retry_count == max_retries: + return Command( + update={ + "papers": [f"Error getting recommendations: {str(e)}"], + "messages": [ + ToolMessage( + content=f"Error getting recommendations: {str(e)}", + tool_call_id=tool_call_id, + ) + ], + } + ) + time.sleep(retry_delay * (2**retry_count)) + + return Command( + update={ + "papers": ["Failed to get recommendations after maximum retries"], + "messages": [ + ToolMessage( + content="Failed to get recommendations after maximum retries", + tool_call_id=tool_call_id, + ) + ], + } + ) From bddcb7b48ab675651f5979b2e65359d4eac32a05 Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:39:54 +0100 Subject: [PATCH 08/19] feat: add results display formatting tool --- .../tools/s2/display_results.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 aiagents4pharma/talk2competitors/tools/s2/display_results.py diff --git a/aiagents4pharma/talk2competitors/tools/s2/display_results.py b/aiagents4pharma/talk2competitors/tools/s2/display_results.py new file mode 100644 index 00000000..6211b87d --- /dev/null +++ b/aiagents4pharma/talk2competitors/tools/s2/display_results.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 + +''' +This tool is used to display the table of studies. +''' + +from typing import Annotated +from langchain_core.tools import tool +from langgraph.prebuilt import InjectedState + +@tool('display_results') +def display_results(state: Annotated[dict, InjectedState]): + """ + Display the table of studies. + + Args: + state (dict): The state of the agent. + """ + print ('Called display_results') + return state["papers"] From 48b33e817bcf20259d5dd8882f283d524595cfa3 Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:42:15 +0100 Subject: [PATCH 09/19] feat: initialize utils package --- aiagents4pharma/talk2competitors/utils/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 aiagents4pharma/talk2competitors/utils/__init__.py diff --git a/aiagents4pharma/talk2competitors/utils/__init__.py b/aiagents4pharma/talk2competitors/utils/__init__.py new file mode 100644 index 00000000..beaf8977 --- /dev/null +++ b/aiagents4pharma/talk2competitors/utils/__init__.py @@ -0,0 +1 @@ +# __init__.py - File description From 48f8e14c3e1c497a11b8a13df85731853445242c Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:42:29 +0100 Subject: [PATCH 10/19] feat: add LLM manager with OpenAI integration and tool binding --- aiagents4pharma/talk2competitors/utils/llm.py | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 aiagents4pharma/talk2competitors/utils/llm.py diff --git a/aiagents4pharma/talk2competitors/utils/llm.py b/aiagents4pharma/talk2competitors/utils/llm.py new file mode 100644 index 00000000..c7f0fa84 --- /dev/null +++ b/aiagents4pharma/talk2competitors/utils/llm.py @@ -0,0 +1,104 @@ +import os +from typing import Any, Dict, List + +from dotenv import load_dotenv +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from langchain_openai import ChatOpenAI + +from config.config import config +from state.shared_state import shared_state + +# Load environment variables from .env file +load_dotenv() + + +# Load environment variables from .env file +load_dotenv() + +# Get API key from environment +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +if not OPENAI_API_KEY: + raise ValueError( + "OPENAI_API_KEY not found in environment variables. Please check your .env file." + ) + + +def create_llm() -> ChatOpenAI: + """Create and configure OpenAI LLM instance""" + return ChatOpenAI( + model="gpt-4o-mini", # Using gpt-4o-mini as requested + temperature=config.TEMPERATURE, + timeout=60, # Timeout in seconds + max_retries=3, + api_key=OPENAI_API_KEY, # Explicitly passing API key + top_p=0.95, # Moved out of model_kwargs + presence_penalty=0, # Moved out of model_kwargs + frequency_penalty=0, # Moved out of model_kwargs + ) + + +class LLMManager: + def __init__(self): + # Initialize the LLM with default configuration + self.llm = create_llm() + + def get_response( + self, + system_prompt: str, + user_input: str, + additional_context: Dict[str, Any] = None, + include_history: bool = True, + ) -> str: + """Get response from LLM with system prompt and user input""" + try: + # Create messages list + messages = [ + SystemMessage(content=system_prompt), + HumanMessage(content=user_input), + ] + + # Add chat history if requested + if include_history: + history = shared_state.get_chat_history(limit=3) + for msg in history: + if msg["role"] == "assistant": + messages.append(AIMessage(content=msg["content"])) + elif msg["role"] == "user": + messages.append(HumanMessage(content=msg["content"])) + + # Add debug logging + print("\nDebug - LLM Input:") + print(f"System prompt: {system_prompt[:200]}...") + print(f"User input: {user_input}") + + # Get response with retries + response = self.llm.invoke(messages) + + # Add debug logging + print("\nDebug - LLM Response:") + print(f"Raw response: {response.content}") + + # Log token usage if available + if hasattr(response, "usage_metadata"): + print(f"Token usage: {response.usage_metadata}") + + if response and response.content.strip(): + return response.content.strip() + + return "" + + except Exception as e: + print(f"Error in get_response: {str(e)}") + return "" + + def bind_tools(self, tools: List[Any], strict: bool = True) -> None: + """Bind tools to the LLM for function/tool calling with strict mode enabled""" + self.llm = self.llm.bind_tools( + tools, + tool_choice="auto", # Let the model decide which tool to use + strict=strict, # Enable strict mode for better schema validation + ) + + +# Create a global instance +llm_manager = LLMManager() From 66ba2af1906cec460faf1a2b989eee6a62b8cc46 Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:42:51 +0100 Subject: [PATCH 11/19] feat: initialize testing package --- aiagents4pharma/talk2competitors/tests/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 aiagents4pharma/talk2competitors/tests/__init__.py diff --git a/aiagents4pharma/talk2competitors/tests/__init__.py b/aiagents4pharma/talk2competitors/tests/__init__.py new file mode 100644 index 00000000..e69de29b From 0e3f0a3620e1166fe44b55980cd872d92edbed60 Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:44:02 +0100 Subject: [PATCH 12/19] feat: add test configuration and fixtures --- .../talk2competitors/tests/conftest.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 aiagents4pharma/talk2competitors/tests/conftest.py diff --git a/aiagents4pharma/talk2competitors/tests/conftest.py b/aiagents4pharma/talk2competitors/tests/conftest.py new file mode 100644 index 00000000..6ab3f6d1 --- /dev/null +++ b/aiagents4pharma/talk2competitors/tests/conftest.py @@ -0,0 +1,29 @@ +"""Test configuration and fixtures""" + +import os + +import pytest +from dotenv import load_dotenv + + +@pytest.fixture(autouse=True) +def setup_env(): + """Setup environment variables for tests""" + load_dotenv() + os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY", "test-key") + yield + + +@pytest.fixture +def mock_response(): + """Fixture for mocked API responses""" + return { + "data": [ + { + "paperId": "1234567890123456789012345678901234567890", + "title": "Test Paper", + "abstract": "Test abstract", + "year": 2024, + } + ] + } From 29d594ceaccae3d2664f16aa8b57c346cac14c82 Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:44:16 +0100 Subject: [PATCH 13/19] feat: implement comprehensive test suite for agents and tools --- .../talk2competitors/tests/test_talk2comp.py | 153 ++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 aiagents4pharma/talk2competitors/tests/test_talk2comp.py diff --git a/aiagents4pharma/talk2competitors/tests/test_talk2comp.py b/aiagents4pharma/talk2competitors/tests/test_talk2comp.py new file mode 100644 index 00000000..6ff2d877 --- /dev/null +++ b/aiagents4pharma/talk2competitors/tests/test_talk2comp.py @@ -0,0 +1,153 @@ +"""Test cases for talk2comp agents and tools""" + +from unittest.mock import patch + +import pytest +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.tools import ToolException # Add ToolException import + +from agents.main_agent import get_app +from agents.s2_agent import s2_agent +from tools.s2.multi_paper_rec import get_multi_paper_recommendations +from tools.s2.search import search_tool +from tools.s2.single_paper_rec import get_single_paper_recommendations + +# Mock data for tests +MOCK_PAPER_RESPONSE = { + "recommendedPapers": [{"paperId": "abc123", "title": "Test Paper"}] +} + + +@pytest.fixture +def mock_api_response(): + """Fixture for mocked API responses""" + return MOCK_PAPER_RESPONSE + + +def test_main_agent_routing(): + """Test the main agent's routing capabilities""" + unique_id = "test_12345" + app = get_app(unique_id) + config = {"configurable": {"thread_id": unique_id}} + + # Test search routing + prompt = "Find me papers about machine learning" + response = app.invoke( + { + "messages": [HumanMessage(content=prompt)], + "papers": [], + "is_last_step": False, + "current_agent": None, + }, + config=config, + ) + + assert response["current_agent"] == "s2_agent" + assert isinstance(response["messages"][-1], AIMessage) + + +def test_s2_agent(): + """Test the S2 agent's functionality""" + assert hasattr(s2_agent, "tools_agent") + + state = { + "messages": [HumanMessage(content="Find papers about machine learning")], + "papers": [], + "is_last_step": False, + } + + with patch("requests.get") as mock_get: + mock_get.return_value.status_code = 200 + mock_get.return_value.json.return_value = MOCK_PAPER_RESPONSE + + response = s2_agent.invoke(state) + assert "messages" in response + assert "papers" in response + + +def test_search_tool(): + """Test the search papers tool""" + query = "machine learning" + tool_call_id = "test_123" + + with patch("requests.get") as mock_get: + mock_get.return_value.status_code = 200 + mock_get.return_value.json.return_value = {"data": [MOCK_PAPER_RESPONSE]} + + response = search_tool.func(query=query, tool_call_id=tool_call_id, limit=2) + + assert "papers" in response + assert "messages" in response + assert isinstance(response["papers"], list) + + +def test_single_paper_rec(): + """Test single paper recommendations""" + paper_id = "1234567890123456789012345678901234567890" + tool_call_id = "test_123" + + with patch("requests.get") as mock_get: + mock_get.return_value.status_code = 200 + mock_get.return_value.json.return_value = MOCK_PAPER_RESPONSE + + response = get_single_paper_recommendations.func( + paper_id=paper_id, tool_call_id=tool_call_id, limit=2 + ) + + assert response is not None + assert "papers" in response.update + + +def test_multi_paper_rec(): + """Test multi paper recommendations""" + paper_ids = [ + "1234567890123456789012345678901234567890", + "0987654321098765432109876543210987654321", + ] + tool_call_id = "test_123" + + with patch("requests.post") as mock_post: + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = MOCK_PAPER_RESPONSE + + response = get_multi_paper_recommendations.func( + paper_ids=paper_ids, tool_call_id=tool_call_id, limit=2 + ) + + assert response is not None + assert "papers" in response.update + + +def test_error_handling(): + """Test error handling in tools""" + # Test invalid paper ID format + with pytest.raises((ValueError, ToolException)) as exc_info: + with patch("requests.get") as mock_get: + mock_get.return_value.status_code = 404 + get_single_paper_recommendations.func( + paper_id="invalid_id", tool_call_id="test_123", limit=2 + ) + assert any( + msg in str(exc_info.value) + for msg in ["40-character hexadecimal", "Error getting recommendations"] + ) + + # Test empty paper IDs list for multi-paper recommendations + with patch("requests.post") as mock_post: + mock_post.return_value.status_code = 400 + with pytest.raises((ValueError, ToolException)) as exc_info: + get_multi_paper_recommendations.func( + paper_ids=[], tool_call_id="test_123", limit=2 + ) + assert "At least one paper ID must be provided" in str(exc_info.value) + + # Test too many paper IDs + with patch("requests.post") as mock_post: + mock_post.return_value.status_code = 400 + with pytest.raises(ValueError) as exc_info: + get_multi_paper_recommendations.func( + paper_ids=["a" * 40 for _ in range(11)], # 11 IDs + tool_call_id="test_123", + limit=2, + ) + assert "Maximum of 10 paper IDs allowed" in str(exc_info.value) From 0e184255ea7aea31aff1f956fadf3432867deb7f Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:44:53 +0100 Subject: [PATCH 14/19] feat: add Streamlit configuration for UI customization --- aiagents4pharma/talk2competitors/app/.streamlit/config.toml | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 aiagents4pharma/talk2competitors/app/.streamlit/config.toml diff --git a/aiagents4pharma/talk2competitors/app/.streamlit/config.toml b/aiagents4pharma/talk2competitors/app/.streamlit/config.toml new file mode 100644 index 00000000..49f7885b --- /dev/null +++ b/aiagents4pharma/talk2competitors/app/.streamlit/config.toml @@ -0,0 +1,2 @@ +[theme] +base = "light" From fd8b068864ba1bd62cade76ffc3c8464ef0d3917 Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:45:10 +0100 Subject: [PATCH 15/19] feat: implement Streamlit chat interface with session management --- .../talk2competitors/app/talk2comp.py | 145 ++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 aiagents4pharma/talk2competitors/app/talk2comp.py diff --git a/aiagents4pharma/talk2competitors/app/talk2comp.py b/aiagents4pharma/talk2competitors/app/talk2comp.py new file mode 100644 index 00000000..d50442b3 --- /dev/null +++ b/aiagents4pharma/talk2competitors/app/talk2comp.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +""" +talk2comp: A Streamlit app for academic paper search and recommendations +""" + +import os +import random +import sys + +sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../") +import streamlit as st +from langchain_core.messages import ChatMessage, HumanMessage + +from agents.main_agent import get_app + +# Page configuration +st.set_page_config(page_title="talk2comp", page_icon="📚", layout="wide") + +# Styles for fixed bottom input +st.markdown( + """ + +""", + unsafe_allow_html=True, +) + +# Initialize session states +if "messages" not in st.session_state: + st.session_state.messages = [] +if "unique_id" not in st.session_state: + st.session_state.unique_id = random.randint(1, 1000) +if "app" not in st.session_state: + st.session_state.app = get_app(str(st.session_state.unique_id)) + +app = st.session_state.app + +# Check OpenAI API key +if "OPENAI_API_KEY" not in os.environ: + st.error("Please set the OPENAI_API_KEY environment variable.") + st.stop() + +# Main layout +col1, col2 = st.columns([1, 4]) + +# Sidebar for settings +with col1: + st.markdown("### 📚 talk2comp") + llms = ["gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo"] + llm_option = st.selectbox( + "Select LLM Model", + llms, + index=0, + ) + +# Main chat area +with col2: + # Container for chat history + chat_container = st.container() + with chat_container: + st.markdown("### 💬 Chat History") + + # Display messages + for message in st.session_state.messages: + with st.chat_message( + message["content"].role, + avatar="🤖" if message["content"].role != "user" else "👤", + ): + st.markdown(message["content"].content) + +# Fixed bottom input +prompt = st.chat_input("Search for papers or ask questions...") + +if prompt: + # Display user message + prompt_msg = ChatMessage(prompt, role="user") + st.session_state.messages.append({"type": "message", "content": prompt_msg}) + + # Get agent response + with st.spinner("Processing your request..."): + # Debug print before invoke + print("Sending request to agent:", prompt) + + # Prepare initial state with all required fields + initial_state = { + "messages": [HumanMessage(content=prompt)], + "papers": [], + "search_table": "", + "next": None, + "current_agent": None, + "is_last_step": False, # Ensure this is included + } + + config = {"configurable": {"thread_id": str(st.session_state.unique_id)}} + os.environ["AIAGENTS4PHARMA_LLM_MODEL"] = llm_option + + response = app.invoke(initial_state, config=config) + + # Debug print response + print("Agent response:", response) + + # Add response to chat history and display + if "messages" in response: + # Get the last AI message + ai_messages = [msg for msg in response["messages"] if msg.type == "ai"] + if ai_messages: + last_ai_message = ai_messages[-1] + + # Format paper results if present + if "papers" in response and response["papers"]: + papers_content = response["papers"] + formatted_message = "Here are the papers I found:\n\n" + + for idx, paper in enumerate(papers_content, start=1): + if isinstance(paper, str): + parts = paper.split("\n") + paper_id = parts[0].replace("Paper ID: ", "").strip() + title = parts[1].replace("Title: ", "").strip() + formatted_message += f"{idx}. **{title}**\n" + formatted_message += f" - Paper ID: {paper_id}\n\n" + else: + # Use the AI message content + formatted_message = last_ai_message.content + + assistant_msg = ChatMessage(formatted_message, role="assistant") + st.session_state.messages.append( + {"type": "message", "content": assistant_msg} + ) + + # Rerun to update display + st.rerun() From 59210933cf4b9d4ca7547771d4da371b7797b6d3 Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:46:01 +0100 Subject: [PATCH 16/19] feat: add centralized configuration with system prompts and API settings --- .../talk2competitors/config/config.py | 183 ++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 aiagents4pharma/talk2competitors/config/config.py diff --git a/aiagents4pharma/talk2competitors/config/config.py b/aiagents4pharma/talk2competitors/config/config.py new file mode 100644 index 00000000..8b955763 --- /dev/null +++ b/aiagents4pharma/talk2competitors/config/config.py @@ -0,0 +1,183 @@ +class Config: + # LLM Configuration + LLM_MODEL = "gpt-4o-mini" # Updated to GPT-4-mini + TEMPERATURE = 0.7 + + # API Endpoints + SEMANTIC_SCHOLAR_API = "https://api.semanticscholar.org/graph/v1" + + # API Keys + SEMANTIC_SCHOLAR_API_KEY = "YOUR_API_KEY" # Get this from Semantic Scholar + + # State Keys + class StateKeys: + PAPERS = "papers" + SELECTED_PAPERS = "selected_papers" + CURRENT_TOOL = "current_tool" + CURRENT_AGENT = "current_agent" + RESPONSE = "response" + ERROR = "error" + CHAT_HISTORY = "chat_history" + USER_INFO = "user_info" + MEMORY = "memory" + + # Agent Names + class AgentNames: + MAIN = "main_agent" + S2 = "semantic_scholar_agent" + ZOTERO = "zotero_agent" + PDF = "pdf_agent" + ARXIV = "arxiv_agent" + + # Tool Names (Keeping for reference) + class ToolNames: + # S2 Tools + S2_SEARCH = "search_papers" + S2_SINGLE_REC = "single_paper_recommendation" + S2_MULTI_REC = "multi_paper_recommendation" + + # Zotero Tools + ZOTERO_READ = "zotero_read" + ZOTERO_WRITE = "zotero_write" + + # PDF Tools + PDF_RAG = "pdf_rag" + + # arXiv Tools + ARXIV_DOWNLOAD = "arxiv_download" + + # Updated System Prompts + MAIN_AGENT_PROMPT = """You are a supervisory AI agent that routes user queries to specialized tools. +Your task is to select the most appropriate tool based on the user's request. + +Available tools and their capabilities: + +1. semantic_scholar_agent: + - Search for academic papers and research + - Get paper recommendations + - Find similar papers + USE FOR: Any queries about finding papers, academic research, or getting paper recommendations + +2. zotero_agent: + - Manage paper library + - Save and organize papers + USE FOR: Saving papers or managing research library + +3. pdf_agent: + - Analyze PDF content + - Answer questions about documents + USE FOR: Analyzing or asking questions about PDF content + +4. arxiv_agent: + - Download papers from arXiv + USE FOR: Specifically downloading papers from arXiv + +ROUTING GUIDELINES: + +ALWAYS route to semantic_scholar_agent for: +- Finding academic papers +- Searching research topics +- Getting paper recommendations +- Finding similar papers +- Any query about academic literature + +Route to zotero_agent for: +- Saving papers to library +- Managing references +- Organizing research materials + +Route to pdf_agent for: +- PDF content analysis +- Document-specific questions +- Understanding paper contents + +Route to arxiv_agent for: +- Paper download requests +- arXiv-specific queries + +Approach: +1. Identify the core need in the user's query +2. Select the most appropriate tool based on the guidelines above +3. If unclear, ask for clarification +4. For multi-step tasks, focus on the immediate next step + +Remember: +- Be decisive in your tool selection +- Focus on the immediate task +- Default to semantic_scholar_agent for any paper-finding tasks +- Ask for clarification if the request is ambiguous + +IMPORTANT GUIDELINES FOR PAPER RECOMMENDATIONS: + +For Multiple Papers: +- When getting recommendations for multiple papers, always use get_multi_paper_recommendations tool +- DO NOT call get_single_paper_recommendations multiple times +- Always pass all paper IDs in a single call to get_multi_paper_recommendations +- Use for queries like "find papers related to both/all papers" or "find similar papers to these papers" + +For Single Paper: +- Use get_single_paper_recommendations when focusing on one specific paper +- Pass only one paper ID at a time +- Use for queries like "find papers similar to this paper" or "get recommendations for paper X" +- Do not use for multiple papers + +Examples: +- For "find related papers for both papers": + ✓ Use get_multi_paper_recommendations with both paper IDs + × Don't make multiple calls to get_single_paper_recommendations + +- For "find papers related to the first paper": + ✓ Use get_single_paper_recommendations with just that paper's ID + × Don't use get_multi_paper_recommendations + +Remember: +- Be precise in identifying which paper ID to use for single recommendations +- Don't reuse previous paper IDs unless specifically requested +- For fresh paper recommendations, always use the original paper ID""" + + S2_AGENT_PROMPT = """You are a specialized academic research assistant with access to the following tools: + +1. search_papers: + USE FOR: General paper searches + - Enhances search terms automatically + - Adds relevant academic keywords + - Focuses on recent research when appropriate + +2. get_single_paper_recommendations: + USE FOR: Finding papers similar to a specific paper + - Takes a single paper ID + - Returns related papers + +3. get_multi_paper_recommendations: + USE FOR: Finding papers similar to multiple papers + - Takes multiple paper IDs + - Finds papers related to all inputs + +GUIDELINES: + +For paper searches: +- Enhance search terms with academic language +- Include field-specific terminology +- Add "recent" or "latest" when appropriate +- Keep queries focused and relevant + +For paper recommendations: +- Identify paper IDs (40-character hexadecimal strings) +- Use single_paper_recommendations for one ID +- Use multi_paper_recommendations for multiple IDs + +Best practices: +1. Start with a broad search if no paper IDs are provided +2. Look for paper IDs in user input +3. Enhance search terms for better results +4. Consider the academic context +5. Be prepared to refine searches based on feedback + +Remember: +- Always select the most appropriate tool +- Enhance search queries naturally +- Consider academic context +- Focus on delivering relevant results""" + + +config = Config() From ade898f09e47105f5442586be59bc799cdb1289e Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Tue, 21 Jan 2025 23:46:43 +0100 Subject: [PATCH 17/19] feat: implement shared state management for agent communication --- .../talk2competitors/state/shared_state.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 aiagents4pharma/talk2competitors/state/shared_state.py diff --git a/aiagents4pharma/talk2competitors/state/shared_state.py b/aiagents4pharma/talk2competitors/state/shared_state.py new file mode 100644 index 00000000..1b3618c9 --- /dev/null +++ b/aiagents4pharma/talk2competitors/state/shared_state.py @@ -0,0 +1,40 @@ +""" +This is the state file for the talk2comp agent. +""" + +import logging +from typing import Annotated, List, Optional + +from langgraph.prebuilt.chat_agent_executor import AgentState +from typing_extensions import NotRequired, Required + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def replace_list(existing: List[str], new: List[str]) -> List[str]: + """Replace the existing list with the new one.""" + logger.info("Updating state list: %s", new) + return new + + +class talk2comp(AgentState): + """ + The state for the talk2comp agent, inheriting from AgentState. + """ + + papers: Annotated[List[str], replace_list] + search_table: NotRequired[str] + next: str # Required for routing in LangGraph + current_agent: NotRequired[Optional[str]] + is_last_step: Required[bool] # Required field for LangGraph + + def log_state_update(self) -> None: + """Log current state for debugging.""" + logger.info( + "Current State - Agent: %s, Next: %s", + self.get("current_agent"), + self.get("next"), + ) + logger.info("Papers count: %d", len(self.get("papers", []))) From 032eb1cacf136dbd6421ab541bad397c27f8df17 Mon Sep 17 00:00:00 2001 From: gurdeep330 Date: Wed, 22 Jan 2025 15:40:11 +0100 Subject: [PATCH 18/19] fix: update --- .../talk2competitors/agents/__init__.py | 9 +- .../talk2competitors/agents/main_agent.py | 94 +++----- .../talk2competitors/agents/s2_agent.py | 190 +++++---------- .../talk2competitors/state/__init__.py | 5 + ...red_state.py => state_talk2competitors.py} | 19 +- .../talk2competitors/tests/conftest.py | 29 --- .../talk2competitors/tests/test_langgraph.py | 92 +++++++ .../talk2competitors/tests/test_talk2comp.py | 153 ------------ .../talk2competitors/tools/__init__.py | 7 + .../talk2competitors/tools/s2/__init__.py | 26 +- .../tools/s2/display_results.py | 9 +- .../tools/s2/multi_paper_rec.py | 188 ++++----------- .../talk2competitors/tools/s2/search.py | 65 +---- .../tools/s2/single_paper_rec.py | 152 +++++------- .../talk2competitors/utils/__init__.py | 1 - aiagents4pharma/talk2competitors/utils/llm.py | 104 -------- .../streamlit_app_talk2competitors.py | 225 ++++++++++++++++++ .../streamlit_app_talk2competitors2.py | 5 +- 18 files changed, 561 insertions(+), 812 deletions(-) create mode 100644 aiagents4pharma/talk2competitors/state/__init__.py rename aiagents4pharma/talk2competitors/state/{shared_state.py => state_talk2competitors.py} (67%) delete mode 100644 aiagents4pharma/talk2competitors/tests/conftest.py create mode 100644 aiagents4pharma/talk2competitors/tests/test_langgraph.py delete mode 100644 aiagents4pharma/talk2competitors/tests/test_talk2comp.py create mode 100644 aiagents4pharma/talk2competitors/tools/__init__.py delete mode 100644 aiagents4pharma/talk2competitors/utils/__init__.py delete mode 100644 aiagents4pharma/talk2competitors/utils/llm.py create mode 100644 app/frontend/streamlit_app_talk2competitors.py rename aiagents4pharma/talk2competitors/app/talk2comp.py => app/frontend/streamlit_app_talk2competitors2.py (97%) diff --git a/aiagents4pharma/talk2competitors/agents/__init__.py b/aiagents4pharma/talk2competitors/agents/__init__.py index 1b1a3415..3423f265 100644 --- a/aiagents4pharma/talk2competitors/agents/__init__.py +++ b/aiagents4pharma/talk2competitors/agents/__init__.py @@ -1,5 +1,6 @@ -# Expose main agent and sub-agents at package level -from agents.main_agent import get_app -from agents.s2_agent import s2_agent +''' +This file is used to import all the modules in the package. +''' -__all__ = ["get_app", "s2_agent"] +from . import main_agent +from . import s2_agent diff --git a/aiagents4pharma/talk2competitors/agents/main_agent.py b/aiagents4pharma/talk2competitors/agents/main_agent.py index e01d8672..f72257dd 100644 --- a/aiagents4pharma/talk2competitors/agents/main_agent.py +++ b/aiagents4pharma/talk2competitors/agents/main_agent.py @@ -1,26 +1,27 @@ +#!/usr/bin/env python3 + +""" +Main agent for the talk2competitors app. +""" + import logging from typing import Literal - -import requests from dotenv import load_dotenv from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import AIMessage -from langchain_core.tools import ToolException from langchain_openai import ChatOpenAI from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, START, StateGraph from langgraph.types import Command - -from agents.s2_agent import s2_agent -from config.config import config -from state.shared_state import talk2comp +from ..agents import s2_agent +from ..config.config import config +from ..state.state_talk2competitors import Talk2Competitors logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) load_dotenv() - def make_supervisor_node(llm: BaseChatModel) -> str: """ Creates a supervisor node following LangGraph patterns. @@ -33,12 +34,12 @@ def make_supervisor_node(llm: BaseChatModel) -> str: """ # options = ["FINISH", "s2_agent"] - def supervisor_node(state: talk2comp) -> Command[Literal["s2_agent", "__end__"]]: + def supervisor_node(state: Talk2Competitors) -> Command[Literal["s2_agent", "__end__"]]: """ Supervisor node that routes to appropriate sub-agents. Args: - state (talk2comp): The current state of the conversation. + state (Talk2Competitors): The current state of the conversation. Returns: Command[Literal["s2_agent", "__end__"]]: The command to execute next. @@ -80,20 +81,29 @@ def supervisor_node(state: talk2comp) -> Command[Literal["s2_agent", "__end__"]] return supervisor_node - -def call_s2_agent(state: talk2comp) -> Command[Literal["__end__"]]: +def get_app(thread_id: str, llm_model ='gpt-4o-mini') -> StateGraph: """ - Node for calling the S2 agent. + Returns the langraph app with hierarchical structure. Args: - state (talk2comp): The current state of the conversation. + thread_id (str): The thread ID for the conversation. Returns: - Command[Literal["__end__"]]: The command to execute next. + The compiled langraph app. """ - logger.info("Calling S2 agent") - try: - response = s2_agent.invoke(state) + def call_s2_agent(state: Talk2Competitors) -> Command[Literal["__end__"]]: + """ + Node for calling the S2 agent. + + Args: + state (Talk2Competitors): The current state of the conversation. + + Returns: + Command[Literal["__end__"]]: The command to execute next. + """ + logger.info("Calling S2 agent") + app = s2_agent.get_app(thread_id, llm_model) + response = app.invoke(state) logger.info("S2 agent completed") return Command( goto=END, @@ -104,52 +114,8 @@ def call_s2_agent(state: talk2comp) -> Command[Literal["__end__"]]: "current_agent": "s2_agent", }, ) - except requests.RequestException as e: - logger.error("Network error in S2 agent: %s", str(e)) - return Command( - goto=END, - update={ - "messages": state["messages"] - + [AIMessage(content=f"Network error: {str(e)}")], - "is_last_step": True, - "current_agent": "s2_agent", - }, - ) - except ValueError as e: - logger.error("Value error in S2 agent: %s", str(e)) - return Command( - goto=END, - update={ - "messages": state["messages"] - + [AIMessage(content=f"Input error: {str(e)}")], - "is_last_step": True, - "current_agent": "s2_agent", - }, - ) - except ToolException as e: - logger.error("Tool error in S2 agent: %s", str(e)) - return Command( - goto=END, - update={ - "messages": state["messages"] + [AIMessage(content=str(e))], - "is_last_step": True, - "current_agent": "s2_agent", - }, - ) - - -def get_app(thread_id: str): - """ - Returns the langraph app with hierarchical structure. - - Args: - thread_id (str): The thread ID for the conversation. - - Returns: - The compiled langraph app. - """ - llm = ChatOpenAI(model="gpt-4o-mini", temperature=0) - workflow = StateGraph(talk2comp) + llm = ChatOpenAI(model=llm_model, temperature=0) + workflow = StateGraph(Talk2Competitors) supervisor = make_supervisor_node(llm) workflow.add_node("supervisor", supervisor) diff --git a/aiagents4pharma/talk2competitors/agents/s2_agent.py b/aiagents4pharma/talk2competitors/agents/s2_agent.py index efd507ee..c090862a 100644 --- a/aiagents4pharma/talk2competitors/agents/s2_agent.py +++ b/aiagents4pharma/talk2competitors/agents/s2_agent.py @@ -1,133 +1,75 @@ -import logging -from typing import Literal +#/usr/bin/env python3 + +''' +Agent for interacting with Semantic Scholar +''' -import requests +import logging from dotenv import load_dotenv -from langchain_core.messages import AIMessage -from langchain_core.tools import ToolException from langchain_openai import ChatOpenAI -from langgraph.graph import END, START, StateGraph +from langgraph.graph import START, StateGraph from langgraph.prebuilt import create_react_agent -from langgraph.types import Command +from langgraph.checkpoint.memory import MemorySaver +from ..config.config import config +from ..state.state_talk2competitors import Talk2Competitors +# from ..tools.s2 import s2_tools +from ..tools.s2.search import search_tool +from ..tools.s2.display_results import display_results +from ..tools.s2.single_paper_rec import get_single_paper_recommendations +from ..tools.s2.multi_paper_rec import get_multi_paper_recommendations -from config.config import config -from state.shared_state import talk2comp -from tools.s2 import s2_tools +load_dotenv() +# Initialize logger logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -load_dotenv() - - -class SemanticScholarAgent: - """ - Agent for interacting with Semantic Scholar using LangGraph and LangChain. - """ - - def __init__(self): - """ - Initializes the SemanticScholarAgent with necessary configurations. - """ - try: - logger.info("Initializing S2 Agent...") - self.llm = ChatOpenAI(model="gpt-4o-mini", temperature=0) - - # Create the tools agent using config prompt - self.tools_agent = create_react_agent( - self.llm, - tools=s2_tools, - state_schema=talk2comp, - state_modifier=config.S2_AGENT_PROMPT, - ) - - def execute_tools(state: talk2comp) -> Command[Literal["__end__"]]: - """ - Execute tools and return results. - - Args: - state (talk2comp): The current state of the conversation. - - Returns: - Command[Literal["__end__"]]: The command to execute next. - """ - logger.info("Executing tools") - try: - result = self.tools_agent.invoke(state) - logger.info("Tool execution completed") - return Command( - goto=END, - update={ - "messages": result["messages"], - "papers": result.get("papers", []), - "is_last_step": True, - }, - ) - except (requests.RequestException, ToolException) as e: - logger.error("API or tool error: %s", str(e)) - return Command( - goto=END, - update={ - "messages": [AIMessage(content=f"Error: {str(e)}")], - "is_last_step": True, - }, - ) - except ValueError as e: - logger.error("Value error: %s", str(e)) - return Command( - goto=END, - update={ - "messages": [ - AIMessage(content=f"Input validation error: {str(e)}") - ], - "is_last_step": True, - }, - ) - - # Create graph - workflow = StateGraph(talk2comp) - workflow.add_node("tools", execute_tools) - workflow.add_edge(START, "tools") - - self.graph = workflow.compile() - logger.info("S2 Agent initialized successfully") - - except Exception as e: - logger.error("Initialization error: %s", str(e)) - raise - - def invoke(self, state): - """ - Invokes the SemanticScholarAgent with the given state. - - Args: - state (talk2comp): The current state of the conversation. - - Returns: - dict: The result of the invocation, including messages and papers. - """ - try: - logger.info("Invoking S2 agent") - return self.graph.invoke(state) - except (requests.RequestException, ToolException) as e: - logger.error("Network or tool error in S2 agent: %s", str(e)) - return { - "messages": [AIMessage(content=f"Error in processing: {str(e)}")], - "papers": [], - } - except ValueError as e: - logger.error("Value error in S2 agent: %s", str(e)) - return { - "messages": [AIMessage(content=f"Invalid input: {str(e)}")], - "papers": [], - } - except RuntimeError as e: - logger.error("Runtime error in S2 agent: %s", str(e)) - return { - "messages": [AIMessage(content=f"Internal error: {str(e)}")], - "papers": [], - } - - -# Create a global instance -s2_agent = SemanticScholarAgent() +def get_app(uniq_id, llm_model='gpt-4o-mini'): + ''' + This function returns the langraph app. + ''' + def agent_s2_node(state: Talk2Competitors): + ''' + This function calls the model. + ''' + logger.log(logging.INFO, "Creating Agent_S2 node with thread_id %s", uniq_id) + response = model.invoke(state, {"configurable": {"thread_id": uniq_id}}) + return response + + # Define the tools + tools = [search_tool, + display_results, + get_single_paper_recommendations, + get_multi_paper_recommendations] + + # Create the LLM + llm = ChatOpenAI(model=llm_model, temperature=0) + model = create_react_agent( + llm, + tools=tools, + state_schema=Talk2Competitors, + state_modifier=config.S2_AGENT_PROMPT, + checkpointer=MemorySaver() + ) + + # Define a new graph + workflow = StateGraph(Talk2Competitors) + + # Define the two nodes we will cycle between + workflow.add_node("agent_s2", agent_s2_node) + + # Set the entrypoint as `agent` + # This means that this node is the first one called + workflow.add_edge(START, "agent_s2") + + # Initialize memory to persist state between graph runs + checkpointer = MemorySaver() + + # Finally, we compile it! + # This compiles it into a LangChain Runnable, + # meaning you can use it as you would any other runnable. + # Note that we're (optionally) passing the memory when compiling the graph + app = workflow.compile(checkpointer=checkpointer) + logger.log(logging.INFO, "Compiled the graph") + + return app diff --git a/aiagents4pharma/talk2competitors/state/__init__.py b/aiagents4pharma/talk2competitors/state/__init__.py new file mode 100644 index 00000000..8cbeabf3 --- /dev/null +++ b/aiagents4pharma/talk2competitors/state/__init__.py @@ -0,0 +1,5 @@ +''' +This file is used to import all the modules in the package. +''' + +from . import state_talk2competitors diff --git a/aiagents4pharma/talk2competitors/state/shared_state.py b/aiagents4pharma/talk2competitors/state/state_talk2competitors.py similarity index 67% rename from aiagents4pharma/talk2competitors/state/shared_state.py rename to aiagents4pharma/talk2competitors/state/state_talk2competitors.py index 1b3618c9..3f800e6c 100644 --- a/aiagents4pharma/talk2competitors/state/shared_state.py +++ b/aiagents4pharma/talk2competitors/state/state_talk2competitors.py @@ -4,7 +4,6 @@ import logging from typing import Annotated, List, Optional - from langgraph.prebuilt.chat_agent_executor import AgentState from typing_extensions import NotRequired, Required @@ -12,29 +11,21 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) - def replace_list(existing: List[str], new: List[str]) -> List[str]: """Replace the existing list with the new one.""" - logger.info("Updating state list: %s", new) + logger.info("Updating existing state %s with the state list: %s", + existing, + new) return new -class talk2comp(AgentState): +class Talk2Competitors(AgentState): """ The state for the talk2comp agent, inheriting from AgentState. """ - papers: Annotated[List[str], replace_list] search_table: NotRequired[str] next: str # Required for routing in LangGraph current_agent: NotRequired[Optional[str]] is_last_step: Required[bool] # Required field for LangGraph - - def log_state_update(self) -> None: - """Log current state for debugging.""" - logger.info( - "Current State - Agent: %s, Next: %s", - self.get("current_agent"), - self.get("next"), - ) - logger.info("Papers count: %d", len(self.get("papers", []))) + llm_model: str diff --git a/aiagents4pharma/talk2competitors/tests/conftest.py b/aiagents4pharma/talk2competitors/tests/conftest.py deleted file mode 100644 index 6ab3f6d1..00000000 --- a/aiagents4pharma/talk2competitors/tests/conftest.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Test configuration and fixtures""" - -import os - -import pytest -from dotenv import load_dotenv - - -@pytest.fixture(autouse=True) -def setup_env(): - """Setup environment variables for tests""" - load_dotenv() - os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY", "test-key") - yield - - -@pytest.fixture -def mock_response(): - """Fixture for mocked API responses""" - return { - "data": [ - { - "paperId": "1234567890123456789012345678901234567890", - "title": "Test Paper", - "abstract": "Test abstract", - "year": 2024, - } - ] - } diff --git a/aiagents4pharma/talk2competitors/tests/test_langgraph.py b/aiagents4pharma/talk2competitors/tests/test_langgraph.py new file mode 100644 index 00000000..27e0597e --- /dev/null +++ b/aiagents4pharma/talk2competitors/tests/test_langgraph.py @@ -0,0 +1,92 @@ +''' +Test cases +''' + +from langchain_core.messages import HumanMessage +from ..agents.main_agent import get_app + +def test_main_agent(): + ''' + Test the main agent. + ''' + unique_id = "test_12345" + app = get_app(unique_id) + config = {"configurable": {"thread_id": unique_id}} + #################################################### + prompt = "Without calling any tool, tell me the capital of France" + response = app.invoke( + { + "messages": [HumanMessage(content=prompt)], + "papers": [], + "is_last_step": False, + "current_agent": None, + }, + config=config, + ) + + assistant_msg = response["messages"][-1].content + # Check if the assistant message is a string + assert 'Paris' in assistant_msg + #################################################### + prompt = "Search articles on machine learning" + response = app.invoke( + { + "messages": [HumanMessage(content=prompt)], + "papers": [], + "is_last_step": False, + "current_agent": None, + }, + config=config, + ) + + assistant_msg = response["messages"][-1].content + # Check if the assistant message is a string + assert 'Fashion-MNIST' in assistant_msg + #################################################### + prompt = "Recommend articles using the first paper of the previous search" + response = app.invoke( + { + "messages": [HumanMessage(content=prompt)], + "papers": [], + "is_last_step": False, + "current_agent": None, + }, + config=config, + ) + + assistant_msg = response["messages"][-1].content + print (assistant_msg) + # Check if the assistant message is a string + assert 'CNN Models' in assistant_msg + #################################################### + prompt = "Recommend articles using both papers of your last response" + response = app.invoke( + { + "messages": [HumanMessage(content=prompt)], + "papers": [], + "is_last_step": False, + "current_agent": None, + }, + config=config, + ) + + assistant_msg = response["messages"][-1].content + print (assistant_msg) + # Check if the assistant message is a string + assert 'Efficient Handwritten Digit Classification' in assistant_msg + ################################################### + prompt = "Show me the papers in the state" + response = app.invoke( + { + "messages": [HumanMessage(content=prompt)], + "papers": [], + "is_last_step": False, + "current_agent": None, + }, + config=config, + ) + + assistant_msg = response["messages"][-1].content + print (assistant_msg) + # Check if the assistant message is a string + assert 'Classification of Fashion-MNIST Dataset' in assistant_msg diff --git a/aiagents4pharma/talk2competitors/tests/test_talk2comp.py b/aiagents4pharma/talk2competitors/tests/test_talk2comp.py deleted file mode 100644 index 6ff2d877..00000000 --- a/aiagents4pharma/talk2competitors/tests/test_talk2comp.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Test cases for talk2comp agents and tools""" - -from unittest.mock import patch - -import pytest -from langchain_core.messages import AIMessage, HumanMessage -from langchain_core.tools import ToolException # Add ToolException import - -from agents.main_agent import get_app -from agents.s2_agent import s2_agent -from tools.s2.multi_paper_rec import get_multi_paper_recommendations -from tools.s2.search import search_tool -from tools.s2.single_paper_rec import get_single_paper_recommendations - -# Mock data for tests -MOCK_PAPER_RESPONSE = { - "recommendedPapers": [{"paperId": "abc123", "title": "Test Paper"}] -} - - -@pytest.fixture -def mock_api_response(): - """Fixture for mocked API responses""" - return MOCK_PAPER_RESPONSE - - -def test_main_agent_routing(): - """Test the main agent's routing capabilities""" - unique_id = "test_12345" - app = get_app(unique_id) - config = {"configurable": {"thread_id": unique_id}} - - # Test search routing - prompt = "Find me papers about machine learning" - response = app.invoke( - { - "messages": [HumanMessage(content=prompt)], - "papers": [], - "is_last_step": False, - "current_agent": None, - }, - config=config, - ) - - assert response["current_agent"] == "s2_agent" - assert isinstance(response["messages"][-1], AIMessage) - - -def test_s2_agent(): - """Test the S2 agent's functionality""" - assert hasattr(s2_agent, "tools_agent") - - state = { - "messages": [HumanMessage(content="Find papers about machine learning")], - "papers": [], - "is_last_step": False, - } - - with patch("requests.get") as mock_get: - mock_get.return_value.status_code = 200 - mock_get.return_value.json.return_value = MOCK_PAPER_RESPONSE - - response = s2_agent.invoke(state) - assert "messages" in response - assert "papers" in response - - -def test_search_tool(): - """Test the search papers tool""" - query = "machine learning" - tool_call_id = "test_123" - - with patch("requests.get") as mock_get: - mock_get.return_value.status_code = 200 - mock_get.return_value.json.return_value = {"data": [MOCK_PAPER_RESPONSE]} - - response = search_tool.func(query=query, tool_call_id=tool_call_id, limit=2) - - assert "papers" in response - assert "messages" in response - assert isinstance(response["papers"], list) - - -def test_single_paper_rec(): - """Test single paper recommendations""" - paper_id = "1234567890123456789012345678901234567890" - tool_call_id = "test_123" - - with patch("requests.get") as mock_get: - mock_get.return_value.status_code = 200 - mock_get.return_value.json.return_value = MOCK_PAPER_RESPONSE - - response = get_single_paper_recommendations.func( - paper_id=paper_id, tool_call_id=tool_call_id, limit=2 - ) - - assert response is not None - assert "papers" in response.update - - -def test_multi_paper_rec(): - """Test multi paper recommendations""" - paper_ids = [ - "1234567890123456789012345678901234567890", - "0987654321098765432109876543210987654321", - ] - tool_call_id = "test_123" - - with patch("requests.post") as mock_post: - mock_post.return_value.status_code = 200 - mock_post.return_value.json.return_value = MOCK_PAPER_RESPONSE - - response = get_multi_paper_recommendations.func( - paper_ids=paper_ids, tool_call_id=tool_call_id, limit=2 - ) - - assert response is not None - assert "papers" in response.update - - -def test_error_handling(): - """Test error handling in tools""" - # Test invalid paper ID format - with pytest.raises((ValueError, ToolException)) as exc_info: - with patch("requests.get") as mock_get: - mock_get.return_value.status_code = 404 - get_single_paper_recommendations.func( - paper_id="invalid_id", tool_call_id="test_123", limit=2 - ) - assert any( - msg in str(exc_info.value) - for msg in ["40-character hexadecimal", "Error getting recommendations"] - ) - - # Test empty paper IDs list for multi-paper recommendations - with patch("requests.post") as mock_post: - mock_post.return_value.status_code = 400 - with pytest.raises((ValueError, ToolException)) as exc_info: - get_multi_paper_recommendations.func( - paper_ids=[], tool_call_id="test_123", limit=2 - ) - assert "At least one paper ID must be provided" in str(exc_info.value) - - # Test too many paper IDs - with patch("requests.post") as mock_post: - mock_post.return_value.status_code = 400 - with pytest.raises(ValueError) as exc_info: - get_multi_paper_recommendations.func( - paper_ids=["a" * 40 for _ in range(11)], # 11 IDs - tool_call_id="test_123", - limit=2, - ) - assert "Maximum of 10 paper IDs allowed" in str(exc_info.value) diff --git a/aiagents4pharma/talk2competitors/tools/__init__.py b/aiagents4pharma/talk2competitors/tools/__init__.py new file mode 100644 index 00000000..c69aba29 --- /dev/null +++ b/aiagents4pharma/talk2competitors/tools/__init__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +''' +Import statements +''' + +from . import s2 diff --git a/aiagents4pharma/talk2competitors/tools/s2/__init__.py b/aiagents4pharma/talk2competitors/tools/s2/__init__.py index a07b8614..f809aa36 100644 --- a/aiagents4pharma/talk2competitors/tools/s2/__init__.py +++ b/aiagents4pharma/talk2competitors/tools/s2/__init__.py @@ -1,20 +1,8 @@ -from tools.s2.display_results import display_results -from tools.s2.multi_paper_rec import get_multi_paper_recommendations -from tools.s2.search import search_tool -from tools.s2.single_paper_rec import get_single_paper_recommendations +''' +This file is used to import all the modules in the package. +''' -# Export all tools in a list for easy access -s2_tools = [ - search_tool, - display_results, - get_single_paper_recommendations, - get_multi_paper_recommendations, -] - -__all__ = [ - "search_tool", - "display_results", - "get_single_paper_recommendations", - "get_multi_paper_recommendations", - "s2_tools", -] +from . import display_results +from . import multi_paper_rec +from . import search +from . import single_paper_rec diff --git a/aiagents4pharma/talk2competitors/tools/s2/display_results.py b/aiagents4pharma/talk2competitors/tools/s2/display_results.py index 6211b87d..1c06dde9 100644 --- a/aiagents4pharma/talk2competitors/tools/s2/display_results.py +++ b/aiagents4pharma/talk2competitors/tools/s2/display_results.py @@ -4,17 +4,22 @@ This tool is used to display the table of studies. ''' +import logging from typing import Annotated from langchain_core.tools import tool from langgraph.prebuilt import InjectedState +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + @tool('display_results') def display_results(state: Annotated[dict, InjectedState]): """ - Display the table of studies. + Display the papers in the state. Args: state (dict): The state of the agent. """ - print ('Called display_results') + logger.info("Displaying papers from the state") return state["papers"] diff --git a/aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py b/aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py index 1254ce6d..58174a3f 100644 --- a/aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py +++ b/aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py @@ -1,15 +1,20 @@ +#!/usr/bin/env python3 + +""" +multi_paper_rec: Tool for getting recommendations + based on multiple papers +""" + +import logging import json -import time from typing import Annotated, Any, Dict, List - import pandas as pd import requests from langchain_core.messages import ToolMessage from langchain_core.tools import tool from langchain_core.tools.base import InjectedToolCallId from langgraph.types import Command -from pydantic import BaseModel, Field, field_validator - +from pydantic import BaseModel, Field class MultiPaperRecInput(BaseModel): """Input schema for multiple paper recommendations tool.""" @@ -25,27 +30,6 @@ class MultiPaperRecInput(BaseModel): ) tool_call_id: Annotated[str, InjectedToolCallId] - @classmethod - @field_validator("paper_ids") - def validate_paper_ids(cls, v: List[str]) -> List[str]: - """ - Validates the list of paper IDs. - - Args: - v (List[str]): The list of paper IDs to validate. - - Returns: - List[str]: The validated list of paper IDs. - - Raises: - ValueError: If the list is empty, contains more than 10 IDs, or any ID has an invalid format. - """ - if not v: - raise ValueError("At least one paper ID must be provided") - if len(v) > 10: - raise ValueError("Maximum of 10 paper IDs allowed") - return v - model_config = {"arbitrary_types_allowed": True} @@ -66,140 +50,54 @@ def get_multi_paper_recommendations( Returns: Dict[str, Any]: The recommendations and related information. """ - # Validate inputs - if not paper_ids: - raise ValueError("At least one paper ID must be provided") - if len(paper_ids) > 10: - raise ValueError("Maximum of 10 paper IDs allowed") - print("Starting multi-paper recommendations search...") + logging.info("Starting multi-paper recommendations search.") endpoint = "https://api.semanticscholar.org/recommendations/v1/papers" headers = {"Content-Type": "application/json"} payload = {"positivePaperIds": paper_ids, "negativePaperIds": []} params = {"limit": min(limit, 500), "fields": "title,paperId"} - max_retries = 3 - retry_count = 0 - retry_delay = 2 - - while retry_count < max_retries: - print(f"Attempt {retry_count + 1} of {max_retries}") - try: - response = requests.post( - endpoint, - headers=headers, - params=params, - data=json.dumps(payload), - timeout=10, + # Getting recommendations + response = requests.post( + endpoint, + headers=headers, + params=params, + data=json.dumps(payload), + timeout=10, + ) + logging.info("API Response Status for multi-paper recommendations: %s", + response.status_code) + + data = response.json() + recommendations = data.get("recommendedPapers", []) + + # Create a list to store the papers + papers_list = [] + for paper in recommendations: + if paper.get("title") and paper.get("paperId"): + papers_list.append( + {"Paper ID": paper["paperId"], "Title": paper["title"]} ) - print(f"API Response Status: {response.status_code}") - - if response.status_code == 200: - data = response.json() - recommendations = data.get("recommendedPapers", []) - - if not recommendations: - print("No recommendations found") - return Command( - update={ - "papers": [ - "No recommendations found for the provided papers" - ], - "messages": [ - ToolMessage( - content="No recommendations found for the provided papers", - tool_call_id=tool_call_id, - ) - ], - } - ) - - # Create a list to store the papers - papers_list = [] - for paper in recommendations: - if paper.get("title") and paper.get("paperId"): - papers_list.append( - {"Paper ID": paper["paperId"], "Title": paper["title"]} - ) - - if not papers_list: - return Command( - update={ - "papers": ["No valid recommendations found"], - "messages": [ - ToolMessage( - content="No valid recommendations found", - tool_call_id=tool_call_id, - ) - ], - } - ) - - df = pd.DataFrame(papers_list) - print("Created DataFrame with results:") - print(df) - - # Format papers for state update - formatted_papers = [ - f"Paper ID: {paper['Paper ID']}\nTitle: {paper['Title']}" - for paper in papers_list - ] - - markdown_table = df.to_markdown(tablefmt="grid") - return Command( - update={ - "papers": formatted_papers, - "messages": [ - ToolMessage( - content=markdown_table, tool_call_id=tool_call_id - ) - ], - } - ) - if response.status_code == 404: - return Command( - update={ - "papers": ["One or more paper IDs not found"], - "messages": [ - ToolMessage( - content="One or more paper IDs not found", - tool_call_id=tool_call_id, - ) - ], - } - ) + # Create a DataFrame from the list of papers + df = pd.DataFrame(papers_list) + # print("Created DataFrame with results:") + logging.info("Created DataFrame with results: %s", df) - retry_count += 1 - if retry_count < max_retries: - wait_time = retry_delay * (2**retry_count) - print(f"Retrying in {wait_time} seconds...") - time.sleep(wait_time) - - except Exception as e: - print(f"Error: {str(e)}") - retry_count += 1 - if retry_count == max_retries: - return Command( - update={ - "papers": [f"Error getting recommendations: {str(e)}"], - "messages": [ - ToolMessage( - content=f"Error getting recommendations: {str(e)}", - tool_call_id=tool_call_id, - ) - ], - } - ) - time.sleep(retry_delay * (2**retry_count)) + # Format papers for state update + formatted_papers = [ + f"Paper ID: {paper['Paper ID']}\nTitle: {paper['Title']}" + for paper in papers_list + ] + # Convert DataFrame to markdown table + markdown_table = df.to_markdown(tablefmt="grid") return Command( update={ - "papers": ["Failed to get recommendations after maximum retries"], + "papers": formatted_papers, "messages": [ ToolMessage( - content="Failed to get recommendations after maximum retries", - tool_call_id=tool_call_id, + content=markdown_table, tool_call_id=tool_call_id ) ], } diff --git a/aiagents4pharma/talk2competitors/tools/s2/search.py b/aiagents4pharma/talk2competitors/tools/s2/search.py index 0d8f61de..31703885 100644 --- a/aiagents4pharma/talk2competitors/tools/s2/search.py +++ b/aiagents4pharma/talk2competitors/tools/s2/search.py @@ -1,15 +1,19 @@ -import time -from typing import Annotated, Any, Dict +#!/usr/bin/env python3 + +""" +This tool is used to search for academic papers on Semantic Scholar. +""" +import logging +from typing import Annotated, Any, Dict import pandas as pd import requests from langchain_core.messages import AIMessage -from langchain_core.tools import ToolException, tool +from langchain_core.tools import tool from langchain_core.tools.base import InjectedToolCallId from pydantic import BaseModel, Field -from config.config import config - +from ...config.config import config class SearchInput(BaseModel): """Input schema for the search papers tool.""" @@ -23,7 +27,6 @@ class SearchInput(BaseModel): ) tool_call_id: Annotated[str, InjectedToolCallId] - @tool(args_schema=SearchInput) def search_tool( query: str, @@ -48,35 +51,7 @@ def search_tool( "limit": min(limit, 100), "fields": "paperId,title,abstract,year,authors,citationCount,openAccessPdf", } - - max_retries = 3 - retry_count = 0 - retry_delay = 2 - while retry_count < max_retries: - try: - print(f"Attempt {retry_count + 1} of {max_retries}") - response = requests.get(endpoint, params=params, timeout=10) - if response.status_code == 429: - retry_count += 1 - wait_time = retry_delay * (2**retry_count) - print(f"Rate limit hit. Waiting {wait_time} seconds...") - time.sleep(wait_time) - continue - if response.status_code == 200: - print("Successful response received") - break - response.raise_for_status() - except requests.exceptions.RequestException as e: - print(f"Request failed: {str(e)}") - retry_count += 1 - if retry_count == max_retries: - raise ToolException( - f"Error searching papers after {max_retries} attempts: {str(e)}" - ) from e - time.sleep(retry_delay * (2**retry_count)) - continue - - print("Processing response...") + response = requests.get(endpoint, params=params, timeout=10) data = response.json() papers = data.get("data", []) @@ -86,25 +61,7 @@ def search_tool( if paper.get("title") and paper.get("authors") ] - if not filtered_papers: - return { - "papers": ["No papers found matching your query."], - "messages": [AIMessage(content="No papers found matching your query")], - "tool_calls": [ - { - "id": tool_call_id, - "type": "function", - "function": { - "name": "search_tool", - "arguments": {"query": query, "limit": limit}, - }, - } - ], - } - df = pd.DataFrame(filtered_papers) - print("Created DataFrame with results") - print(df) papers = [ f"Paper ID: {paper['Paper ID']}\nTitle: {paper['Title']}" @@ -112,7 +69,7 @@ def search_tool( ] markdown_table = df.to_markdown(tablefmt="grid") - print("Search tool execution completed") + logging.info("Search results: %s", papers) return { "papers": papers, diff --git a/aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py b/aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py index 0ef61a76..96f47a98 100644 --- a/aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py +++ b/aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py @@ -1,15 +1,22 @@ -import re -import time -from typing import Annotated, Any, Dict +#!/usr/bin/env python3 + +''' +This tool is used to return recommendations for a single paper. +''' +import logging +from typing import Annotated, Any, Dict import pandas as pd import requests from langchain_core.messages import ToolMessage -from langchain_core.tools import ToolException, tool +from langchain_core.tools import tool from langchain_core.tools.base import InjectedToolCallId from langgraph.types import Command -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) class SinglePaperRecInput(BaseModel): """Input schema for single paper recommendation tool.""" @@ -24,35 +31,15 @@ class SinglePaperRecInput(BaseModel): le=500, ) tool_call_id: Annotated[str, InjectedToolCallId] - - @classmethod - @field_validator("paper_id") - def validate_paper_id(cls, v: str) -> str: - """ - Validates the paper ID. - - Args: - v (str): The paper ID to validate. - - Returns: - str: The validated paper ID. - - Raises: - ValueError: If the paper ID is not a 40-character hexadecimal string. - """ - if not re.match(r"^[a-f0-9]{40}$", v): - raise ValueError("Paper ID must be a 40-character hexadecimal string") - return v - model_config = {"arbitrary_types_allowed": True} @tool(args_schema=SinglePaperRecInput) def get_single_paper_recommendations( - paper_id: str, - tool_call_id: Annotated[str, InjectedToolCallId], - limit: int = 2, -) -> Dict[str, Any]: + paper_id: str, + tool_call_id: Annotated[str, InjectedToolCallId], + limit: int = 2, + ) -> Dict[str, Any]: """ Get paper recommendations based on a single paper. @@ -64,10 +51,7 @@ def get_single_paper_recommendations( Returns: Dict[str, Any]: The recommendations and related information. """ - # Validate paper ID format first - if not re.match(r"^[a-f0-9]{40}$", paper_id): - raise ValueError("Paper ID must be a 40-character hexadecimal string") - print("Starting single paper recommendations search...") + logger.info("Starting single paper recommendations search.") endpoint = ( f"https://api.semanticscholar.org/recommendations/v1/papers/forpaper/{paper_id}" @@ -78,68 +62,42 @@ def get_single_paper_recommendations( "from": "all-cs", # Using all-cs pool as specified in docs } - max_retries = 3 - retry_count = 0 - retry_delay = 2 - - while retry_count < max_retries: - print(f"Attempt {retry_count + 1} of {max_retries}") - response = requests.get(endpoint, params=params, timeout=10) - print(f"API Response Status: {response.status_code}") - print(f"Request params: {params}") - - if response.status_code == 200: - data = response.json() - print(f"Raw API Response: {data}") - recommendations = data.get("recommendedPapers", []) - - if recommendations: - filtered_papers = [ - {"Paper ID": paper["paperId"], "Title": paper["title"]} - for paper in recommendations - if paper.get("title") and paper.get("paperId") - ] - - if filtered_papers: - df = pd.DataFrame(filtered_papers) - - papers = [ - f"Paper ID: {paper['Paper ID']}\nTitle: {paper['Title']}" - for paper in filtered_papers - ] - - markdown_table = df.to_markdown(tablefmt="grid") - - return Command( - update={ - "papers": papers, - "messages": [ - ToolMessage( - content=markdown_table, tool_call_id=tool_call_id - ) - ], - } - ) - - return Command( - update={ - "papers": [], - "messages": [ - ToolMessage( - content="No recommendations found for this paper", - tool_call_id=tool_call_id, - ) - ], - } - ) - - retry_count += 1 - if retry_count < max_retries: - wait_time = retry_delay * (2**retry_count) - print(f"Retrying in {wait_time} seconds...") - time.sleep(wait_time) - - raise ToolException( - "Error getting recommendations after " - f"{max_retries} attempts. Status code: {response.status_code}" + response = requests.get(endpoint, params=params, timeout=10) + # print(f"API Response Status: {response.status_code}") + logging.info("API Response Status for recommendations of paper %s: %s", + paper_id, + response.status_code) + # print(f"Request params: {params}") + logging.info("Request params: %s", params) + + data = response.json() + recommendations = data.get("recommendedPapers", []) + + # Extract paper ID and title from recommendations + filtered_papers = [ + {"Paper ID": paper["paperId"], "Title": paper["title"]} + for paper in recommendations + if paper.get("title") and paper.get("paperId") + ] + + # Create a DataFrame for pretty printing + df = pd.DataFrame(filtered_papers) + + papers = [ + f"Paper ID: {paper['Paper ID']}\nTitle: {paper['Title']}" + for paper in filtered_papers + ] + + # Convert DataFrame to markdown table + markdown_table = df.to_markdown(tablefmt="grid") + + return Command( + update={ + "papers": papers, + "messages": [ + ToolMessage( + content=markdown_table, tool_call_id=tool_call_id + ) + ], + } ) diff --git a/aiagents4pharma/talk2competitors/utils/__init__.py b/aiagents4pharma/talk2competitors/utils/__init__.py deleted file mode 100644 index beaf8977..00000000 --- a/aiagents4pharma/talk2competitors/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# __init__.py - File description diff --git a/aiagents4pharma/talk2competitors/utils/llm.py b/aiagents4pharma/talk2competitors/utils/llm.py deleted file mode 100644 index c7f0fa84..00000000 --- a/aiagents4pharma/talk2competitors/utils/llm.py +++ /dev/null @@ -1,104 +0,0 @@ -import os -from typing import Any, Dict, List - -from dotenv import load_dotenv -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage -from langchain_openai import ChatOpenAI - -from config.config import config -from state.shared_state import shared_state - -# Load environment variables from .env file -load_dotenv() - - -# Load environment variables from .env file -load_dotenv() - -# Get API key from environment -OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") -if not OPENAI_API_KEY: - raise ValueError( - "OPENAI_API_KEY not found in environment variables. Please check your .env file." - ) - - -def create_llm() -> ChatOpenAI: - """Create and configure OpenAI LLM instance""" - return ChatOpenAI( - model="gpt-4o-mini", # Using gpt-4o-mini as requested - temperature=config.TEMPERATURE, - timeout=60, # Timeout in seconds - max_retries=3, - api_key=OPENAI_API_KEY, # Explicitly passing API key - top_p=0.95, # Moved out of model_kwargs - presence_penalty=0, # Moved out of model_kwargs - frequency_penalty=0, # Moved out of model_kwargs - ) - - -class LLMManager: - def __init__(self): - # Initialize the LLM with default configuration - self.llm = create_llm() - - def get_response( - self, - system_prompt: str, - user_input: str, - additional_context: Dict[str, Any] = None, - include_history: bool = True, - ) -> str: - """Get response from LLM with system prompt and user input""" - try: - # Create messages list - messages = [ - SystemMessage(content=system_prompt), - HumanMessage(content=user_input), - ] - - # Add chat history if requested - if include_history: - history = shared_state.get_chat_history(limit=3) - for msg in history: - if msg["role"] == "assistant": - messages.append(AIMessage(content=msg["content"])) - elif msg["role"] == "user": - messages.append(HumanMessage(content=msg["content"])) - - # Add debug logging - print("\nDebug - LLM Input:") - print(f"System prompt: {system_prompt[:200]}...") - print(f"User input: {user_input}") - - # Get response with retries - response = self.llm.invoke(messages) - - # Add debug logging - print("\nDebug - LLM Response:") - print(f"Raw response: {response.content}") - - # Log token usage if available - if hasattr(response, "usage_metadata"): - print(f"Token usage: {response.usage_metadata}") - - if response and response.content.strip(): - return response.content.strip() - - return "" - - except Exception as e: - print(f"Error in get_response: {str(e)}") - return "" - - def bind_tools(self, tools: List[Any], strict: bool = True) -> None: - """Bind tools to the LLM for function/tool calling with strict mode enabled""" - self.llm = self.llm.bind_tools( - tools, - tool_choice="auto", # Let the model decide which tool to use - strict=strict, # Enable strict mode for better schema validation - ) - - -# Create a global instance -llm_manager = LLMManager() diff --git a/app/frontend/streamlit_app_talk2competitors.py b/app/frontend/streamlit_app_talk2competitors.py new file mode 100644 index 00000000..7e536ad8 --- /dev/null +++ b/app/frontend/streamlit_app_talk2competitors.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 + +''' +Talk2Competitors: A Streamlit app for the Talk2Competitors graph. +''' + +import os +import sys +import random +import streamlit as st +from streamlit_feedback import streamlit_feedback +from langchain_core.messages import SystemMessage, HumanMessage, AIMessage +from langchain_core.messages import ChatMessage +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.tracers.context import collect_runs +from langchain.callbacks.tracers import LangChainTracer +from langsmith import Client +sys.path.append('./') +from aiagents4pharma.talk2competitors.agents.main_agent import get_app + +st.set_page_config(page_title="Talk2Competitors", page_icon="🤖", layout="wide") + +# Check if env variable OPENAI_API_KEY exists +if "OPENAI_API_KEY" not in os.environ: + st.error("Please set the OPENAI_API_KEY environment \ + variable in the terminal where you run the app.") + st.stop() + +# Create a chat prompt template +prompt = ChatPromptTemplate.from_messages([ + ("system", "Welcome to Talk2Competitors!"), + MessagesPlaceholder(variable_name='chat_history', optional=True), + ("human", "{input}"), + ("placeholder", "{agent_scratchpad}"), +]) + +# Initialize chat history +if "messages" not in st.session_state: + st.session_state.messages = [] + +# Initialize project_name for Langsmith +if "project_name" not in st.session_state: + # st.session_state.project_name = str(st.session_state.user_name) + '@' + str(uuid.uuid4()) + st.session_state.project_name = 'Talk2Competitors-' + str(random.randint(1000, 9999)) + +# Initialize run_id for Langsmith +if "run_id" not in st.session_state: + st.session_state.run_id = None + +# Initialize graph +if "unique_id" not in st.session_state: + st.session_state.unique_id = random.randint(1, 1000) +if "app" not in st.session_state: + # st.session_state.app = get_app(st.session_state.unique_id) + if "llm_model" not in st.session_state: + st.session_state.app = get_app(st.session_state.unique_id) + else: + st.session_state.app = get_app(st.session_state.unique_id, + llm_model=st.session_state.llm_model) + +# Get the app +app = st.session_state.app + +def _submit_feedback(user_response): + ''' + Function to submit feedback to the developers. + ''' + client = Client() + client.create_feedback( + st.session_state.run_id, + key="feedback", + score=1 if user_response['score'] == "👍" else 0, + comment=user_response['text'] + ) + st.info("Your feedback is on its way to the developers. Thank you!", icon="🚀") + +@st.dialog("Warning ⚠️") +def update_llm_model(): + """ + Function to update the LLM model. + """ + llm_model = st.session_state.llm_model + st.warning(f"Clicking 'Continue' will reset all agents, \ + set the selected LLM to {llm_model}. \ + This action will reset the entire app, \ + and agents will lose access to the \ + conversation history. Are you sure \ + you want to proceed?") + if st.button("Continue"): + # Delete all the messages and the app key + for key in st.session_state.keys(): + if key in ["messages", "app"]: + del st.session_state[key] + +# Main layout of the app split into two columns +main_col1, main_col2 = st.columns([3, 7]) +# First column +with main_col1: + with st.container(border=True): + # Title + st.write(""" +

+ 🤖 Talk2Competitors +

+ """, + unsafe_allow_html=True) + + # LLM panel (Only at the front-end for now) + llms = ["gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo"] + # llm_option = st.selectbox( + # "Pick an LLM to power the agent", + # llms, + # index=0, + # key="st_selectbox_llm" + # ) + st.selectbox( + "Pick an LLM to power the agent", + llms, + index=0, + key="llm_model", + on_change=update_llm_model + ) + + # Upload files (placeholder) + # uploaded_file = st.file_uploader( + # "Upload sequencing data", + # accept_multiple_files=False, + # type=["h5ad"], + # help='''Upload a single h5ad file containing the sequencing data. + # The file should be in the AnnData format.''' + # ) + + with st.container(border=False, height=500): + prompt = st.chat_input("Say something ...", key="st_chat_input") + +# Second column +with main_col2: + # Chat history panel + with st.container(border=True, height=575): + st.write("#### 💬 Chat History") + + # Display chat messages + for count, message in enumerate(st.session_state.messages): + with st.chat_message(message["content"].role, + avatar="🤖" + if message["content"].role != 'user' + else "👩🏻‍💻"): + st.markdown(message["content"].content) + st.empty() + + # When the user asks a question + if prompt: + # Create a key 'uploaded_file' to read the uploaded file + # if uploaded_file: + # st.session_state.article_pdf = uploaded_file.read().decode("utf-8") + + # Display user prompt + prompt_msg = ChatMessage(prompt, role="user") + st.session_state.messages.append( + { + "type": "message", + "content": prompt_msg + } + ) + with st.chat_message("user", avatar="👩🏻‍💻"): + st.markdown(prompt) + st.empty() + + with st.chat_message("assistant", avatar="🤖"): + # with st.spinner("Fetching response ..."): + with st.spinner(): + # Get chat history + history = [(m["content"].role, m["content"].content) + for m in st.session_state.messages + if m["type"] == "message"] + # Convert chat history to ChatMessage objects + chat_history = [ + SystemMessage(content=m[1]) if m[0] == "system" else + HumanMessage(content=m[1]) if m[0] == "human" else + AIMessage(content=m[1]) + for m in history + ] + + # Create config for the agent + config = {"configurable": {"thread_id": st.session_state.unique_id}} + + # Update the agent state with the selected LLM model + current_state = app.get_state(config) + app.update_state( + config, + {"llm_model": st.session_state.llm_model} + ) + + with collect_runs() as cb: + # Add Langsmith tracer + tracer = LangChainTracer( + project_name=st.session_state.project_name + ) + # Get response from the agent + response = app.invoke( + {"messages": [HumanMessage(content=prompt)]}, + config=config|{"callbacks": [tracer]} + ) + st.session_state.run_id = cb.traced_runs[-1].id + # Print the response + # print (response) + + # Add assistant response to chat history + assistant_msg = ChatMessage(response["messages"][-1].content, + role="assistant") + st.session_state.messages.append({ + "type": "message", + "content": assistant_msg + }) + # Display the response in the chat + st.markdown(response["messages"][-1].content) + st.empty() + # Collect feedback and display the thumbs feedback + if st.session_state.get("run_id"): + feedback = streamlit_feedback( + feedback_type="thumbs", + optional_text_label="[Optional] Please provide an explanation", + on_submit=_submit_feedback, + key=f"feedback_{st.session_state.run_id}" + ) diff --git a/aiagents4pharma/talk2competitors/app/talk2comp.py b/app/frontend/streamlit_app_talk2competitors2.py similarity index 97% rename from aiagents4pharma/talk2competitors/app/talk2comp.py rename to app/frontend/streamlit_app_talk2competitors2.py index d50442b3..50deebbb 100644 --- a/aiagents4pharma/talk2competitors/app/talk2comp.py +++ b/app/frontend/streamlit_app_talk2competitors2.py @@ -7,11 +7,12 @@ import random import sys -sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../") import streamlit as st from langchain_core.messages import ChatMessage, HumanMessage -from agents.main_agent import get_app +# from agents.main_agent import get_app +sys.path.append('./') +from aiagents4pharma.talk2competitors.agents.main_agent import get_app # Page configuration st.set_page_config(page_title="talk2comp", page_icon="📚", layout="wide") From 898c0d12721ba06261b3969728a6df9a165ed96b Mon Sep 17 00:00:00 2001 From: Ansh-info Date: Wed, 22 Jan 2025 19:00:31 +0100 Subject: [PATCH 19/19] feat: update --- .../app/.streamlit/config.toml | 2 - .../talk2competitors/config/config.py | 41 ------------------- 2 files changed, 43 deletions(-) delete mode 100644 aiagents4pharma/talk2competitors/app/.streamlit/config.toml diff --git a/aiagents4pharma/talk2competitors/app/.streamlit/config.toml b/aiagents4pharma/talk2competitors/app/.streamlit/config.toml deleted file mode 100644 index 49f7885b..00000000 --- a/aiagents4pharma/talk2competitors/app/.streamlit/config.toml +++ /dev/null @@ -1,2 +0,0 @@ -[theme] -base = "light" diff --git a/aiagents4pharma/talk2competitors/config/config.py b/aiagents4pharma/talk2competitors/config/config.py index 8b955763..a611dd43 100644 --- a/aiagents4pharma/talk2competitors/config/config.py +++ b/aiagents4pharma/talk2competitors/config/config.py @@ -1,51 +1,10 @@ class Config: - # LLM Configuration - LLM_MODEL = "gpt-4o-mini" # Updated to GPT-4-mini - TEMPERATURE = 0.7 - # API Endpoints SEMANTIC_SCHOLAR_API = "https://api.semanticscholar.org/graph/v1" # API Keys SEMANTIC_SCHOLAR_API_KEY = "YOUR_API_KEY" # Get this from Semantic Scholar - # State Keys - class StateKeys: - PAPERS = "papers" - SELECTED_PAPERS = "selected_papers" - CURRENT_TOOL = "current_tool" - CURRENT_AGENT = "current_agent" - RESPONSE = "response" - ERROR = "error" - CHAT_HISTORY = "chat_history" - USER_INFO = "user_info" - MEMORY = "memory" - - # Agent Names - class AgentNames: - MAIN = "main_agent" - S2 = "semantic_scholar_agent" - ZOTERO = "zotero_agent" - PDF = "pdf_agent" - ARXIV = "arxiv_agent" - - # Tool Names (Keeping for reference) - class ToolNames: - # S2 Tools - S2_SEARCH = "search_papers" - S2_SINGLE_REC = "single_paper_recommendation" - S2_MULTI_REC = "multi_paper_recommendation" - - # Zotero Tools - ZOTERO_READ = "zotero_read" - ZOTERO_WRITE = "zotero_write" - - # PDF Tools - PDF_RAG = "pdf_rag" - - # arXiv Tools - ARXIV_DOWNLOAD = "arxiv_download" - # Updated System Prompts MAIN_AGENT_PROMPT = """You are a supervisory AI agent that routes user queries to specialized tools. Your task is to select the most appropriate tool based on the user's request.