Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3078 rate limit login #3216

Merged
merged 9 commits into from
Apr 10, 2023
25 changes: 19 additions & 6 deletions envs/monkey_zoo/blackbox/test_blackbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,24 +171,37 @@ def test_logout_invalidates_all_tokens(island):
assert resp.status_code == HTTPStatus.UNAUTHORIZED


def test_agent_otp_rate_limit(monkey_island_requests):
AGENT_OTP_LOGIN_ENDPOINT = "/api/agent-otp-login"


@pytest.mark.parametrize(
"request_callback, successful_request_status, max_requests_per_second",
[
(lambda mir: mir.get(GET_AGENT_OTP_ENDPOINT), HTTPStatus.OK, MAX_OTP_REQUESTS_PER_SECOND),
],
)
def test_rate_limit(
monkey_island_requests, request_callback, successful_request_status, max_requests_per_second
):
monkey_island_requests.login()
threads = []
response_codes = []

def make_request():
response = monkey_island_requests.get(GET_AGENT_OTP_ENDPOINT)
def make_request(monkey_island_requests, request_callback):
response = request_callback(monkey_island_requests)
response_codes.append(response.status_code)

for _ in range(0, MAX_OTP_REQUESTS_PER_SECOND + 1):
t = Thread(target=make_request, daemon=True)
for _ in range(0, max_requests_per_second + 1):
t = Thread(
target=make_request, args=(monkey_island_requests, request_callback), daemon=True
)
t.start()
threads.append(t)

for t in threads:
t.join()

assert response_codes.count(HTTPStatus.OK) == MAX_OTP_REQUESTS_PER_SECOND
assert response_codes.count(successful_request_status) == max_requests_per_second
assert response_codes.count(HTTPStatus.TOO_MANY_REQUESTS) == 1


Expand Down
3 changes: 3 additions & 0 deletions monkey/infection_monkey/island_api_client/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
IslandAPIError,
IslandAPIRequestError,
IslandAPIRequestFailedError,
IslandAPIRequestLimitExceededError,
IslandAPITimeoutError,
)

Expand Down Expand Up @@ -47,6 +48,8 @@ def decorated(*args, **kwargs):
HTTPStatus.FORBIDDEN.value,
]:
raise IslandAPIAuthenticationError(err)
if err.response.status_code == HTTPStatus.TOO_MANY_REQUESTS:
raise IslandAPIRequestLimitExceededError(err)
if 400 <= err.response.status_code < 500:
raise IslandAPIRequestError(err)
if 500 <= err.response.status_code < 600:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import functools
import json
import logging
from http import HTTPStatus
from pprint import pformat
from time import sleep
from typing import Any, Dict, List, Sequence

import requests
Expand Down Expand Up @@ -110,8 +110,14 @@ def logout(self):

@handle_response_parsing_errors
def _refresh_token(self):
response = self._http_client.post("/refresh-authentication-token", {})
self._update_token_from_response(response)
for _ in range(6):
try:
response = self._http_client.post("/refresh-authentication-token", {})
self._update_token_from_response(response)
break
except IslandAPIRequestLimitExceededError:
sleep(0.5)
continue

@handle_authentication_token_expiration
def get_agent_binary(self, operating_system: OperatingSystem) -> bytes:
Expand All @@ -123,8 +129,6 @@ def get_agent_binary(self, operating_system: OperatingSystem) -> bytes:
@handle_authentication_token_expiration
def get_otp(self) -> str:
response = self._http_client.get("/agent-otp")
if response.status_code == HTTPStatus.TOO_MANY_REQUESTS:
raise IslandAPIRequestLimitExceededError("Too many requests to get OTP.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't the IslandAPIAgentOTPProvider still expect this error to be raised?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see, it's now handled by the HTTPClient decorator

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It gets raised from self._http_client.get() in the handle_island_errors() decorator.

return response.json()["otp"]

@handle_response_parsing_errors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,10 @@ def refresh_user_token(self, user: User) -> Tuple[Token, int]:
:param user: The user to refresh the token for
:return: The new token and the time when it will expire (in Unix time)
"""
self.revoke_all_tokens_for_user(user)
with self._user_lock:
self.revoke_all_tokens_for_user(user)

return Token(user.get_auth_token()), self._token_ttl_sec
return Token(user.get_auth_token()), self._token_ttl_sec

def authorize_otp(self, otp: OTP) -> bool:
# SECURITY: This method must not run concurrently, otherwise there could be TOCTOU errors,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def __init__(self, otp_generator: IOTPGenerator, limiter: Limiter):
# the class variable.
#
# TODO: The limit is currently applied per IP address. We will want to change
# it to per-user once we require authentication for this endpoint.
# it to per-user, per-IP once we require authentication for this endpoint.
# Note that we do not want to limit to just per-user, otherwise this endpoint could be used
# to enumerate users/tokens.
with AgentOTP.lock:
if AgentOTP.limiter is None:
AgentOTP.limiter = limiter.limit(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import string
from http import HTTPStatus
from threading import Lock
from typing import Tuple

from flask import make_response, request
from flask_limiter import Limiter, RateLimitExceeded
from flask_limiter.util import get_remote_address

from common.common_consts.token_keys import ACCESS_TOKEN_KEY_NAME, TOKEN_TTL_KEY_NAME
from common.types import OTP, AgentID
Expand All @@ -13,6 +16,11 @@
from ..authentication_facade import AuthenticationFacade
from .utils import include_auth_token

# 100 requests per second is arbitrary, but is expected to be a good-enough limit. Remember that,
# because of the agent's relay/tunnel capability, many requests could be funneled through the same
# agent's relay, making them appear to come from the same IP.
MAX_OTP_LOGIN_REQUESTS_PER_SECOND = 100


class ArgumentParsingException(Exception):
pass
Expand All @@ -26,8 +34,21 @@ class AgentOTPLogin(AbstractResource):
"""

urls = ["/api/agent-otp-login"]
lock = Lock()
limiter = None

def __init__(self, authentication_facade: AuthenticationFacade, limiter: Limiter):
# Since flask generates a new instance of this class for each request,
# we need to ensure that a single instance of the limiter is used. Hence
# the class variable.
with AgentOTPLogin.lock:
if AgentOTPLogin.limiter is None:
AgentOTPLogin.limiter = limiter.limit(
f"{MAX_OTP_LOGIN_REQUESTS_PER_SECOND}/second",
key_func=get_remote_address,
per_method=True,
)

def __init__(self, authentication_facade: AuthenticationFacade):
self._authentication_facade = authentication_facade

# Secured via OTP, not via authentication token.
Expand All @@ -39,6 +60,16 @@ def post(self):

:return: Authentication token in the response body
"""
if AgentOTPLogin.limiter is None:
raise RuntimeError("limiter has not been initialized")

try:
with AgentOTPLogin.limiter:
return self._handle_otp_login_request()
except RateLimitExceeded:
return make_response("Rate limit exceeded", HTTPStatus.TOO_MANY_REQUESTS)

def _handle_otp_login_request(self):
try:
agent_id, otp = self._get_request_arguments(request.json)
except ArgumentParsingException as err:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
from http import HTTPStatus
from threading import Lock

from flask import Response, make_response, request
from flask.typing import ResponseValue
from flask_limiter import Limiter, RateLimitExceeded
from flask_security.views import login

from monkey_island.cc.flask_utils import AbstractResource, responses
Expand All @@ -12,16 +14,27 @@

logger = logging.getLogger(__name__)

MAX_LOGIN_REQUESTS_PER_SECOND = 5


class Login(AbstractResource):
"""
A resource for user authentication
"""

urls = ["/api/login"]
lock = Lock()
limiter = None

def __init__(self, authentication_facade: AuthenticationFacade):
def __init__(self, authentication_facade: AuthenticationFacade, limiter: Limiter):
self._authentication_facade = authentication_facade
with Login.lock:
if Login.limiter is None:
Login.limiter = limiter.limit(
f"{MAX_LOGIN_REQUESTS_PER_SECOND}/second",
key_func=lambda: "key", # Limit all requests, not just per IP
per_method=True,
)

# Can't be secured, used for login
@include_auth_token
Expand All @@ -34,6 +47,16 @@ def post(self):

:return: Access token in the response body
"""
if Login.limiter is None:
raise RuntimeError("limiter has not been initialized")

try:
with Login.limiter:
return self._handle_login_request()
except RateLimitExceeded:
return make_response("Rate limit exceeded", HTTPStatus.TOO_MANY_REQUESTS)

def _handle_login_request(self):
try:
username, password = get_username_password_from_request(request)
except Exception:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,47 @@
import logging
from http import HTTPStatus
from threading import Lock

from flask import make_response
from flask_limiter import Limiter, RateLimitExceeded
from flask_limiter.util import get_remote_address
from flask_login import current_user
from flask_security import auth_token_required

from common.common_consts.token_keys import ACCESS_TOKEN_KEY_NAME, TOKEN_TTL_KEY_NAME
from monkey_island.cc.flask_utils import AbstractResource, responses

from ..authentication_facade import AuthenticationFacade
from .agent_otp_login import MAX_OTP_LOGIN_REQUESTS_PER_SECOND

logger = logging.getLogger(__name__)

# We're assuming that whatever agents registered with the island simultaneously will more or less
# request refresh tokens simultaneously.
MAX_REFRESH_AUTHENTICATION_TOKEN_REQUESTS_PER_SECOND = MAX_OTP_LOGIN_REQUESTS_PER_SECOND


class RefreshAuthenticationToken(AbstractResource):
"""
A resource for refreshing tokens
"""

urls = ["/api/refresh-authentication-token"]
lock = Lock()
limiter = None

def __init__(self, authentication_facade: AuthenticationFacade, limiter: Limiter):
# Since flask generates a new instance of this class for each request,
# we need to ensure that a single instance of the limiter is used. Hence
# the class variable.
with RefreshAuthenticationToken.lock:
if RefreshAuthenticationToken.limiter is None:
RefreshAuthenticationToken.limiter = limiter.limit(
f"{MAX_REFRESH_AUTHENTICATION_TOKEN_REQUESTS_PER_SECOND}/second",
key_func=get_remote_address,
per_method=True,
)

def __init__(self, authentication_facade: AuthenticationFacade):
self._authentication_facade = authentication_facade

@auth_token_required
Expand All @@ -29,6 +51,16 @@ def post(self):

:return: Response with a new token or an invalid request response
"""
if RefreshAuthenticationToken.limiter is None:
raise RuntimeError("limiter has not been initialized")

try:
with RefreshAuthenticationToken.limiter:
return self._handle_refresh_authentication_token_request()
except RateLimitExceeded:
return make_response("Rate limit exceeded", HTTPStatus.TOO_MANY_REQUESTS)

def _handle_refresh_authentication_token_request(self):
try:
new_token, token_ttl_sec = self._authentication_facade.refresh_user_token(current_user)
response = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ def register_resources(
api.add_resource(
RegistrationStatus, *RegistrationStatus.urls, resource_class_args=(authentication_facade,)
)
api.add_resource(Login, *Login.urls, resource_class_args=(authentication_facade,))
api.add_resource(Login, *Login.urls, resource_class_args=(authentication_facade, limiter))
api.add_resource(Logout, *Logout.urls, resource_class_args=(authentication_facade,))

api.add_resource(AgentOTP, *AgentOTP.urls, resource_class_args=(otp_generator, limiter))
api.add_resource(
AgentOTPLogin,
*AgentOTPLogin.urls,
resource_class_args=(authentication_facade,),
resource_class_args=(authentication_facade, limiter),
)
api.add_resource(
RefreshAuthenticationToken,
*RefreshAuthenticationToken.urls,
resource_class_args=(authentication_facade,),
resource_class_args=(authentication_facade, limiter),
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
IslandAPIError,
IslandAPIRequestError,
IslandAPIRequestFailedError,
IslandAPIRequestLimitExceededError,
IslandAPITimeoutError,
)
from infection_monkey.island_api_client.http_client import RETRIES, HTTPClient
Expand Down Expand Up @@ -69,6 +70,7 @@ def test_http_client__unsupported_protocol(server):
(401, IslandAPIAuthenticationError),
(403, IslandAPIAuthenticationError),
(400, IslandAPIRequestError),
(429, IslandAPIRequestLimitExceededError),
(501, IslandAPIRequestFailedError),
],
)
Expand Down
Loading