|
4 | 4 | from http import HTTPStatus
|
5 | 5 | from threading import Thread
|
6 | 6 | from time import sleep
|
| 7 | +from typing import List |
7 | 8 | from uuid import uuid4
|
8 | 9 |
|
9 | 10 | import pytest
|
@@ -205,6 +206,67 @@ def make_request(monkey_island_requests, request_callback):
|
205 | 206 | assert response_codes.count(HTTPStatus.TOO_MANY_REQUESTS) == 1
|
206 | 207 |
|
207 | 208 |
|
| 209 | +RATE_LIMIT_AGENT1_ID = uuid4() |
| 210 | +RATE_LIMIT_AGENT2_ID = uuid4() |
| 211 | + |
| 212 | + |
| 213 | +@pytest.mark.parametrize( |
| 214 | + "request_callback, successful_request_status, max_requests_per_second", |
| 215 | + [ |
| 216 | + (lambda mir: mir.get(GET_AGENT_OTP_ENDPOINT), HTTPStatus.OK, MAX_OTP_REQUESTS_PER_SECOND), |
| 217 | + ], |
| 218 | +) |
| 219 | +def test_rate_limit__agent_user( |
| 220 | + island, |
| 221 | + monkey_island_requests, |
| 222 | + request_callback, |
| 223 | + successful_request_status, |
| 224 | + max_requests_per_second, |
| 225 | +): |
| 226 | + monkey_island_requests.login() |
| 227 | + response = monkey_island_requests.get(GET_AGENT_OTP_ENDPOINT) |
| 228 | + otp1 = response.json()["otp"] |
| 229 | + response = monkey_island_requests.get(GET_AGENT_OTP_ENDPOINT) |
| 230 | + otp2 = response.json()["otp"] |
| 231 | + |
| 232 | + agent1_requests = AgentRequests(island, RATE_LIMIT_AGENT1_ID, OTP(otp1)) |
| 233 | + agent1_requests.login() |
| 234 | + agent2_requests = AgentRequests(island, RATE_LIMIT_AGENT2_ID, OTP(otp2)) |
| 235 | + agent2_requests.login() |
| 236 | + |
| 237 | + threads = [] |
| 238 | + response_codes1: List[int] = [] |
| 239 | + response_codes2: List[int] = [] |
| 240 | + |
| 241 | + def make_request(agent_requests, request_callback, response_codes): |
| 242 | + response = request_callback(agent_requests) |
| 243 | + response_codes.append(response.status_code) |
| 244 | + |
| 245 | + for _ in range(0, max_requests_per_second + 1): |
| 246 | + t1 = Thread( |
| 247 | + target=make_request, |
| 248 | + args=(agent1_requests, request_callback, response_codes1), |
| 249 | + daemon=True, |
| 250 | + ) |
| 251 | + t1.start() |
| 252 | + t2 = Thread( |
| 253 | + target=make_request, |
| 254 | + args=(agent2_requests, request_callback, response_codes2), |
| 255 | + daemon=True, |
| 256 | + ) |
| 257 | + t2.start() |
| 258 | + threads.append(t1) |
| 259 | + threads.append(t2) |
| 260 | + |
| 261 | + for t in threads: |
| 262 | + t.join() |
| 263 | + |
| 264 | + assert response_codes1.count(successful_request_status) == max_requests_per_second |
| 265 | + assert response_codes1.count(HTTPStatus.TOO_MANY_REQUESTS) == 1 |
| 266 | + assert response_codes2.count(successful_request_status) == max_requests_per_second |
| 267 | + assert response_codes2.count(HTTPStatus.TOO_MANY_REQUESTS) == 1 |
| 268 | + |
| 269 | + |
208 | 270 | def test_refresh_access_token(monkey_island_requests):
|
209 | 271 | monkey_island_requests.login()
|
210 | 272 | original_token = monkey_island_requests.token
|
|
0 commit comments