Added scenario self-stop, agent setup, and action masking in PPO
@@ -5,6 +5,7 @@
 # or https://opensource.org/licenses/BSD-3-Clause
 
 import random
+import uuid
 
 import numpy as np
@@ -38,7 +39,7 @@ class BaseAgent:
 
         if idx is None:
             idx = 0
+        self.uuid = uuid.uuid4()
 
         if multi_action_mode is None:
             multi_action_mode = False
@@ -134,6 +134,7 @@ class BaseComponent(ABC):
 
     def reset(self):
         """Reset any portion of the state managed by this component."""
         world = self.world
+        self.n_agents = world.n_agents
         all_agents = world.agents + [world.planner]
         for agent in all_agents:
             agent.state.update(self.get_additional_state_fields(agent.name))
@@ -234,7 +234,7 @@ class BaseEnvironment(ABC):
         self.num_agents = (
-            n_agents + 1
+            n_agents + n_planners
         )  # used in the warp_drive env wrapper (+ 1 for the planner)
 
         # Components must be a tuple/list where each element is either a...
         #   tuple: ('Component Name', {Component kwargs})
         #   dict : {'Component Name': {Component kwargs}}
@@ -345,11 +345,11 @@ class BaseEnvironment(ABC):
 
         self.world.planner.register_inventory(self.resources)
         self.world.planner.register_components(self._components)
-        self.apply_scenario_config_to_agents()
+        self.reapply_scenario_config_to_agents()
 
         self._completions = 0
+        self._finish_episode = False
         self._last_ep_metrics = None
 
         # For dense logging
@@ -366,7 +366,7 @@ class BaseEnvironment(ABC):
         # into a single agent with index 'a'
         self.collate_agent_step_and_reset_data = collate_agent_step_and_reset_data
 
-    def apply_scenario_config_to_agents(self):
+    def reapply_scenario_config_to_agents(self):
         # Register the components with the agents
         # to finish setting up their state/action spaces.
         for agent in self.world.agents:
@@ -506,6 +506,8 @@ class BaseEnvironment(ABC):
 
     # Getters & Setters
     # -----------------
+    def set_finish_episode(self, done):
+        self._finish_episode = done
 
     def get_component(self, component_name):
         """
@@ -909,6 +911,9 @@ class BaseEnvironment(ABC):
         # Reset the timestep counter
         self.world.timestep = 0
 
+        # Reset done flag
+        self._finish_episode = False
+
         # Perform the scenario reset,
         # which includes resetting the world and agent states
         self.reset_starting_layout()
@@ -1021,7 +1026,7 @@ class BaseEnvironment(ABC):
             flatten_masks=self._flatten_masks,
         )
         rew = self._generate_rewards()
-        done = {"__all__": self.world.timestep >= self._episode_length}
+        done = {"__all__": (self.world.timestep >= self._episode_length) or self._finish_episode}
         info = {k: {} for k in obs.keys()}
 
         if self._dense_log_this_episode:
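Taken together, the base-environment changes implement the "scenario self stop" from the commit message: a scenario or component can call set_finish_episode(True) mid-episode, the flag is cleared again on reset, and step() reports done["__all__"] once either the horizon is reached or the flag is set. Note the parentheses in the done expression above: in Python, | binds tighter than >=, so the explicit comparison combined with a boolean "or" is the correct form. A minimal sketch of how a scenario subclass might use the new hook, assuming scenario_step is the per-timestep scenario callback as in ai-economist's BaseEnvironment; the subclass name and stop condition are illustrative, not part of this commit:

class EconWithBankruptcyStop(Econ):
    def scenario_step(self):
        super().scenario_step()
        # Illustrative condition: end the episode early once every agent is out of coin.
        if all(agent.state["inventory"]["Coin"] <= 0 for agent in self.world.agents):
            self.set_finish_episode(True)  # step() now returns done["__all__"] = True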
envs/econ.py
@@ -23,8 +23,7 @@ class Econ(BaseEnvironment):
     stone, wood, and water tiles.
 
     Args:
         planner_gets_spatial_obs (bool): Whether the planner agent receives spatial
             observations from the world.
-        action_against_mask_penelty=-1 (int): Reward penelty for performing action against mask
         full_observability (bool): Whether the mobile agents' spatial observation
             includes the full world view or is instead an egocentric view.
         mobile_agent_observation_range (int): If not using full_observability,
@@ -64,7 +63,7 @@ class Econ(BaseEnvironment):
 
     name = "econ"
     agent_subclasses = ["BasicMobileAgent"]
-    required_entities = ["Wood", "Stone", "Water"]
+    required_entities = ["Wood", "Stone", "Water", "Gem_Raw", "Gem"]
 
     def __init__(
         self,
@@ -143,6 +142,7 @@ class Econ(BaseEnvironment):
         """
         self.agent_starting_pos = {agent.idx: [] for agent in self.world.agents}
 
+        self._persist_between_resets = False
 
         self.last_log_loged = {}
@@ -172,6 +172,7 @@ class Econ(BaseEnvironment):
         bad = agent.bad_action
+        agent.bad_action = False
         return bad
 
     def get_current_optimization_metrics(self):
         """
         Compute optimization metrics based on the current state. Used to compute reward.
@@ -239,9 +240,13 @@ class Econ(BaseEnvironment):
 
         Here, reset to the layout in the fixed layout file
         """
+
+        if self._persist_between_resets:  # if we only want to modify some values and not actually reset
+            return
+
         self.world.maps.clear()
 
-        resources = ["Wood", "Stone"]
+        resources = ["Wood", "Stone", "Gem_Raw"]
 
         for resource in resources:
             self.world.maps.set_point_add(resource, 0, 0, 1)
@@ -255,15 +260,18 @@ class Econ(BaseEnvironment):
         locations to start. Note: If using fixed_four_skill_and_loc, the starting
         locations will be overridden in self.additional_reset_steps.
         """
-        self.world.clear_agent_locs()
+        if not self._persist_between_resets:
+            self.world.clear_agent_locs()
 
         for agent in self.world.agents:
-            if not agent.is_setup():
-                agent.state["inventory"] = {k: 0 for k in agent.inventory.keys()}
-                agent.state["escrow"] = {k: 0 for k in agent.inventory.keys()}
-                agent.state["endogenous"] = {k: 0 for k in agent.endogenous.keys()}
-                # Add starting coin
-                agent.state["inventory"]["Coin"] = float(self.starting_agent_coin)
+            if not self._persist_between_resets:
+                agent.set_setup(False)  # resets agent states
+            if not agent.is_setup():  # agent has not been set up for the scenario
+                agent.state["inventory"] = {k: 0 for k in agent.inventory.keys()}
+                agent.state["escrow"] = {k: 0 for k in agent.inventory.keys()}
+                agent.state["endogenous"] = {k: 0 for k in agent.endogenous.keys()}
+                # Add starting coin
+                agent.state["inventory"]["Coin"] = float(self.starting_agent_coin)
+            agent.bad_action = False
 
         self.world.planner.state["inventory"] = {
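The reset path now distinguishes a full reset from a persisted one: when _persist_between_resets is set, reset_starting_layout returns early, agent locations are kept, and an agent's inventory/escrow/endogenous state is wiped only if is_setup() is false; on an ordinary reset, set_setup(False) forces the wipe, and bad_action is cleared either way. The is_setup/set_setup pair itself is not shown in this diff (presumably something marks the agent as set up once its scenario configuration completes); a minimal sketch of the flag they imply on the agent, with placement and implementation assumed:

# Assumed implementation of the setup flag the diff calls on agents.
class BaseAgent:
    def set_setup(self, value):
        self._is_setup = bool(value)

    def is_setup(self):
        # Default False so a freshly constructed agent is initialized on first reset.
        return getattr(self, "_is_setup", False)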
@@ -286,7 +294,7 @@ class Econ(BaseEnvironment):
         regeneration.
         """
 
-        resources = ["Wood", "Stone"]
+        resources = ["Wood", "Stone", "Gem_Raw"]
 
         for resource in resources:
             self.world.maps.set_point_add(resource, 0, 0, 20)
main.py
@@ -4,6 +4,7 @@ import numpy as np
 from ai_economist import foundation
 from stable_baselines3.common.vec_env import vec_frame_stack
 from stable_baselines3.common.evaluation import evaluate_policy
+from sb3_contrib.ppo_mask import MaskablePPO
 import envs
 import wrapper
 from wrapper.base_econ_wrapper import BaseEconWrapper
@@ -69,7 +70,7 @@ env_config = {
     'flatten_observations': False,
     # When flattening masks, concatenate each action subspace mask into a single array.
     # Note: flatten_masks = True is required for masking action logits in the code below.
-    'flatten_masks': False,
+    'flatten_masks': True,
 }
@@ -121,7 +122,7 @@ eval_env_config = {
     'flatten_observations': False,
     # When flattening masks, concatenate each action subspace mask into a single array.
     # Note: flatten_masks = True is required for masking action logits in the code below.
-    'flatten_masks': False,
+    'flatten_masks': True,
 }
 
 num_frames = 2
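Flipping flatten_masks to True in both configs is what makes masking usable downstream: per the config comment, the foundation env then concatenates each action subspace's mask into a single flat 0/1 array per agent, which is the form the wrapper's action_masks() (below) hands to MaskablePPO. Illustratively, assuming the usual string agent keys and the "action_mask" observation key the wrapper reads:

obs = econ.reset()
mask = obs["0"]["action_mask"]  # flat 0/1 array, one entry per discrete action of agent "0"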
@@ -226,7 +227,7 @@ obs=monenv.reset()
 
 runname = "run_{}".format(int(np.random.rand() * 100))
 
-model = PPO("MlpPolicy", n_steps=int(env_config['episode_length'] * 2), ent_coef=0.1, vf_coef=0.8, gamma=0.95, learning_rate=5e-3, env=monenv, verbose=1, device="cuda", tensorboard_log="./log")
+model = MaskablePPO("MlpPolicy", n_steps=int(env_config['episode_length'] * 2), ent_coef=0.1, vf_coef=0.8, gamma=0.95, learning_rate=5e-3, env=monenv, verbose=1, device="cuda", tensorboard_log="./log")
 n_agents = econ.n_agents
 total_required_for_episode = n_agents * env_config['episode_length']
 print("this is run {}".format(runname))
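Swapping PPO for sb3_contrib's MaskablePPO leaves every hyperparameter unchanged; the difference is that invalid actions are now removed from the policy distribution instead of discouraged through a reward penalty (consistent with dropping action_against_mask_penelty from the Econ docstring above). MaskablePPO fetches masks from the env automatically during rollouts, but at prediction time they must be passed explicitly. A hedged usage sketch, assuming monenv is the wrapped VecEnv from this script and the timestep budget is illustrative:

from sb3_contrib.common.maskable.utils import get_action_masks

model.learn(total_timesteps=total_required_for_episode * 100)

obs = monenv.reset()
masks = get_action_masks(monenv)  # queries env_method("action_masks") on the VecEnv
action, _state = model.predict(obs, action_masks=masks)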
@@ -36,6 +36,7 @@ class SB3EconConverter(VecEnv, gym.Env):
 
     def step_wait(self) -> VecEnvStepReturn:
         obs, rew, done, info = self.env.step_wait()
+        self.curr_obs = obs
         # flatten obs
         f_obs = {}
         for k, v in obs.items():
@@ -62,11 +63,13 @@ class SB3EconConverter(VecEnv, gym.Env):
             done_g[i] = done
             c_info[i]["terminal_observation"] = c_obs[i]
+            c_obs = self.reset()
 
         return np.copy(c_obs), np.copy(c_rew), np.copy(done_g), np.copy(c_info)
 
     def reset(self) -> VecEnvObs:
         obs = self.env.reset()
         f_obs = {}
+        self.curr_obs = obs
         for k, v in obs.items():
             f_obs[k] = utils.package(v, *self.packager)
         g_obs = {}
@@ -84,15 +87,20 @@ class SB3EconConverter(VecEnv, gym.Env):
             seeds.append(env.seed(seed + idx))
         return seeds
 
+    def action_masks(self):
+        """Return the per-agent action masks from the current obs."""
+        masks = []
+        for key in self.curr_obs:
+            masks.append(self.curr_obs[key]["action_mask"])
+        return masks
+
     def close(self) -> None:
         return
 
     def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
         """Return attribute from vectorized environment (see base class)."""
-        target_envs = self._get_target_envs(indices)
-        return [getattr(env_i, attr_name) for env_i in target_envs]
+        return getattr(self, attr_name)
@@ -106,8 +114,7 @@ class SB3EconConverter(VecEnv, gym.Env):
 
     def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
         """Call instance methods of vectorized environments."""
-        target_envs = self._get_target_envs(indices)
-        return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
+        return getattr(self, method_name)(*method_args, **method_kwargs)
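The wrapper changes close the loop: step_wait and reset cache the raw per-agent observations in curr_obs, the new action_masks() reads each agent's action_mask entry out of that cache, and get_attr/env_method now dispatch to the converter itself rather than to per-process sub-envs, because sb3-contrib asks a VecEnv for masks by method name. Roughly what the library does on its side, as a simplified sketch of sb3_contrib.common.maskable.utils.get_action_masks rather than a verbatim copy:

import numpy as np
from stable_baselines3.common.vec_env import VecEnv

def get_action_masks(env):
    if isinstance(env, VecEnv):
        # Ask the vectorized env for masks by method name; SB3EconConverter's
        # env_method override answers this directly with its cached masks.
        return np.stack(env.env_method("action_masks"))
    return env.action_masks()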