Added scenario self-stop, agent setup, and action masking in PPO

This commit is contained in:
2023-01-14 10:01:37 +01:00
parent 4f1044b87e
commit 692b932302
6 changed files with 50 additions and 27 deletions

View File

@@ -36,6 +36,7 @@ class SB3EconConverter(VecEnv, gym.Env):
def step_wait(self) -> VecEnvStepReturn:
obs,rew,done,info=self.env.step_wait()
self.curr_obs=obs
#flatten obs
f_obs={}
for k,v in obs.items():
@@ -62,11 +63,13 @@ class SB3EconConverter(VecEnv, gym.Env):
done_g[i]=done
c_info[i]["terminal_observation"]=c_obs[i]
c_obs=self.reset()
return np.copy(c_obs),np.copy(c_rew),np.copy(done_g),np.copy(c_info)
def reset(self) -> VecEnvObs:
obs=self.env.reset()
f_obs={}
self.curr_obs=obs
for k,v in obs.items():
f_obs[k]=utils.package(v,*self.packager)
g_obs={}
@@ -84,15 +87,20 @@ class SB3EconConverter(VecEnv, gym.Env):
seeds.append(env.seed(seed + idx))
return seeds
def action_masks(self):
    """Collect the "action_mask" entry of every observation in ``self.curr_obs``.

    Returns:
        A list holding one mask per key of ``self.curr_obs``, in the order
        the keys are iterated.
    """
    return [self.curr_obs[agent_key]["action_mask"] for agent_key in self.curr_obs]
def close(self) -> None:
    """Do nothing; this method performs no cleanup and returns ``None``."""
    return None
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
    """Return the attribute ``attr_name`` looked up on this wrapper itself.

    Args:
        attr_name: Name of the attribute to read from ``self``.
        indices: Accepted for signature compatibility only; it is not used.

    Returns:
        Whatever ``getattr(self, attr_name)`` yields.

    Raises:
        AttributeError: If this object has no attribute ``attr_name``.
    """
    # The block previously contained two consecutive returns: a per-env list
    # built from self._get_target_envs(indices), followed by this one, which
    # was therefore unreachable. Only the wrapper-level lookup is kept.
    # NOTE(review): the value is returned as-is, not wrapped in a list, even
    # though the annotation says List[Any] -- confirm callers accept that.
    return getattr(self, attr_name)
@@ -106,8 +114,7 @@ class SB3EconConverter(VecEnv, gym.Env):
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
    """Call the method ``method_name`` of this wrapper and return its result.

    Args:
        method_name: Name of the method to look up on ``self``.
        *method_args: Positional arguments forwarded to the method.
        indices: Accepted for signature compatibility only; it is not used.
        **method_kwargs: Keyword arguments forwarded to the method.

    Returns:
        The value returned by the called method, unchanged.

    Raises:
        AttributeError: If this object has no attribute ``method_name``.
    """
    # The block previously contained two consecutive returns: a per-env call
    # over self._get_target_envs(indices), followed by this one, which was
    # therefore unreachable. Only the wrapper-level call is kept.
    # NOTE(review): the result is not wrapped in a list, even though the
    # annotation says List[Any] -- confirm callers accept that.
    return getattr(self, method_name)(*method_args, **method_kwargs)