I think I did the conversion correctly.

This commit is contained in:
2023-01-12 18:33:06 +01:00
parent 03c6341b19
commit ee444cb56c
9 changed files with 131 additions and 60 deletions

View File

@@ -1,3 +1,3 @@
import base_econ_wrapper
import reciever_econ_wrapper
import utils
from . import(
utils
)

View File

@@ -1,7 +1,7 @@
from ai_economist.foundation.base import base_env
from threading import Event, Lock, Thread
from queue import Queue
class BaseEconVecEnv():
class BaseEconWrapper():
"""Base class for connecting reciever wrapper to a multi threaded econ simulation and training session"""
base_notification=Event() #Notification for Base
@@ -33,11 +33,12 @@ class BaseEconVecEnv():
def register_vote(self):
    """Record that one more reciever has registered itself with the base."""
    self.n_voters = self.n_voters + 1
def run(self):
    """Start the base wrapper's main loop in a background daemon thread.

    Returns:
        threading.Thread: the already-started daemon thread running _run().
    """
    thr = Thread(target=self._run, daemon=True)
    # Bug fix: the previous code called thr.run() before thr.start().
    # Thread.run() executes the target synchronously in the CALLER's thread;
    # only start() spawns a new thread, so the run() call must go.
    thr.start()
    return thr
def _run(self):
@@ -46,16 +47,20 @@ class BaseEconVecEnv():
self.reset_notification.clear()
self.step_notification.clear()
self.stop_edit_lock.release()
if self.stop_edit_lock.locked():
self.stop_edit_lock.release()
self.stop=False
self.action_edit_lock.release()
if self.action_edit_lock.locked():
self.action_edit_lock.release()
self.actor_actions={}
self.vote_lock.release()
if self.vote_lock.locked():
self.vote_lock.release()
self.reset_notification.clear()
self.n_votes_reset=0
self.n_votes_step=0
self.env_data_lock.release()
if self.env_data_lock.locked():
self.env_data_lock.release()
self.obs=None
self.rew=None
self.done=None

View File

@@ -6,16 +6,15 @@ from ai_economist.foundation.base import base_env
import gym
import gym.spaces
import numpy as np
from base_econ_wrapper import BaseEconVecEnv
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
from wrapper.base_econ_wrapper import BaseEconWrapper
from ai_economist import foundation
class RecieverEconVecEnv(gym.Env):
class RecieverEconWrapper(gym.Env):
"""Reciever part of BaseEconVecEnv. Filters by agent class and presents gym api to RL algos. Enables multi threading learning for different agent types."""
def __init__(self, base_econ: BaseEconVecEnv, agent_classname: str):
def __init__(self, base_econ: BaseEconWrapper, agent_classname: str):
self.base_econ=base_econ
base_econ.register_vote()
self.econ=base_econ.env
@@ -23,7 +22,7 @@ class RecieverEconVecEnv(gym.Env):
self.agnet_idx=list(self.econ.world._agent_class_idx_map[agent_classname])
self.idx_to_index={}
#create idx to index map
for i in range(len(self.agnet_idx)):
for i in range(len(self.agnet_idx)):
self.idx_to_index[self.agnet_idx[i]]=i
first_idx=self.agnet_idx[0]
@@ -36,13 +35,16 @@ class RecieverEconVecEnv(gym.Env):
def _dict_idx_to_index(self, data):
    """Re-key a per-agent dict from world agent idx to this wrapper's dense index.

    Bug fix: the rendered block kept both the old unguarded assignment
    (which raises KeyError for agents of other classes) and the new guarded
    one; only the guarded mapping is correct. Keys not present in
    idx_to_index are silently dropped, so a dict covering all agents can be
    filtered down to this wrapper's agent class.
    """
    data_out = {}
    for k, v in data.items():
        if k in self.idx_to_index:
            data_out[self.idx_to_index[k]] = v
    return data_out
def _dict_index_to_idx(self, data):
    """Re-key a per-agent dict from this wrapper's dense index back to world agent idx.

    Bug fix: the rendered block performed the same mapping twice (old diff
    line left next to its replacement); a single assignment suffices.
    """
    data_out = {}
    for k, v in data.items():
        data_out[self.agnet_idx[k]] = v
    return data_out
def step_wait(self):

View File

@@ -1,9 +1,10 @@
import gym
import gym.spaces
import numpy as np
import utils
from wrapper import utils
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
from ai_economist.foundation.base import base_env,base_agent
from typing import Any, Callable, List, Optional, Sequence, Type, Union
class SB3EconConverter(VecEnv, gym.Env):
@@ -15,7 +16,7 @@ class SB3EconConverter(VecEnv, gym.Env):
self.num_envs=len(obs.keys())
#get action and obervation space
self.action_space=self._get_action_space_by_class(agentclass)
self.packager=utils.build_packager(obs[0],put_in_both=["time"])
self.packager=utils.build_packager(obs[0])
#flatten obervation of first agent
obs0=utils.package(obs[0],*self.packager)
self.observation_space=gym.spaces.Box(low=-np.inf,high=np.inf,shape=(len(obs0),1),dtype=np.float32)
@@ -23,8 +24,9 @@ class SB3EconConverter(VecEnv, gym.Env):
def _get_action_space_by_class(self, agentClass: str):
    """Return a gym Discrete action space sized for the given agent class.

    Bug fix: stale lines from the old implementation (constructing
    base_agent.BaseAgent from an agent instance and indexing the idx map
    directly) were left next to the new lookup; only the new lookup is kept.
    NOTE(review): assumes every agent of one class has the same
    action_spaces count -- confirm against the econ config.
    """
    idx_list = self.econ.world._agent_class_idx_map[agentClass]
    idx = int(idx_list[0])
    agent = self.econ.world.agents[idx]
    return gym.spaces.Discrete(agent.action_spaces)
def step_async(self, actions: np.ndarray) -> None:
@@ -34,12 +36,12 @@ class SB3EconConverter(VecEnv, gym.Env):
def step_wait(self) -> VecEnvStepReturn:
obs,rew,done,info=self.env.step_wait()
#flatten obs
f_obs=utils.package(obs,*self.packager)
#convert to flat
g_obs={}
for k,v in f_obs.items():
g_obs[k]=v["flat"]
c_obs=utils.convert_econ_to_gym(g_obs)
f_obs={}
for k,v in obs.items():
o=utils.package(v,*self.packager)
f_obs[k]=o["flat"]
c_obs=utils.convert_econ_to_gym(f_obs)
c_rew=utils.convert_econ_to_gym(rew)
c_done=utils.convert_econ_to_gym(done)
c_info=utils.convert_econ_to_gym(info)
@@ -53,9 +55,59 @@ class SB3EconConverter(VecEnv, gym.Env):
return c_obs,c_rew,c_done,c_info
def reset(self) -> VecEnvObs:
    """Reset the wrapped env and return flattened, gym-style observations.

    Bug fix: the rendered block kept a dead ``f_obs=utils.package(obs,...)``
    line (immediately overwritten by the per-agent loop) and a duplicated
    ``return c_obs``; both leftovers are removed.
    """
    obs = self.env.reset()
    # Flatten each agent's raw observation dict with the prebuilt packager.
    f_obs = {}
    for k, v in obs.items():
        f_obs[k] = utils.package(v, *self.packager)
    # Keep only the flat vector for each agent.
    g_obs = {}
    for k, v in f_obs.items():
        g_obs[k] = v["flat"]
    c_obs = utils.convert_econ_to_gym(g_obs)
    return c_obs
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
    """Seed every sub-environment with consecutive seeds derived from *seed*."""
    if seed is None:
        # Draw a base seed when the caller did not supply one.
        seed = np.random.randint(0, 2**32 - 1)
    return [env.seed(seed + offset) for offset, env in enumerate(self.envs)]
def close(self) -> None:
    """No resources to release; present to satisfy the VecEnv interface."""
    return
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
    """Read *attr_name* from each selected sub-environment."""
    collected = []
    for target in self._get_target_envs(indices):
        collected.append(getattr(target, attr_name))
    return collected
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
    """Write *value* to *attr_name* on each selected sub-environment."""
    for target in self._get_target_envs(indices):
        setattr(target, attr_name, value)
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
    """Invoke *method_name* on each selected sub-environment and collect the results."""
    results = []
    for target in self._get_target_envs(indices):
        bound = getattr(target, method_name)
        results.append(bound(*method_args, **method_kwargs))
    return results
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
    """Report, per selected sub-environment, whether it is wrapped in *wrapper_class*."""
    # Imported lazily to avoid a circular import with stable_baselines3.
    from stable_baselines3.common import env_util
    return [
        env_util.is_wrapped(target, wrapper_class)
        for target in self._get_target_envs(indices)
    ]
def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
    """Resolve *indices* (see VecEnv._get_indices) and return the matching sub-environments."""
    resolved = self._get_indices(indices)
    return [self.envs[pos] for pos in resolved]

View File

@@ -44,26 +44,28 @@ def build_packager(sub_obs, put_in_both=None):
def package(obs_dict, keep_as_is, flatten, wrap_as_list):
    """Flattens observation with packagers."""
    # Carry over the keys that should stay untouched.
    packed = {key: obs_dict[key] for key in keep_as_is}
    if len(flatten) == 1:
        # Single key: no concatenation needed, build the array directly.
        only_key = flatten[0]
        value = obs_dict[only_key]
        if wrap_as_list[only_key]:
            value = [value]
        packed["flat"] = np.array(value, dtype=np.float32)
    else:
        # Scalars flagged in wrap_as_list get a list wrapper so they concatenate.
        pieces = []
        for key in flatten:
            pieces.append([obs_dict[key]] if wrap_as_list[key] else obs_dict[key])
        try:
            packed["flat"] = np.concatenate(pieces).astype(np.float32)
        except ValueError:
            # Dump each piece's shape/content to aid debugging, then re-raise.
            for key, piece in zip(flatten, pieces):
                print(key, np.array(piece).shape)
                print(piece)
                print("")
            raise
    return packed
"""Flattens observation with packagers."""
new_obs = {k: obs_dict[k] for k in keep_as_is}
if len(flatten) == 1:
k = flatten[0]
o = obs_dict[k]
if wrap_as_list[k]:
o = [o]
new_obs["flat"] = np.array(o, dtype=np.float32)
else:
to_flatten = [
[obs_dict[k]] if wrap_as_list[k] else obs_dict[k] for k in flatten
]
try:
new_obs["flat"] = np.concatenate(to_flatten).astype(np.float32)
except ValueError:
for k, v in zip(flatten, to_flatten):
print(k, np.array(v).shape)
print(v)
print("")
raise
return new_obs