Add scenario self-stop, agent setup, and action masking in PPO
This commit is contained in:
@@ -36,6 +36,7 @@ class SB3EconConverter(VecEnv, gym.Env):
|
||||
|
||||
def step_wait(self) -> VecEnvStepReturn:
|
||||
obs,rew,done,info=self.env.step_wait()
|
||||
self.curr_obs=obs
|
||||
#flatten obs
|
||||
f_obs={}
|
||||
for k,v in obs.items():
|
||||
@@ -62,11 +63,13 @@ class SB3EconConverter(VecEnv, gym.Env):
|
||||
done_g[i]=done
|
||||
c_info[i]["terminal_observation"]=c_obs[i]
|
||||
c_obs=self.reset()
|
||||
|
||||
return np.copy(c_obs),np.copy(c_rew),np.copy(done_g),np.copy(c_info)
|
||||
|
||||
def reset(self) -> VecEnvObs:
|
||||
obs=self.env.reset()
|
||||
f_obs={}
|
||||
self.curr_obs=obs
|
||||
for k,v in obs.items():
|
||||
f_obs[k]=utils.package(v,*self.packager)
|
||||
g_obs={}
|
||||
@@ -84,15 +87,20 @@ class SB3EconConverter(VecEnv, gym.Env):
|
||||
seeds.append(env.seed(seed + idx))
|
||||
return seeds
|
||||
|
||||
|
||||
def action_masks(self):
    """Collect the "action_mask" entry of every observation in self.curr_obs.

    Returns:
        A list holding one mask per key of self.curr_obs, in the order the
        keys are iterated.
    """
    current = self.curr_obs
    return [current[key]["action_mask"] for key in current]
|
||||
|
||||
def close(self) -> None:
    """Do nothing; this implementation performs no cleanup and returns None."""
    return None
|
||||
|
||||
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
|
||||
"""Return attribute from vectorized environment (see base class)."""
|
||||
# NOTE(review): this text comes from a diff view without +/- markers, so the
# two return statements below are presumably the old and new sides of one
# change rather than a single body -- confirm against the actual file.
# Variant 1: look the attribute up on each env selected by `indices`.
target_envs = self._get_target_envs(indices)
|
||||
return [getattr(env_i, attr_name) for env_i in target_envs]
|
||||
|
||||
# Variant 2: look the attribute up on this wrapper itself; `indices` is not
# used. As written here it is unreachable because of the return above.
# NOTE(review): this returns the bare attribute, not a list -- confirm it
# matches the annotated List[Any] return type.
return getattr(self, attr_name)
|
||||
|
||||
|
||||
|
||||
@@ -106,8 +114,7 @@ class SB3EconConverter(VecEnv, gym.Env):
|
||||
|
||||
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
|
||||
"""Call instance methods of vectorized environments."""
|
||||
# NOTE(review): this text comes from a diff view without +/- markers, so the
# two return statements below are presumably the old and new sides of one
# change rather than a single body -- confirm against the actual file.
# Variant 1: call `method_name` on each env selected by `indices`.
target_envs = self._get_target_envs(indices)
|
||||
return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
|
||||
# Variant 2: call `method_name` on this wrapper itself; `indices` is not used.
# As written here it is unreachable because of the return above.
# NOTE(review): this returns the method's raw result, not a list -- confirm
# it matches the annotated List[Any] return type.
return getattr(self, method_name)(*method_args, **method_kwargs)
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user