Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1605 get updated credentials #1721

Merged
merged 5 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion monkey/infection_monkey/master/automated_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def __init__(
self._control_channel = control_channel

ip_scanner = IPScanner(self._puppet, NUM_SCAN_THREADS)
exploiter = Exploiter(self._puppet, NUM_EXPLOIT_THREADS)
exploiter = Exploiter(
self._puppet, NUM_EXPLOIT_THREADS, self._control_channel.get_credentials_for_propagation
)
self._propagator = Propagator(
self._telemetry_messenger,
ip_scanner,
Expand Down
12 changes: 9 additions & 3 deletions monkey/infection_monkey/master/control_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
from infection_monkey.config import WormConfiguration
from infection_monkey.control import ControlClient
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
from infection_monkey.utils.decorators import request_cache

requests.packages.urllib3.disable_warnings()

logger = logging.getLogger(__name__)

CREDENTIALS_POLL_PERIOD_SEC = 30


class ControlChannel(IControlChannel):
def __init__(self, server: str, agent_id: str):
Expand Down Expand Up @@ -66,18 +69,21 @@ def get_config(self) -> dict:
) as e:
raise IslandCommunicationError(e)

@request_cache(CREDENTIALS_POLL_PERIOD_SEC)
def get_credentials_for_propagation(self) -> dict:
propagation_credentials_url = (
f"https://{self._control_channel_server}/api/propagation-credentials/{self._agent_id}"
)
try:
response = requests.get( # noqa: DUO123
f"{self._control_channel_server}/api/propagation-credentials/{self._agent_id}",
propagation_credentials_url,
verify=False,
proxies=ControlClient.proxies,
timeout=SHORT_REQUEST_TIMEOUT,
)
response.raise_for_status()

response = json.loads(response.content.decode())["propagation_credentials"]
return response
return json.loads(response.content.decode())["propagation_credentials"]
except (
json.JSONDecodeError,
requests.exceptions.ConnectionError,
Expand Down
25 changes: 22 additions & 3 deletions monkey/infection_monkey/master/exploiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import threading
from queue import Queue
from threading import Event
from typing import Callable, Dict, List
from typing import Callable, Dict, List, Mapping

from infection_monkey.i_puppet import ExploiterResultData, IPuppet
from infection_monkey.model import VictimHost
Expand All @@ -18,9 +18,15 @@


class Exploiter:
def __init__(self, puppet: IPuppet, num_workers: int):
def __init__(
self,
puppet: IPuppet,
num_workers: int,
get_updated_credentials_for_propagation: Callable[[], Mapping],
):
self._puppet = puppet
self._num_workers = num_workers
self._get_updated_credentials_for_propagation = get_updated_credentials_for_propagation

def exploit_hosts(
self,
Expand Down Expand Up @@ -74,6 +80,7 @@ def _run_all_exploiters(
results_callback: Callback,
stop: Event,
):

for exploiter in interruptable_iter(exploiters_to_run, stop):
exploiter_name = exploiter["name"]
exploiter_results = self._run_exploiter(exploiter_name, victim_host, stop)
Expand All @@ -86,7 +93,19 @@ def _run_exploiter(
self, exploiter_name: str, victim_host: VictimHost, stop: Event
) -> ExploiterResultData:
logger.debug(f"Attempting to use {exploiter_name} on {victim_host}")
return self._puppet.exploit_host(exploiter_name, victim_host.ip_addr, {}, stop)

credentials = self._get_credentials_for_propagation()
options = {"credentials": credentials}

return self._puppet.exploit_host(exploiter_name, victim_host.ip_addr, options, stop)

def _get_credentials_for_propagation(self) -> Mapping:
try:
return self._get_updated_credentials_for_propagation()
except Exception as ex:
logger.error(f"Error while attempting to retrieve credentials for propagation: {ex}")

return {}


def _all_hosts_have_been_processed(scan_completed: Event, hosts_to_exploit: Queue):
Expand Down
46 changes: 46 additions & 0 deletions monkey/infection_monkey/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import threading
from functools import wraps

from .timer import Timer


def request_cache(ttl: float):
"""
This is a decorator that allows a single response of a function to be cached with an expiration
time (TTL). The first call to the function is executed and the response is cached. Subsequent
calls to the function result in the cached value being returned until the TTL elapses. Once the
TTL elapses, the cache is considered stale and the decorated function will be called, its
response cached, and the TTL reset.

An example usage of this decorator is to wrap a function that makes frequent slow calls to an
external resource, such as an HTTP request to a remote endpoint. If the most up-to-date
information is not need, this decorator provides a simple way to cache the response for a
certain amount of time.

Example:
@request_cache(600)
def raining_outside():
return requests.get(f"https://weather.service.api/check_for_rain/{MY_ZIP_CODE}")

:param ttl: The time-to-live in seconds for the cached return value
:return: The return value of the decorated function, or the cached return value if the TTL has
not elapsed.
"""

def decorator(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
with wrapper.lock:
if wrapper.timer.is_expired():
wrapper.cached_value = fn(*args, **kwargs)
wrapper.timer.set(ttl)

return wrapper.cached_value

wrapper.cached_value = None
wrapper.timer = Timer()
wrapper.lock = threading.Lock()

return wrapper

return decorator
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


def test_terminate_without_start():
m = AutomatedMaster(None, None, None, None, [])
m = AutomatedMaster(None, None, None, MagicMock(), [])

# Test that call to terminate does not raise exception
m.terminate()
Expand Down
35 changes: 29 additions & 6 deletions monkey/tests/unit_tests/infection_monkey/master/test_exploiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,27 @@ def hosts_to_exploit(hosts):
return q


def test_exploiter(exploiter_config, callback, scan_completed, stop, hosts, hosts_to_exploit):
# Set this so that Exploiter() exits once it has processed all victims
scan_completed.set()
CREDENTIALS_FOR_PROPAGATION = {"usernames": ["m0nk3y", "user"], "passwords": ["1234", "pword"]}

e = Exploiter(MockPuppet(), 2)
e.exploit_hosts(exploiter_config, hosts_to_exploit, callback, scan_completed, stop)

def get_credentials_for_propagation():
return CREDENTIALS_FOR_PROPAGATION


@pytest.fixture
def run_exploiters(exploiter_config, hosts_to_exploit, callback, scan_completed, stop):
def inner(puppet, num_workers):
# Set this so that Exploiter() exits once it has processed all victims
scan_completed.set()

e = Exploiter(puppet, num_workers, get_credentials_for_propagation)
e.exploit_hosts(exploiter_config, hosts_to_exploit, callback, scan_completed, stop)

return inner


def test_exploiter(callback, hosts, hosts_to_exploit, run_exploiters):
run_exploiters(MockPuppet(), 2)

assert callback.call_count == 5
host_exploit_combos = set()
Expand All @@ -81,6 +96,14 @@ def test_exploiter(exploiter_config, callback, scan_completed, stop, hosts, host
assert ("SSHExploiter", hosts[1]) in host_exploit_combos


def test_credentials_passed_to_exploiter(run_exploiters):
mock_puppet = MagicMock()
run_exploiters(mock_puppet, 1)

for call_args in mock_puppet.exploit_host.call_args_list:
assert call_args[0][2].get("credentials") == CREDENTIALS_FOR_PROPAGATION


def test_stop_after_callback(exploiter_config, callback, scan_completed, stop, hosts_to_exploit):
callback_barrier_count = 2

Expand All @@ -96,7 +119,7 @@ def _callback(*_):

# Intentionally NOT setting scan_completed.set(); _callback() will set stop

e = Exploiter(MockPuppet(), callback_barrier_count + 2)
e = Exploiter(MockPuppet(), callback_barrier_count + 2, get_credentials_for_propagation)
e.exploit_hosts(exploiter_config, hosts_to_exploit, stoppable_callback, scan_completed, stop)

assert stoppable_callback.call_count == 2
78 changes: 78 additions & 0 deletions monkey/tests/unit_tests/infection_monkey/utils/test_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import time
from unittest.mock import MagicMock

import pytest

from infection_monkey.utils.decorators import request_cache
from infection_monkey.utils.timer import Timer


class MockTimer(Timer):
def __init__(self):
self._time_remaining = 0
self._set_time = 0

def set(self, timeout_sec: float):
self._time_remaining = timeout_sec
self._set_time = timeout_sec

def set_expired(self):
self._time_remaining = 0

@property
def time_remaining(self) -> float:
return self._time_remaining

def reset(self):
"""
Reset the timer without changing the timeout
"""
self._time_remaining = self._set_time


class MockTimerFactory:
def __init__(self):
self._instance = None

def __call__(self):
if self._instance is None:
mt = MockTimer()
self._instance = mt

return self._instance

def reset(self):
self._instance = None


mock_timer_factory = MockTimerFactory()


@pytest.fixture
def mock_timer(monkeypatch):
mock_timer_factory.reset

monkeypatch.setattr("infection_monkey.utils.decorators.Timer", mock_timer_factory)

return mock_timer_factory()


def test_request_cache(mock_timer):
mock_request = MagicMock(side_effect=lambda: time.time())

@request_cache(10)
def make_request():
return mock_request()

t1 = make_request()
t2 = make_request()

assert t1 == t2

mock_timer.set_expired()

t3 = make_request()
t4 = make_request()

assert t3 != t1
assert t3 == t4