1
1
import logging
2
- import threading
3
2
from copy import copy
4
3
from threading import RLock
5
4
from typing import Any , Dict
6
5
7
- from serpentarium import MultiUsePlugin , PluginLoader , PluginThreadName , SingleUsePlugin
6
+ from serpentarium import PluginThreadName , SingleUsePlugin
8
7
9
8
from common import OperatingSystem
10
9
from common .agent_plugins import AgentPlugin , AgentPluginType
11
- from common .event_queue import IAgentEventPublisher
12
- from common .types import AgentID
13
- from infection_monkey .exploit import IAgentBinaryRepository , IAgentOTPProvider
14
10
from infection_monkey .i_puppet import UnknownPluginError
15
11
from infection_monkey .island_api_client import IIslandAPIClient , IslandAPIRequestError
16
- from infection_monkey .network import TCPPortSelector
17
- from infection_monkey .propagation_credentials_repository import IPropagationCredentialsRepository
12
+ from infection_monkey .plugin .i_plugin_factory import IPluginFactory
18
13
19
14
from . import PluginSourceExtractor
20
15
21
16
logger = logging .getLogger ()
22
17
23
18
24
- # TODO: We should add an ExploiterPluginFactor and pass that to this component instead of passing
25
- # all of the requirements for exploiters.
26
19
class PluginRegistry :
27
20
def __init__ (
28
21
self ,
29
22
operating_system : OperatingSystem ,
30
23
island_api_client : IIslandAPIClient ,
31
24
plugin_source_extractor : PluginSourceExtractor ,
32
- plugin_loader : PluginLoader ,
33
- agent_binary_repository : IAgentBinaryRepository ,
34
- agent_event_publisher : IAgentEventPublisher ,
35
- propagation_credentials_repository : IPropagationCredentialsRepository ,
36
- tcp_port_selector : TCPPortSelector ,
37
- otp_provider : IAgentOTPProvider ,
38
- agent_id : AgentID ,
25
+ plugin_factories : Dict [AgentPluginType , IPluginFactory ],
39
26
):
40
27
"""
41
28
`self._registry` looks like -
@@ -51,14 +38,8 @@ def __init__(
51
38
self ._operating_system = operating_system
52
39
self ._island_api_client = island_api_client
53
40
self ._plugin_source_extractor = plugin_source_extractor
54
- self ._plugin_loader = plugin_loader
55
- self ._agent_binary_repository = agent_binary_repository
56
- self ._agent_event_publisher = agent_event_publisher
57
- self ._propagation_credentials_repository = propagation_credentials_repository
58
- self ._tcp_port_selector = tcp_port_selector
59
- self ._otp_provider = otp_provider
60
-
61
- self ._agent_id = agent_id
41
+ self ._plugin_factories = plugin_factories
42
+
62
43
self ._lock = RLock ()
63
44
64
45
def get_plugin (self , plugin_type : AgentPluginType , plugin_name : str ) -> Any :
@@ -74,18 +55,15 @@ def get_plugin(self, plugin_type: AgentPluginType, plugin_name: str) -> Any:
74
55
def _load_plugin_from_island (self , plugin_name : str , plugin_type : AgentPluginType ):
75
56
agent_plugin = self ._download_plugin_from_island (plugin_name , plugin_type )
76
57
self ._plugin_source_extractor .extract_plugin_source (agent_plugin )
77
- multiprocessing_plugin = MultiprocessingPluginWrapper (
78
- plugin_loader = self ._plugin_loader ,
79
- plugin_name = plugin_name ,
80
- reset_modules_cache = False ,
81
- main_thread_name = PluginThreadName .CALLING_THREAD ,
82
- agent_id = self ._agent_id ,
83
- agent_binary_repository = self ._agent_binary_repository ,
84
- agent_event_publisher = self ._agent_event_publisher ,
85
- propagation_credentials_repository = self ._propagation_credentials_repository ,
86
- tcp_port_selector = self ._tcp_port_selector ,
87
- otp_provider = self ._otp_provider ,
88
- )
58
+
59
+ if plugin_type in self ._plugin_factories :
60
+ factory = self ._plugin_factories [plugin_type ]
61
+ multiprocessing_plugin = factory .create (plugin_name )
62
+ else :
63
+ raise UnknownPluginError (
64
+ "Loading of custom plugins has not been implemented for plugin type "
65
+ f"'{ plugin_type .value } '"
66
+ )
89
67
90
68
self .load_plugin (plugin_type , plugin_name , multiprocessing_plugin )
91
69
@@ -110,39 +88,3 @@ def load_plugin(self, plugin_type: AgentPluginType, plugin_name: str, plugin: ob
110
88
111
89
self ._registry [plugin_type ][plugin_name ] = plugin
112
90
logger .debug (f"Plugin '{ plugin_name } ' loaded" )
113
-
114
-
115
- # NOTE: This should probably get moved to serpentarium.
116
- class MultiprocessingPluginWrapper (MultiUsePlugin ):
117
- """
118
- Wraps a MultiprocessingPlugin so it can be used like a MultiUsePlugin
119
- """
120
-
121
- process_start_lock = threading .Lock ()
122
-
123
- def __init__ (self , * , plugin_loader : PluginLoader , plugin_name : str , ** kwargs ):
124
- self ._plugin_loader = plugin_loader
125
- self ._name = plugin_name
126
- self ._constructor_kwargs = kwargs
127
-
128
- def run (self , ** kwargs ) -> Any :
129
- logger .debug (f"Constructing a new instance of { self ._name } " )
130
- plugin = self ._plugin_loader .load_multiprocessing_plugin (
131
- plugin_name = self ._name , ** self ._constructor_kwargs
132
- )
133
-
134
- # HERE BE DRAGONS! multiprocessing.Process.start() is not thread-safe on Linux when used
135
- # with the "spawn" method. See https://github.com/pyinstaller/pyinstaller/issues/7410 for
136
- # more details.
137
- # UPDATE: This has been resolved in PyInstaller 5.8.0. Consider removing this lock, but
138
- # leaving a comment here for future reference.
139
- with MultiprocessingPluginWrapper .process_start_lock :
140
- logger .debug ("Invoking plugin.start()" )
141
- plugin .start (** kwargs )
142
-
143
- plugin .join ()
144
- return plugin .return_value
145
-
146
- @property
147
- def name (self ) -> str :
148
- return self ._name
0 commit comments