-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path wiki-table-questions.py
112 lines (66 loc) · 2.76 KB
/
wiki-table-questions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Dataset loading script for WikiTableQuestions sparated by tab."""
import datasets
logger = datasets.logging.get_logger(__name__)
_DESCRIPTION = "WikiTableQuestions"
class WikiTableQuestionsConfig(datasets.BuilderConfig):
    """BuilderConfig for the WikiTableQuestions question-type dataset."""

    def __init__(self, **kwargs):
        """Create a WikiTableQuestions builder config.

        Args:
            **kwargs: keyword arguments forwarded unchanged to
                ``datasets.BuilderConfig`` (e.g. ``name``, ``description``).
        """
        super().__init__(**kwargs)
class WikiTableQuestions(datasets.GeneratorBasedBuilder):
    """Builder for tab-separated WikiTableQuestions examples.

    Every split file holds one example per line, formatted as
    ``label<TAB>length<TAB>sentence``. ``label`` must be one of the names
    declared in ``_info``, ``length`` is parsed as an integer, and
    ``sentence`` is split on whitespace into ``tokens``.
    """

    # Directory used when the builder config carries no ``data_dir``;
    # resolved relative to the current working directory.
    _DEFAULT_DATA_DIR = "./data/wikiTable"

    BUILDER_CONFIGS = [
        WikiTableQuestionsConfig(
            name="WikiTableQuestions",
            description="WikiTableQuestions dataset.",
        )
    ]

    def _info(self):
        """Return the ``DatasetInfo`` describing one example's features."""
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(
                {
                    "id": datasets.Value("string"),
                    "length": datasets.Value("int32"),
                    "tokens": datasets.Sequence(datasets.Value("string")),
                    # Closed label set; ClassLabel maps each name to an index.
                    "label": datasets.features.ClassLabel(
                        names=[
                            "who",
                            "what",
                            "when",
                            "where",
                            "why",
                            "how",
                            "which",
                            "whose",
                        ]
                    ),
                }
            ),
        )

    def _split_generators(self, dl_manager):
        """Return one ``SplitGenerator`` each for train, validation and test.

        Files are read from the config's ``data_dir`` when it is set, and
        otherwise from ``./data/wikiTable``. ``dl_manager`` is unused because
        nothing is downloaded.
        """
        # getattr keeps this working on `datasets` versions whose
        # BuilderConfig has no ``data_dir`` field.
        data_dir = getattr(self.config, "data_dir", None) or self._DEFAULT_DATA_DIR
        split_files = (
            (datasets.Split.TRAIN, "sample.train"),
            (datasets.Split.VALIDATION, "sample.dev"),
            (datasets.Split.TEST, "sample.test"),
        )
        return [
            datasets.SplitGenerator(
                name=split, gen_kwargs={"filepath": f"{data_dir}/{filename}"}
            )
            for split, filename in split_files
        ]

    def _generate_examples(self, filepath):
        """Yield ``(key, example)`` pairs parsed from one split file.

        Blank lines are skipped. Keys (and the ``id`` field) are consecutive
        integers starting at 0, so each example gets a unique key.

        Args:
            filepath: path of a UTF-8 file with ``label<TAB>length<TAB>sentence``
                per line.

        Raises:
            ValueError: if a non-blank line does not have exactly three
                tab-separated fields, or its length field is not an integer.
        """
        logger.info("⏳ Generating examples from = %s", filepath)
        with open(filepath, encoding="utf-8") as f:
            guid = 0
            for line_number, line in enumerate(f, start=1):
                stripped = line.strip()
                if not stripped:
                    # e.g. a trailing newline at end of file.
                    continue
                fields = stripped.split("\t")
                # Validate before unpacking so malformed lines are reported
                # with their location instead of a bare unpacking error.
                if len(fields) != 3:
                    raise ValueError(
                        f"{filepath}:{line_number}: expected 3 tab-separated "
                        f"fields (label, length, sentence), got "
                        f"{len(fields)}: {fields!r}"
                    )
                label, length, sentence = fields
                yield guid, {
                    "id": str(guid),
                    "tokens": sentence.split(),
                    "length": int(length),
                    "label": label,
                }
                guid += 1
if __name__ == "__main__":
    # Smoke test: build every split from the local files and print a summary.
    # No positional argument is passed because the builder's first
    # positional parameter is ``cache_dir``; the default cache is used.
    builder = WikiTableQuestions()
    builder.download_and_prepare()
    print(builder.as_dataset())