-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsse.py
117 lines (87 loc) · 2.89 KB
/
sse.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import multiprocessing
import os
import psutil
import redis
from flask import Flask, render_template, request
from flask_sse import sse
from modules.training import Training
from modules.observer import Observer
# Flask application; the flask_sse blueprint (backed by the Redis server named
# in REDIS_URL) publishes server-sent events under the /stream prefix.
app = Flask(__name__)
app.config.update(REDIS_URL="redis://localhost")
app.register_blueprint(sse, url_prefix='/stream')

# Mutable per-app state written by start(), stop() and startTraining().
app.training = app.thread = app.conf = app.pid = None
@app.route('/')
def index():
    """Render the landing page (``index.html``).

    No Redis client is created here: flask_sse opens its own connection from
    ``app.config["REDIS_URL"]`` when events are published or streamed.
    """
    return render_template("index.html")
@app.route('/training', methods=['GET', 'POST'])
def training():
    """Render the training page that matches the subject's group.

    Reads ``group`` and ``id`` from the query string.  Group ``"1"`` gets the
    baseline template; any other value (including a missing ``group``) gets
    the regular training template, which also receives the group.
    """
    params = request.args
    subject = params.get('id')
    group = params.get('group')
    if group != "1":
        return render_template("training.html", id=subject, group=group)
    return render_template("training_baseline.html", id=subject)
@app.route('/rating', methods=['GET', 'POST'])
def rating():
    """Render the first rating page for the subject given by ``?id=``."""
    return render_template("rating.html", id=request.args.get('id'))
@app.route('/finale', methods=['GET', 'POST'])
def finale():
    """Render the final page, echoing the subject id and both ratings.

    The ``id``, ``rating`` and ``rating2`` query parameters are passed to the
    template under the same names (``None`` when absent).
    """
    query = request.args
    context = {key: query.get(key) for key in ('id', 'rating', 'rating2')}
    return render_template("finale.html", **context)
@app.route('/rating2', methods=['GET', 'POST'])
def rating2():
    """Render the second rating page, carrying ``id`` and ``rating`` forward."""
    query = request.args
    return render_template(
        "rating2.html", id=query.get('id'), rating=query.get('rating')
    )
@app.route('/start', methods=['GET', 'POST'])
def start():
    """Stop any running training, then launch a new one from the posted config.

    The configuration arrives as the *name* of the first form field.  It is
    decoded with ``eval``, stored on ``app.conf``, and handed to
    ``startTraining`` in a new ``multiprocessing.Process`` kept on
    ``app.thread``.  When no form data is posted, only the stop happens.

    Returns:
        str: an acknowledgement string (returned even if nothing was posted).
    """
    stop()
    form = request.form
    if form:
        raw_conf = next(iter(form.to_dict()))
        # SECURITY(review): eval() on request data executes arbitrary Python
        # sent by any client that can reach this endpoint.  Consider
        # ast.literal_eval or json.loads once the client's payload format is
        # confirmed.
        conf = eval(raw_conf)
        app.conf = conf
        worker = multiprocessing.Process(target=startTraining, args=[conf])
        app.thread = worker
        worker.start()
    return "Training is starting..."
@app.route('/startTraining')
def startTraining(args=None):
    """Run one training session; used as the worker target by ``start()``.

    The function is also registered as the ``/startTraining`` route.  Flask
    calls views without positional arguments, so when invoked over HTTP no
    configuration is available and a 400 response is returned instead of
    raising ``TypeError``.

    Args:
        args: configuration mapping read for the keys ``"group"``,
            ``"subj_id"``, ``"training_round"`` and ``"kwargs"``; ``None``
            when called as an HTTP view.
    """
    if args is None:
        return "Training must be started via /start.", 400
    print(args)
    # When launched by start() this runs in a separate process, so these
    # app attributes are set in the worker's copy of `app`, not the server's.
    app.pid = os.getpid()
    app.training = Training(args["group"], args["subj_id"])
    logger = ObservableLogger()
    app.training.start(args["training_round"], logger, args["kwargs"])
def kill_proc_tree(pid, including_parent=False):
    """Kill every descendant of process *pid*, and optionally *pid* itself.

    Descendants are killed before the parent.  A descendant that has already
    exited by the time it is killed is skipped, so one finished child does
    not leave the rest of the tree running.

    Args:
        pid: id of the root process.
        including_parent: when True, also kill *pid* after its descendants.

    Raises:
        psutil.NoSuchProcess: if *pid* does not exist (or exits before it is
            killed when ``including_parent`` is True).
    """
    parent = psutil.Process(pid)
    for child in parent.children(recursive=True):
        try:
            child.kill()
        except psutil.NoSuchProcess:
            # Exited between listing and killing; nothing left to do for it.
            pass
    if including_parent:
        parent.kill()
@app.route('/stop', methods=['GET', 'POST'])
def stop():
    """Stop the training worker launched by ``start()``, if still running.

    The worker's descendants and the worker process itself are killed, and
    the worker is joined.  ``app.thread`` and ``app.training`` are then
    cleared.  When no worker is running, only the clearing happens.

    Returns:
        str: a fixed acknowledgement string.
    """
    worker = app.thread
    # Only signal a live worker: a finished worker's pid may be stale.
    if worker is not None and worker.is_alive():
        try:
            # startTraining runs Training.start inside the worker itself, so
            # killing only its descendants would leave the training running.
            kill_proc_tree(worker.pid, including_parent=True)
        except psutil.NoSuchProcess:
            # Part of the tree exited mid-kill; make sure the worker is gone.
            if worker.is_alive():
                worker.kill()
        worker.join(timeout=5)
    app.thread = None
    app.training = None
    return "Training stopped!"
def send_message(message, head):
    """Publish *message* over SSE as an event of type ``'greeting'``.

    Every single quote in *message* is replaced by a double quote first
    (presumably so a Python ``str()`` of a dict parses as JSON on the
    client; TODO confirm).  The payload is ``{head: message}``.

    Returns:
        str: a fixed acknowledgement string.
    """
    normalized = message.replace("\'", "\"")
    payload = {head: normalized}
    sse.publish(payload, type='greeting')
    return "Message sent!"
class ObservableLogger(Observer):
    """Observer whose hooks forward their message through ``send_message``.

    Each hook publishes under its own key: ``append`` under ``"message"``
    and ``plot_cm`` under ``"plot_cm"``.
    """

    def __init__(self):
        super().__init__()

    def append(self, message="Default"):
        """Publish *message* under the ``"message"`` key."""
        send_message(message, "message")

    def plot_cm(self, message="Default"):
        """Publish *message* under the ``"plot_cm"`` key."""
        send_message(message, "plot_cm")
if __name__ == "__main__":
    # Flask development server in debug mode (interactive debugger and
    # reloader enabled); the debugger allows code execution, so keep it local.
    use_debug = True
    app.run(debug=use_debug)