its not braking anymore
This commit is contained in:
@@ -50,6 +50,7 @@ class Craft(BaseComponent):
|
||||
assert len(commodities)>0
|
||||
#setup commodities
|
||||
self.recip_map={}
|
||||
self.commodities=[]
|
||||
for v in commodities:
|
||||
res_class=resource_registry.get(v)
|
||||
res=res_class()
|
||||
@@ -80,11 +81,14 @@ class Craft(BaseComponent):
|
||||
def agent_can_build(self, agent, res):
|
||||
"""Return True if agent can actually build in its current location."""
|
||||
# See if the agent has the resources necessary to complete the action
|
||||
recipe= self.recip_map[res]
|
||||
for resource, cost in recipe.items():
|
||||
if agent.state["inventory"][resource] < cost:
|
||||
return False
|
||||
return True
|
||||
if res in self.recip_map:
|
||||
recipe= self.recip_map[res]
|
||||
for resource, cost in recipe.items():
|
||||
if agent.state["inventory"][resource] < cost:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
||||
# Required methods for implementing components
|
||||
# --------------------------------------------
|
||||
@@ -136,9 +140,10 @@ class Craft(BaseComponent):
|
||||
|
||||
# Build! (If you can.)
|
||||
else:
|
||||
action-=1
|
||||
comm=self.commodities[action]
|
||||
|
||||
if self.agent_can_build(agent,comm.craft_recp):
|
||||
if self.agent_can_build(agent,comm.name):
|
||||
# Remove the resources
|
||||
for resource, cost in comm.craft_recp.items():
|
||||
agent.state["inventory"][resource] -= cost
|
||||
@@ -211,7 +216,7 @@ class Craft(BaseComponent):
|
||||
where metric_value is a scalar.
|
||||
"""
|
||||
world = self.world
|
||||
|
||||
"""
|
||||
build_stats = {a.idx: {"n_builds": 0} for a in world.agents}
|
||||
for builds in self.builds:
|
||||
for build in builds:
|
||||
@@ -225,8 +230,8 @@ class Craft(BaseComponent):
|
||||
|
||||
num_houses = np.sum(world.maps.get("House") > 0)
|
||||
out_dict["total_builds"] = num_houses
|
||||
|
||||
return out_dict
|
||||
"""
|
||||
return {}
|
||||
|
||||
def additional_reset_steps(self):
|
||||
"""
|
||||
@@ -257,7 +262,8 @@ class Craft(BaseComponent):
|
||||
elif self.skill_dist == "pareto":
|
||||
labour = 1
|
||||
sampled_skill = np.random.pareto(2)
|
||||
amount = np.minimum(MSAB, MSAB * sampled_skill)
|
||||
|
||||
amount = 1+np.minimum(MSAB,(MSAB-1) * (sampled_skill) )
|
||||
labour_modifier = 1 - np.minimum(1 - MSLB, (1 - MSLB) * sampled_skill)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -107,8 +107,8 @@ class ExternalMarket(BaseComponent):
|
||||
|
||||
# NO-OP!
|
||||
if action == 0:
|
||||
pass
|
||||
action-=1
|
||||
continue
|
||||
|
||||
res_name=self.action_res_map[action]
|
||||
# Build! (If you can.)
|
||||
|
||||
@@ -163,10 +163,11 @@ class ExternalMarket(BaseComponent):
|
||||
# Mobile agents' build action is masked if they cannot build with their
|
||||
# current location and/or endowment
|
||||
for agent in self.world.agents:
|
||||
mask=[]
|
||||
for res in self.market_demand:
|
||||
mask.append(self.agent_can_sell(agent,res))
|
||||
masks[agent.idx] = mask
|
||||
if agent.name in self.agent_subclasses:
|
||||
mask=[]
|
||||
for res in self.market_demand:
|
||||
mask.append(self.agent_can_sell(agent,res))
|
||||
masks[agent.idx] = mask
|
||||
|
||||
return masks
|
||||
|
||||
@@ -182,7 +183,7 @@ class ExternalMarket(BaseComponent):
|
||||
where metric_value is a scalar.
|
||||
"""
|
||||
world = self.world
|
||||
|
||||
"""
|
||||
build_stats = {a.idx: {"n_builds": 0} for a in world.agents}
|
||||
for builds in self.builds:
|
||||
for build in builds:
|
||||
@@ -196,8 +197,8 @@ class ExternalMarket(BaseComponent):
|
||||
|
||||
num_houses = np.sum(world.maps.get("House") > 0)
|
||||
out_dict["total_builds"] = num_houses
|
||||
|
||||
return out_dict
|
||||
"""
|
||||
return {}
|
||||
|
||||
def additional_reset_steps(self):
|
||||
"""
|
||||
|
||||
31
main.py
31
main.py
@@ -23,6 +23,7 @@ from envs.econ_wrapper import EconVecEnv
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
import yaml
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
env_config = {
|
||||
# ===== SCENARIO CLASS =====
|
||||
@@ -61,7 +62,7 @@ env_config = {
|
||||
'allow_observation_scaling': True,
|
||||
'dense_log_frequency': 100,
|
||||
'world_dense_log_frequency':1,
|
||||
'energy_cost':0.21,
|
||||
'energy_cost':0,
|
||||
'energy_warmup_method': "auto",
|
||||
'energy_warmup_constant': 4000,
|
||||
|
||||
@@ -116,7 +117,7 @@ eval_env_config = {
|
||||
'allow_observation_scaling': True,
|
||||
'dense_log_frequency': 10,
|
||||
'world_dense_log_frequency':1,
|
||||
'energy_cost':0.21,
|
||||
'energy_cost':0,
|
||||
'energy_warmup_method': "auto",
|
||||
'energy_warmup_constant': 4000,
|
||||
|
||||
@@ -223,23 +224,33 @@ baseEconWrapper=BaseEconWrapper(econ)
|
||||
baseEconWrapper.run()
|
||||
mobileRecieverEconWrapper=RecieverEconWrapper(base_econ=baseEconWrapper,agent_classname="BasicMobileAgent")
|
||||
tradeRecieverEconWrapper=RecieverEconWrapper(base_econ=baseEconWrapper,agent_classname="TradingAgent")
|
||||
sb3_traderConverter=SB3EconConverter(tradeRecieverEconWrapper,econ,"TradingAgent")
|
||||
sb3Converter=SB3EconConverter(mobileRecieverEconWrapper,econ,"BasicMobileAgent")
|
||||
#obs=sb3Converter.reset()
|
||||
#vecenv=EconVecEnv(env_config=env_config)
|
||||
|
||||
monenv=VecMonitor(venv=sb3Converter,info_keywords=["social/productivity","trend/productivity"])
|
||||
|
||||
montraidingenv=VecMonitor(venv=sb3_traderConverter)
|
||||
#normenv=VecNormalize(sb3Converter,norm_reward=False,clip_obs=1)
|
||||
#stackenv=vec_frame_stack.VecFrameStack(venv=monenv,n_stack=10)
|
||||
obs=monenv.reset()
|
||||
|
||||
|
||||
# define training functions
|
||||
def train(model,timesteps, econ_call,process_bar,name,db,index):
|
||||
db[index]=model.learn(total_timesteps=timesteps,progress_bar=process_bar,reset_num_timesteps=False,tb_log_name=name,callback=TensorboardCallback(econ_call))
|
||||
|
||||
|
||||
|
||||
runname="run_{}".format(int(np.random.rand()*100))
|
||||
# prepare training
|
||||
run_number=int(np.random.rand()*100)
|
||||
runname="run_{}".format(run_number)
|
||||
model_db=[] # object for storing model
|
||||
|
||||
|
||||
model = MaskablePPO("MlpPolicy",n_steps=int(env_config['episode_length']*2),ent_coef=0.1, vf_coef=0.8 ,gamma=0.95, learning_rate=5e-3,env=monenv, seed=225,verbose=1,device="cuda",tensorboard_log="./log")
|
||||
model_trade=MaskablePPO("MlpPolicy",n_steps=int(env_config['episode_length']*2),ent_coef=0.1, vf_coef=0.8 ,gamma=0.95, learning_rate=5e-3,env=montraidingenv, seed=225,verbose=1,device="cuda",tensorboard_log="./log")
|
||||
|
||||
n_agents=econ.n_agents
|
||||
total_required_for_episode=n_agents*env_config['episode_length']
|
||||
print("this is run {}".format(runname))
|
||||
@@ -249,11 +260,21 @@ eval_base_econ.run()
|
||||
eval_mobileRecieverEconWrapper=RecieverEconWrapper(eval_base_econ,"BasicMobileAgent")
|
||||
time.sleep(0.5)
|
||||
eval_sb3_converter=SB3EconConverter(eval_mobileRecieverEconWrapper,eval_econ,"BasicMobileAgent")
|
||||
|
||||
while True:
|
||||
# Create Eval ENV
|
||||
vec_mon_eval=VecMonitor(venv=eval_sb3_converter)
|
||||
#Train
|
||||
model=model.learn(total_timesteps=total_required_for_episode*10,progress_bar=True,reset_num_timesteps=False,tb_log_name=runname,callback=TensorboardCallback(econ=econ))
|
||||
runname="run_{}_{}".format(run_number,"basic")
|
||||
|
||||
thread_model=Thread(target=train,args=(model,total_required_for_episode*10,econ,True,runname,model_db,0))
|
||||
runname="run_{}_{}".format(run_number,"trader")
|
||||
thread_model_traid=Thread(target=train,args=(model_trade,total_required_for_episode*10,econ,False,runname,model_db,1))
|
||||
|
||||
thread_model.start()
|
||||
thread_model_traid.start()
|
||||
thread_model.join()
|
||||
thread_model_traid.join()
|
||||
#normenv.save("temp-normalizer.ai")
|
||||
|
||||
|
||||
|
||||
@@ -149,7 +149,8 @@ class BaseEconWrapper():
|
||||
self.action_edit_lock.acquire() # Start to submit action dict
|
||||
for k,v in actions.items():
|
||||
if k in self.actor_actions.keys():
|
||||
raise Exception("Actor action has already been submitted. {}".format(k))
|
||||
print("Actor action has already been submitted. {}".format(k))
|
||||
continue
|
||||
self.actor_actions[k]=v
|
||||
self.step_notifications[voter_id].clear()
|
||||
self.base_notification.set() #Alert base for action changes
|
||||
|
||||
Reference in New Issue
Block a user