diff --git a/aiagents4pharma/talk2competitors/agents/__init__.py b/aiagents4pharma/talk2competitors/agents/__init__.py
new file mode 100644
index 00000000..3423f265
--- /dev/null
+++ b/aiagents4pharma/talk2competitors/agents/__init__.py
@@ -0,0 +1,6 @@
+'''
+This file is used to import all the modules in the package.
+'''
+
+from . import main_agent
+from . import s2_agent
diff --git a/aiagents4pharma/talk2competitors/agents/main_agent.py b/aiagents4pharma/talk2competitors/agents/main_agent.py
new file mode 100644
index 00000000..f72257dd
--- /dev/null
+++ b/aiagents4pharma/talk2competitors/agents/main_agent.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+
+"""
+Main agent for the talk2competitors app.
+"""
+
+import logging
+from typing import Callable, Literal
+from dotenv import load_dotenv
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import AIMessage
+from langchain_openai import ChatOpenAI
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import END, START, StateGraph
+from langgraph.types import Command
+from ..agents import s2_agent
+from ..config.config import config
+from ..state.state_talk2competitors import Talk2Competitors
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+load_dotenv()
+
+def make_supervisor_node(llm: BaseChatModel) -> Callable:
+    """
+    Creates a supervisor node following LangGraph patterns.
+
+    Args:
+        llm (BaseChatModel): The language model to use for generating responses.
+
+    Returns:
+        Callable: The supervisor node function.
+    """
+    def supervisor_node(state: Talk2Competitors) -> Command[Literal["s2_agent", "__end__"]]:
+        """
+        Supervisor node that routes to appropriate sub-agents.
+
+        Args:
+            state (Talk2Competitors): The current state of the conversation.
+
+        Returns:
+            Command[Literal["s2_agent", "__end__"]]: The command to execute next.
+        """
+        logger.info("Supervisor node called")
+
+        messages = [{"role": "system", "content": config.MAIN_AGENT_PROMPT}] + state[
+            "messages"
+        ]
+        response = llm.invoke(messages)
+        # Routing is keyword-based for now; the LLM response is only used
+        # as the final answer when no sub-agent is needed.
+        goto = (
+            "FINISH"
+            if not any(
+                kw in state["messages"][-1].content.lower()
+                for kw in ["search", "paper", "find"]
+            )
+            else "s2_agent"
+        )
+
+        if goto == "FINISH":
+            return Command(
+                goto=END,
+                update={
+                    "messages": state["messages"]
+                    + [AIMessage(content=response.content)],
+                    "is_last_step": True,
+                    "current_agent": None,
+                },
+            )
+
+        return Command(
+            goto="s2_agent",
+            update={
+                "messages": state["messages"],
+                "is_last_step": False,
+                "current_agent": "s2_agent",
+            },
+        )
+
+    return supervisor_node
+
+def get_app(thread_id: str, llm_model='gpt-4o-mini'):
+    """
+    Returns the langgraph app with hierarchical structure.
+
+    Args:
+        thread_id (str): The thread ID for the conversation.
+        llm_model (str, optional): The OpenAI model to use. Defaults to 'gpt-4o-mini'.
+
+    Returns:
+        The compiled langgraph app.
+    """
+    def call_s2_agent(state: Talk2Competitors) -> Command[Literal["__end__"]]:
+        """
+        Node for calling the S2 agent.
+
+        Args:
+            state (Talk2Competitors): The current state of the conversation.
+
+        Returns:
+            Command[Literal["__end__"]]: The command to execute next.
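+
+        Note: the S2 sub-agent runs as its own compiled graph; its final
+        messages and any papers it found are merged back into the parent
+        state via the returned Command.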
+        """
+        logger.info("Calling S2 agent")
+        app = s2_agent.get_app(thread_id, llm_model)
+        response = app.invoke(state)
+        logger.info("S2 agent completed")
+        return Command(
+            goto=END,
+            update={
+                "messages": response["messages"],
+                "papers": response.get("papers", []),
+                "is_last_step": True,
+                "current_agent": "s2_agent",
+            },
+        )
+    llm = ChatOpenAI(model=llm_model, temperature=0)
+    workflow = StateGraph(Talk2Competitors)
+
+    supervisor = make_supervisor_node(llm)
+    workflow.add_node("supervisor", supervisor)
+    workflow.add_node("s2_agent", call_s2_agent)
+
+    # Define edges
+    workflow.add_edge(START, "supervisor")
+    workflow.add_edge("s2_agent", END)
+
+    app = workflow.compile(checkpointer=MemorySaver())
+    logger.info("Main agent workflow compiled")
+    return app
diff --git a/aiagents4pharma/talk2competitors/agents/s2_agent.py b/aiagents4pharma/talk2competitors/agents/s2_agent.py
new file mode 100644
index 00000000..c090862a
--- /dev/null
+++ b/aiagents4pharma/talk2competitors/agents/s2_agent.py
@@ -0,0 +1,75 @@
+#!/usr/bin/env python3
+
+'''
+Agent for interacting with Semantic Scholar
+'''
+
+import logging
+from dotenv import load_dotenv
+from langchain_openai import ChatOpenAI
+from langgraph.graph import START, StateGraph
+from langgraph.prebuilt import create_react_agent
+from langgraph.checkpoint.memory import MemorySaver
+from ..config.config import config
+from ..state.state_talk2competitors import Talk2Competitors
+from ..tools.s2.search import search_tool
+from ..tools.s2.display_results import display_results
+from ..tools.s2.single_paper_rec import get_single_paper_recommendations
+from ..tools.s2.multi_paper_rec import get_multi_paper_recommendations
+
+load_dotenv()
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def get_app(uniq_id, llm_model='gpt-4o-mini'):
+    '''
+    This function returns the langgraph app.
+    '''
+    def agent_s2_node(state: Talk2Competitors):
+        '''
+        This function calls the model.
+        '''
+        logger.info("Calling Agent_S2 node with thread_id %s", uniq_id)
+        response = model.invoke(state, {"configurable": {"thread_id": uniq_id}})
+        return response
+
+    # Define the tools
+    tools = [search_tool,
+             display_results,
+             get_single_paper_recommendations,
+             get_multi_paper_recommendations]
+
+    # Create the LLM
+    llm = ChatOpenAI(model=llm_model, temperature=0)
+    # The graph compiled below already has a checkpointer, so the react
+    # agent does not need a separate one of its own.
+    model = create_react_agent(
+        llm,
+        tools=tools,
+        state_schema=Talk2Competitors,
+        state_modifier=config.S2_AGENT_PROMPT,
+    )
+
+    # Define a new graph
+    workflow = StateGraph(Talk2Competitors)
+
+    # Add the single react-agent node
+    workflow.add_node("agent_s2", agent_s2_node)
+
+    # Set the entrypoint as `agent_s2`
+    # This means that this node is the first one called
+    workflow.add_edge(START, "agent_s2")
+
+    # Initialize memory to persist state between graph runs
+    checkpointer = MemorySaver()
+
+    # Finally, we compile it!
+    # This compiles it into a LangChain Runnable,
+    # meaning you can use it as you would any other runnable.
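+    # (e.g. app.invoke(...), app.stream(...), or app.get_state(...)).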
+    # Note that we're (optionally) passing the memory when compiling the graph
+    app = workflow.compile(checkpointer=checkpointer)
+    logger.info("Compiled the graph")
+
+    return app
diff --git a/aiagents4pharma/talk2competitors/config/config.py b/aiagents4pharma/talk2competitors/config/config.py
new file mode 100644
index 00000000..a611dd43
--- /dev/null
+++ b/aiagents4pharma/talk2competitors/config/config.py
@@ -0,0 +1,142 @@
+class Config:
+    '''
+    Configuration for the talk2competitors agents:
+    API endpoints, API keys, and system prompts.
+    '''
+    # API Endpoints
+    SEMANTIC_SCHOLAR_API = "https://api.semanticscholar.org/graph/v1"
+
+    # API Keys
+    # Placeholder only - obtain a key from Semantic Scholar and load it
+    # from the environment; do not commit a real key.
+    SEMANTIC_SCHOLAR_API_KEY = "YOUR_API_KEY"
+
+    # System Prompts
+    MAIN_AGENT_PROMPT = """You are a supervisory AI agent that routes user queries to specialized tools.
+Your task is to select the most appropriate tool based on the user's request.
+
+Available tools and their capabilities:
+
+1. semantic_scholar_agent:
+   - Search for academic papers and research
+   - Get paper recommendations
+   - Find similar papers
+   USE FOR: Any queries about finding papers, academic research, or getting paper recommendations
+
+2. zotero_agent:
+   - Manage paper library
+   - Save and organize papers
+   USE FOR: Saving papers or managing research library
+
+3. pdf_agent:
+   - Analyze PDF content
+   - Answer questions about documents
+   USE FOR: Analyzing or asking questions about PDF content
+
+4. arxiv_agent:
+   - Download papers from arXiv
+   USE FOR: Specifically downloading papers from arXiv
+
+ROUTING GUIDELINES:
+
+ALWAYS route to semantic_scholar_agent for:
+- Finding academic papers
+- Searching research topics
+- Getting paper recommendations
+- Finding similar papers
+- Any query about academic literature
+
+Route to zotero_agent for:
+- Saving papers to library
+- Managing references
+- Organizing research materials
+
+Route to pdf_agent for:
+- PDF content analysis
+- Document-specific questions
+- Understanding paper contents
+
+Route to arxiv_agent for:
+- Paper download requests
+- arXiv-specific queries
+
+Approach:
+1. Identify the core need in the user's query
+2. Select the most appropriate tool based on the guidelines above
+3. If unclear, ask for clarification
+4. For multi-step tasks, focus on the immediate next step
+
+Remember:
+- Be decisive in your tool selection
+- Focus on the immediate task
+- Default to semantic_scholar_agent for any paper-finding tasks
+- Ask for clarification if the request is ambiguous
+
+IMPORTANT GUIDELINES FOR PAPER RECOMMENDATIONS:
+
+For Multiple Papers:
+- When getting recommendations for multiple papers, always use get_multi_paper_recommendations tool
+- DO NOT call get_single_paper_recommendations multiple times
+- Always pass all paper IDs in a single call to get_multi_paper_recommendations
+- Use for queries like "find papers related to both/all papers" or "find similar papers to these papers"
+
+For Single Paper:
+- Use get_single_paper_recommendations when focusing on one specific paper
+- Pass only one paper ID at a time
+- Use for queries like "find papers similar to this paper" or "get recommendations for paper X"
+- Do not use for multiple papers
+
+Examples:
+- For "find related papers for both papers":
+  ✓ Use get_multi_paper_recommendations with both paper IDs
+  × Don't make multiple calls to get_single_paper_recommendations
+
+- For "find papers related to the first paper":
+  ✓ Use get_single_paper_recommendations with just that paper's ID
+  × Don't use get_multi_paper_recommendations
+
+Remember:
+- Be precise in identifying which paper ID to use for single recommendations
+- Don't reuse previous paper IDs unless specifically requested
+- For fresh paper recommendations, always use the original paper ID"""
+
+    S2_AGENT_PROMPT = """You are a specialized academic research assistant with access to the following tools:
+
+1. search_tool:
+   USE FOR: General paper searches
+   - Enhances search terms automatically
+   - Adds relevant academic keywords
+   - Focuses on recent research when appropriate
+
+2. get_single_paper_recommendations:
+   USE FOR: Finding papers similar to a specific paper
+   - Takes a single paper ID
+   - Returns related papers
+
+3. get_multi_paper_recommendations:
+   USE FOR: Finding papers similar to multiple papers
+   - Takes multiple paper IDs
+   - Finds papers related to all inputs
+
+GUIDELINES:
+
+For paper searches:
+- Enhance search terms with academic language
+- Include field-specific terminology
+- Add "recent" or "latest" when appropriate
+- Keep queries focused and relevant
+
+For paper recommendations:
+- Identify paper IDs (40-character hexadecimal strings)
+- Use get_single_paper_recommendations for one ID
+- Use get_multi_paper_recommendations for multiple IDs
+
+Best practices:
+1. Start with a broad search if no paper IDs are provided
+2. Look for paper IDs in user input
+3. Enhance search terms for better results
+4. Consider the academic context
+5. Be prepared to refine searches based on feedback
+
+Remember:
+- Always select the most appropriate tool
+- Enhance search queries naturally
+- Consider academic context
+- Focus on delivering relevant results"""
+
+
+config = Config()
diff --git a/aiagents4pharma/talk2competitors/state/__init__.py b/aiagents4pharma/talk2competitors/state/__init__.py
new file mode 100644
index 00000000..8cbeabf3
--- /dev/null
+++ b/aiagents4pharma/talk2competitors/state/__init__.py
@@ -0,0 +1,5 @@
+'''
+This file is used to import all the modules in the package.
+'''
+
+from . import state_talk2competitors
diff --git a/aiagents4pharma/talk2competitors/state/state_talk2competitors.py b/aiagents4pharma/talk2competitors/state/state_talk2competitors.py
new file mode 100644
index 00000000..3f800e6c
--- /dev/null
+++ b/aiagents4pharma/talk2competitors/state/state_talk2competitors.py
@@ -0,0 +1,31 @@
+"""
+This is the state file for the talk2competitors agent.
+"""
+
+import logging
+from typing import Annotated, List, Optional
+from langgraph.prebuilt.chat_agent_executor import AgentState
+from typing_extensions import NotRequired, Required
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def replace_list(existing: List[str], new: List[str]) -> List[str]:
+    """Replace the existing list with the new one."""
+    logger.info("Updating existing state %s with the state list: %s",
+                existing,
+                new)
+    return new
+
+
+class Talk2Competitors(AgentState):
+    """
+    The state for the talk2competitors agent, inheriting from AgentState.
+    """
+    papers: Annotated[List[str], replace_list]
+    search_table: NotRequired[str]
+    next: str  # Required for routing in LangGraph
+    current_agent: NotRequired[Optional[str]]
+    is_last_step: Required[bool]  # Required field for LangGraph
+    llm_model: str
diff --git a/aiagents4pharma/talk2competitors/tests/__init__.py b/aiagents4pharma/talk2competitors/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/aiagents4pharma/talk2competitors/tests/test_langgraph.py b/aiagents4pharma/talk2competitors/tests/test_langgraph.py
new file mode 100644
index 00000000..27e0597e
--- /dev/null
+++ b/aiagents4pharma/talk2competitors/tests/test_langgraph.py
@@ -0,0 +1,92 @@
+'''
+Test cases for the talk2competitors main agent.
+'''
+
+from langchain_core.messages import HumanMessage
+from ..agents.main_agent import get_app
+
+def test_main_agent():
+    '''
+    Test the main agent.
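+
+    Note: this is an integration test. It requires a valid
+    OPENAI_API_KEY and live access to the Semantic Scholar API,
+    and it asserts on content returned by those services.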
+    '''
+    unique_id = "test_12345"
+    app = get_app(unique_id)
+    config = {"configurable": {"thread_id": unique_id}}
+    ####################################################
+    prompt = "Without calling any tool, tell me the capital of France"
+    response = app.invoke(
+        {
+            "messages": [HumanMessage(content=prompt)],
+            "papers": [],
+            "is_last_step": False,
+            "current_agent": None,
+        },
+        config=config,
+    )
+
+    assistant_msg = response["messages"][-1].content
+    # Check that the assistant answers directly without routing to a tool
+    assert 'Paris' in assistant_msg
+    ####################################################
+    prompt = "Search articles on machine learning"
+    response = app.invoke(
+        {
+            "messages": [HumanMessage(content=prompt)],
+            "papers": [],
+            "is_last_step": False,
+            "current_agent": None,
+        },
+        config=config,
+    )
+
+    assistant_msg = response["messages"][-1].content
+    # Check that the search results contain an expected paper title
+    assert 'Fashion-MNIST' in assistant_msg
+    ####################################################
+    prompt = "Recommend articles using the first paper of the previous search"
+    response = app.invoke(
+        {
+            "messages": [HumanMessage(content=prompt)],
+            "papers": [],
+            "is_last_step": False,
+            "current_agent": None,
+        },
+        config=config,
+    )
+
+    assistant_msg = response["messages"][-1].content
+    print(assistant_msg)
+    # Check that single-paper recommendations contain an expected title
+    assert 'CNN Models' in assistant_msg
+    ####################################################
+    prompt = "Recommend articles using both papers of your last response"
+    response = app.invoke(
+        {
+            "messages": [HumanMessage(content=prompt)],
+            "papers": [],
+            "is_last_step": False,
+            "current_agent": None,
+        },
+        config=config,
+    )
+
+    assistant_msg = response["messages"][-1].content
+    print(assistant_msg)
+    # Check that multi-paper recommendations contain an expected title
+    assert 'Efficient Handwritten Digit Classification' in assistant_msg
+    ###################################################
+    prompt = "Show me the papers in the state"
+    response = app.invoke(
+        {
+            "messages": [HumanMessage(content=prompt)],
+            "papers": [],
+            "is_last_step": False,
+            "current_agent": None,
+        },
+        config=config,
+    )
+
+    assistant_msg = response["messages"][-1].content
+    print(assistant_msg)
+    # Check that the papers stored in the state are displayed
+    assert 'Classification of Fashion-MNIST Dataset' in assistant_msg
diff --git a/aiagents4pharma/talk2competitors/tools/__init__.py b/aiagents4pharma/talk2competitors/tools/__init__.py
new file mode 100644
index 00000000..c69aba29
--- /dev/null
+++ b/aiagents4pharma/talk2competitors/tools/__init__.py
@@ -0,0 +1,7 @@
+#!/usr/bin/env python3
+
+'''
+Import statements
+'''
+
+from . import s2
diff --git a/aiagents4pharma/talk2competitors/tools/s2/__init__.py b/aiagents4pharma/talk2competitors/tools/s2/__init__.py
new file mode 100644
index 00000000..f809aa36
--- /dev/null
+++ b/aiagents4pharma/talk2competitors/tools/s2/__init__.py
@@ -0,0 +1,8 @@
+'''
+This file is used to import all the modules in the package.
+'''
+
+from . import display_results
+from . import multi_paper_rec
+from . import search
+from . import single_paper_rec
diff --git a/aiagents4pharma/talk2competitors/tools/s2/display_results.py b/aiagents4pharma/talk2competitors/tools/s2/display_results.py
new file mode 100644
index 00000000..1c06dde9
--- /dev/null
+++ b/aiagents4pharma/talk2competitors/tools/s2/display_results.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python3
+
+'''
+This tool is used to display the table of studies.
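+It simply returns the `papers` list from the injected agent state;
+no Semantic Scholar request is made.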
+'''
+
+import logging
+from typing import Annotated
+from langchain_core.tools import tool
+from langgraph.prebuilt import InjectedState
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+@tool('display_results')
+def display_results(state: Annotated[dict, InjectedState]):
+    """
+    Display the papers in the state.
+
+    Args:
+        state (dict): The state of the agent.
+
+    Returns:
+        The list of papers stored in the state.
+    """
+    logger.info("Displaying papers from the state")
+    return state["papers"]
diff --git a/aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py b/aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py
new file mode 100644
index 00000000..58174a3f
--- /dev/null
+++ b/aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+
+"""
+multi_paper_rec: Tool for getting recommendations
+                 based on multiple papers
+"""
+
+import logging
+import json
+from typing import Annotated, List
+import pandas as pd
+import requests
+from langchain_core.messages import ToolMessage
+from langchain_core.tools import tool
+from langchain_core.tools.base import InjectedToolCallId
+from langgraph.types import Command
+from pydantic import BaseModel, Field
+
+class MultiPaperRecInput(BaseModel):
+    """Input schema for multiple paper recommendations tool."""
+
+    paper_ids: List[str] = Field(
+        description=("List of Semantic Scholar Paper IDs to get recommendations for")
+    )
+    limit: int = Field(
+        default=2,
+        description="Maximum total number of recommendations to return",
+        ge=1,
+        le=500,
+    )
+    tool_call_id: Annotated[str, InjectedToolCallId]
+
+    model_config = {"arbitrary_types_allowed": True}
+
+
+@tool(args_schema=MultiPaperRecInput)
+def get_multi_paper_recommendations(
+    paper_ids: List[str],
+    tool_call_id: Annotated[str, InjectedToolCallId],
+    limit: int = 2,
+) -> Command:
+    """
+    Get paper recommendations based on multiple papers.
+
+    Args:
+        paper_ids (List[str]): The list of paper IDs to base recommendations on.
+        tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
+        limit (int, optional): The maximum number of recommendations to return. Defaults to 2.
+
+    Returns:
+        Command: A LangGraph command that updates the state with the
+        recommended papers and appends a tool message.
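+
+    Note: all paper IDs are sent to the Semantic Scholar Recommendations
+    API in a single request, rather than one request per paper.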
+    """
+    logging.info("Starting multi-paper recommendations search.")
+
+    endpoint = "https://api.semanticscholar.org/recommendations/v1/papers"
+    headers = {"Content-Type": "application/json"}
+    payload = {"positivePaperIds": paper_ids, "negativePaperIds": []}
+    params = {"limit": min(limit, 500), "fields": "title,paperId"}
+
+    # Getting recommendations
+    response = requests.post(
+        endpoint,
+        headers=headers,
+        params=params,
+        data=json.dumps(payload),
+        timeout=10,
+    )
+    # Fail fast on HTTP errors instead of parsing an error body
+    response.raise_for_status()
+    logging.info("API Response Status for multi-paper recommendations: %s",
+                 response.status_code)
+
+    data = response.json()
+    recommendations = data.get("recommendedPapers", [])
+
+    # Create a list to store the papers
+    papers_list = []
+    for paper in recommendations:
+        if paper.get("title") and paper.get("paperId"):
+            papers_list.append(
+                {"Paper ID": paper["paperId"], "Title": paper["title"]}
+            )
+
+    # Create a DataFrame from the list of papers
+    df = pd.DataFrame(papers_list)
+    logging.info("Created DataFrame with results: %s", df)
+
+    # Format papers for state update
+    formatted_papers = [
+        f"Paper ID: {paper['Paper ID']}\nTitle: {paper['Title']}"
+        for paper in papers_list
+    ]
+
+    # Convert DataFrame to markdown table
+    markdown_table = df.to_markdown(tablefmt="grid")
+    return Command(
+        update={
+            "papers": formatted_papers,
+            "messages": [
+                ToolMessage(
+                    content=markdown_table, tool_call_id=tool_call_id
+                )
+            ],
+        }
+    )
diff --git a/aiagents4pharma/talk2competitors/tools/s2/search.py b/aiagents4pharma/talk2competitors/tools/s2/search.py
new file mode 100644
index 00000000..31703885
--- /dev/null
+++ b/aiagents4pharma/talk2competitors/tools/s2/search.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+
+"""
+This tool is used to search for academic papers on Semantic Scholar.
+"""
+
+import logging
+from typing import Annotated
+import pandas as pd
+import requests
+from langchain_core.messages import ToolMessage
+from langchain_core.tools import tool
+from langchain_core.tools.base import InjectedToolCallId
+from langgraph.types import Command
+from pydantic import BaseModel, Field
+
+from ...config.config import config
+
+class SearchInput(BaseModel):
+    """Input schema for the search papers tool."""
+
+    query: str = Field(
+        description="Search query string to find academic papers. "
+        "Be specific and include relevant academic terms."
+    )
+    limit: int = Field(
+        default=2, description="Maximum number of results to return", ge=1, le=100
+    )
+    tool_call_id: Annotated[str, InjectedToolCallId]
+
+@tool(args_schema=SearchInput)
+def search_tool(
+    query: str,
+    tool_call_id: Annotated[str, InjectedToolCallId],
+    limit: int = 2,
+) -> Command:
+    """
+    Search for academic papers on Semantic Scholar.
+
+    Args:
+        query (str): The search query string to find academic papers.
+        tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
+        limit (int, optional): The maximum number of results to return. Defaults to 2.
+
+    Returns:
+        Command: A LangGraph command that updates the state with the
+        search results and appends a tool message.
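+
+    Note: results come from the Semantic Scholar Graph API paper search
+    endpoint; at most `limit` papers (capped at 100) are returned.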
+    """
+    logging.info("Starting paper search.")
+    endpoint = f"{config.SEMANTIC_SCHOLAR_API}/paper/search"
+    params = {
+        "query": query,
+        "limit": min(limit, 100),
+        "fields": "paperId,title,abstract,year,authors,citationCount,openAccessPdf",
+    }
+    response = requests.get(endpoint, params=params, timeout=10)
+    # Fail fast on HTTP errors instead of parsing an error body
+    response.raise_for_status()
+    data = response.json()
+    papers = data.get("data", [])
+
+    filtered_papers = [
+        {"Paper ID": paper["paperId"], "Title": paper["title"]}
+        for paper in papers
+        if paper.get("title") and paper.get("authors")
+    ]
+
+    df = pd.DataFrame(filtered_papers)
+
+    papers = [
+        f"Paper ID: {paper['Paper ID']}\nTitle: {paper['Title']}"
+        for paper in filtered_papers
+    ]
+
+    markdown_table = df.to_markdown(tablefmt="grid")
+    logging.info("Search results: %s", papers)
+
+    # Return a Command so the papers reach the agent state,
+    # mirroring the other S2 tools.
+    return Command(
+        update={
+            "papers": papers,
+            "messages": [
+                ToolMessage(
+                    content=markdown_table, tool_call_id=tool_call_id
+                )
+            ],
+        }
+    )
diff --git a/aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py b/aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py
new file mode 100644
index 00000000..96f47a98
--- /dev/null
+++ b/aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py
@@ -0,0 +1,103 @@
+#!/usr/bin/env python3
+
+'''
+This tool is used to return recommendations for a single paper.
+'''
+
+import logging
+from typing import Annotated
+import pandas as pd
+import requests
+from langchain_core.messages import ToolMessage
+from langchain_core.tools import tool
+from langchain_core.tools.base import InjectedToolCallId
+from langgraph.types import Command
+from pydantic import BaseModel, Field
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class SinglePaperRecInput(BaseModel):
+    """Input schema for single paper recommendation tool."""
+
+    paper_id: str = Field(
+        description="Semantic Scholar Paper ID to get recommendations for (40-character string)"
+    )
+    limit: int = Field(
+        default=2,
+        description="Maximum number of recommendations to return",
+        ge=1,
+        le=500,
+    )
+    tool_call_id: Annotated[str, InjectedToolCallId]
+    model_config = {"arbitrary_types_allowed": True}
+
+
+@tool(args_schema=SinglePaperRecInput)
+def get_single_paper_recommendations(
+    paper_id: str,
+    tool_call_id: Annotated[str, InjectedToolCallId],
+    limit: int = 2,
+) -> Command:
+    """
+    Get paper recommendations based on a single paper.
+
+    Args:
+        paper_id (str): The Semantic Scholar Paper ID to get recommendations for.
+        tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
+        limit (int, optional): The maximum number of recommendations to return. Defaults to 2.
+
+    Returns:
+        Command: A LangGraph command that updates the state with the
+        recommended papers and appends a tool message.
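+
+    Note: recommendations come from the Semantic Scholar Recommendations
+    API ("forpaper" endpoint) using the "all-cs" candidate pool.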
+    """
+    logger.info("Starting single paper recommendations search.")
+
+    endpoint = (
+        f"https://api.semanticscholar.org/recommendations/v1/papers/forpaper/{paper_id}"
+    )
+    params = {
+        "limit": min(limit, 500),  # Max 500 per API docs
+        "fields": "title,paperId,abstract,year",
+        "from": "all-cs",  # Using all-cs pool as specified in docs
+    }
+
+    response = requests.get(endpoint, params=params, timeout=10)
+    # Fail fast on HTTP errors instead of parsing an error body
+    response.raise_for_status()
+    logger.info("API Response Status for recommendations of paper %s: %s",
+                paper_id,
+                response.status_code)
+    logger.info("Request params: %s", params)
+
+    data = response.json()
+    recommendations = data.get("recommendedPapers", [])
+
+    # Extract paper ID and title from recommendations
+    filtered_papers = [
+        {"Paper ID": paper["paperId"], "Title": paper["title"]}
+        for paper in recommendations
+        if paper.get("title") and paper.get("paperId")
+    ]
+
+    # Create a DataFrame for pretty printing
+    df = pd.DataFrame(filtered_papers)
+
+    papers = [
+        f"Paper ID: {paper['Paper ID']}\nTitle: {paper['Title']}"
+        for paper in filtered_papers
+    ]
+
+    # Convert DataFrame to markdown table
+    markdown_table = df.to_markdown(tablefmt="grid")
+
+    return Command(
+        update={
+            "papers": papers,
+            "messages": [
+                ToolMessage(
+                    content=markdown_table, tool_call_id=tool_call_id
+                )
+            ],
+        }
+    )
diff --git a/app/frontend/streamlit_app_talk2competitors.py b/app/frontend/streamlit_app_talk2competitors.py
new file mode 100644
index 00000000..7e536ad8
--- /dev/null
+++ b/app/frontend/streamlit_app_talk2competitors.py
@@ -0,0 +1,225 @@
+#!/usr/bin/env python3
+
+'''
+Talk2Competitors: A Streamlit app for the Talk2Competitors graph.
+'''
+
+import os
+import sys
+import random
+import streamlit as st
+from streamlit_feedback import streamlit_feedback
+from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
+from langchain_core.messages import ChatMessage
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.tracers.context import collect_runs
+from langchain.callbacks.tracers import LangChainTracer
+from langsmith import Client
+sys.path.append('./')
+from aiagents4pharma.talk2competitors.agents.main_agent import get_app
+
+st.set_page_config(page_title="Talk2Competitors", page_icon="🤖", layout="wide")
+
+# Check if env variable OPENAI_API_KEY exists
+if "OPENAI_API_KEY" not in os.environ:
+    st.error("Please set the OPENAI_API_KEY environment \
+        variable in the terminal where you run the app.")
+    st.stop()
+
+# Create a chat prompt template
+prompt = ChatPromptTemplate.from_messages([
+    ("system", "Welcome to Talk2Competitors!"),
+    MessagesPlaceholder(variable_name='chat_history', optional=True),
+    ("human", "{input}"),
+    ("placeholder", "{agent_scratchpad}"),
+])
+
+# Initialize chat history
+if "messages" not in st.session_state:
+    st.session_state.messages = []
+
+# Initialize project_name for Langsmith
+if "project_name" not in st.session_state:
+    st.session_state.project_name = 'Talk2Competitors-' + str(random.randint(1000, 9999))
+
+# Initialize run_id for Langsmith
+if "run_id" not in st.session_state:
+    st.session_state.run_id = None
+
+# Initialize graph
+if "unique_id" not in st.session_state:
+    st.session_state.unique_id = random.randint(1, 1000)
+if "app" not in st.session_state:
+    if "llm_model" not in st.session_state:
+        st.session_state.app = get_app(st.session_state.unique_id)
+    else:
+        st.session_state.app = get_app(st.session_state.unique_id,
+                                       llm_model=st.session_state.llm_model)
+
+# Get the app
+app = st.session_state.app
+
+def _submit_feedback(user_response):
+    '''
+    Function to submit feedback to the developers.
+    '''
+    client = Client()
+    client.create_feedback(
+        st.session_state.run_id,
+        key="feedback",
+        score=1 if user_response['score'] == "👍" else 0,
+        comment=user_response['text']
+    )
+    st.info("Your feedback is on its way to the developers. Thank you!", icon="🚀")
+
+@st.dialog("Warning ⚠️")
+def update_llm_model():
+    """
+    Function to update the LLM model.
+    """
+    llm_model = st.session_state.llm_model
+    st.warning(f"Clicking 'Continue' will reset all agents \
+            and set the selected LLM to {llm_model}. \
+            This action will reset the entire app, \
+            and agents will lose access to the \
+            conversation history. Are you sure \
+            you want to proceed?")
+    if st.button("Continue"):
+        # Delete all the messages and the app key
+        for key in st.session_state.keys():
+            if key in ["messages", "app"]:
+                del st.session_state[key]
+
+# Main layout of the app split into two columns
+main_col1, main_col2 = st.columns([3, 7])
+# First column
+with main_col1:
+    with st.container(border=True):
+        # Title
+        st.write("""
+            <h3>
+            🤖 Talk2Competitors
+            </h3>
+            """,
+            unsafe_allow_html=True)
+
+        # LLM panel (Only at the front-end for now)
+        llms = ["gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo"]
+        st.selectbox(
+            "Pick an LLM to power the agent",
+            llms,
+            index=0,
+            key="llm_model",
+            on_change=update_llm_model
+        )
+
+        # Upload files (placeholder)
+        # uploaded_file = st.file_uploader(
+        #     "Upload sequencing data",
+        #     accept_multiple_files=False,
+        #     type=["h5ad"],
+        #     help='''Upload a single h5ad file containing the sequencing data.
+        #     The file should be in the AnnData format.'''
+        # )
+
+    with st.container(border=False, height=500):
+        prompt = st.chat_input("Say something ...", key="st_chat_input")
+
+# Second column
+with main_col2:
+    # Chat history panel
+    with st.container(border=True, height=575):
+        st.write("#### 💬 Chat History")
+
+        # Display chat messages
+        for count, message in enumerate(st.session_state.messages):
+            with st.chat_message(message["content"].role,
+                                 avatar="🤖"
+                                 if message["content"].role != 'user'
+                                 else "👩🏻‍💻"):
+                st.markdown(message["content"].content)
+                st.empty()
+
+        # When the user asks a question
+        if prompt:
+            # Create a key 'uploaded_file' to read the uploaded file
+            # if uploaded_file:
+            #     st.session_state.article_pdf = uploaded_file.read().decode("utf-8")
+
+            # Display user prompt
+            prompt_msg = ChatMessage(prompt, role="user")
+            st.session_state.messages.append(
+                {
+                    "type": "message",
+                    "content": prompt_msg
+                }
+            )
+            with st.chat_message("user", avatar="👩🏻‍💻"):
+                st.markdown(prompt)
+                st.empty()
+
+            with st.chat_message("assistant", avatar="🤖"):
+                with st.spinner():
+                    # Get chat history
+                    history = [(m["content"].role, m["content"].content)
+                               for m in st.session_state.messages
+                               if m["type"] == "message"]
+                    # Convert chat history to LangChain message objects
+                    chat_history = [
+                        SystemMessage(content=m[1]) if m[0] == "system" else
+                        HumanMessage(content=m[1]) if m[0] == "human" else
+                        AIMessage(content=m[1])
+                        for m in history
+                    ]
+
+                    # Create config for the agent
+                    config = {"configurable": {"thread_id": st.session_state.unique_id}}
+
+                    # Update the agent state with the selected LLM model
+                    app.update_state(
+                        config,
+                        {"llm_model": st.session_state.llm_model}
+                    )
+
+                    with collect_runs() as cb:
+                        # Add Langsmith tracer
+                        tracer = LangChainTracer(
+                            project_name=st.session_state.project_name
+                        )
+                        # Get response from the agent
+                        response = app.invoke(
+                            {"messages": [HumanMessage(content=prompt)]},
+                            config=config|{"callbacks": [tracer]}
+                        )
+                        st.session_state.run_id = cb.traced_runs[-1].id
+
+                    # Add assistant response to chat history
+                    assistant_msg = ChatMessage(response["messages"][-1].content,
+                                                role="assistant")
+                    st.session_state.messages.append({
+                        "type": "message",
+                        "content": assistant_msg
+                    })
+                    # Display the response in the chat
+                    st.markdown(response["messages"][-1].content)
+                    st.empty()
+        # Collect feedback and display the thumbs feedback
+        if st.session_state.get("run_id"):
+            feedback = streamlit_feedback(
+                feedback_type="thumbs",
+                optional_text_label="[Optional] Please provide an explanation",
+                on_submit=_submit_feedback,
+                key=f"feedback_{st.session_state.run_id}"
+            )
diff --git a/app/frontend/streamlit_app_talk2competitors2.py b/app/frontend/streamlit_app_talk2competitors2.py
new file mode 100644
index 00000000..50deebbb
--- /dev/null
+++ b/app/frontend/streamlit_app_talk2competitors2.py
@@ -0,0 +1,146 @@
+#!/usr/bin/env python3
+"""
+talk2comp: A Streamlit app for academic paper search and recommendations
+"""
+
+import os
+import random
+import sys
+
+import streamlit as st
+from langchain_core.messages import ChatMessage, HumanMessage
+
+sys.path.append('./')
+from aiagents4pharma.talk2competitors.agents.main_agent import get_app
+
+# Page configuration
+st.set_page_config(page_title="talk2comp", page_icon="📚", layout="wide")
+
+# Styles for fixed bottom input
+st.markdown(
+    """
+""",
+    unsafe_allow_html=True,
+)
+
+# Initialize session states
+if "messages" not in st.session_state:
+    st.session_state.messages = []
+if "unique_id" not in st.session_state:
+    st.session_state.unique_id = random.randint(1, 1000)
+if "app" not in st.session_state:
+    st.session_state.app = get_app(str(st.session_state.unique_id))
+
+app = st.session_state.app
+
+# Check OpenAI API key
+if "OPENAI_API_KEY" not in os.environ:
+    st.error("Please set the OPENAI_API_KEY environment variable.")
+    st.stop()
+
+# Main layout
+col1, col2 = st.columns([1, 4])
+
+# Sidebar for settings
+with col1:
+    st.markdown("### 📚 talk2comp")
+    llms = ["gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo"]
+    llm_option = st.selectbox(
+        "Select LLM Model",
+        llms,
+        index=0,
+    )
+
+# Main chat area
+with col2:
+    # Container for chat history
+    chat_container = st.container()
+    with chat_container:
+        st.markdown("### 💬 Chat History")
+
+        # Display messages
+        for message in st.session_state.messages:
+            with st.chat_message(
+                message["content"].role,
+                avatar="🤖" if message["content"].role != "user" else "👤",
+            ):
+                st.markdown(message["content"].content)
+
+# Fixed bottom input
+prompt = st.chat_input("Search for papers or ask questions...")
+
+if prompt:
+    # Display user message
+    prompt_msg = ChatMessage(prompt, role="user")
+    st.session_state.messages.append({"type": "message", "content": prompt_msg})
+
+    # Get agent response
+    with st.spinner("Processing your request..."):
+        # Debug print before invoke
+        print("Sending request to agent:", prompt)
+
+        # Prepare initial state with all required fields
+        initial_state = {
+            "messages": [HumanMessage(content=prompt)],
+            "papers": [],
+            "search_table": "",
+            "next": None,
+            "current_agent": None,
+            "is_last_step": False,  # Ensure this is included
+        }
+
+        config = {"configurable": {"thread_id": str(st.session_state.unique_id)}}
+        os.environ["AIAGENTS4PHARMA_LLM_MODEL"] = llm_option
+
+        response = app.invoke(initial_state, config=config)
+
+        # Debug print response
+        print("Agent response:", response)
+
+        # Add response to chat history and display
+        if "messages" in response:
+            # Get the last AI message
+            ai_messages = [msg for msg in response["messages"] if msg.type == "ai"]
+            if ai_messages:
+                last_ai_message = ai_messages[-1]
+
+                # Format paper results if present
+                if "papers" in response and response["papers"]:
+                    papers_content = response["papers"]
+                    formatted_message = "Here are the papers I found:\n\n"
+
+                    for idx, paper in enumerate(papers_content, start=1):
+                        if isinstance(paper, str):
+                            parts = paper.split("\n")
+                            paper_id = parts[0].replace("Paper ID: ", "").strip()
+                            title = parts[1].replace("Title: ", "").strip()
+                            formatted_message += f"{idx}. **{title}**\n"
+                            formatted_message += f"   - Paper ID: {paper_id}\n\n"
+                else:
+                    # Use the AI message content
+                    formatted_message = last_ai_message.content
+
+                assistant_msg = ChatMessage(formatted_message, role="assistant")
+                st.session_state.messages.append(
+                    {"type": "message", "content": assistant_msg}
+                )
+
+                # Rerun to update display
+                st.rerun()
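
Usage sketch: a minimal way to exercise the new main agent outside Streamlit,
assuming OPENAI_API_KEY is set; the thread ID and query below are illustrative.
It mirrors the calls in tests/test_langgraph.py.

    from langchain_core.messages import HumanMessage
    from aiagents4pharma.talk2competitors.agents.main_agent import get_app

    app = get_app("demo-thread")
    config = {"configurable": {"thread_id": "demo-thread"}}
    response = app.invoke(
        {
            "messages": [HumanMessage(content="Search papers on CRISPR")],
            "papers": [],
            "is_last_step": False,
            "current_agent": None,
        },
        config=config,
    )
    print(response["messages"][-1].content)  # the assistant's final reply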