8
8
9
9
class Agent :
10
10
def __init__ (self , env = Environment (), output_path = None ):
11
+ self .output_path = output_path
12
+
11
13
# Training
12
14
self .n_steps = 100000 # 100k
13
15
self .memory = list ()
@@ -23,20 +25,26 @@ def __init__(self, env=Environment(), output_path=None):
23
25
def sample (self ):
24
26
state = self .env .reset ()
25
27
28
+ import time
29
+ start = time .time ()
30
+
26
31
for step in range (self .n_steps ):
27
32
action = random .randrange (self .n_actions )
28
33
29
34
next_state , reward , done , _ = self .env .step (action )
30
35
31
- self .memory .append ((state , [action ], [reward ], next_state , [done ]))
36
+ self .memory .append ((state . tolist () , [action ], [reward ], next_state . tolist () , [done ]))
32
37
33
38
state = next_state
34
39
35
40
if (step != 0 and step % self .target_update_interval == 0 ):
36
- with open ('memory_samples .json' , 'w+' ) as f :
41
+ with open ('{}_samples .json' . format ( self . output_path ) , 'w+' ) as f :
37
42
json .dump (self .memory , f )
43
+ print ("" )
44
+ print (time .time () - start )
38
45
print (step )
39
46
self .env .reset ()
47
+ break
40
48
41
49
self .env .close ()
42
50
@@ -45,6 +53,24 @@ def sample(self):
45
53
import os
46
54
print ("Restarting PostgreSQL..." )
47
55
os .system ('sudo systemctl restart postgresql@12-main' )
56
+
57
+ agent1 = Agent (env = Environment (allow_columns = False , flip = False , window_size = 40 ), output_path = 'env1' )
58
+ agent2 = Agent (env = Environment (allow_columns = False , flip = False , window_size = 80 ), output_path = 'env2' )
59
+
60
+ agent3 = Agent (env = Environment (allow_columns = True , flip = False , window_size = 40 ), output_path = 'env3' )
61
+ agent4 = Agent (env = Environment (allow_columns = True , flip = False , window_size = 80 ), output_path = 'env4' )
62
+
63
+ agent5 = Agent (env = Environment (allow_columns = False , flip = True , window_size = 40 ), output_path = 'env5' )
64
+ agent6 = Agent (env = Environment (allow_columns = False , flip = True , window_size = 80 ), output_path = 'env6' )
48
65
49
- agent = Agent (env = Environment ())
50
- agent .sample ()
66
+ agent7 = Agent (env = Environment (allow_columns = True , flip = True , window_size = 40 ), output_path = 'env7' )
67
+ agent8 = Agent (env = Environment (allow_columns = True , flip = True , window_size = 80 ), output_path = 'env8' )
68
+
69
+ agent1 .sample ()
70
+ agent2 .sample ()
71
+ agent3 .sample ()
72
+ agent4 .sample ()
73
+ agent5 .sample ()
74
+ agent6 .sample ()
75
+ agent7 .sample ()
76
+ agent8 .sample ()
0 commit comments