-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain_program_dataset_generator.py
94 lines (68 loc) · 3.5 KB
/
main_program_dataset_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import pickle
import numpy as np
import tqdm
from config import Config
from dsl import DSL
from dsl.program_generator import ProgramGenerator
from karel.world_generator import WorldGenerator
from logger.stdout_logger import StdoutLogger
from search.sketch_sampler import SketchSampler
from multiprocessing import pool
if __name__ == '__main__':
dsl = DSL.init_default_karel()
program_generator = ProgramGenerator(dsl)
sketch_sampler = SketchSampler()
world_generator = WorldGenerator()
generate_demos = Config.datagen_generate_demos
generate_sketches = Config.datagen_generate_sketches
seen_programs = set()
program_dataset = []
program_and_demos_dataset = []
sketches_dataset = []
programs_and_sketches_dataset = []
program_setup = f'size{Config.data_max_program_size}_dep{Config.data_max_program_depth}_seq{Config.data_max_program_sequence}'
sketch_setup = f'sk{Config.datagen_sketch_iterations}'
StdoutLogger.log('Generator', f'Generating programs in setup {program_setup}')
with tqdm.tqdm(total=Config.datagen_num_programs) as pbar:
while len(program_dataset) < Config.datagen_num_programs:
program = program_generator.generate_program()
program_str = dsl.parse_node_to_str(program)
if program_str in seen_programs: continue
if generate_demos:
try:
s_h, a_h = program_generator.generate_demos(program, world_generator,
Config.data_num_demo_per_program,
Config.data_max_demo_length)
except Exception: # In case a timeout occurs
continue
seen_programs.add(program_str)
p = dsl.parse_str_to_int(program_str)
program_dataset.append(p)
if generate_demos:
program_and_demos_dataset.append((p, s_h, a_h))
pbar.update(1)
if generate_sketches:
programs_nodes = [dsl.parse_int_to_node(p) for p in program_dataset]
StdoutLogger.log('Generator', f'Generating sketches in setup {sketch_setup}')
with tqdm.tqdm(total=Config.datagen_num_programs) as pbar:
def sample(program_nodes):
s = sketch_sampler.sample_sketch(program_nodes)
pbar.update(1)
return s
with pool.Pool() as pl:
sketches_nodes = pl.map(sample, programs_nodes)
for p, sketch_nodes in zip(program_dataset, sketches_nodes):
s = dsl.parse_node_to_int(sketch_nodes)
sketches_dataset.append(s)
programs_and_sketches_dataset.append((p, s))
StdoutLogger.log('Generator', 'Saving files.')
with open(f'data/programs_{program_setup}_only.pkl', 'wb') as f:
pickle.dump(program_dataset, f)
if generate_demos:
with open(f'data/programs_{program_setup}_and_demos.pkl', 'wb') as f:
pickle.dump(program_and_demos_dataset, f)
if generate_sketches:
with open(f'data/sketches_{program_setup}_{sketch_setup}.pkl', 'wb') as f:
pickle.dump(sketches_dataset, f)
with open(f'data/programs_{program_setup}_and_sketches_{sketch_setup}.pkl', 'wb') as f:
pickle.dump(programs_and_sketches_dataset, f)