I think I did the conversion correctly
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
import base_econ_wrapper
|
||||
import reciever_econ_wrapper
|
||||
import utils
|
||||
from . import(
|
||||
utils
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
from ai_economist.foundation.base import base_env
|
||||
from threading import Event, Lock, Thread
|
||||
from queue import Queue
|
||||
class BaseEconVecEnv():
|
||||
class BaseEconWrapper():
|
||||
"""Base class for connecting reciever wrapper to a multi threaded econ simulation and training session"""
|
||||
|
||||
base_notification=Event() #Notification for Base
|
||||
@@ -33,11 +33,12 @@ class BaseEconVecEnv():
|
||||
|
||||
def register_vote(self):
|
||||
"""Register reciever on base. Base now knows"""
|
||||
self.n_voters+=1
|
||||
|
||||
def run(self):
|
||||
"""Start the base wrapper"""
|
||||
thr=Thread(target=self._run,daemon=True)
|
||||
thr.run()
|
||||
thr.start()
|
||||
return thr
|
||||
|
||||
def _run(self):
|
||||
@@ -46,16 +47,20 @@ class BaseEconVecEnv():
|
||||
self.reset_notification.clear()
|
||||
self.step_notification.clear()
|
||||
|
||||
self.stop_edit_lock.release()
|
||||
if self.stop_edit_lock.locked():
|
||||
self.stop_edit_lock.release()
|
||||
self.stop=False
|
||||
self.action_edit_lock.release()
|
||||
|
||||
if self.action_edit_lock.locked():
|
||||
self.action_edit_lock.release()
|
||||
self.actor_actions={}
|
||||
self.vote_lock.release()
|
||||
if self.vote_lock.locked():
|
||||
self.vote_lock.release()
|
||||
self.reset_notification.clear()
|
||||
self.n_votes_reset=0
|
||||
self.n_votes_step=0
|
||||
|
||||
self.env_data_lock.release()
|
||||
if self.env_data_lock.locked():
|
||||
self.env_data_lock.release()
|
||||
self.obs=None
|
||||
self.rew=None
|
||||
self.done=None
|
||||
|
||||
@@ -6,16 +6,15 @@ from ai_economist.foundation.base import base_env
|
||||
import gym
|
||||
import gym.spaces
|
||||
import numpy as np
|
||||
from base_econ_wrapper import BaseEconVecEnv
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
|
||||
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
|
||||
|
||||
from wrapper.base_econ_wrapper import BaseEconWrapper
|
||||
from ai_economist import foundation
|
||||
|
||||
class RecieverEconVecEnv(gym.Env):
|
||||
class RecieverEconWrapper(gym.Env):
|
||||
"""Reciever part of BaseEconVecEnv. Filters by agent class and presents gym api to RL algos. Enables multi threading learning for different agent types."""
|
||||
def __init__(self, base_econ: BaseEconVecEnv, agent_classname: str):
|
||||
def __init__(self, base_econ: BaseEconWrapper, agent_classname: str):
|
||||
self.base_econ=base_econ
|
||||
base_econ.register_vote()
|
||||
self.econ=base_econ.env
|
||||
@@ -23,7 +22,7 @@ class RecieverEconVecEnv(gym.Env):
|
||||
self.agnet_idx=list(self.econ.world._agent_class_idx_map[agent_classname])
|
||||
self.idx_to_index={}
|
||||
#create idx to index map
|
||||
for i in range(len(self.agnet_idx)):
|
||||
for i in range(len(self.agnet_idx)):
|
||||
self.idx_to_index[self.agnet_idx[i]]=i
|
||||
first_idx=self.agnet_idx[0]
|
||||
|
||||
@@ -36,13 +35,16 @@ class RecieverEconVecEnv(gym.Env):
|
||||
def _dict_idx_to_index(self, data):
|
||||
data_out={}
|
||||
for k,v in data.items():
|
||||
data_out[self.idx_to_index[k]]=v
|
||||
if k in self.idx_to_index:
|
||||
index=self.idx_to_index[k]
|
||||
data_out[index]=v
|
||||
return data_out
|
||||
|
||||
def _dict_index_to_idx(self, data):
|
||||
data_out={}
|
||||
for k,v in data.items():
|
||||
data_out[self.agnet_idx[k]]=v
|
||||
idx=self.agnet_idx[k]
|
||||
data_out[idx]=v
|
||||
return data_out
|
||||
|
||||
def step_wait(self):
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import gym
|
||||
import gym.spaces
|
||||
import numpy as np
|
||||
import utils
|
||||
from wrapper import utils
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
|
||||
from ai_economist.foundation.base import base_env,base_agent
|
||||
from typing import Any, Callable, List, Optional, Sequence, Type, Union
|
||||
|
||||
class SB3EconConverter(VecEnv, gym.Env):
|
||||
|
||||
@@ -15,7 +16,7 @@ class SB3EconConverter(VecEnv, gym.Env):
|
||||
self.num_envs=len(obs.keys())
|
||||
#get action and obervation space
|
||||
self.action_space=self._get_action_space_by_class(agentclass)
|
||||
self.packager=utils.build_packager(obs[0],put_in_both=["time"])
|
||||
self.packager=utils.build_packager(obs[0])
|
||||
#flatten obervation of first agent
|
||||
obs0=utils.package(obs[0],*self.packager)
|
||||
self.observation_space=gym.spaces.Box(low=-np.inf,high=np.inf,shape=(len(obs0),1),dtype=np.float32)
|
||||
@@ -23,8 +24,9 @@ class SB3EconConverter(VecEnv, gym.Env):
|
||||
|
||||
|
||||
def _get_action_space_by_class(self,agentClass: str):
|
||||
idx=self.econ.world._agent_class_idx_map[agentClass]
|
||||
agent=base_agent.BaseAgent(self.econ.world.agents[idx[0]])
|
||||
idx_list=self.econ.world._agent_class_idx_map[agentClass]
|
||||
idx=int(idx_list[0])
|
||||
agent=self.econ.world.agents[idx]
|
||||
return gym.spaces.Discrete(agent.action_spaces)
|
||||
|
||||
def step_async(self, actions: np.ndarray) -> None:
|
||||
@@ -34,12 +36,12 @@ class SB3EconConverter(VecEnv, gym.Env):
|
||||
def step_wait(self) -> VecEnvStepReturn:
|
||||
obs,rew,done,info=self.env.step_wait()
|
||||
#flatten obs
|
||||
f_obs=utils.package(obs,*self.packager)
|
||||
#convert to flat
|
||||
g_obs={}
|
||||
for k,v in f_obs.items():
|
||||
g_obs[k]=v["flat"]
|
||||
c_obs=utils.convert_econ_to_gym(g_obs)
|
||||
f_obs={}
|
||||
for k,v in obs.items():
|
||||
o=utils.package(v,*self.packager)
|
||||
f_obs[k]=o["flat"]
|
||||
|
||||
c_obs=utils.convert_econ_to_gym(f_obs)
|
||||
c_rew=utils.convert_econ_to_gym(rew)
|
||||
c_done=utils.convert_econ_to_gym(done)
|
||||
c_info=utils.convert_econ_to_gym(info)
|
||||
@@ -53,9 +55,59 @@ class SB3EconConverter(VecEnv, gym.Env):
|
||||
return c_obs,c_rew,c_done,c_info
|
||||
def reset(self) -> VecEnvObs:
|
||||
obs=self.env.reset()
|
||||
f_obs=utils.package(obs,*self.packager)
|
||||
f_obs={}
|
||||
for k,v in obs.items():
|
||||
f_obs[k]=utils.package(v,*self.packager)
|
||||
g_obs={}
|
||||
for k,v in f_obs.items():
|
||||
g_obs[k]=v["flat"]
|
||||
c_obs=utils.convert_econ_to_gym(g_obs)
|
||||
return c_obs
|
||||
return c_obs
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
|
||||
if seed is None:
|
||||
seed = np.random.randint(0, 2**32 - 1)
|
||||
seeds = []
|
||||
for idx, env in enumerate(self.envs):
|
||||
seeds.append(env.seed(seed + idx))
|
||||
return seeds
|
||||
|
||||
|
||||
|
||||
def close(self) -> None:
|
||||
return
|
||||
|
||||
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
|
||||
"""Return attribute from vectorized environment (see base class)."""
|
||||
target_envs = self._get_target_envs(indices)
|
||||
return [getattr(env_i, attr_name) for env_i in target_envs]
|
||||
|
||||
|
||||
|
||||
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
|
||||
"""Set attribute inside vectorized environments (see base class)."""
|
||||
target_envs = self._get_target_envs(indices)
|
||||
for env_i in target_envs:
|
||||
setattr(env_i, attr_name, value)
|
||||
|
||||
|
||||
|
||||
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
|
||||
"""Call instance methods of vectorized environments."""
|
||||
target_envs = self._get_target_envs(indices)
|
||||
return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
|
||||
|
||||
|
||||
|
||||
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
|
||||
"""Check if worker environments are wrapped with a given wrapper"""
|
||||
target_envs = self._get_target_envs(indices)
|
||||
# Import here to avoid a circular import
|
||||
from stable_baselines3.common import env_util
|
||||
|
||||
return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs]
|
||||
|
||||
|
||||
def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
|
||||
indices = self._get_indices(indices)
|
||||
return [self.envs[i] for i in indices]
|
||||
@@ -44,26 +44,28 @@ def build_packager(sub_obs, put_in_both=None):
|
||||
|
||||
|
||||
def package(obs_dict, keep_as_is, flatten, wrap_as_list):
|
||||
"""Flattens observation with packagers."""
|
||||
new_obs = {k: obs_dict[k] for k in keep_as_is}
|
||||
if len(flatten) == 1:
|
||||
k = flatten[0]
|
||||
o = obs_dict[k]
|
||||
if wrap_as_list[k]:
|
||||
o = [o]
|
||||
new_obs["flat"] = np.array(o, dtype=np.float32)
|
||||
else:
|
||||
to_flatten = [
|
||||
[obs_dict[k]] if wrap_as_list[k] else obs_dict[k] for k in flatten
|
||||
]
|
||||
try:
|
||||
new_obs["flat"] = np.concatenate(to_flatten).astype(np.float32)
|
||||
except ValueError:
|
||||
for k, v in zip(flatten, to_flatten):
|
||||
print(k, np.array(v).shape)
|
||||
print(v)
|
||||
print("")
|
||||
raise
|
||||
return new_obs
|
||||
"""Flattens observation with packagers."""
|
||||
|
||||
new_obs = {k: obs_dict[k] for k in keep_as_is}
|
||||
if len(flatten) == 1:
|
||||
k = flatten[0]
|
||||
o = obs_dict[k]
|
||||
if wrap_as_list[k]:
|
||||
o = [o]
|
||||
new_obs["flat"] = np.array(o, dtype=np.float32)
|
||||
else:
|
||||
to_flatten = [
|
||||
[obs_dict[k]] if wrap_as_list[k] else obs_dict[k] for k in flatten
|
||||
]
|
||||
try:
|
||||
new_obs["flat"] = np.concatenate(to_flatten).astype(np.float32)
|
||||
except ValueError:
|
||||
for k, v in zip(flatten, to_flatten):
|
||||
print(k, np.array(v).shape)
|
||||
print(v)
|
||||
print("")
|
||||
raise
|
||||
|
||||
return new_obs
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user