Skip to content

Commit

Permalink
Merge pull request #238 from pfeinsper/wrappers-for-training-2
Browse files Browse the repository at this point in the history
new wrapper + fix retain position wrapper
  • Loading branch information
JorasOliveira authored May 1, 2024
2 parents 4cacfdc + 362d8d1 commit b1bfbe5
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 4 deletions.
2 changes: 2 additions & 0 deletions DSSE/environment/wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from .matrix_encode_wrapper import MatrixEncodeWrapper
from .top_n_cells_wrapper import TopNProbsWrapper
from .retain_drone_pos_wrapper import RetainDronePosWrapper
from .all_flatten_wrapper import AllFlattenWrapper


__all__ = [
"AllPositionsWrapper",
"MatrixEncodeWrapper",
"TopNProbsWrapper",
"RetainDronePosWrapper",
"AllFlattenWrapper",
]
48 changes: 48 additions & 0 deletions DSSE/environment/wrappers/all_flatten_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy as np
from pettingzoo.utils.wrappers import BaseParallelWrapper
from gymnasium.spaces import Box
from DSSE import DroneSwarmSearch


class AllFlattenWrapper(BaseParallelWrapper):
"""
Wrapper that modifies the observation space to include the positions of all agents + the flatten matrix.
"""
def __init__(self, env: DroneSwarmSearch):
super().__init__(env)

self.observation_spaces = {
agent: self.observation_space(agent)
for agent in self.env.possible_agents
}

def step(self, actions):
obs, reward, terminated, truncated, infos = self.env.step(actions)
self.flatten_obs(obs)
return obs, reward, terminated, truncated, infos

def flatten_obs(self, obs):
for idx, agent in enumerate(obs.keys()):
agents_positions = np.array(self.env.agents_positions) / (self.env.grid_size - 1)
agents_positions[[0, idx]] = agents_positions[[idx, 0]]
obs[agent] = (
np.concatenate((agents_positions.flatten(), obs[agent][1].flatten()))
)


def reset(self, **kwargs):
obs, infos = self.env.reset(**kwargs)
self.flatten_obs(obs)
return obs, infos

def observation_space(self, agent):
return Box(
low=0,
high=1,
shape=(len(self.env.possible_agents) * 2 + self.env.grid_size * self.env.grid_size, ),
dtype=np.float64,
)




15 changes: 11 additions & 4 deletions DSSE/environment/wrappers/retain_drone_pos_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@ def __init__(self, env: DroneSwarmSearch, drone_positions: list):


def reset(self, **kwargs):
options = {
"drone_positions": self.drone_positions
}
obs, infos = self.env.reset(options=options)
opt = kwargs.get("options", {})
if not opt:
options = {
"drones_positions": self.drone_positions
}
kwargs["options"] = options
else:
opt["drones_positions"] = self.drone_positions
kwargs["options"] = opt
obs, infos = self.env.reset(**kwargs)

return obs, infos


Expand Down

0 comments on commit b1bfbe5

Please sign in to comment.