-
Notifications
You must be signed in to change notification settings - Fork 222
/
Copy pathSpikeConnection.kt
108 lines (97 loc) · 4.04 KB
/
SpikeConnection.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
package burp
import burp.network.stack.http2.frame.DataFrame
import burp.network.stack.http2.frame.Frame
import burp.network.stack.http2.frame.HeaderFrame
import net.hackxor.api.Http2Constants
import net.hackxor.api.StreamFrameProcessor
import java.io.ByteArrayOutputStream
import java.util.*
import java.util.concurrent.ConcurrentHashMap
/**
 * HTTP/2 frame sink for a connection driven by [SpikeEngine].
 *
 * HEADERS and DATA frames are buffered per stream id. When a frame carrying END_STREAM
 * arrives, the buffered frames for that stream are assembled into an HTTP/1-style response
 * string and passed to [SpikeEngine.handleResponse] together with the matching [Request]
 * taken from [inflight].
 */
class SpikeConnection(private val engine: SpikeEngine) : StreamFrameProcessor {

    /**
     * Requests awaiting a response, keyed by HTTP/2 stream id. Entries are read and removed
     * here; presumably they are inserted by the engine when each request is sent — confirm
     * against SpikeEngine.
     */
    var inflight: ConcurrentHashMap<Int, Request> = ConcurrentHashMap()

    private val dataFrames: MutableMap<Int, MutableList<DataFrame>> = ConcurrentHashMap()
    private val headerFrames: MutableMap<Int, MutableList<HeaderFrame>> = ConcurrentHashMap()

    /** Per gate name: how many gated requests have already received their first HEADERS frame. */
    private val gates: ConcurrentHashMap<String, Int> = ConcurrentHashMap()

    /**
     * Buffers [frame] under its stream id and, if it carries END_STREAM, completes the response.
     *
     * Timing and gate ordering are recorded only on the first HEADERS frame of a stream (which
     * may be an informational 1xx response). Later HEADERS frames on the same stream, such as
     * trailers, are buffered but must not rerun that bookkeeping: by then `req.time` has been
     * converted from a start timestamp into an elapsed duration.
     *
     * Any exception is logged and swallowed, so a single bad stream does not abort processing.
     */
    override fun process(frame: Frame) {
        try {
            val streamId = frame.G
            when (frame) {
                is HeaderFrame -> {
                    // Capture the clock first so map work does not inflate the measured latency.
                    val now = System.nanoTime()
                    if (!headerFrames.containsKey(streamId)) {
                        recordFirstHeaders(streamId, now)
                    }
                    headerFrames.computeIfAbsent(streamId) { LinkedList() }.add(frame)
                }
                is DataFrame -> {
                    dataFrames.computeIfAbsent(streamId) { LinkedList() }.add(frame)
                }
            }
            if (frame.isFlagSet(Http2Constants.END_STREAM_FLAG)) {
                prepareCallback(streamId)
            }
        } catch (e: Exception) {
            Utils.out("Oh no: " + e.message)
            e.printStackTrace()
        }
    }

    /**
     * Records latency, arrival offset and (for gated requests) arrival order for the request on
     * [streamId], using [now] as the arrival timestamp.
     *
     * NOTE(review): assumes `req.time` still holds the System.nanoTime() send timestamp and that
     * `engine.start` is a System.nanoTime() value — confirm where they are set. Afterwards
     * `req.time` and `req.arrival` hold microseconds (nanosecond differences divided by 1000).
     *
     * @throws IllegalStateException if [streamId] has no entry in [inflight].
     */
    private fun recordFirstHeaders(streamId: Int, now: Long) {
        val req = checkNotNull(inflight[streamId]) { "No in-flight request for stream $streamId" }
        req.time = (now - req.time) / 1000
        req.arrival = (now - engine.start) / 1000
        req.gate?.let { gate ->
            // 0-based position of this response among requests sharing the same gate.
            val seen = gates.getOrDefault(gate.name, 0)
            req.order = seen
            gates[gate.name] = seen + 1
        }
    }

    /**
     * Assembles the buffered frames for [streamID] into a response and dispatches it.
     *
     * Removes the stream's HEADERS/DATA buffers and its [inflight] entry, builds a response
     * string (status line, headers, blank line, body) and passes it to
     * [SpikeEngine.handleResponse]. When the response declares `content-encoding: gzip`, that
     * header is omitted and the body is decompressed via `ThreadedRequestEngine.decompress`.
     *
     * @throws IllegalStateException if no HEADERS frame was buffered for [streamID].
     * @throws RuntimeException if [streamID] has no entry in [inflight].
     */
    fun prepareCallback(streamID: Int) {
        val headers: List<HeaderFrame> = checkNotNull(headerFrames.remove(streamID)) {
            "No HEADERS frames buffered for stream $streamID"
        }
        val data: List<DataFrame> = dataFrames.remove(streamID) ?: emptyList()

        val resp = StringBuilder()
        var shouldUnzip = false
        for (frame in headers) {
            for (header in frame.headers()) {
                when {
                    // In HTTP/2 responses the only defined pseudo-header is :status. The reason
                    // phrase is always written as "OK", whatever the actual status code.
                    header.isPseudoHeader -> resp.append("HTTP/2 ${header.value()} OK\r\n")
                    // Content codings are case-insensitive, so accept e.g. "GZIP" as well.
                    "content-encoding" == header.name() &&
                        "gzip".equals(header.value()?.trim(), ignoreCase = true) -> shouldUnzip = true
                    else -> resp.append("${header.name()}: ${header.value()}\r\n")
                }
            }
        }
        resp.append("\r\n")

        val bodyBytes = ByteArrayOutputStream()
        for (frame in data) {
            bodyBytes.writeBytes(frame.data())
        }
        val body = bodyBytes.toByteArray()
        if (shouldUnzip) {
            resp.append(ThreadedRequestEngine.decompress(body))
        } else {
            // NOTE(review): String(ByteArray) decodes with the platform default charset, so
            // non-ASCII bodies may differ across JVMs — confirm which charset callers expect.
            resp.append(String(body))
        }

        val req = inflight.remove(streamID)
            ?: throw RuntimeException("Couldn't find $streamID in inflight: ${inflight.keys().toList()}")
        engine.handleResponse(streamID, resp.toString(), req)
    }
}