184 lines
5.9 KiB
Python
184 lines
5.9 KiB
Python
from ai_economist.foundation.base import base_env
|
|
from threading import Event, Lock, Thread
|
|
from queue import Queue
|
|
class BaseEconWrapper:
    """Connect receiver wrappers to an econ simulation driven from its own thread.

    Receivers (e.g. training workers running in other threads) talk to the
    wrapper only through these methods:

    * ``register_vote`` once, to obtain a voter id;
    * ``reciever_request_step`` / ``reciever_block_step`` to submit actions
      and wait for the data of the resulting step;
    * ``reciever_request_reset`` / ``reciever_block_reset`` to vote for an
      environment reset and wait for the new observations.

    The thread started by ``run`` calls ``env.step`` once the number of
    submitted actions equals ``env.n_agents``, and ``env.reset`` once every
    registered voter has requested a reset.
    """

    def __init__(self, econ: "base_env.BaseEnvironment"):
        """Wrap *econ*.

        Args:
            econ: environment to drive; its ``reset()``, ``step(actions)`` and
                ``n_agents`` are used.
        """
        self.env = econ

        # All synchronisation state is per instance. It used to live at class
        # level, which made every wrapper share the same events, locks,
        # action dict and notification list.
        self.base_notification = Event()  # wakes the base thread
        self.reset_notification = Event()  # set by a reset, cleared by a step
        self.step_notifications = []  # one Event per registered voter

        self.action_edit_lock = Lock()  # guards actor_actions
        self.actor_actions = {}

        self.stop_edit_lock = Lock()  # guards stop
        self.stop = False

        self.vote_lock = Lock()  # guards n_voters / n_votes_reset / step_notifications growth
        self.n_voters = 0
        self.n_votes_reset = 0
        self.n_votes_step = 0

        # Latest environment data, guarded by env_data_lock.
        self.env_data_lock = Lock()
        self.obs = None
        self.rew = None
        self.done = None
        self.info = None
        self.n_data_retrieved = 0

    def register_vote(self):
        """Register a receiver with the base.

        Returns:
            The voter id to pass to ``reciever_request_step`` and
            ``reciever_block_step``.
        """
        # Held so concurrent registrations get distinct ids and the base loop
        # never sees n_voters ahead of step_notifications.
        with self.vote_lock:
            voter_id = self.n_voters
            self.step_notifications.append(Event())
            self.n_voters += 1
        return voter_id

    def run(self):
        """Start the base loop in a daemon thread and return that thread."""
        thread = Thread(target=self._run, daemon=True)
        thread.start()
        return thread

    def _run(self):
        """Base loop: reset the env, then serve reset votes and step requests until stopped."""
        # Reset bookkeeping for this run. Every lock is only ever held inside
        # a ``with`` block, so none can be left acquired by a previous run and
        # nothing needs force-releasing (force-releasing could free a lock a
        # receiver thread legitimately held at that moment).
        self.base_notification.clear()
        self.reset_notification.clear()
        with self.stop_edit_lock:
            self.stop = False
        with self.action_edit_lock:
            self.actor_actions = {}
        with self.vote_lock:
            self.n_votes_reset = 0
            self.n_votes_step = 0
        with self.env_data_lock:
            self.obs = None
            self.rew = None
            self.done = None
            self.info = None

        self._reset()

        while True:
            # Sleep until a receiver (or stop_env) signals a change. Clearing
            # before inspecting state means a signal raised after this point
            # wakes the next iteration rather than being lost.
            self.base_notification.wait()
            self.base_notification.clear()

            with self.stop_edit_lock:
                if self.stop:
                    return  # the with-block releases the lock on the way out

            # Reset once every registered voter has asked for it.
            with self.vote_lock:
                if self.n_voters == self.n_votes_reset:
                    self.n_votes_reset = 0
                    self._reset()

            # Step once one action per agent has been submitted.
            # NOTE(review): if env.n_agents excludes a planner whose action is
            # also submitted, this equality may never hold (or hold early) —
            # confirm against the env's action keys.
            with self.action_edit_lock:
                enough_votes_for_step = self.env.n_agents == len(self.actor_actions)
            if enough_votes_for_step:
                self._step()

    def stop_env(self):
        """Ask the base loop to exit, and wake it so it notices."""
        with self.stop_edit_lock:
            self.stop = True
        self.base_notification.set()

    def _reset(self):
        """Reset the env, publish its observations and notify receivers."""
        with self.env_data_lock:
            self.n_votes_reset = 0
            self.obs = self.env.reset()
            self.rew = None
            self.done = None
            self.info = None
        # Wake receivers blocked in reciever_block_reset.
        self.reset_notification.set()
        # Re-arm each voter's step event so reciever_block_step waits for the
        # first step after this reset.
        for notification in self.step_notifications:
            notification.clear()

    def _step(self):
        """Step the env with the collected actions and notify every voter."""
        with self.env_data_lock:
            self.reset_notification.clear()  # data no longer reflects a fresh reset
            # Swap the action dict out under its lock. env.step then reads a
            # private snapshot, and actions submitted while it runs land in
            # the fresh dict for the next step instead of mutating this one.
            with self.action_edit_lock:
                actions = self.actor_actions
                self.actor_actions = {}
            self.obs, self.rew, self.done, self.info = self.env.step(actions)
            self.n_data_retrieved = 0
        for notification in self.step_notifications:
            notification.set()

    def _prepare_step(self, voter):
        """Clear *voter*'s step event so its next block waits for a new step."""
        self.step_notifications[voter].clear()

    def reciever_request_step(self, actions, voter_id):
        """Queue actions for the next step.

        Args:
            actions: dict pairing agent idx with action. An agent whose action
                is already queued for this step is skipped with a message.
            voter_id: id returned by ``register_vote``.
        """
        self._prepare_step(voter_id)
        with self.action_edit_lock:
            for agent_idx, action in actions.items():
                if agent_idx in self.actor_actions:
                    print("Actor action has already been submitted. {}".format(agent_idx))
                    continue
                self.actor_actions[agent_idx] = action
            # Cleared again while the lock is held so no step can set it
            # between the first clear and this voter's actions being queued.
            self.step_notifications[voter_id].clear()
            self.base_notification.set()  # alert the base about new actions

    def reciever_block_step(self, voter_id):
        """Block until the next step completes; return ``(obs, rew, done, info)``.

        The step happens once all actors have submitted an action.

        Args:
            voter_id: id returned by ``register_vote``.
        """
        self.step_notifications[voter_id].wait()
        with self.env_data_lock:
            return self.obs, self.rew, self.done, self.info

    def reciever_request_reset(self):
        """Add one vote to reset; the reset happens once every voter has voted."""
        with self.vote_lock:
            self.n_votes_reset += 1
        self.base_notification.set()  # alert the base about the new vote

    def reciever_block_reset(self):
        """Block until a reset has happened (and no step since); return its observations."""
        self.reset_notification.wait()
        with self.env_data_lock:
            return self.obs
|