-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathstarkqa_primekg.py
198 lines (163 loc) · 7.17 KB
/
starkqa_primekg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""
Class for loading StarkQAPrimeKG dataset.
"""
import os
import shutil
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from huggingface_hub import hf_hub_download, list_repo_files
import gdown
from .dataset import Dataset
class StarkQAPrimeKG(Dataset):
    """
    Class for loading StarkQAPrimeKG dataset.

    It downloads the data from the HuggingFace repo and stores it in the local directory.
    The data is then loaded into pandas DataFrame of QA pairs, dictionary of split indices,
    and node information. Query and node embeddings are downloaded separately (via gdown)
    and loaded with torch.
    """

    def __init__(self, local_dir: str = "../../../data/starkqa_primekg/"):
        """
        Constructor for StarkQAPrimeKG class.

        Args:
            local_dir (str): The local directory to store the dataset files.
        """
        self.name: str = "starkqa_primekg"
        self.hf_repo_id: str = "snap-stanford/stark"
        self.local_dir: str = local_dir

        # Attributes to store the data; they stay None until load_data() is called
        self.starkqa: pd.DataFrame = None
        self.starkqa_split_idx: dict = None
        self.starkqa_node_info: dict = None
        self.query_emb_dict: dict = None
        self.node_emb_dict: dict = None

        # Set up the dataset
        self.setup()

    def setup(self):
        """
        A method to set up the dataset.

        Creates the local directory (and any missing parents) if it doesn't exist.
        """
        # Create local_dir itself rather than os.path.dirname(local_dir): for a path
        # without a trailing separator dirname() yields only the parent directory,
        # and for a bare name such as "data" it yields "" which makes makedirs fail.
        if self.local_dir:
            os.makedirs(self.local_dir, exist_ok=True)

    def _load_stark_repo(self) -> tuple:
        """
        Private method to load related files of StarkQAPrimeKG dataset.

        Returns:
            tuple: A tuple of (QA dataframe, dictionary mapping each split name to a
                numpy array of positions, node information loaded from node_info.pkl).

        Raises:
            ValueError: If a split file lists a query id that is not present in the
                'id' column of the QA dataframe.
        """
        # Download the file if it does not exist in the local directory
        # Otherwise, load the data from the local directory
        local_file = os.path.join(self.local_dir, "qa/prime/stark_qa/stark_qa.csv")
        if os.path.exists(local_file):
            print(f"{local_file} already exists. Loading the data from the local directory.")
        else:
            print(f"Downloading files from {self.hf_repo_id}")

            # List all related files in the HuggingFace Hub repository,
            # skipping anything whose path contains "raw"
            files = [
                f for f in list_repo_files(self.hf_repo_id, repo_type="dataset")
                if f.startswith(("qa/prime/", "skb/prime/")) and "raw" not in f
            ]

            # Download and save each file in the specified folder
            for file in tqdm(files):
                hf_hub_download(self.hf_repo_id,
                                file,
                                repo_type="dataset",
                                local_dir=self.local_dir)

            # Unzip the processed files
            shutil.unpack_archive(
                os.path.join(self.local_dir, "skb/prime/processed.zip"),
                os.path.join(self.local_dir, "skb/prime/")
            )

        # Load StarkQA dataframe
        starkqa = pd.read_csv(local_file, low_memory=False)

        # Map every query id to its position in the sorted id list. setdefault keeps
        # the first position for a repeated id, matching what list.index() returns,
        # while making each lookup O(1) instead of a linear scan.
        position_of = {}
        for position, query_id in enumerate(sorted(starkqa['id'].tolist())):
            position_of.setdefault(query_id, position)

        # Read split indices
        starkqa_split_idx = {}
        for split in ['train', 'val', 'test', 'test-0.1']:
            indices_file = os.path.join(self.local_dir, "qa/prime/split", f'{split}.index')
            with open(indices_file, 'r', encoding='utf-8') as f:
                # One query id per line; blank lines (or an empty file) are skipped
                query_ids = [int(line) for line in f.read().splitlines() if line.strip()]
            try:
                positions = [position_of[query_id] for query_id in query_ids]
            except KeyError as err:
                # Keep the ValueError that list.index() raised for unknown ids
                raise ValueError(
                    f"Query id {err.args[0]} listed in {indices_file} "
                    f"is not present in {local_file}"
                ) from None
            # Explicit integer dtype so that an empty split is still an index array
            starkqa_split_idx[split] = np.array(positions, dtype=int)

        # Load the node info of PrimeKG preprocessed for StarkQA
        # NOTE(security): pickle.load executes arbitrary code from the file; this is
        # only acceptable because the archive comes from the configured dataset repo.
        with open(os.path.join(self.local_dir, 'skb/prime/processed/node_info.pkl'), 'rb') as f:
            starkqa_node_info = pickle.load(f)

        return starkqa, starkqa_split_idx, starkqa_node_info

    def _load_stark_embeddings(self) -> tuple:
        """
        Private method to load the embeddings of StarkQAPrimeKG dataset.

        Returns:
            tuple: A tuple of query and node embeddings dictionaries.

        Raises:
            FileNotFoundError: If an embedding file is still missing after the
                download attempt.
        """
        # Load the provided embeddings of query and nodes
        # Note that they utilized 'text-embedding-ada-002' for embeddings
        emb_model = 'text-embedding-ada-002'
        query_emb_url = 'https://drive.google.com/uc?id=1MshwJttPZsHEM2cKA5T13SIrsLeBEdyU'
        node_emb_url = 'https://drive.google.com/uc?id=16EJvCMbgkVrQ0BuIBvLBp-BYPaye-Edy'

        # Prepare respective directories to store the embeddings
        emb_dir = os.path.join(self.local_dir, emb_model)
        query_emb_dir = os.path.join(emb_dir, "query")
        node_emb_dir = os.path.join(emb_dir, "doc")
        os.makedirs(query_emb_dir, exist_ok=True)
        os.makedirs(node_emb_dir, exist_ok=True)
        query_emb_path = os.path.join(query_emb_dir, "query_emb_dict.pt")
        node_emb_path = os.path.join(node_emb_dir, "candidate_emb_dict.pt")

        # Download each embedding file only if that particular file is missing,
        # so an already-downloaded file is not fetched again
        for url, path in ((query_emb_url, query_emb_path), (node_emb_url, node_emb_path)):
            if not os.path.exists(path):
                gdown.download(url, path, quiet=False)
            if not os.path.exists(path):
                raise FileNotFoundError(
                    f"Embedding file {path} is missing and could not be downloaded from {url}"
                )

        # Load the embeddings
        query_emb_dict = torch.load(query_emb_path)
        node_emb_dict = torch.load(node_emb_path)

        return query_emb_dict, node_emb_dict

    def load_data(self):
        """
        Load the StarkQAPrimeKG dataset into pandas DataFrame of QA pairs,
        dictionary of split indices, and node information, followed by the
        query and node embeddings.
        """
        print("Loading StarkQAPrimeKG dataset...")
        self.starkqa, self.starkqa_split_idx, self.starkqa_node_info = self._load_stark_repo()

        print("Loading StarkQAPrimeKG embeddings...")
        self.query_emb_dict, self.node_emb_dict = self._load_stark_embeddings()

    def get_starkqa(self) -> pd.DataFrame:
        """
        Get the dataframe of StarkQAPrimeKG dataset, containing the QA pairs.

        Returns:
            pd.DataFrame: The QA dataframe of StarkQAPrimeKG dataset
                (None until load_data() has been called).
        """
        return self.starkqa

    def get_starkqa_split_indicies(self) -> dict:
        """
        Get the split indices of StarkQAPrimeKG dataset.

        The misspelled name is kept for backward compatibility;
        get_starkqa_split_indices() returns the same object.

        Returns:
            dict: The split indices of StarkQAPrimeKG dataset.
        """
        return self.starkqa_split_idx

    def get_starkqa_split_indices(self) -> dict:
        """
        Get the split indices of StarkQAPrimeKG dataset.

        Returns:
            dict: The split indices of StarkQAPrimeKG dataset.
        """
        return self.starkqa_split_idx

    def get_starkqa_node_info(self) -> dict:
        """
        Get the node information of StarkQAPrimeKG dataset.

        Returns:
            dict: The node information of StarkQAPrimeKG dataset.
        """
        return self.starkqa_node_info

    def get_query_embeddings(self) -> dict:
        """
        Get the query embeddings of StarkQAPrimeKG dataset.

        Returns:
            dict: The query embeddings of StarkQAPrimeKG dataset.
        """
        return self.query_emb_dict

    def get_node_embeddings(self) -> dict:
        """
        Get the node embeddings of StarkQAPrimeKG dataset.

        Returns:
            dict: The node embeddings of StarkQAPrimeKG dataset.
        """
        return self.node_emb_dict