it is working
This commit is contained in:
227
envs/econ_wrapper.py
Normal file
227
envs/econ_wrapper.py
Normal file
@@ -0,0 +1,227 @@
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, List, Optional, Sequence, Type, Union
|
||||
from ai_economist.foundation.base import base_env
|
||||
|
||||
import gym
|
||||
import gym.spaces
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
|
||||
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
|
||||
|
||||
from ai_economist import foundation
|
||||
|
||||
class EconVecEnv(VecEnv, gym.Env):
    """
    Expose one multi-agent ``foundation`` environment as an SB3 ``VecEnv``.

    A single environment is built from ``env_config`` and each agent of its
    world is presented to stable-baselines3 as one vectorized slot, i.e.
    ``num_envs == env.world.n_agents``. All slots share the same underlying
    environment (``self.env``), so a ``step`` advances every agent at once and
    the episode ends for all slots together when ``done["__all__"]`` is set.

    The ``"p"`` entry of the environment's observation / reward / info
    dictionaries (presumably the planner -- confirm against the scenario) is
    dropped, as are the ``"time"`` and ``"action_mask"`` observation fields.

    :param env_config: keyword arguments forwarded to
        ``foundation.make_env_instance``.
    """

    def __init__(self, env_config):
        self.config = env_config
        env = foundation.make_env_instance(**env_config)
        self.env = env

        # One vectorized slot per agent. Set before converting observations
        # because ``_convert_econ_obs_to_gym`` sizes its output with it.
        self.num_envs = env.world.n_agents

        initial_obs = env.reset()
        # Prime the metrics so ``step_wait`` can compute a productivity trend
        # even if the caller steps before calling ``reset``.
        self.metrics = env.scenario_metrics()

        # Build spaces. The per-agent observation shape is taken from an
        # actual converted observation so the declared space always matches
        # what ``reset`` / ``step_wait`` return.
        sample_obs = self._convert_econ_obs_to_gym(initial_obs)
        self.observation_space = gym.spaces.Box(
            low=0, high=np.inf, shape=sample_obs.shape[1:], dtype=np.float32
        )
        self.action_space = gym.spaces.Discrete(env.world.agents[0].action_spaces)

        VecEnv.__init__(self, self.num_envs, self.observation_space, action_space=self.action_space)
        self.keys, shapes, dtypes = obs_space_info(self.observation_space)

        self.buf_obs = OrderedDict(
            [(k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys]
        )
        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
        self.buf_infos = [{} for _ in range(self.num_envs)]
        self.actions = None

    def step_async(self, actions: np.ndarray) -> None:
        """Store the per-agent actions to be applied by ``step_wait``."""
        self.actions = actions

    def step_wait(self) -> VecEnvStepReturn:
        """
        Apply the stored actions and return per-agent results.

        Every info dict is augmented with ``"social/productivity"`` (current
        value from ``scenario_metrics``) and ``"trend/productivity"`` (change
        since the previous step or reset). When the environment reports
        ``done["__all__"]``, each info dict also receives the final
        observation under ``"terminal_observation"`` and the environment is
        reset, so the returned observations are those of the new episode.

        :return: ``(observations, rewards, dones, infos)``
        """
        # Convert the action array into the ``{"<agent index>": action}``
        # mapping passed to the wrapped environment.
        econ_actions = {str(idx): action for idx, action in enumerate(self.actions)}

        obs, rew, done, info = self.env.step(econ_actions)
        obs_g = self._convert_econ_obs_to_gym(obs)
        rew_g = self._convert_econ_to_gym(rew)
        # Shallow-copy each info dict so the keys added below do not leak
        # into the dictionaries owned by the wrapped environment.
        info_g = [dict(agent_info) for agent_info in self._convert_econ_to_gym(info)]

        # Collect metrics.
        prev_metrics = self.metrics
        self.metrics = self.env.scenario_metrics()
        curr_prod = self.metrics["social/productivity"]
        trend_prod = curr_prod - prev_metrics["social/productivity"]
        for agent_info in info_g:
            agent_info["social/productivity"] = curr_prod
            agent_info["trend/productivity"] = trend_prod

        episode_over = done["__all__"]
        done_g = [False] * self.num_envs
        if episode_over:
            for i in range(self.num_envs):
                done_g[i] = episode_over
                info_g[i]["terminal_observation"] = obs_g[i]
            # ``reset`` also refreshes ``self.metrics`` for the next episode.
            obs_g = self.reset()

        return (np.copy(obs_g), np.array(rew_g), np.array(done_g), deepcopy(info_g))

    def step_predict(self, actions):
        """
        Step using only the first element of ``actions``.

        Workaround for the malformed action tensor produced by the SB3
        ``predict`` method: ``actions[0]`` is forwarded to ``step``.
        """
        return self.step(actions[0])

    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        """
        Seed the shared environment.

        :param seed: seed to use; a random one is drawn when ``None``.
        :return: the value returned by ``self.env.seed``, repeated once per
            agent slot.
        """
        if seed is None:
            # ``dtype=np.uint32`` keeps the upper bound valid on platforms
            # where NumPy's default integer type is 32-bit.
            seed = int(np.random.randint(0, np.iinfo(np.uint32).max, dtype=np.uint32))
        # There is only one underlying environment, so it is seeded once.
        result = self.env.seed(seed)
        return [result] * self.num_envs

    def reset(self) -> VecEnvObs:
        """Reset the shared environment and return the stacked agent observations."""
        obs = self.env.reset()
        self.metrics = self.env.scenario_metrics()
        return self._convert_econ_obs_to_gym(obs)

    def close(self) -> None:
        """Close the shared environment."""
        self.env.close()

    def get_images(self) -> Sequence[np.ndarray]:
        """
        Return a one-element list with the shared environment's RGB frame.

        NOTE(review): assumes the wrapped environment implements
        ``render(mode="rgb_array")`` -- confirm.
        """
        return [self.env.render(mode="rgb_array")]

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """
        Render the shared environment.

        All agent slots live in the same environment, so the call is passed
        directly to it instead of tiling one image per slot.

        NOTE(review): assumes the wrapped environment implements
        ``render(mode=...)`` -- confirm.

        :param mode: The rendering type.
        """
        return self.env.render(mode=mode)

    def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
        """Write ``obs`` into slot ``env_idx`` of the observation buffers."""
        for key in self.keys:
            if key is None:
                self.buf_obs[key][env_idx] = obs
            else:
                self.buf_obs[key][env_idx] = obs[key]

    def _obs_from_buf(self) -> VecEnvObs:
        """Return a copy of the buffered observations in observation-space layout."""
        return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        """Return attribute from the shared environment, once per selected slot."""
        target_envs = self._get_target_envs(indices)
        if not target_envs:
            return []
        value = getattr(self.env, attr_name)
        return [value] * len(target_envs)

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        """Set attribute on the shared environment if any slot is selected."""
        if self._get_target_envs(indices):
            setattr(self.env, attr_name, value)

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
        """
        Call a method of the shared environment.

        The method is invoked a single time (all slots share one environment,
        so repeated calls would repeat its side effects) and the result is
        returned once per selected slot.
        """
        target_envs = self._get_target_envs(indices)
        if not target_envs:
            return []
        result = getattr(self.env, method_name)(*method_args, **method_kwargs)
        return [result] * len(target_envs)

    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
        """Check if the shared environment is wrapped with a given wrapper."""
        target_envs = self._get_target_envs(indices)
        if not target_envs:
            return []
        # Import here to avoid a circular import
        from stable_baselines3.common import env_util

        wrapped = env_util.is_wrapped(self.env, wrapper_class)
        return [wrapped] * len(target_envs)

    def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
        """Return the shared environment once for every selected slot index."""
        indices = self._get_indices(indices)
        return [self.env for _ in indices]

    # Convert econ to gym
    def _convert_econ_to_gym(self, econ):
        """
        Turn a per-agent dictionary into a list ordered by agent index.

        The ``"p"`` entry is skipped and ``econ`` itself is left unmodified.
        Remaining keys must be convertible with ``int``; this matches the
        ordering used by ``_convert_econ_obs_to_gym``.

        :param econ: mapping from agent id to value (e.g. rewards or infos).
        :return: list of the values, sorted by integer agent id.
        """
        indexed = [(int(agent_id), value) for agent_id, value in econ.items() if agent_id != "p"]
        # Sort on the index only so the values themselves are never compared.
        indexed.sort(key=lambda pair: pair[0])
        return [value for _, value in indexed]

    def _convert_gym_to_acon(self, gy):
        """Build a dictionary from an iterable of ``(key, value)`` pairs."""
        return {k: v for k, v in gy}

    def _convert_econ_obs_to_gym(self, econ):
        """
        Stack the per-agent observations into one array.

        The ``"p"`` entry is skipped; the ``"time"`` and ``"action_mask"``
        fields of each agent observation are left out. Neither ``econ`` nor
        the nested observation dictionaries are modified.

        :param econ: mapping from agent id to that agent's observation dict.
        :return: ``np.stack`` of the agent observations, indexed by
            ``int(agent_id)`` along the first axis.
        """
        excluded_fields = ("time", "action_mask")
        gy = [None] * self.num_envs
        for agent_id, agent_obs in econ.items():
            if agent_id == "p":
                continue
            features = {name: val for name, val in agent_obs.items() if name not in excluded_fields}
            gy[int(agent_id)] = np.array(self.extract_dict(features))
        return np.stack(gy)

    def extract_dict(self, obj):
        """
        Collect the values of ``obj`` into a list.

        For a dict the values are taken in key order; any other iterable is
        walked element by element. Values that are themselves dicts are
        converted recursively and appended as nested lists (not flattened).
        """
        values = obj.values() if isinstance(obj, dict) else obj
        return [self.extract_dict(v) if isinstance(v, dict) else v for v in values]
|
||||
Reference in New Issue
Block a user