140 lines
5.1 KiB
Python
140 lines
5.1 KiB
Python
import gym
|
|
import gym.spaces
|
|
import numpy as np
|
|
from wrapper import utils
|
|
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
|
|
from ai_economist.foundation.base import base_env,base_agent
|
|
from typing import Any, Callable, List, Optional, Sequence, Type, Union
|
|
|
|
class SB3EconConverter(VecEnv, gym.Env):
    """Adapt a multi-agent AI-Economist environment to the SB3 ``VecEnv`` API.

    Each agent of the selected ``agentclass`` is exposed as one "parallel
    environment": ``num_envs`` equals the number of agents, observations are
    flattened to a fixed-size ``Box``, and per-agent dicts are converted to
    the array-of-envs layout Stable-Baselines3 expects.
    """

    def __init__(self, env: gym.Env, econ: base_env.BaseEnvironment, agentclass: str, auto_reset: bool):
        """
        :param env: async multi-agent wrapper (must support step_async/step_wait).
        :param econ: underlying AI-Economist foundation environment.
        :param agentclass: agent class whose action space is exposed to SB3.
        :param auto_reset: if True, reset automatically when the episode ends.
        """
        self.env = env
        self.econ = econ

        # Sample a reset observation to size the spaces and build the flattener.
        obs = env.reset()
        self.num_envs = len(obs)

        self.action_space = self._get_action_space_by_class(agentclass)
        self.packager = utils.build_packager(obs[0])

        # Flatten the first agent's observation to determine the obs shape.
        # NOTE(review): the low/high bounds 0..10 look like placeholders —
        # confirm flattened observations actually stay in this range.
        obs0 = utils.package(obs[0], *self.packager)
        self.observation_space = gym.spaces.Box(
            low=0, high=10, shape=(len(obs0["flat"]),), dtype=np.float32
        )

        self.step_request_send = False
        self.auto_reset = auto_reset

        # Initialise metrics here so step_wait() can compute a productivity
        # trend even if reset() has not yet been called through this wrapper.
        self.metrics = self.econ.scenario_metrics()

        super().__init__(self.num_envs, self.observation_space, self.action_space)

    def _get_action_space_by_class(self, agent_class: str) -> gym.spaces.Discrete:
        """Return the Discrete action space of the first agent of *agent_class*."""
        idx_list = self.econ.world._agent_class_idx_map[agent_class]
        idx = int(idx_list[0])
        agent = self.econ.world.agents[idx]
        return gym.spaces.Discrete(agent.action_spaces)

    def step_async(self, actions: np.ndarray):
        """Forward *actions* to the wrapped env; ignore duplicate requests."""
        if not self.step_request_send:
            self.step_request_send = True
            d_actions = utils.convert_gym_to_econ(actions)
            return self.env.step_async(d_actions)

    def step_wait(self) -> VecEnvStepReturn:
        """Collect the pending step result and convert it to VecEnv arrays.

        Also annotates every agent's info dict with the current productivity
        level and its trend (difference from the previous step's metrics).
        """
        obs, rew, done, info = self.env.step_wait()
        self.curr_obs = obs

        # Flatten each agent's observation dict to a single vector.
        f_obs = {k: utils.package(v, *self.packager)["flat"] for k, v in obs.items()}

        c_obs = utils.convert_econ_to_gym(f_obs)
        c_rew = utils.convert_econ_to_gym(rew)
        c_info = utils.convert_econ_to_gym(info)

        # Productivity level and trend between consecutive scenario snapshots.
        prev_metrics = self.metrics
        self.metrics = self.econ.scenario_metrics()
        curr_prod = self.metrics["social/productivity"]
        trend_prod = curr_prod - prev_metrics["social/productivity"]

        for agent_info in c_info:
            agent_info["social/productivity"] = curr_prod
            agent_info["trend/productivity"] = trend_prod

        done_g = [False] * self.num_envs
        episode_done = done["__all__"]
        if episode_done:
            for i in range(self.num_envs):
                done_g[i] = True
                # SB3 convention: stash the final observation before any reset.
                c_info[i]["terminal_observation"] = c_obs[i]
            if self.auto_reset:
                c_obs = self.reset()

        self.step_request_send = False
        return np.copy(c_obs), np.copy(c_rew), np.copy(done_g), np.copy(c_info)

    def reset(self) -> VecEnvObs:
        """Reset the wrapped env and return the flattened observation array."""
        obs = self.env.reset()
        self.step_request_send = False
        self.curr_obs = obs

        # Flatten each agent's observation and convert to VecEnv layout.
        g_obs = {k: utils.package(v, *self.packager)["flat"] for k, v in obs.items()}
        c_obs = utils.convert_econ_to_gym(g_obs)

        # Snapshot metrics so the first step_wait() can compute a trend.
        self.metrics = self.econ.scenario_metrics()
        return np.copy(c_obs)

    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        """Seed the economic simulation; returns one seed entry per env."""
        if seed is None:
            seed = np.random.randint(0, 2**32 - 1)
        self.econ.seed(seed)
        # VecEnv contract: a list with one entry per (pseudo-)environment.
        return [seed] * self.num_envs

    def action_masks(self):
        """Returns action masks for agents and current obs"""
        masks = []
        for agent_key in self.curr_obs:
            # Mask entries are floats (1.0 = allowed); convert to booleans.
            mask = [num == 1.0 for num in self.curr_obs[agent_key]["action_mask"]]
            masks.append(mask)
        return masks

    def close(self) -> None:
        """No resources to release; wrapped env owns its own lifecycle."""
        return

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        """Return attribute from vectorized environment (see base class)."""
        # NOTE(review): ignores *indices* and returns the raw attribute rather
        # than a per-env list — deviates from the SB3 VecEnv contract; confirm
        # callers expect this.
        return getattr(self, attr_name)

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        """Set attribute inside vectorized environments (see base class)."""
        target_envs = self._get_target_envs(indices)
        for env_i in target_envs:
            setattr(env_i, attr_name, value)

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
        """Call instance methods of vectorized environments."""
        return getattr(self, method_name)(*method_args, **method_kwargs)

    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
        """Check if worker environments are wrapped with a given wrapper"""
        target_envs = self._get_target_envs(indices)
        # Import here to avoid a circular import
        from stable_baselines3.common import env_util

        return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs]

    def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
        """Resolve *indices* to a list of target environments.

        NOTE(review): ``self.envs`` is never assigned anywhere in this class,
        so this will raise AttributeError if reached — set_attr() and
        env_is_wrapped() both call it. Verify whether ``self.env`` (or a
        per-agent list) was intended.
        """
        indices = self._get_indices(indices)
        return [self.envs[i] for i in indices]