Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Island: Set terminate signal for duplicate agents #3058

Merged
merged 4 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion monkey/monkey_island/cc/services/agent_signals_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from common.agent_signals import AgentSignals
from common.types import AgentID
from monkey_island.cc.models import Simulation, TerminateAllAgents
from monkey_island.cc.models import Agent, MachineID, Simulation, TerminateAllAgents
from monkey_island.cc.repositories import IAgentRepository, ISimulationRepository

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -40,9 +40,24 @@ def get_signals(self, agent_id: AgentID) -> AgentSignals:
if agent.stop_time is not None:
return AgentSignals(terminate=agent.stop_time)

if not self._agent_is_first_to_register(agent):
return AgentSignals(terminate=agent.registration_time)

terminate_timestamp = self._get_terminate_signal_timestamp(agent_id)
return AgentSignals(terminate=terminate_timestamp)

def _agent_is_first_to_register(self, agent: Agent) -> bool:
agents_on_same_machine = self._agents_running_on_machine(agent.machine_id)
first_to_register = min(
agents_on_same_machine, key=lambda a: a.registration_time, default=agent
)
return agent is first_to_register

def _agents_running_on_machine(self, machine_id: MachineID):
return [
a for a in self._agent_repository.get_running_agents() if a.machine_id == machine_id
]

def _get_terminate_signal_timestamp(self, agent_id: AgentID) -> Optional[datetime]:
simulation = self._simulation_repository.get_simulation()
terminate_all_signal_time = simulation.terminate_signal_time
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,19 @@
AGENT_3 = Agent(
id=UUID("0fc9afcb-1902-436b-bd5c-1ad194252484"),
machine_id=3,
registration_time=301,
start_time=300,
parent_id=AGENT_2.id,
)

DUPLICATE_MACHINE_AGENT = Agent(
id=UUID("0fc9afcb-1902-436b-bd5c-1ad194252485"),
machine_id=3,
registration_time=302,
start_time=299,
parent_id=AGENT_2.id,
)

AGENTS = [AGENT_1, AGENT_2, AGENT_3]

STOPPED_AGENT = Agent(
Expand All @@ -43,12 +52,14 @@
parent_id=AGENT_3.id,
)

ALL_AGENTS = [*AGENTS, STOPPED_AGENT]
ALL_AGENTS = [*AGENTS, DUPLICATE_MACHINE_AGENT, STOPPED_AGENT]


@pytest.fixture
def mock_simulation_repository() -> IAgentRepository:
return MagicMock(spec=ISimulationRepository)
repository = MagicMock(spec=ISimulationRepository)
repository.get_simulation = MagicMock(return_value=Simulation(terminate_signal_time=None))
return repository


@pytest.fixture(scope="session")
Expand All @@ -63,6 +74,9 @@ def get_agent_by_id(agent_id: AgentID) -> Agent:
agent_repository = MagicMock(spec=IAgentRepository)
agent_repository.get_progenitor = MagicMock(return_value=AGENT_1)
agent_repository.get_agent_by_id = MagicMock(side_effect=get_agent_by_id)
agent_repository.get_running_agents = MagicMock(
return_value=[a for a in ALL_AGENTS if a.stop_time is None]
)

return agent_repository

Expand All @@ -77,9 +91,6 @@ def test_stopped_agent(
mock_simulation_repository: ISimulationRepository,
):
agent = STOPPED_AGENT
mock_simulation_repository.get_simulation = MagicMock(
return_value=Simulation(terminate_signal_time=None)
)

signals = agent_signals_service.get_signals(agent.id)
assert signals.terminate == agent.stop_time
Expand All @@ -91,10 +102,6 @@ def test_terminate_is_none(
agent_signals_service: AgentSignalsService,
mock_simulation_repository: ISimulationRepository,
):
mock_simulation_repository.get_simulation = MagicMock(
return_value=Simulation(terminate_signal_time=None)
)

signals = agent_signals_service.get_signals(agent.id)
assert signals.terminate is None

Expand Down Expand Up @@ -153,7 +160,6 @@ def test_on_terminate_agents_signal__stores_timestamp(
timestamp = 100

terminate_all_agents = TerminateAllAgents(timestamp=timestamp)
mock_simulation_repository.get_simulation = MagicMock(return_value=Simulation())
agent_signals_service.on_terminate_agents_signal(terminate_all_agents)

expected_value = Simulation(terminate_signal_time=timestamp)
Expand All @@ -174,3 +180,24 @@ def test_on_terminate_agents_signal__updates_timestamp(

expected_value = Simulation(mode=IslandMode.RANSOMWARE, terminate_signal_time=timestamp)
assert mock_simulation_repository.save_simulation.called_once_with(expected_value)


def test_terminate_signal__not_set_if_agent_registered_before_another(agent_signals_service):
signals = agent_signals_service.get_signals(AGENT_3.id)

assert signals.terminate is None


def test_terminate_signal__set_if_agent_registered_after_another(agent_signals_service):
signals = agent_signals_service.get_signals(DUPLICATE_MACHINE_AGENT.id)

assert signals.terminate is not None


def test_terminate_signal__not_set_if_agent_registered_after_stopped_agent(
agent_signals_service: AgentSignalsService, mock_agent_repository: IAgentRepository
):
mock_agent_repository.get_running_agents = MagicMock(return_value=[AGENT_1, AGENT_2])
signals = agent_signals_service.get_signals(DUPLICATE_MACHINE_AGENT.id)

assert signals.terminate is None