This commit is contained in:
2023-01-12 17:48:15 +01:00
parent f945247fd6
commit 03c6341b19
6 changed files with 139 additions and 4 deletions

View File

@@ -1,5 +1,5 @@
from . import (
simple_market,
econ_wrapper,
base_econ_wrapper
econ_wrapper
)

3
wrapper/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
import base_econ_wrapper
import reciever_econ_wrapper
import utils

View File

@@ -54,7 +54,7 @@ class RecieverEconVecEnv(gym.Env):
c_info=self._dict_idx_to_index(info)
return c_obs,c_rew,c_done,c_info
def reset(self):
# env=foundation.make_env_instance(**self.config)
# self.env = env
@@ -63,5 +63,7 @@ class RecieverEconVecEnv(gym.Env):
c_obs=self._dict_idx_to_index(obs)
return c_obs
def step(self, action):
self.step_async(action)
return self.step_wait()

View File

@@ -0,0 +1,61 @@
import gym
import gym.spaces
import numpy as np
import utils
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
from ai_economist.foundation.base import base_env,base_agent
class SB3EconConverter(VecEnv, gym.Env):
def __init__(self, env: gym.Env, econ: base_env.BaseEnvironment,agentclass: str):
self.env=env
self.econ=econ
#get observation sample
obs=env.reset()
self.num_envs=len(obs.keys())
#get action and obervation space
self.action_space=self._get_action_space_by_class(agentclass)
self.packager=utils.build_packager(obs[0],put_in_both=["time"])
#flatten obervation of first agent
obs0=utils.package(obs[0],*self.packager)
self.observation_space=gym.spaces.Box(low=-np.inf,high=np.inf,shape=(len(obs0),1),dtype=np.float32)
super().__init__(self.num_envs, self.observation_space, self.action_space)
def _get_action_space_by_class(self,agentClass: str):
idx=self.econ.world._agent_class_idx_map[agentClass]
agent=base_agent.BaseAgent(self.econ.world.agents[idx[0]])
return gym.spaces.Discrete(agent.action_spaces)
def step_async(self, actions: np.ndarray) -> None:
d_actions=utils.convert_gym_to_econ(actions)
return self.env.step_async(d_actions)
def step_wait(self) -> VecEnvStepReturn:
obs,rew,done,info=self.env.step_wait()
#flatten obs
f_obs=utils.package(obs,*self.packager)
#convert to flat
g_obs={}
for k,v in f_obs.items():
g_obs[k]=v["flat"]
c_obs=utils.convert_econ_to_gym(g_obs)
c_rew=utils.convert_econ_to_gym(rew)
c_done=utils.convert_econ_to_gym(done)
c_info=utils.convert_econ_to_gym(info)
done_g=[False]*self.num_envs
done=(done["__all__"])
if done:
for i in range(self.num_envs):
c_done[i]=done
c_info[i]["terminal_observation"]=c_obs[i]
c_obs=self.reset()
return c_obs,c_rew,c_done,c_info
def reset(self) -> VecEnvObs:
obs=self.env.reset()
f_obs=utils.package(obs,*self.packager)
g_obs={}
for k,v in f_obs.items():
g_obs[k]=v["flat"]
c_obs=utils.convert_econ_to_gym(g_obs)
return c_obs

69
wrapper/utils.py Normal file
View File

@@ -0,0 +1,69 @@
import numpy as np
# Convert econ to gym
def convert_econ_to_gym(econ):
gy=[]
gy=[v for k,v in econ.items()]
return gy
def convert_gym_to_econ(gy):
econ={}
for k,v in gy:
econ[k]=v
return econ
def build_packager(sub_obs, put_in_both=None):
"""
Decides which keys-vals should be flattened or not.
put_in_both: include in both (e.g., 'time')
"""
if put_in_both is None:
put_in_both = []
keep_as_is = []
flatten = []
wrap_as_list = {}
for k, v in sub_obs.items():
if isinstance(v, np.ndarray):
multi_d_array = len(v.shape) > 1
else:
multi_d_array = False
if k == "action_mask" or multi_d_array:
keep_as_is.append(k)
else:
flatten.append(k)
if k in put_in_both:
keep_as_is.append(k)
wrap_as_list[k] = np.isscalar(v)
flatten = sorted(flatten)
return keep_as_is, flatten, wrap_as_list
def package(obs_dict, keep_as_is, flatten, wrap_as_list):
"""Flattens observation with packagers."""
new_obs = {k: obs_dict[k] for k in keep_as_is}
if len(flatten) == 1:
k = flatten[0]
o = obs_dict[k]
if wrap_as_list[k]:
o = [o]
new_obs["flat"] = np.array(o, dtype=np.float32)
else:
to_flatten = [
[obs_dict[k]] if wrap_as_list[k] else obs_dict[k] for k in flatten
]
try:
new_obs["flat"] = np.concatenate(to_flatten).astype(np.float32)
except ValueError:
for k, v in zip(flatten, to_flatten):
print(k, np.array(v).shape)
print(v)
print("")
raise
return new_obs