
Commit

adding latest training and test files
JorasOliveira committed May 7, 2024
1 parent 62de098 commit 2080cbc
Showing 6 changed files with 627 additions and 60 deletions.
4 changes: 2 additions & 2 deletions DSSE/environment/env.py
@@ -221,8 +221,8 @@ def step(self, actions):
if self.timestep >= self.timestep_limit:
# TODO: Check if is really necessary to add rewards_sum into this reward
rewards[agent] = self.reward_scheme.exceed_timestep
if self.rewards_sum[agent] > 0:
rewards[agent] += self.rewards_sum[agent] // 2
# if self.rewards_sum[agent] > 0:
# rewards[agent] += self.rewards_sum[agent] // 2
truncations[agent] = True
terminations[agent] = True
continue
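Net effect of this hunk: on reaching the timestep limit an agent now receives only the exceed_timestep penalty; the former bonus of half its accumulated positive reward is commented out. The resulting branch reads roughly as follows (a sketch; surrounding attributes are assumed from the original file):

if self.timestep >= self.timestep_limit:
    # Flat timeout penalty only; the bonus of half the accumulated
    # positive reward is disabled in this commit.
    rewards[agent] = self.reward_scheme.exceed_timestep
    truncations[agent] = True
    terminations[agent] = True
    continue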
166 changes: 127 additions & 39 deletions descentralized_ppo.py
@@ -1,9 +1,11 @@
import os
import pathlib
from DSSE import DroneSwarmSearch
from DSSE.environment.wrappers import TopNProbsWrapper
from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper
import ray
from ray import tune
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.rllib.models import ModelCatalog
@@ -26,7 +28,7 @@



class MLPModel(TorchModelV2, nn.Module):
class CNNModel(TorchModelV2, nn.Module):
def __init__(
self,
obs_space,
@@ -42,18 +44,52 @@ def __init__(
)
nn.Module.__init__(self)

self.model = nn.Sequential(
nn.Linear(obs_space.shape[0], 512),
nn.ReLU(),
flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[1] - 7 - 3)
self.cnn = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=16,
kernel_size=(8, 8),
stride=(1, 1),
),
nn.Tanh(),
nn.Conv2d(
in_channels=16,
out_channels=32,
kernel_size=(4, 4),
stride=(1, 1),
),
nn.Tanh(),
nn.Flatten(),
nn.Linear(flatten_size, 256),
nn.Tanh(),
)

self.linear = nn.Sequential(
nn.Linear(obs_space[0].shape[0], 512),
nn.Tanh(),
nn.Linear(512, 256),
nn.ReLU(),
nn.Tanh(),
)

self.join = nn.Sequential(
nn.Linear(256 * 2, 256),
nn.Tanh(),
)

self.policy_fn = nn.Linear(256, num_outputs)
self.value_fn = nn.Linear(256, 1)

def forward(self, input_dict, state, seq_lens):
input_ = input_dict["obs"].float()
value_input = self.model(input_)
input_positions = input_dict["obs"][0].float()
input_matrix = input_dict["obs"][1].float()

input_matrix = input_matrix.unsqueeze(1)
cnn_out = self.cnn(input_matrix)
linear_out = self.linear(input_positions)

value_input = torch.cat((cnn_out, linear_out), dim=1)
value_input = self.join(value_input)

self._value_out = self.value_fn(value_input)
return self.policy_fn(value_input), state
@@ -62,60 +98,112 @@ def value_function(self):
return self._value_out.flatten()



def env_creator(args):
env = DroneSwarmSearch(
drone_amount=4,
grid_size=20,
dispersion_inc=0.08,
person_initial_position=(10, 10),
grid_size=40,
dispersion_inc=0.1,
person_initial_position=(20, 20),
)
env = TopNProbsWrapper(env, 10)
# env = RetainDronePosWrapper(env, [(10, 0), (0, 10), (10, 19), (19, 10)])
positions = [
(20, 0),
(20, 39),
(0, 20),
(39, 20),
]
env = AllPositionsWrapper(env)
env = RetainDronePosWrapper(env, positions)
return env


if __name__ == "__main__":
ray.init()

env_name = "DSSE"

register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
ModelCatalog.register_custom_model("MLPModel", MLPModel)
ModelCatalog.register_custom_model("CNNModel", CNNModel)

# Policies are called just like the agents (exact 1:1 mapping).
num_agents = 4
policies = {f"drone{i}" for i in range(num_agents)}
# policies = {f"drone{i}": (None, obs_space, act_space, {}) for i in range(num_agents)}
config = {
"env": env_name,
"rollout_fragment_length": "auto",
"num_workers": 14,
"multiagent": {
"policies": policies,
"policy_mapping_fn": (lambda agent_id, _, **kwargs: agent_id),
},
"train_batch_size": 8192,
"lr": 1e-5,
"gamma": 0.9999999,
"lambda": 0.9,
"use_gae": True,
"sgd_minibatch_size": 300,
"num_sgd_iter": 10,
"model": {
"custom_model": "MLPModel",
"_disable_preprocessor_api": True,
},
"framework": "torch",
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "1")),
}

config = (
PPOConfig()
.environment(env=env_name)
.rollouts(num_rollout_workers=12, rollout_fragment_length="auto")
.multi_agent(
policies=policies,
# Exact 1:1 mapping from AgentID to ModuleID.
policy_mapping_fn=(lambda aid, *args, **kwargs: aid),
)
.training(
train_batch_size=8192,
lr=1e-5,
gamma=0.9999999,
lambda_=0.9,
use_gae=True,
# clip_param=0.3,
# grad_clip=None,
entropy_coeff=0.01,
# vf_loss_coeff=0.25,
# vf_clip_param=10,
sgd_minibatch_size=300,
num_sgd_iter=10,
model={
"custom_model": "CNNModel",
"_disable_preprocessor_api": True,
},
)
.rl_module(
model_config_dict={"vf_share_layers": True},
rl_module_spec=MultiAgentRLModuleSpec(
module_specs={p: SingleAgentRLModuleSpec() for p in policies},
),
)
.experimental(_disable_preprocessor_api=True)
.debugging(log_level="ERROR")
.framework(framework="torch")
.resources(num_gpus=1)
)


# config = {
# "env": env_name,
# "rollout_fragment_length": "auto",
# "num_workers": 14,
# "multiagent": {
# "policies": policies,
# "policy_mapping_fn": (lambda agent_id, _, **kwargs: agent_id),
# },
# "train_batch_size": 8192,
# "lr": 1e-5,
# "gamma": 0.9999999,
# "lambda": 0.9,
# "use_gae": True,
# "sgd_minibatch_size": 300,
# "num_sgd_iter": 10,
# "model": {
# "custom_model": "CNNModel",
# "_disable_preprocessor_api": True,
# },
# "framework": "torch",
# "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "1")),
# "rl_module":{
# "model_config_dict":{"vf_share_layers": True},
# "rl_module_spec" : MultiAgentRLModuleSpec( module_specs = {p: SingleAgentRLModuleSpec() for p in policies}),
# },
# "experimental": {"_disable_preprocessor_api": True},
# }


curr_path = pathlib.Path().resolve()
tune.run(
"PPO",
name="PPO",
stop={"timesteps_total": 5_000_000 if not os.environ.get("CI") else 50000},
checkpoint_freq=10,
stop={"timesteps_total": 5_000_000 if not os.environ.get("CI") else 1.7},
checkpoint_freq=20,
storage_path=f"{curr_path}/ray_res/" + env_name,
config=config,
)
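For a quick sanity check of the two-branch CNNModel above, a hypothetical smoke test follows (not part of this commit). It assumes the wrapped observation is a Tuple of a flat drone-positions vector and a 40x40 probability matrix, matching the obs_space[0]/obs_space[1] indexing in __init__; the vector length of 8 and the action count of 5 are illustrative guesses.

import numpy as np
import torch
from gymnasium import spaces
from descentralized_ppo import CNNModel  # training code stays under the __main__ guard

# Assumed observation structure: (positions vector, probability matrix).
obs_space = spaces.Tuple((
    spaces.Box(low=0.0, high=40.0, shape=(8,), dtype=np.float32),    # length 8 is a guess
    spaces.Box(low=0.0, high=1.0, shape=(40, 40), dtype=np.float32),
))
act_space = spaces.Discrete(5)  # action count is a guess, for illustration only

model = CNNModel(obs_space, act_space, num_outputs=5, model_config={}, name="cnn_smoke_test")

# Conv arithmetic: 40 -> 33 (8x8 kernel, stride 1) -> 30 (4x4 kernel),
# so flatten_size = 32 * 30 * 30 = 28800, matching the constructor.
batch = 2
dummy_obs = (torch.rand(batch, 8), torch.rand(batch, 40, 40))
logits, _ = model.forward({"obs": dummy_obs}, [], None)
print(logits.shape)                  # expected: torch.Size([2, 5])
print(model.value_function().shape)  # expected: torch.Size([2])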
149 changes: 149 additions & 0 deletions jj_rllib_ppo_cnn.py
@@ -0,0 +1,149 @@
import os
import pathlib
from DSSE import DroneSwarmSearch
from DSSE.environment.wrappers import RetainDronePosWrapper, AllPositionsWrapper
import ray
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.tune.registry import register_env
from torch import nn
import torch


class CNNModel(TorchModelV2, nn.Module):
def __init__(
self,
obs_space,
act_space,
num_outputs,
model_config,
name,
**kw,
):
print("OBSSPACE: ", obs_space)
TorchModelV2.__init__(
self, obs_space, act_space, num_outputs, model_config, name, **kw
)
nn.Module.__init__(self)

flatten_size = 32 * (obs_space[1].shape[0] - 7 - 3) * (obs_space[1].shape[1] - 7 - 3)
self.cnn = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=16,
kernel_size=(8, 8),
stride=(1, 1),
),
nn.Tanh(),
nn.Conv2d(
in_channels=16,
out_channels=32,
kernel_size=(4, 4),
stride=(1, 1),
),
nn.Tanh(),
nn.Flatten(),
nn.Linear(flatten_size, 256),
nn.Tanh(),
)

self.linear = nn.Sequential(
nn.Linear(obs_space[0].shape[0], 512),
nn.Tanh(),
nn.Linear(512, 256),
nn.Tanh(),
)

self.join = nn.Sequential(
nn.Linear(256 * 2, 256),
nn.Tanh(),
)

self.policy_fn = nn.Linear(256, num_outputs)
self.value_fn = nn.Linear(256, 1)

def forward(self, input_dict, state, seq_lens):
input_positions = input_dict["obs"][0].float()
input_matrix = input_dict["obs"][1].float()

input_matrix = input_matrix.unsqueeze(1)
cnn_out = self.cnn(input_matrix)
linear_out = self.linear(input_positions)

value_input = torch.cat((cnn_out, linear_out), dim=1)
value_input = self.join(value_input)

self._value_out = self.value_fn(value_input)
return self.policy_fn(value_input), state

def value_function(self):
return self._value_out.flatten()


def env_creator(args):
env = DroneSwarmSearch(
drone_amount=4,
grid_size=40,
dispersion_inc=0.1,
person_initial_position=(20, 20),
person_amount=4,
)
positions = [
(20, 0),
(20, 39),
(0, 20),
(39, 20),
]
env = AllPositionsWrapper(env)
env = RetainDronePosWrapper(env, positions)
return env


if __name__ == "__main__":
ray.init()

env_name = "DSSE"

register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
ModelCatalog.register_custom_model("CNNModel", CNNModel)

config = (
PPOConfig()
.environment(env=env_name)
.rollouts(num_rollout_workers=14, rollout_fragment_length="auto")
.training(
train_batch_size=8192,
lr=1e-5,
gamma=0.9999999,
lambda_=0.9,
use_gae=True,
# clip_param=0.3,
# grad_clip=None,
entropy_coeff=0.01,
# vf_loss_coeff=0.25,
# vf_clip_param=10,
sgd_minibatch_size=300,
num_sgd_iter=10,
model={
"custom_model": "CNNModel",
"_disable_preprocessor_api": True,
},
)
.experimental(_disable_preprocessor_api=True)
.debugging(log_level="ERROR")
.framework(framework="torch")
.resources(num_gpus=1)
)

curr_path = pathlib.Path().resolve()
tune.run(
"PPO",
name="PPO",
stop={"timesteps_total": 10_000_000 if not os.environ.get("CI") else 6.5, "episode_reward_mean": 5000 }, #1.75 * 4}
checkpoint_freq=10,
storage_path=f"{curr_path}/ray_res/" + env_name,
config=config.to_dict(),
)
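A minimal random-action rollout of the wrapped environment can help confirm the observation structure the model expects (positions vector plus 40x40 probability matrix) and that RetainDronePosWrapper pins the four spawn positions. A hypothetical sketch (not part of this commit), assuming the standard PettingZoo parallel API and that env_creator is importable from jj_rllib_ppo_cnn:

from jj_rllib_ppo_cnn import env_creator

env = env_creator(None)
observations, infos = env.reset()  # parallel-API reset signature assumed

for agent, obs in observations.items():
    positions, prob_matrix = obs  # (positions vector, probability matrix) structure assumed
    print(agent, positions, prob_matrix.shape)  # expect a 40x40 matrix

# Step with random actions until every drone is done.
while env.agents:
    actions = {agent: env.action_space(agent).sample() for agent in env.agents}
    observations, rewards, terminations, truncations, infos = env.step(actions)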
