Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OTP validation #3197

Merged
merged 27 commits into from
Apr 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9670255
Island: Add IOTPRepository.update_otp()
shreyamalviya Apr 5, 2023
ca58d5f
Island: Implement MongoOTPRepository.update_otp()
shreyamalviya Apr 5, 2023
a70678e
Island: Add AuthenticationFacade.mark_otp_as_used()
shreyamalviya Apr 5, 2023
e1db463
Island: Add IOTPRepository.otp_is_used()
shreyamalviya Apr 5, 2023
2b83028
Island: Add MongoOTPRepository.otp_is_used()
shreyamalviya Apr 5, 2023
f456b8e
Island: Reduce duplication in MongoOTPRepository
shreyamalviya Apr 5, 2023
d5aa355
Island: Reorder methods in AuthenticationFacade
shreyamalviya Apr 5, 2023
35934c6
Island: Add AuthenticationFacade.otp_is_valid()
shreyamalviya Apr 5, 2023
2e4c808
UT: Add tests for OTP functions in AuthenticationFacade
shreyamalviya Apr 5, 2023
d723c51
Island: Cache object IDs in MongoOTPRepository
mssalvatore Apr 5, 2023
aad4233
Island: Change IOTPRepository.update_otp() -> set_used()
mssalvatore Apr 5, 2023
21b30b6
UT: Add test_set_used__storage_error()
mssalvatore Apr 5, 2023
74ec8ae
Island: Handle known OTP error in MongoOTPRepository.set_used()
mssalvatore Apr 5, 2023
90792ce
Island: Use cached _get_otp_object_id() in set_used()
mssalvatore Apr 5, 2023
b413666
UT: Add test for idempotence of set_used()
mssalvatore Apr 5, 2023
342c560
Island: Use uniform RetrievalError message in _get_otp_document()
mssalvatore Apr 5, 2023
2bd6f21
Island: Fix UnknownRecordError logic in set_used()
mssalvatore Apr 5, 2023
6ee6b71
Island: Set OTP as used if otp_is_valid() is called
mssalvatore Apr 5, 2023
a6a3dca
Island: Make otp_is_valid() more explicit
mssalvatore Apr 5, 2023
ea8afd4
Island: Handle unknown record error in otp_is_valid()
mssalvatore Apr 5, 2023
9d10aba
Island: Rename otp_is_valid() -> authorize_otp()
mssalvatore Apr 5, 2023
989aefd
UT: Fix "called_once_with()" calls
mssalvatore Apr 5, 2023
1195d1c
Island: Fix broken test_generate_otp__saves_otp()
mssalvatore Apr 5, 2023
a6bc055
Island: Remove unneeded AuthenticationFacade.mark_otp_as_used()
mssalvatore Apr 5, 2023
7b93791
Island: Perform real OTP authorization check in AgentOTPLogin
mssalvatore Apr 5, 2023
092d2b0
Island: Don't log out OTPs on failure
mssalvatore Apr 5, 2023
4e77738
Island: Prevent TOCTOU vulnerabilities in authorize_otp()
mssalvatore Apr 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import string
import time
from threading import Lock
from typing import Tuple

from flask_security import UserDatastore

from common.utils.code_utils import secure_generate_random_string
from monkey_island.cc.event_queue import IIslandEventQueue, IslandEventTopic
from monkey_island.cc.models import IslandMode
from monkey_island.cc.repositories import UnknownRecordError
from monkey_island.cc.server_utils.encryption import ILockableEncryptor
from monkey_island.cc.services.authentication_service.token_generator import TokenGenerator

Expand Down Expand Up @@ -39,6 +41,7 @@ def __init__(
self._token_generator = token_generator
self._token_parser = token_parser
self._otp_repository = otp_repository
self._otp_read_lock = Lock()

def needs_registration(self) -> bool:
"""
Expand All @@ -57,6 +60,13 @@ def revoke_all_tokens_for_user(self, user: User):
"""
self._datastore.set_uniquifier(user)

def revoke_all_tokens_for_all_users(self):
"""
Revokes all tokens for all users
"""
for user in User.objects:
self.revoke_all_tokens_for_user(user)

def generate_new_token_pair(self, refresh_token: Token) -> Tuple[Token, Token]:
"""
Generates a new access token and refresh, given a valid refresh token
Expand Down Expand Up @@ -97,12 +107,28 @@ def generate_refresh_token(self, user: User) -> Token:
"""
return self._token_generator.generate_token(user.fs_uniquifier)

def revoke_all_tokens_for_all_users(self):
"""
Revokes all tokens for all users
"""
for user in User.objects:
self.revoke_all_tokens_for_user(user)
def authorize_otp(self, otp: OTP) -> bool:
# SECURITY: This method must not run concurrently, otherwise there could be TOCTOU errors,
# resulting in an OTP being used twice.
with self._otp_read_lock:
try:
otp_is_used = self._otp_repository.otp_is_used(otp)
# When this method is called, that constitutes the OTP being "used".
# Set it as used ASAP.
self._otp_repository.set_used(otp)

if otp_is_used:
return False

if not self._otp_ttl_elapsed(otp):
return True

return False
except UnknownRecordError:
return False

def _otp_ttl_elapsed(self, otp: OTP) -> bool:
return self._otp_repository.get_expiration(otp) < time.monotonic()

def handle_successful_registration(self, username: str, password: str):
self._reset_island_data()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def post(self):
except ArgumentParsingException as err:
return make_response(str(err), HTTPStatus.BAD_REQUEST)

if not self._validate_otp(otp):
if not self._authentication_facade.authorize_otp(otp):
return make_response({}, HTTPStatus.UNAUTHORIZED)

agent_user = register_user(
Expand Down Expand Up @@ -91,6 +91,3 @@ def _get_request_arguments(self, request_data) -> Tuple[AgentID, OTP]:
raise ArgumentParsingException("Could not parse the login request")

return agent_id, otp

def _validate_otp(self, otp: OTP):
return len(otp) > 0
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,38 @@ def insert_otp(self, otp: OTP, expiration: float):
:raises StorageError: If an error occurs while attempting to insert the OTP
"""

@abstractmethod
def set_used(self, otp: OTP):
"""
Update an OTP in the repository

:param otp: The OTP set as "used"
:raises StorageError: If an error occurs while attempting to update the OTP
:raises UnknownRecordError: If the OTP is not found in the repository
"""

@abstractmethod
def get_expiration(self, otp: OTP) -> float:
"""
Get the expiration time of a given OTP

:param otp: OTP for which to get the expiration time
:param otp: The OTP for which to get the expiration time
:return: The time that the OTP expires
:raises RetrievalError: If an error occurs while attempting to retrieve the expiration time
:raises UnknownRecordError: If the OTP was not found
"""

@abstractmethod
def otp_is_used(self, otp: OTP) -> bool:
"""
Check if the OTP has already been used

:param otp: The OTP to check
:return: Whether the OTP has been used
:raises RetrievalError: If an error occurs while attempting to retrieve the OTP's usage
:raises UnknownRecordError: If the OTP was not found
"""

@abstractmethod
def reset(self):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from functools import lru_cache
from typing import Any, Mapping

from bson.objectid import ObjectId
from pymongo import MongoClient

from monkey_island.cc.repositories import (
Expand Down Expand Up @@ -26,25 +30,60 @@ def __init__(
def insert_otp(self, otp: OTP, expiration: float):
try:
encrypted_otp = self._encryptor.encrypt(otp.encode())
self._otp_collection.insert_one({"otp": encrypted_otp, "expiration_time": expiration})
self._otp_collection.insert_one(
{"otp": encrypted_otp, "expiration_time": expiration, "used": False}
)
except Exception as err:
raise StorageError(f"Error inserting OTP: {err}")

def set_used(self, otp: OTP):
try:
otp_id = self._get_otp_object_id(otp)
self._otp_collection.update_one({MONGO_OBJECT_ID_KEY: otp_id}, {"$set": {"used": True}})
except UnknownRecordError as err:
raise err
except Exception as err:
raise StorageError(f"Error updating otp: {err}")
raise StorageError(f"Error updating OTP: {err}")

def get_expiration(self, otp: OTP) -> float:
otp_dict = self._get_otp_document(otp)
return otp_dict["expiration_time"]

def otp_is_used(self, otp: OTP) -> bool:
otp_dict = self._get_otp_document(otp)
return otp_dict["used"]

def _get_otp_document(self, otp: OTP) -> Mapping[str, Any]:
otp_object_id = self._get_otp_object_id(otp)
retrieval_error_message = f"Error retrieving OTP with ID {otp_object_id}"

try:
encrypted_otp = self._encryptor.encrypt(otp.encode())
otp_dict = self._otp_collection.find_one(
{"otp": encrypted_otp}, {MONGO_OBJECT_ID_KEY: False}
{"_id": otp_object_id}, {MONGO_OBJECT_ID_KEY: False}
)
except Exception as err:
raise RetrievalError(f"Error retrieving otp: {err}")
raise RetrievalError(f"{retrieval_error_message}: {err}")

if otp_dict is None:
raise RetrievalError(retrieval_error_message)

return otp_dict

@lru_cache
def _get_otp_object_id(self, otp: OTP) -> ObjectId:
try:
encrypted_otp = self._encryptor.encrypt(otp.encode())
otp_dict = self._otp_collection.find_one({"otp": encrypted_otp}, [MONGO_OBJECT_ID_KEY])
except Exception as err:
raise RetrievalError(f"Error retrieving OTP: {err}")

if otp_dict is None:
raise UnknownRecordError("OTP not found")
return otp_dict["expiration_time"]

return otp_dict[MONGO_OBJECT_ID_KEY]

def reset(self):
try:
self._otp_collection.drop()
except Exception as err:
raise RemovalError(f"Error resetting the repository: {err}")
raise RemovalError(f"Error resetting the OTP repository: {err}")
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,16 @@ def test_invalid_json(flask_client):
assert response.status_code == HTTPStatus.BAD_REQUEST


def test_unauthorized(agent_otp_login):
def test_unauthorized(mock_authentication_facade, agent_otp_login):
# TODO: Update this test when OTP validation is implemented.
response = agent_otp_login({"agent_id": AGENT_ID, "otp": ""})
mock_authentication_facade.authorize_otp.return_value = False
response = agent_otp_login({"agent_id": AGENT_ID, "otp": "password"})

assert response.status_code == HTTPStatus.UNAUTHORIZED


def test_unexpected_error(mock_authentication_facade, agent_otp_login):
mock_authentication_facade.generate_refresh_token.side_effect = Exception("Unexpected error")
mock_authentication_facade.authorize_otp.side_effect = Exception("Unexpected error")
response = agent_otp_login({"agent_id": AGENT_ID, "otp": "password"})

assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from monkey_island.cc.event_queue import IIslandEventQueue, IslandEventTopic
from monkey_island.cc.models import IslandMode
from monkey_island.cc.repositories import UnknownRecordError
from monkey_island.cc.server_utils.encryption import ILockableEncryptor
from monkey_island.cc.services.authentication_service.authentication_facade import (
OTP_EXPIRATION_TIME,
Expand Down Expand Up @@ -198,7 +199,7 @@ def test_generate_otp__saves_otp(
):
otp = authentication_facade.generate_otp()

assert mock_otp_repository.insert_otp.called_once_with(otp)
assert mock_otp_repository.insert_otp.call_args[0][0] == otp


def test_generate_otp__uses_expected_expiration_time(
Expand All @@ -211,6 +212,53 @@ def test_generate_otp__uses_expected_expiration_time(
assert expiration_time == expected_expiration_time


TIME = "2020-01-01 00:00:00"
TIME_FLOAT = 1577836800.0


@pytest.mark.parametrize(
"otp_is_used_return_value, get_expiration_return_value, otp_is_valid_expected_value",
[
(False, TIME_FLOAT - 1, False), # not used, after expiration time
(True, TIME_FLOAT - 1, False), # used, after expiration time
(False, TIME_FLOAT, True), # not used, at expiration time
(True, TIME_FLOAT, False), # used, at expiration time
(False, TIME_FLOAT + 1, True), # not used, before expiration time
(True, TIME_FLOAT + 1, False), # used, before expiration time
],
)
def test_authorize_otp(
authentication_facade: AuthenticationFacade,
mock_otp_repository: IOTPRepository,
freezer,
otp_is_used_return_value: bool,
get_expiration_return_value: int,
otp_is_valid_expected_value: bool,
):
otp = "secret"

freezer.move_to(TIME)

mock_otp_repository.otp_is_used.return_value = otp_is_used_return_value
mock_otp_repository.get_expiration.return_value = get_expiration_return_value

assert authentication_facade.authorize_otp(otp) == otp_is_valid_expected_value
mock_otp_repository.set_used.assert_called_once()


def test_authorize_otp__unknown_otp(
authentication_facade: AuthenticationFacade,
mock_otp_repository: IOTPRepository,
):
otp = "secret"

mock_otp_repository.otp_is_used.side_effect = UnknownRecordError(f"Unknown otp {otp}")
mock_otp_repository.set_used.side_effect = UnknownRecordError(f"Unknown otp {otp}")
mock_otp_repository.get_expiration.side_effect = UnknownRecordError(f"Unknown otp {otp}")

assert authentication_facade.authorize_otp(otp) is False


# mongomock.MongoClient is not a pymongo.MongoClient. This class allows us to register a
# mongomock.MongoClient as a pymongo.MongoClient with the StubDIContainer.
class MockMongoClient(mongomock.MongoClient, pymongo.MongoClient):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
StorageError,
UnknownRecordError,
)
from monkey_island.cc.server_utils.encryption import ILockableEncryptor
from monkey_island.cc.services.authentication_service.i_otp_repository import IOTPRepository
from monkey_island.cc.services.authentication_service.mongo_otp_repository import MongoOTPRepository

Expand All @@ -28,8 +29,15 @@ class OTP:


@pytest.fixture
def otp_repository(repository_encryptor) -> IOTPRepository:
return MongoOTPRepository(mongomock.MongoClient(), repository_encryptor)
def mongo_client() -> mongomock.MongoClient:
return mongomock.MongoClient()


@pytest.fixture
def otp_repository(
mongo_client: mongomock.MongoClient, repository_encryptor: ILockableEncryptor
) -> IOTPRepository:
return MongoOTPRepository(mongo_client, repository_encryptor)


@pytest.fixture
Expand All @@ -38,6 +46,7 @@ def error_raising_mongo_client() -> mongomock.MongoClient:
client.monkey_island = MagicMock(spec=mongomock.Database)
client.monkey_island.otp = MagicMock(spec=mongomock.Collection)
client.monkey_island.otp.insert_one = MagicMock(side_effect=Exception("insert failed"))
client.monkey_island.otp.update_one = MagicMock(side_effect=Exception("insert failed"))
client.monkey_island.otp.find_one = MagicMock(side_effect=Exception("find failed"))
client.monkey_island.otp.delete_one = MagicMock(side_effect=Exception("delete failed"))
client.monkey_island.otp.drop = MagicMock(side_effect=Exception("drop failed"))
Expand Down Expand Up @@ -111,3 +120,36 @@ def test_reset__deletes_all_otp(otp_repository: IOTPRepository):
def test_reset__raises_removal_error_if_error_occurs(error_raising_otp_repository: IOTPRepository):
with pytest.raises(RemovalError):
error_raising_otp_repository.reset()


def test_set_used(otp_repository: IOTPRepository):
otp = "test_otp"
otp_repository.insert_otp(otp, 1)
assert not otp_repository.otp_is_used(otp)

otp_repository.set_used(otp)
assert otp_repository.otp_is_used(otp)


def test_set_used__storage_error(
error_raising_mongo_client: mongomock.MongoClient, error_raising_otp_repository: IOTPRepository
):
error_raising_mongo_client.monkey_island.otp.find_one.side_effect = None
with pytest.raises(StorageError):
error_raising_otp_repository.set_used("test_otp")


def test_set_used__unknown_record_error(otp_repository: IOTPRepository):
with pytest.raises(UnknownRecordError):
otp_repository.set_used("test_otp")


def test_set_used__idempotent(otp_repository: IOTPRepository):
otp = "test_otp"
otp_repository.insert_otp(otp, 1)

otp_repository.set_used(otp)
otp_repository.set_used(otp)
otp_repository.set_used(otp)

assert otp_repository.otp_is_used(otp)
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def test_on_terminate_agents_signal__stores_timestamp(
agent_signals_service.on_terminate_agents_signal(terminate_all_agents)

expected_value = Simulation(terminate_signal_time=timestamp)
assert mock_simulation_repository.save_simulation.called_once_with(expected_value)
mock_simulation_repository.save_simulation.assert_called_once_with(expected_value)


def test_on_terminate_agents_signal__updates_timestamp(
Expand All @@ -180,7 +180,7 @@ def test_on_terminate_agents_signal__updates_timestamp(
agent_signals_service.on_terminate_agents_signal(terminate_all_agents)

expected_value = Simulation(mode=IslandMode.RANSOMWARE, terminate_signal_time=timestamp)
assert mock_simulation_repository.save_simulation.called_once_with(expected_value)
mock_simulation_repository.save_simulation.assert_called_once_with(expected_value)


def test_terminate_signal__not_set_if_agent_registered_before_another(agent_signals_service):
Expand Down