Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: hackathon/kg model hack Jack Saleh Sandeep (Team Galway) #137

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/notebooks/talk2biomodels/descriptions_output.json

Large diffs are not rendered by default.

72 changes: 72 additions & 0 deletions docs/notebooks/talk2biomodels/embed_descriptions_search_ncbi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os
import json
import numpy as np
import pandas as pd
import openai
import faiss
import pickle
from dotenv import load_dotenv

# Load environment variables from a local .env file (if one exists) and hand
# the OpenAI key to the client library. os.getenv returns None when
# OPENAI_API_KEY is not set; nothing here checks for that.
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")

# File locations used by this script. All are relative paths, so they are
# resolved against the current working directory.
# DESCRIPTIONS_FILE: JSON input read by main().
DESCRIPTIONS_FILE = "descriptions_output.json"
# EMBEDDINGS_FILE: not referenced by any function visible in this script.
EMBEDDINGS_FILE = "gene_embeddings.pkl"
# INDEX_FILE / GENE_MAPPING_FILE: read by load_faiss_index().
INDEX_FILE = "faiss_index.bin"
GENE_MAPPING_FILE = "gene_id_mapping.pkl"
# OUTPUT_FILE: CSV written by main().
OUTPUT_FILE = "species_gene_matches.csv"


def get_openai_embedding(text, model="text-embedding-3-large"):
    """Return the embedding of ``text`` as a NumPy array.

    A single-item request is sent to the OpenAI embeddings endpoint with the
    given ``model``; the first (and only) embedding in the reply is converted
    with ``np.array``.
    """
    api_reply = openai.embeddings.create(input=[text], model=model)
    first_item = api_reply.data[0]
    return np.array(first_item.embedding)


def load_faiss_index():
    """Read the persisted FAISS index and its gene mapping from disk.

    Returns:
        tuple: ``(index, gene_id_mapping)`` as loaded from ``INDEX_FILE`` and
        ``GENE_MAPPING_FILE``.

    Raises:
        FileNotFoundError: if either of the two files does not exist.
    """
    artifacts_present = os.path.exists(INDEX_FILE) and os.path.exists(GENE_MAPPING_FILE)
    if not artifacts_present:
        raise FileNotFoundError("FAISS index or gene mapping file not found. Run the gene extraction script first.")

    faiss_index = faiss.read_index(INDEX_FILE)
    # NOTE: unpickling can execute arbitrary code; only load mapping files
    # that were produced locally by a trusted run.
    with open(GENE_MAPPING_FILE, "rb") as mapping_file:
        mapping = pickle.load(mapping_file)
    return faiss_index, mapping


def search_species_genes(descriptions):
    """Match every species description to its nearest gene/protein embedding.

    Each description is embedded with ``get_openai_embedding`` and the single
    nearest neighbour (``k=1``) is looked up in the index returned by
    ``load_faiss_index``. The score printed is whatever distance that index
    reports (the index is expected to be an L2 index — confirm against the
    script that built it).

    Args:
        descriptions: mapping of species name -> free-text description.

    Returns:
        list[dict]: one ``{"species_name", "ncbi_node_id"}`` record per entry,
        in the iteration order of ``descriptions``.
    """
    faiss_index, mapping = load_faiss_index()
    matches = []

    for species_name, text in descriptions.items():
        # FAISS expects a 2-D float32 batch, here a batch of one query vector.
        query = get_openai_embedding(text).astype("float32").reshape(1, -1)
        distances, indices = faiss_index.search(query, k=1)

        nearest_row = indices[0][0]
        score = distances[0][0]
        gene_name, ncbi_id = mapping[nearest_row]

        print(f"Best match for {species_name}: {gene_name} (NCBI ID: {ncbi_id}, Score: {score:.4f})")
        matches.append({"species_name": species_name, "ncbi_node_id": ncbi_id})

    return matches


def main():
    """Read species descriptions, match them to genes, and write a CSV.

    Reads ``DESCRIPTIONS_FILE`` (JSON), passes its content to
    ``search_species_genes`` and writes the resulting records to
    ``OUTPUT_FILE`` with the columns ``species_name`` and ``ncbi_node_id``.

    Raises:
        FileNotFoundError: if ``DESCRIPTIONS_FILE`` does not exist.
    """
    if not os.path.exists(DESCRIPTIONS_FILE):
        raise FileNotFoundError("Descriptions JSON file not found!")

    # JSON text is UTF-8; be explicit so non-ASCII characters in the
    # descriptions decode the same way on every platform.
    with open(DESCRIPTIONS_FILE, "r", encoding="utf-8") as f:
        descriptions = json.load(f)

    results = search_species_genes(descriptions)
    # Pin the columns so the CSV keeps its header even when there are no results.
    df = pd.DataFrame(results, columns=["species_name", "ncbi_node_id"])
    df.to_csv(OUTPUT_FILE, index=False)
    print(f"Results saved to {OUTPUT_FILE}")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
154 changes: 154 additions & 0 deletions docs/notebooks/talk2biomodels/embed_genes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import os
import numpy as np
import pandas as pd
import openai
import faiss
import pickle # To save/load embeddings
from dotenv import load_dotenv
from aiagents4pharma.talk2knowledgegraphs.datasets.primekg import PrimeKG

# Load environment variables from a local .env file (if one exists) and hand
# the OpenAI key to the client library. os.getenv returns None when
# OPENAI_API_KEY is not set; nothing here checks for that.
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")

# Cache/artifact locations, relative to the current working directory.
EMBEDDINGS_FILE = "gene_embeddings.pkl" # Pickled dict of embeddings; reused by compute_embeddings_for_genes() when present
INDEX_FILE = "faiss_index.bin" # Serialized FAISS index; reused by build_faiss_index() when present
GENE_MAPPING_FILE = "gene_id_mapping.pkl" # Pickled dict: FAISS row position -> [node_name, node_id]


def get_openai_embedding(text, model="text-embedding-3-large"):
    """Embed ``text`` with the OpenAI API and return the vector as ``np.ndarray``.

    The request carries ``text`` as a one-element input list; only the first
    embedding of the response is used.
    """
    reply = openai.embeddings.create(input=[text], model=model)
    vector = reply.data[0].embedding
    return np.array(vector)


def extract_genes_from_go(go_keywords=None):
    """Return PrimeKG gene/protein nodes linked to keyword-matched GO terms.

    GO terms are nodes typed ``biological_process``, ``molecular_function`` or
    ``cellular_component`` whose ``node_name`` matches any keyword
    (case-insensitive). Genes are the ``gene/protein`` endpoints of edges that
    touch one of those GO terms.

    Args:
        go_keywords: A keyword or an iterable of keywords. They are joined
            with ``|`` and passed to ``str.contains``, so each keyword is
            interpreted as a regular expression. Defaults to
            ``["immune", "inflammation"]``.

    Returns:
        DataFrame with the columns ``node_index``, ``node_name``, ``node_id``
        and ``node_type`` for the matched genes, or ``None`` when no keyword
        is given, no GO term matches, or no gene-GO edge exists.
    """
    if go_keywords is None:
        # Built per call instead of using a shared mutable default.
        go_keywords = ["immune", "inflammation"]
    elif isinstance(go_keywords, str):
        # A bare string would otherwise be joined character by character.
        go_keywords = [go_keywords]

    go_query = "|".join(go_keywords)
    if not go_query:
        # An empty pattern matches every node name; treat it as "no match"
        # instead of silently selecting all GO terms.
        print(f"No GO terms matching {go_keywords} found in PrimeKG!")
        return None

    primekg_data = PrimeKG(local_dir="../../../../data/primekg/")
    primekg_data.load_data()

    primekg_nodes = primekg_data.get_nodes()
    primekg_edges = primekg_data.get_edges()

    # STEP 1: GO-term nodes whose name matches any keyword.
    is_go_node = primekg_nodes["node_type"].isin(
        ["biological_process", "molecular_function", "cellular_component"]
    )
    name_matches = primekg_nodes["node_name"].str.contains(go_query, case=False, na=False)
    go_terms_df = primekg_nodes[is_go_node & name_matches]

    if go_terms_df.empty:
        print(f"No GO terms matching {go_keywords} found in PrimeKG!")
        return None

    # NOTE(review): the DataFrame index is used as the node identifier and is
    # compared with head_index/tail_index below; this assumes the nodes frame
    # is indexed by node_index — confirm.
    go_term_ids = go_terms_df.index.values
    print(f"Found {len(go_term_ids)} GO terms related to {go_keywords}.")

    # STEP 2: edges joining a selected GO term and a gene/protein, either way round.
    head_is_gene = primekg_edges.head_type == "gene/protein"
    tail_is_gene = primekg_edges.tail_type == "gene/protein"
    gene_go_edges_df = primekg_edges[
        (primekg_edges.head_index.isin(go_term_ids) & tail_is_gene)
        | (primekg_edges.tail_index.isin(go_term_ids) & head_is_gene)
    ]

    if gene_go_edges_df.empty:
        print(f"No gene-GO relationships found for {go_keywords}!")
        return None

    gene_ids = np.unique(
        np.concatenate([
            gene_go_edges_df[gene_go_edges_df.head_type == "gene/protein"].head_index.unique(),
            gene_go_edges_df[gene_go_edges_df.tail_type == "gene/protein"].tail_index.unique(),
        ])
    )

    # Report the keywords actually used rather than a hard-coded topic.
    print(f"Found {len(gene_ids)} genes/proteins linked to GO terms matching {go_keywords}.")

    genes_df = primekg_nodes[primekg_nodes.index.isin(gene_ids)][
        ["node_index", "node_name", "node_id", "node_type"]
    ]

    print(f"Extracted {len(genes_df)} gene/protein nodes from PrimeKG based on GO terms matching {go_keywords}.")
    return genes_df


def compute_embeddings_for_genes(genes_df):
    """Return a ``node_index -> embedding`` dict, using the on-disk cache if present.

    When ``EMBEDDINGS_FILE`` exists its pickled content is returned directly
    and ``genes_df`` is not consulted. Otherwise every row's ``node_name`` is
    embedded with ``get_openai_embedding``; rows that fail are reported and
    skipped. The resulting dict is pickled to ``EMBEDDINGS_FILE``.

    Args:
        genes_df: DataFrame with ``node_index`` and ``node_name`` columns.

    Returns:
        dict: embeddings keyed by ``node_index``.
    """
    cache_available = os.path.exists(EMBEDDINGS_FILE)
    if cache_available:
        print("Loading precomputed embeddings...")
        with open(EMBEDDINGS_FILE, "rb") as cache_file:
            return pickle.load(cache_file)

    print("🔹 Computing embeddings for all genes/proteins...")
    vectors_by_node = {}

    for _, gene_row in genes_df.iterrows():
        # Best effort: a failure on one gene must not abort the whole run.
        try:
            vectors_by_node[gene_row["node_index"]] = get_openai_embedding(gene_row["node_name"])
        except Exception as err:
            print(f"Failed to embed: {gene_row['node_name']} - {err}")

    # Persist the result so later runs can skip the API calls.
    with open(EMBEDDINGS_FILE, "wb") as cache_file:
        pickle.dump(vectors_by_node, cache_file)

    print(f"Saved {len(vectors_by_node)} gene embeddings.")
    return vectors_by_node


def build_faiss_index(genes_df, embeddings):
    """Build and persist a FAISS L2 index over gene embeddings, or reload it.

    If both ``INDEX_FILE`` and ``GENE_MAPPING_FILE`` already exist they are
    loaded and returned as-is; ``genes_df`` and ``embeddings`` are not used in
    that case.

    Args:
        genes_df: DataFrame with ``node_index``, ``node_name`` and ``node_id``
            columns.
        embeddings: dict mapping ``node_index`` values to embedding vectors.
            The vectors are stacked with ``np.vstack``, so they must share
            one length.

    Returns:
        tuple: ``(index, gene_id_mapping)`` where ``gene_id_mapping`` maps a
        FAISS row position to an array ``[node_name, node_id]``.

    Raises:
        ValueError: if ``embeddings`` is empty.
        KeyError: if an embedded ``node_index`` is missing from ``genes_df``.
    """
    if os.path.exists(INDEX_FILE) and os.path.exists(GENE_MAPPING_FILE):
        print("Loading FAISS index from file...")
        index = faiss.read_index(INDEX_FILE)
        with open(GENE_MAPPING_FILE, "rb") as f:
            gene_id_mapping = pickle.load(f)
        return index, gene_id_mapping

    if not embeddings:
        # Same exception type the tuple-unpacking used to raise, with a
        # message that says what is actually wrong.
        raise ValueError("No embeddings available to build the FAISS index.")

    # A list (not a tuple) so that .loc below treats it as a list of labels.
    gene_ids = list(embeddings)
    vectors = np.vstack([embeddings[gene_id] for gene_id in gene_ids]).astype("float32")

    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)

    # One indexed lookup table instead of a full DataFrame scan per gene.
    # drop_duplicates keeps the first row per node_index, matching the
    # previous "first matching row" behaviour.
    details = (
        genes_df.drop_duplicates(subset="node_index")
        .set_index("node_index")[["node_name", "node_id"]]
    )
    rows = details.loc[gene_ids].values
    # Row i of the index corresponds to gene_ids[i].
    gene_id_mapping = dict(enumerate(rows))

    faiss.write_index(index, INDEX_FILE)
    with open(GENE_MAPPING_FILE, "wb") as f:
        pickle.dump(gene_id_mapping, f)

    print(f"FAISS index built and saved with {index.ntotal} vectors.")
    return index, gene_id_mapping


def find_closest_gene(description, index, gene_id_mapping):
    """Return the gene whose embedding is nearest to ``description``.

    Args:
        description: free text to embed with ``get_openai_embedding``.
        index: object exposing ``search(vectors, k=...)`` (a FAISS index).
        gene_id_mapping: mapping of index row position -> ``(name, id)`` pair.

    Returns:
        dict with the keys ``"Gene Name"``, ``"NCBI ID"`` and ``"Score"``, or
        ``None`` when ``index`` or ``gene_id_mapping`` is ``None``.
    """
    nothing_to_search = index is None or gene_id_mapping is None
    if nothing_to_search:
        print("FAISS index not built.")
        return None

    print("🔹 Embedding description with OpenAI API...")
    query = get_openai_embedding(description).astype("float32")

    print("🔹 Performing FAISS vector search...")
    # The index expects a 2-D batch; reshape the single vector to (1, dim).
    distances, indices = index.search(query.reshape(1, -1), k=1)

    nearest_row = indices[0][0]
    score = distances[0][0]
    gene_name, ncbi_id = gene_id_mapping[nearest_row]

    print(f"Best match found: {gene_name} (NCBI ID: {ncbi_id}, Score: {score:.4f})")

    return {"Gene Name": gene_name, "NCBI ID": ncbi_id, "Score": score}




# The three steps below run at module level, i.e. whenever this file is
# executed or imported.

# Extract genes
# NOTE(review): extract_genes_from_go() returns None when nothing matches;
# that None is passed on unchecked to the two calls below.
genes_df = extract_genes_from_go()

# Compute or load embeddings
# (loads EMBEDDINGS_FILE when it exists, otherwise calls the OpenAI API per gene)
embeddings = compute_embeddings_for_genes(genes_df)

# Build or load FAISS index
# (loads INDEX_FILE / GENE_MAPPING_FILE when both exist, otherwise builds and saves them)
index, gene_id_mapping = build_faiss_index(genes_df, embeddings)

67 changes: 67 additions & 0 deletions docs/notebooks/talk2biomodels/extract_all_genes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
import numpy as np
import pandas as pd
from aiagents4pharma.talk2knowledgegraphs.datasets.primekg import PrimeKG

def extract_genes_from_go(go_keywords=None):
    """
    Extract genes/proteins associated with GO terms whose names match the given keywords.

    GO terms are nodes of type ``biological_process``, ``molecular_function``
    or ``cellular_component`` whose ``node_name`` matches any keyword
    (case-insensitive). The matching gene/protein nodes are also written to
    ``genes_df.csv`` in the current working directory.

    Parameters:
    - go_keywords (str | list): Keyword(s) to match GO term names. They are
      joined with ``|`` and used as a regular expression
      (default: ['drug discovery', 'immune system process', 'inflammatory response']).

    Returns:
    - genes_df (pd.DataFrame | None): Gene/protein nodes (all node columns)
      linked to the selected GO terms, or None when no keyword is given, no GO
      term matches, or no gene-GO edge exists.
    """
    if go_keywords is None:
        # Built per call instead of using a shared mutable default.
        go_keywords = ['drug discovery', 'immune system process', 'inflammatory response']
    elif isinstance(go_keywords, str):
        # A bare string would otherwise be joined character by character.
        go_keywords = [go_keywords]

    go_query = "|".join(go_keywords)  # Create OR search pattern
    if not go_query:
        # An empty pattern matches every node name; report "no match" instead
        # of silently selecting every GO term.
        print(f"No GO terms matching {go_keywords} found in PrimeKG!")
        return None

    # Load PrimeKG dataset
    primekg_data = PrimeKG(local_dir="../../../../data/primekg/")
    primekg_data.load_data()

    # Get all nodes and edges
    primekg_nodes = primekg_data.get_nodes()
    primekg_edges = primekg_data.get_edges()

    # 🔹 STEP 1: Extract Relevant GO Terms
    go_node_types = ["biological_process", "molecular_function", "cellular_component"]
    go_terms_df = primekg_nodes[
        primekg_nodes["node_type"].isin(go_node_types)
        & primekg_nodes["node_name"].str.contains(go_query, case=False, na=False)
    ]

    if go_terms_df.empty:
        print(f"No GO terms matching {go_keywords} found in PrimeKG!")
        return None

    # NOTE(review): the DataFrame index is used as the node identifier and is
    # compared with head_index/tail_index below; this assumes the nodes frame
    # is indexed by node_index — confirm.
    go_term_ids = go_terms_df.index.values
    print(f"Found {len(go_term_ids)} GO terms related to {go_keywords}.")

    # 🔹 STEP 2: Extract Gene-GO Relationships (GO term on either end of the edge)
    gene_go_edges_df = primekg_edges[
        ((primekg_edges.head_index.isin(go_term_ids)) & (primekg_edges.tail_type == "gene/protein"))
        | ((primekg_edges.tail_index.isin(go_term_ids)) & (primekg_edges.head_type == "gene/protein"))
    ]

    if gene_go_edges_df.empty:
        print(f"No gene-GO relationships found for {go_keywords}!")
        return None

    gene_ids = np.unique(
        np.concatenate([
            gene_go_edges_df[gene_go_edges_df.head_type == "gene/protein"].head_index.unique(),
            gene_go_edges_df[gene_go_edges_df.tail_type == "gene/protein"].tail_index.unique(),
        ])
    )

    # Report the keywords actually used rather than a hard-coded topic.
    print(f"Found {len(gene_ids)} genes/proteins linked to GO terms matching {go_keywords}.")

    # 🔹 STEP 3: Extract the Gene Nodes
    genes_df = primekg_nodes[primekg_nodes.index.isin(gene_ids)]
    genes_df.to_csv('genes_df.csv')

    print(f"Extracted {len(genes_df)} gene/protein nodes from PrimeKG based on GO terms matching {go_keywords}.")

    return genes_df


# Runs at module level (on execution or import) with the default keywords.
# The result is None when extract_genes_from_go() finds nothing.
genes_df = extract_genes_from_go()
97 changes: 97 additions & 0 deletions docs/notebooks/talk2biomodels/extract_disease_subgraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import os
import numpy as np
import pandas as pd
import networkx as nx
import pickle
from tqdm import tqdm
from aiagents4pharma.talk2knowledgegraphs.datasets.primekg import PrimeKG

def extract_disease_subgraph(disease_names=None):
    """
    Extracts the subgraph of genes, GO terms, and ontologies related to given diseases.

    Disease nodes are matched by name (case-insensitive); genes are the
    gene/protein endpoints of edges touching those diseases. The returned
    edges are the gene-GO, GO-ontology and gene-ontology edges; the
    disease-gene edges themselves are not part of the result.

    Parameters:
    - disease_names (str | list): Disease name(s). They are lower-cased,
      joined with ``|`` and used as a regular expression
      (default: ["crohn", "inflammatory bowel disease", "ulcerative colitis"]).

    Returns:
    - subgraph_nodes_df (pd.DataFrame): DataFrame containing the nodes in the extracted subgraph.
    - subgraph_edges_df (pd.DataFrame): DataFrame containing the edges in the extracted subgraph.
      Both are None when no name is given or no disease node matches.
    """
    if disease_names is None:
        # Built per call instead of using a shared mutable default.
        disease_names = ["crohn", "inflammatory bowel disease", "ulcerative colitis"]
    elif isinstance(disease_names, str):
        # A bare string would otherwise be iterated character by character.
        disease_names = [disease_names]

    disease_names = [d.lower() for d in disease_names]
    disease_query = "|".join(disease_names)  # Combine all disease names for OR search
    if not disease_query:
        # An empty pattern matches every disease name; report "no match"
        # instead of silently selecting every disease.
        print("⚠️ No matching disease found in PrimeKG!")
        return None, None

    # Load PrimeKG dataset
    primekg_data = PrimeKG(local_dir="../../../../data/primekg/")
    primekg_data.load_data()

    # Get all nodes and edges
    primekg_nodes = primekg_data.get_nodes()
    primekg_edges = primekg_data.get_edges()

    # 🔹 STEP 1: Extract Disease Nodes
    disease_nodes_df = primekg_nodes[
        (primekg_nodes["node_type"] == "disease")
        & (primekg_nodes["node_name"].str.contains(disease_query, case=False, na=False))
    ]

    if disease_nodes_df.empty:
        print("⚠️ No matching disease found in PrimeKG!")
        return None, None

    # NOTE(review): the DataFrame index is used as the node identifier and is
    # compared with head_index/tail_index below; this assumes the nodes frame
    # is indexed by node_index — confirm.
    disease_ids = disease_nodes_df.index.values
    print(f"✅ Found {len(disease_ids)} disease nodes for '{disease_names}'.")

    head_index = primekg_edges.head_index
    tail_index = primekg_edges.tail_index

    # 🔹 STEP 2: Extract Disease-Gene Relationships
    disease_gene_edges_df = primekg_edges[
        (head_index.isin(disease_ids) & (primekg_edges.tail_type == "gene/protein"))
        | (tail_index.isin(disease_ids) & (primekg_edges.head_type == "gene/protein"))
    ]

    gene_ids = np.unique(
        np.concatenate([
            disease_gene_edges_df[disease_gene_edges_df.head_type == "gene/protein"].head_index.unique(),
            disease_gene_edges_df[disease_gene_edges_df.tail_type == "gene/protein"].tail_index.unique(),
        ])
    )
    print(f"✅ Found {len(gene_ids)} genes/proteins related to '{disease_names}'.")

    # 🔹 STEP 3: Extract GO Terms (biological_process, molecular_function, cellular_component)
    go_terms_df = primekg_nodes[
        primekg_nodes["node_type"].isin(["biological_process", "molecular_function", "cellular_component"])
    ]
    go_term_ids = go_terms_df.index.values

    # 🔹 STEP 4: Extract Ontologies (SNOMEDCT, BTO, FMA, Anatomy)
    ontology_nodes = primekg_nodes[
        primekg_nodes["node_type"].isin(["SNOMEDCT", "BTO", "FMA", "anatomy"])
    ]
    ontology_ids = ontology_nodes.index.values

    # 🔹 STEP 5: Extract Subgraph Edges (Gene ↔ GO Terms, GO Terms ↔ Ontologies, Gene ↔ Ontologies)
    # Each membership mask is computed once and reused for both edge directions.
    head_gene, tail_gene = head_index.isin(gene_ids), tail_index.isin(gene_ids)
    head_go, tail_go = head_index.isin(go_term_ids), tail_index.isin(go_term_ids)
    head_onto, tail_onto = head_index.isin(ontology_ids), tail_index.isin(ontology_ids)
    subgraph_edges_df = primekg_edges[
        (head_gene & tail_go) | (tail_gene & head_go)
        | (head_go & tail_onto) | (tail_go & head_onto)
        | (head_gene & tail_onto) | (tail_gene & head_onto)
    ]

    # 🔹 STEP 6: Extract All Related Nodes
    subgraph_node_ids = np.unique(
        np.hstack([subgraph_edges_df.head_index.unique(), subgraph_edges_df.tail_index.unique()])
    )
    subgraph_nodes_df = primekg_nodes[primekg_nodes.index.isin(subgraph_node_ids)]

    print(f"✅ Final subgraph contains {len(subgraph_nodes_df)} nodes and {len(subgraph_edges_df)} edges.")

    return subgraph_nodes_df, subgraph_edges_df


# Example Usage
# Runs at module level (on execution or import). The names are lower-cased
# inside extract_disease_subgraph() and matched case-insensitively against
# disease node names. Both results are None when no disease node matches.
subgraph_nodes, subgraph_edges = extract_disease_subgraph(
["Crohn's Disease", "Inflammatory Bowel Disease", "Ulcerative Colitis"]
)

# To extract for another set of diseases, simply change the list:
# subgraph_nodes, subgraph_edges = extract_disease_subgraph(["Lung Cancer", "Breast Cancer"])
Loading