again
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
|
||||
from . import (
|
||||
simple_market,
|
||||
econ_wrapper,
|
||||
base_econ_wrapper
|
||||
econ_wrapper
|
||||
)
|
||||
|
||||
@@ -1,169 +0,0 @@
|
||||
from ai_economist.foundation.base import base_env
|
||||
from threading import Event, Lock, Thread
|
||||
from queue import Queue
|
||||
class BaseEconVecEnv():
    """Base class connecting receiver wrappers to a multi-threaded econ simulation and training session.

    One background thread (started by ``run``) owns the wrapped environment.
    Receivers talk to it through the ``reciever_*`` methods:

    * ``reciever_request_step`` / ``reciever_block_step`` submit actions and
      collect the result once every agent has an action.
    * ``reciever_request_reset`` / ``reciever_block_reset`` vote for a reset
      and collect the fresh observations once every registered receiver has
      voted.

    All events, locks and the pending-action dict are created per instance so
    that two wrappers never share synchronisation state.
    """

    # Plain-value defaults. Instances rebind these as they run.
    stop = False
    n_voters = 0
    n_votes_reset = 0
    n_votes_step = 0  # not read anywhere in this class; kept for compatibility

    # States of Env
    obs = None
    rew = None
    done = None
    info = None
    n_data_retrieved = 0

    def __init__(self, econ: "base_env.BaseEnvironment"):
        """Wrap ``econ`` and create this wrapper's own synchronisation objects.

        Args:
            econ: the environment to drive. This class uses its ``reset()``,
                ``step(actions)`` and ``n_agents`` members.
        """
        self.env = econ

        self.base_notification = Event()   # Notification for Base
        self.reset_notification = Event()  # Notification for receivers
        self.step_notification = Event()   # Notification for receivers

        self.action_edit_lock = Lock()
        self.actor_actions = {}

        self.stop_edit_lock = Lock()
        self.vote_lock = Lock()
        self.env_data_lock = Lock()

        # True while action_edit_lock is kept locked after a step so that no
        # new actions are accepted before every receiver has read the result.
        # Guarded by env_data_lock.
        self._step_hold = False

    def register_vote(self):
        """Register one receiver on the base.

        The number of registered receivers is the quorum for a reset and the
        number of reads that must happen before the next step is accepted.
        """
        with self.vote_lock:
            self.n_voters += 1

    def run(self):
        """Start the base wrapper.

        State is reinitialised and the environment reset in the calling
        thread, so requests made right after this returns are not lost.

        Returns:
            The started daemon thread running the service loop.
        """
        self._prepare_run()
        thr = Thread(target=self._run, daemon=True)
        # start() spawns the thread; run() would execute the loop right here
        # and never return.
        thr.start()
        return thr

    def _prepare_run(self):
        """Bring flags, votes and cached data to their start-of-run state and reset the env."""
        self.base_notification.clear()
        self.reset_notification.clear()
        self.step_notification.clear()

        with self.stop_edit_lock:
            self.stop = False

        with self.env_data_lock:
            if self._step_hold:
                # An earlier run stopped while step results were still being
                # handed out; give the action lock back.
                self._step_hold = False
                self.action_edit_lock.release()
            self.obs = None
            self.rew = None
            self.done = None
            self.info = None
            self.n_data_retrieved = 0

        with self.action_edit_lock:
            self.actor_actions = {}

        with self.vote_lock:
            self.n_votes_reset = 0
            self.n_votes_step = 0

        # Resetting Env
        self._reset()

    def _run(self):
        """Service loop: handle stop, reset and step requests until ``stop_env`` is called."""
        while True:
            # Wait for notification
            self.base_notification.wait()
            self.base_notification.clear()

            # Check for stop signal
            with self.stop_edit_lock:
                if self.stop:
                    return

            # Check for reset. Without any registered receiver there is no
            # quorum, so nothing is reset.
            with self.vote_lock:
                if self.n_voters > 0 and self.n_votes_reset >= self.n_voters:
                    self.n_votes_reset = 0
                    self._reset()

            # Check for actions
            self.action_edit_lock.acquire()
            ready = (
                len(self.actor_actions) == self.env.n_agents
                and not self.step_notification.is_set()
            )
            if not ready:
                self.action_edit_lock.release()
                continue
            # We have all the actions -> STEP. On success the lock stays
            # locked; reciever_block_step releases it after the last read.
            try:
                self._step()
            except BaseException:
                self.action_edit_lock.release()
                raise

    def stop_env(self):
        """Stops the wrapper by flagging the service loop and waking it."""
        with self.stop_edit_lock:
            self.stop = True
        self.base_notification.set()

    def _reset(self):
        """Reset the environment, store its observations and notify receivers."""
        with self.env_data_lock:
            self.n_votes_reset = 0
            self.obs = self.env.reset()  # Reset env
            self.rew = None
            self.done = None
            self.info = None
        # Notify for reset
        self.reset_notification.set()

    def _step(self):
        """Step the environment with the collected actions and publish the result.

        The caller must hold ``action_edit_lock``. It is intentionally left
        locked so nobody can submit actions until every receiver has had the
        chance to look at the new data.
        """
        with self.env_data_lock:
            self.reset_notification.clear()  # reset data is stale after a step
            self.obs, self.rew, self.done, self.info = self.env.step(self.actor_actions)
            self.n_data_retrieved = 0
            self._step_hold = True
        self.step_notification.set()  # notify receivers

    def _prepare_step(self):
        """Prepare the base for the next step: drop the previous step's flag and actions."""
        with self.action_edit_lock:
            if self.step_notification.is_set():
                self.step_notification.clear()
                self.actor_actions = {}

    def reciever_request_step(self, actions):
        """Submits actions to the base. Actions is a dict pairing agent idx with action.

        Blocks while the previous step's result has not been read by every
        registered receiver.

        Raises:
            Exception: if an action for one of the given idx was already
                submitted for the pending step. Nothing is stored in that case.
        """
        self._prepare_step()  # New actions are being submitted
        with self.action_edit_lock:
            for k in actions:
                if k in self.actor_actions:
                    raise Exception("Actor action has already been submitted. {}".format(k))
            self.actor_actions.update(actions)
            self.base_notification.set()  # Alert base for action changes

    def reciever_block_step(self):
        """Blocks until a step has happened, then returns ``(obs, rew, done, info)``.

        The step happens once actions for all agents have been submitted. The
        last registered receiver to read the data re-opens action submission.
        """
        self.step_notification.wait()  # new data available
        with self.env_data_lock:
            obs = self.obs
            rew = self.rew
            done = self.done
            info = self.info
            self.n_data_retrieved += 1
            # Release exactly once per step, however often this is called.
            if self._step_hold and self.n_data_retrieved >= self.n_voters:
                self._step_hold = False
                self.action_edit_lock.release()
        return obs, rew, done, info

    def reciever_request_reset(self):
        """Adds a vote to reset. Once every registered receiver voted the reset occurs."""
        with self.vote_lock:
            self.n_votes_reset += 1
        # Wake the base so it counts the votes.
        self.base_notification.set()

    def reciever_block_reset(self):
        """Called after a reset request; blocks until reset data is available and returns the observations.

        NOTE(review): the reset flag is only cleared by a step, so this returns
        immediately when no step has happened since the last reset.
        """
        self.reset_notification.wait()
        with self.env_data_lock:
            obs = self.obs
        return obs
|
||||
@@ -1,67 +0,0 @@
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, List, Optional, Sequence, Type, Union
|
||||
from ai_economist.foundation.base import base_env
|
||||
|
||||
import gym
|
||||
import gym.spaces
|
||||
import numpy as np
|
||||
from base_econ_wrapper import BaseEconVecEnv
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
|
||||
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
|
||||
|
||||
from ai_economist import foundation
|
||||
|
||||
class RecieverEconVecEnv(gym.Env):
    """Receiver part of BaseEconVecEnv.

    Filters by agent class and presents the data of that class to RL algos,
    re-keyed from the simulation's agent idx to a dense local index
    (0, 1, ...). Enables multi-threaded learning for different agent types.
    """

    def __init__(self, base_econ: BaseEconVecEnv, agent_classname: str):
        """Attach to ``base_econ`` and build the idx <-> index maps for one agent class.

        Args:
            base_econ: the shared base wrapper; this receiver registers
                itself on it as a voter.
            agent_classname: key into ``world._agent_class_idx_map`` of the
                wrapped environment, selecting the agents served here.
        """
        self.base_econ = base_econ
        base_econ.register_vote()
        self.econ = base_econ.env
        self.agent_name = agent_classname
        self.agnet_idx = list(self.econ.world._agent_class_idx_map[agent_classname])
        # Map from simulation idx to local index
        self.idx_to_index = {idx: index for index, idx in enumerate(self.agnet_idx)}

    def step_async(self, actions: dict) -> None:
        """Submits actions to Env. actions is a dict with local index -> action pairs."""
        data = self._dict_index_to_idx(actions)
        self.base_econ.reciever_request_step(data)

    def _dict_idx_to_index(self, data):
        """Return the entries of ``data`` for this agent class, re-keyed by local index.

        Keys that do not belong to this receiver's agent class are dropped;
        the base hands every receiver the data of the whole simulation.
        """
        # NOTE(review): keys are matched as-is against the idx map; confirm
        # the environment keys its dicts with the same idx type.
        return {
            self.idx_to_index[k]: v
            for k, v in data.items()
            if k in self.idx_to_index
        }

    def _dict_index_to_idx(self, data):
        """Return ``data`` re-keyed from local index to simulation idx."""
        return {self.agnet_idx[k]: v for k, v in data.items()}

    def step_wait(self):
        """Block until the step has happened and return ``(obs, rew, done, info)`` keyed by local index."""
        obs, rew, done, info = self.base_econ.reciever_block_step()
        c_obs = self._dict_idx_to_index(obs)
        c_rew = self._dict_idx_to_index(rew)
        c_done = self._dict_idx_to_index(done)
        c_info = self._dict_idx_to_index(info)
        return c_obs, c_rew, c_done, c_info

    def reset(self):
        """Vote for a reset, block until reset data is available and return the observations keyed by local index."""
        self.base_econ.reciever_request_reset()
        obs = self.base_econ.reciever_block_reset()
        return self._dict_idx_to_index(obs)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user