again
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
|
||||
from . import (
|
||||
simple_market,
|
||||
econ_wrapper,
|
||||
base_econ_wrapper
|
||||
econ_wrapper
|
||||
)
|
||||
|
||||
3
wrapper/__init__.py
Normal file
3
wrapper/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
import base_econ_wrapper
|
||||
import reciever_econ_wrapper
|
||||
import utils
|
||||
@@ -54,7 +54,7 @@ class RecieverEconVecEnv(gym.Env):
|
||||
c_info=self._dict_idx_to_index(info)
|
||||
return c_obs,c_rew,c_done,c_info
|
||||
|
||||
|
||||
|
||||
def reset(self):
|
||||
# env=foundation.make_env_instance(**self.config)
|
||||
# self.env = env
|
||||
@@ -63,5 +63,7 @@ class RecieverEconVecEnv(gym.Env):
|
||||
c_obs=self._dict_idx_to_index(obs)
|
||||
return c_obs
|
||||
|
||||
|
||||
def step(self, action):
    """Advance the environment by one step synchronously.

    Hands ``action`` to ``step_async`` and then returns whatever
    ``step_wait`` produces.
    """
    self.step_async(action)
    result = self.step_wait()
    return result
|
||||
|
||||
61
wrapper/sb3_econ_converter.py
Normal file
61
wrapper/sb3_econ_converter.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import gym
|
||||
import gym.spaces
|
||||
import numpy as np
|
||||
import utils
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
|
||||
from ai_economist.foundation.base import base_env,base_agent
|
||||
|
||||
class SB3EconConverter(VecEnv, gym.Env):
    """Present a dict-based multi-agent econ environment as an SB3 ``VecEnv``.

    Each key of the observation dict returned by ``env.reset()`` is treated
    as one sub-environment: per-agent dicts are turned into positional lists
    and each agent's observation is flattened to a single float32 vector.

    NOTE(review): only ``reset``, ``step_async`` and ``step_wait`` are
    implemented here. Confirm that the remaining abstract ``VecEnv`` methods
    (``close``, ``get_attr``, ``set_attr``, ``env_method``,
    ``env_is_wrapped``) are provided before instantiating this class.
    """

    def __init__(self, env: gym.Env, econ: base_env.BaseEnvironment, agentclass: str):
        """Wrap ``env`` and derive the spaces from a sample observation.

        Args:
            env: Environment whose ``reset``/``step_async``/``step_wait``
                work on dicts keyed by agent index.
            econ: The underlying econ environment, used to look up agents.
            agentclass: Name of the agent class whose action space is used.
        """
        self.env = env
        self.econ = econ
        # Sample one observation to learn the agent count and the layout.
        obs = env.reset()
        self.num_envs = len(obs)
        self.action_space = self._get_action_space_by_class(agentclass)
        # The packager is derived from agent 0 and reused for every agent.
        # NOTE(review): assumes all agents share agent 0's observation keys.
        self.packager = utils.build_packager(obs[0], put_in_both=["time"])
        obs0 = utils.package(obs[0], *self.packager)
        # ``package`` returns a dict; the vector length is that of its
        # "flat" entry (``len(obs0)`` would count dict keys instead).
        # NOTE(review): the trailing ``1`` dimension is kept from the
        # original, but the observations returned are 1-D. Confirm the
        # intended shape.
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(len(obs0["flat"]), 1),
            dtype=np.float32,
        )
        super().__init__(self.num_envs, self.observation_space, self.action_space)

    def _get_action_space_by_class(self, agentClass: str):
        """Return a ``Discrete`` space sized by the first agent of ``agentClass``."""
        idx = self.econ.world._agent_class_idx_map[agentClass]
        # Use the existing agent directly. Calling ``BaseAgent(agent)`` would
        # build a new base agent with ``agent`` as its first argument.
        agent = self.econ.world.agents[idx[0]]
        return gym.spaces.Discrete(agent.action_spaces)

    def _flatten_obs(self, obs):
        """Package each agent's observation and return the flat vectors as a list."""
        # ``utils.package`` works on ONE agent's observation dict, so it is
        # applied per agent rather than to the outer multi-agent dict.
        return [
            utils.package(agent_obs, *self.packager)["flat"]
            for agent_obs in obs.values()
        ]

    def step_async(self, actions: np.ndarray) -> None:
        """Convert positional ``actions`` to a dict and forward them to the env."""
        d_actions = utils.convert_gym_to_econ(actions)
        return self.env.step_async(d_actions)

    def step_wait(self) -> VecEnvStepReturn:
        """Collect the step result and convert it to per-agent lists.

        When ``done["__all__"]`` is set, each agent's final observation is
        stored under ``"terminal_observation"`` in its info dict and the
        environment is reset, so the returned observations are those of the
        new episode.
        """
        obs, rew, done, info = self.env.step_wait()
        c_obs = self._flatten_obs(obs)
        c_rew = utils.convert_econ_to_gym(rew)
        c_info = utils.convert_econ_to_gym(info)
        # Broadcast the episode-level flag to every agent, so the done list
        # always has ``num_envs`` entries.
        # NOTE(review): assumes agents never finish individually.
        all_done = done["__all__"]
        c_done = [all_done] * self.num_envs
        if all_done:
            for agent_info, agent_obs in zip(c_info, c_obs):
                agent_info["terminal_observation"] = agent_obs
            c_obs = self.reset()
        return c_obs, c_rew, c_done, c_info

    def reset(self) -> VecEnvObs:
        """Reset the wrapped env and return the flattened per-agent observations."""
        return self._flatten_obs(self.env.reset())
|
||||
69
wrapper/utils.py
Normal file
69
wrapper/utils.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import numpy as np
|
||||
# Convert econ to gym
|
||||
def convert_econ_to_gym(econ):
    """Return the values of the ``econ`` dict as a list.

    Keys are dropped; values keep the dict's iteration order.
    """
    return list(econ.values())
|
||||
|
||||
def convert_gym_to_econ(gy):
    """Turn a positional sequence into a dict keyed by position.

    Entry ``i`` of ``gy`` is stored under the integer key ``i``. Positions
    are enumerated because unpacking each element as a ``(key, value)``
    pair fails on a flat sequence of scalar actions.
    """
    return dict(enumerate(gy))
|
||||
|
||||
|
||||
|
||||
def build_packager(sub_obs, put_in_both=None):
    """
    Split the keys of one observation dict into "keep" and "flatten" groups.

    A key is kept as-is when it is "action_mask" or its value is a numpy
    array with more than one dimension; every other key is flattened.
    A flattened key listed in put_in_both (e.g., 'time') is kept as well.

    Returns (keep_as_is, flatten, wrap_as_list): the kept keys, the sorted
    keys to flatten, and a dict telling for each key whether its value is
    a scalar.
    """
    both = [] if put_in_both is None else put_in_both
    keep_as_is, flatten, wrap_as_list = [], [], {}

    for key, value in sub_obs.items():
        is_multi_d = isinstance(value, np.ndarray) and len(value.shape) > 1

        if key != "action_mask" and not is_multi_d:
            flatten.append(key)
            if key in both:
                keep_as_is.append(key)
        else:
            keep_as_is.append(key)

        wrap_as_list[key] = np.isscalar(value)

    flatten.sort()

    return keep_as_is, flatten, wrap_as_list
|
||||
|
||||
|
||||
def package(obs_dict, keep_as_is, flatten, wrap_as_list):
    """Build a packaged observation: kept entries plus one "flat" float32 array.

    Keys in keep_as_is are copied over unchanged. Values of the keys in
    flatten are joined, in that order, into the "flat" entry; a value whose
    wrap_as_list flag is set is wrapped in a one-element list first.
    """
    packed = {}
    for key in keep_as_is:
        packed[key] = obs_dict[key]

    pieces = []
    for key in flatten:
        item = obs_dict[key]
        pieces.append([item] if wrap_as_list[key] else item)

    # A single piece is converted directly; no concatenation is needed.
    if len(pieces) == 1:
        packed["flat"] = np.array(pieces[0], dtype=np.float32)
        return packed

    try:
        packed["flat"] = np.concatenate(pieces).astype(np.float32)
    except ValueError:
        # Print every piece with its shape to help locate the one that
        # cannot be concatenated, then re-raise.
        for key, piece in zip(flatten, pieces):
            print(key, np.array(piece).shape)
            print(piece)
            print("")
        raise
    return packed
|
||||
|
||||
|
||||
Reference in New Issue
Block a user