This commit is contained in:
2023-01-12 17:48:15 +01:00
parent f945247fd6
commit 03c6341b19
6 changed files with 139 additions and 4 deletions

View File

@@ -1,5 +1,5 @@
from . import (
simple_market,
econ_wrapper,
base_econ_wrapper
econ_wrapper
)

View File

@@ -1,169 +0,0 @@
from ai_economist.foundation.base import base_env
from threading import Event, Lock, Thread
from queue import Queue
class BaseEconVecEnv():
    """Base side of a multi-threaded wrapper around one econ simulation.

    A single background thread (started by :meth:`run`) is the only caller of
    ``env.reset`` and ``env.step``. Receiver wrappers talk to it from their
    own threads through the ``reciever_*`` methods:

    * every receiver registers once with :meth:`register_vote`;
    * a reset is performed once every registered receiver has called
      :meth:`reciever_request_reset`;
    * a step is performed once ``env.n_agents`` actions have been collected
      through :meth:`reciever_request_step`.
    """

    def __init__(self, econ: "base_env.BaseEnvironment"):
        """Wrap *econ* and create the per-instance synchronisation state."""
        self.env = econ
        # Events, locks and the action dict are created per instance so that
        # two wrappers never share notifications or submitted actions.
        self.base_notification = Event()   # wakes the base thread
        self.reset_notification = Event()  # tells receivers a reset happened
        self.step_notification = Event()   # tells receivers a step happened
        # Set while no step result is waiting to be fetched; cleared by a step
        # and set again once every receiver has fetched the result.
        self._step_data_consumed = Event()
        self._step_data_consumed.set()
        self.action_edit_lock = Lock()     # guards actor_actions
        self.actor_actions = {}
        self.stop_edit_lock = Lock()       # guards stop
        self.stop = False
        self.vote_lock = Lock()            # guards n_voters / n_votes_reset
        self.n_voters = 0
        self.n_votes_reset = 0
        self.n_votes_step = 0              # not read anywhere in this class
        # Latest state of the environment.
        self.env_data_lock = Lock()        # guards obs/rew/done/info
        self.obs = None
        self.rew = None
        self.done = None
        self.info = None
        self.n_data_retrieved = 0

    def register_vote(self):
        """Register one receiver so resets and steps wait for it as well."""
        with self.vote_lock:
            self.n_voters += 1

    def run(self):
        """Start the base wrapper and return its (daemon) thread.

        The run state is re-initialised and the environment is reset once in
        the calling thread *before* the loop thread starts, so a receiver that
        votes right after this call cannot have its vote wiped by the startup.
        """
        self._init_run_state()
        thr = Thread(target=self._run, daemon=True)
        thr.start()  # start(), not run(): run() would execute in this thread
        return thr

    def _init_run_state(self):
        """Bring flags, counters and data back to their initial values."""
        self.base_notification.clear()
        self.reset_notification.clear()
        self.step_notification.clear()
        self._step_data_consumed.set()
        with self.stop_edit_lock:
            self.stop = False
        with self.action_edit_lock:
            self.actor_actions = {}
        with self.vote_lock:
            self.n_votes_reset = 0
            self.n_votes_step = 0
        with self.env_data_lock:
            self.obs = None
            self.rew = None
            self.done = None
            self.info = None
            self.n_data_retrieved = 0
        # Initial reset so observations exist before anyone asks for them.
        self._reset()

    def _run(self):
        """Loop of the base thread; returns after :meth:`stop_env` was called."""
        while True:
            # Sleep until a receiver (or stop_env) signals a change. State is
            # inspected only after clear(), so a signal that arrives while we
            # are busy is either seen now or wakes us up again.
            self.base_notification.wait()
            self.base_notification.clear()
            # Check for the stop signal.
            with self.stop_edit_lock:
                if self.stop:
                    return
            # Check for a reset: every registered receiver must have voted.
            with self.vote_lock:
                if self.n_voters > 0 and self.n_votes_reset >= self.n_voters:
                    self._reset()
            # Check for a step: all actions are in and the previous step
            # result has been handed over (step_notification cleared again).
            with self.action_edit_lock:
                all_submitted = len(self.actor_actions) == self.env.n_agents
                if all_submitted and not self.step_notification.is_set():
                    self._step()

    def stop_env(self):
        """Stop the wrapper: flag the loop to exit and wake it up."""
        with self.stop_edit_lock:
            self.stop = True
        self.base_notification.set()

    def _reset(self):
        """Reset the environment, store the observations and notify receivers."""
        with self.env_data_lock:
            self.n_votes_reset = 0
            self.obs = self.env.reset()
            self.rew = None
            self.done = None
            self.info = None
        self.reset_notification.set()

    def _step(self):
        """Step the environment with the collected actions.

        Must be called with ``action_edit_lock`` held (as :meth:`_run` does).
        """
        with self.env_data_lock:
            self.reset_notification.clear()  # a reset is over after the first step
            self.obs, self.rew, self.done, self.info = self.env.step(self.actor_actions)
            self.n_data_retrieved = 0
            # Keep new submissions out until every receiver fetched this result.
            self._step_data_consumed.clear()
        self.step_notification.set()  # notify receivers

    def _prepare_step(self):
        """Prepare the base for a new round of action submissions.

        Blocks until the previous step result has been fetched by every
        receiver; the first caller afterwards clears the old actions.
        """
        self._step_data_consumed.wait()
        with self.action_edit_lock:
            if self.step_notification.is_set():
                self.step_notification.clear()
                self.actor_actions = {}

    def reciever_request_step(self, actions):
        """Submit actions for the next step.

        Args:
            actions: mapping (or iterable of pairs) of agent idx -> action.

        Raises:
            Exception: if an action for one of the agents was already
                submitted for the pending step. Nothing is stored then.
        """
        self._prepare_step()
        submitted = dict(actions)
        with self.action_edit_lock:
            for k in submitted:
                if self.actor_actions.get(k) is not None:
                    raise Exception("Actor action has already been submitted. {}".format(k))
            self.actor_actions.update(submitted)
            self.base_notification.set()  # alert base about the new actions

    def reciever_block_step(self):
        """Block until the requested step was performed and return its result.

        Returns:
            The ``(obs, rew, done, info)`` values stored by the last step.
        """
        self.step_notification.wait()
        with self.env_data_lock:
            result = (self.obs, self.rew, self.done, self.info)
            self.n_data_retrieved += 1
            if self.n_data_retrieved >= self.n_voters:
                # Everybody has the data: new actions may be submitted.
                self._step_data_consumed.set()
        return result

    def reciever_request_reset(self):
        """Vote for a reset; it is performed once all receivers have voted."""
        with self.vote_lock:
            if self.n_votes_reset == 0:
                # First vote of a round: forget the previous reset so that
                # reciever_block_reset waits for the upcoming one.
                self.reset_notification.clear()
            self.n_votes_reset += 1
        self.base_notification.set()  # wake the base so it counts the votes

    def reciever_block_reset(self):
        """Block until the reset was performed and return the observations."""
        self.reset_notification.wait()
        with self.env_data_lock:
            return self.obs

View File

@@ -1,67 +0,0 @@
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Callable, List, Optional, Sequence, Type, Union
from ai_economist.foundation.base import base_env
import gym
import gym.spaces
import numpy as np
from base_econ_wrapper import BaseEconVecEnv
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
from ai_economist import foundation
class RecieverEconVecEnv(gym.Env):
    """Receiver part of BaseEconVecEnv.

    Filters by agent class and presents a gym-style api to RL algos, so that
    different agent types can be trained from different threads against the
    same base environment.

    Two key spaces are used: ``idx`` is the agent identifier stored in the
    world's ``_agent_class_idx_map``; ``index`` is the position (0..n-1) of
    that agent inside this receiver.
    """

    def __init__(self, base_econ: "BaseEconVecEnv", agent_classname: str):
        """Attach to *base_econ* and select the agents of *agent_classname*."""
        self.base_econ = base_econ
        base_econ.register_vote()
        self.econ = base_econ.env
        self.agent_name = agent_classname
        self.agnet_idx = list(self.econ.world._agent_class_idx_map[agent_classname])
        # Reverse lookup: agent idx -> local index.
        self.idx_to_index = {idx: index for index, idx in enumerate(self.agnet_idx)}

    def step_async(self, actions: dict) -> None:
        """Submit actions to the env. *actions* maps local index -> action."""
        data = self._dict_index_to_idx(actions)
        self.base_econ.reciever_request_step(data)

    def _dict_idx_to_index(self, data):
        """Re-key *data* from agent idx to local index.

        Only entries of agents handled by this receiver are kept; the base
        hands out data for the whole environment.
        """
        # NOTE(review): keys that are not an agent idx of this class (other
        # agent classes, or any environment-wide entry) are dropped here --
        # confirm whether callers need such entries passed through.
        return {
            self.idx_to_index[k]: v
            for k, v in data.items()
            if k in self.idx_to_index
        }

    def _dict_index_to_idx(self, data):
        """Re-key *data* from local index to agent idx."""
        return {self.agnet_idx[k]: v for k, v in data.items()}

    def step_wait(self):
        """Block until the step was performed and return this class' data.

        Returns:
            ``(obs, rew, done, info)``, each re-keyed by local index.
        """
        obs, rew, done, info = self.base_econ.reciever_block_step()
        c_obs = self._dict_idx_to_index(obs)
        c_rew = self._dict_idx_to_index(rew)
        c_done = self._dict_idx_to_index(done)
        c_info = self._dict_idx_to_index(info)
        return c_obs, c_rew, c_done, c_info

    def step(self, actions: dict):
        """Submit *actions* and block for the result (async + wait)."""
        self.step_async(actions)
        return self.step_wait()

    def reset(self):
        """Vote for a reset, block until it happened and return observations."""
        self.base_econ.reciever_request_reset()
        obs = self.base_econ.reciever_block_reset()
        return self._dict_idx_to_index(obs)