Skip to content

Commit f7278d5

Browse files
committed
Minor fixes
1 parent b1a5c11 commit f7278d5

File tree

1 file changed

+30
-4
lines changed

1 file changed

+30
-4
lines changed

sampler.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
class Agent:
1010
def __init__(self, env=Environment(), output_path=None):
11+
self.output_path = output_path
12+
1113
# Training
1214
self.n_steps = 100000 # 100k
1315
self.memory = list()
@@ -23,20 +25,26 @@ def __init__(self, env=Environment(), output_path=None):
2325
def sample(self):
2426
state = self.env.reset()
2527

28+
import time
29+
start = time.time()
30+
2631
for step in range(self.n_steps):
2732
action = random.randrange(self.n_actions)
2833

2934
next_state, reward, done, _ = self.env.step(action)
3035

31-
self.memory.append((state, [action], [reward], next_state, [done]))
36+
self.memory.append((state.tolist(), [action], [reward], next_state.tolist(), [done]))
3237

3338
state = next_state
3439

3540
if (step != 0 and step % self.target_update_interval == 0):
36-
with open('memory_samples.json', 'w+') as f:
41+
with open('{}_samples.json'.format(self.output_path), 'w+') as f:
3742
json.dump(self.memory, f)
43+
print("")
44+
print(time.time() - start)
3845
print(step)
3946
self.env.reset()
47+
break
4048

4149
self.env.close()
4250

@@ -45,6 +53,24 @@ def sample(self):
4553
import os
4654
print("Restarting PostgreSQL...")
4755
os.system('sudo systemctl restart postgresql@12-main')
56+
57+
agent1 = Agent(env=Environment(allow_columns=False, flip=False, window_size=40), output_path='env1')
58+
agent2 = Agent(env=Environment(allow_columns=False, flip=False, window_size=80), output_path='env2')
59+
60+
agent3 = Agent(env=Environment(allow_columns=True, flip=False, window_size=40), output_path='env3')
61+
agent4 = Agent(env=Environment(allow_columns=True, flip=False, window_size=80), output_path='env4')
62+
63+
agent5 = Agent(env=Environment(allow_columns=False, flip=True, window_size=40), output_path='env5')
64+
agent6 = Agent(env=Environment(allow_columns=False, flip=True, window_size=80), output_path='env6')
4865

49-
agent = Agent(env=Environment())
50-
agent.sample()
66+
agent7 = Agent(env=Environment(allow_columns=True, flip=True, window_size=40), output_path='env7')
67+
agent8 = Agent(env=Environment(allow_columns=True, flip=True, window_size=80), output_path='env8')
68+
69+
agent1.sample()
70+
agent2.sample()
71+
agent3.sample()
72+
agent4.sample()
73+
agent5.sample()
74+
agent6.sample()
75+
agent7.sample()
76+
agent8.sample()

0 commit comments

Comments
 (0)