Fix training setup: adjust Craft skill benefit, starting coin, PPO hyperparameters, and per-model training step counts

This commit is contained in:
2023-01-15 00:27:37 +01:00
parent fcd25276bb
commit 3a8d10e0b4
2 changed files with 13 additions and 10 deletions

View File

@@ -159,7 +159,7 @@ class ContinuousDoubleAuction(BaseComponent):
"""If agent can submit an ask for resource."""
return (
self.n_orders[resource][agent.idx] < self.max_num_orders
and agent.state["inventory"][resource] > 0
and agent.state["inventory"][resource] >= 1
)
# Core components for this market

21
main.py
View File

@@ -38,7 +38,7 @@ env_config = {
# The order in which components reset, step, and generate obs follows their listed order below.
'components': [
# (1) Building houses
('Craft', {'skill_dist': "pareto", 'commodities': ["Gem"]}),
('Craft', {'skill_dist': "pareto", 'commodities': ["Gem"],'max_skill_amount_benefit':1.5}),
# (2) Trading collectible resources
('ContinuousDoubleAuction', {'max_num_orders': 10}),
# (3) Movement and resource collection
@@ -51,7 +51,7 @@ env_config = {
# ===== SCENARIO CLASS ARGUMENTS =====
# (optional) kwargs that are added by the Scenario class (i.e. not defined in BaseEnvironment)
'starting_agent_coin': 0,
'starting_agent_coin': 10,
'fixed_four_skill_and_loc': True,
# ===== STANDARD ARGUMENTS ======
@@ -93,7 +93,7 @@ eval_env_config = {
# The order in which components reset, step, and generate obs follows their listed order below.
'components': [
# (1) Building houses
('Craft', {'skill_dist': "pareto", 'commodities': ["Gem"]}),
('Craft', {'skill_dist': "pareto", 'commodities': ["Gem"],'max_skill_amount_benefit':1.5}),
# (2) Trading collectible resources
('ContinuousDoubleAuction', {'max_num_orders': 10}),
# (3) Movement and resource collection
@@ -106,7 +106,7 @@ eval_env_config = {
# ===== SCENARIO CLASS ARGUMENTS =====
# (optional) kwargs that are added by the Scenario class (i.e. not defined in BaseEnvironment)
'starting_agent_coin': 0,
'starting_agent_coin': 10,
'fixed_four_skill_and_loc': True,
# ===== STANDARD ARGUMENTS ======
@@ -248,11 +248,14 @@ runname="run_{}".format(run_number)
model_db=[] # object for storing model
model = MaskablePPO("MlpPolicy",n_steps=int(env_config['episode_length']*2),ent_coef=0.1, vf_coef=0.8 ,gamma=0.95, learning_rate=5e-3,env=monenv, seed=225,verbose=1,device="cuda",tensorboard_log="./log")
model_trade=MaskablePPO("MlpPolicy",n_steps=int(env_config['episode_length']*2),ent_coef=0.1, vf_coef=0.8 ,gamma=0.95, learning_rate=5e-3,env=montraidingenv, seed=225,verbose=1,device="cuda",tensorboard_log="./log")
model = MaskablePPO("MlpPolicy",n_steps=int(env_config['episode_length']*2),ent_coef=0.1, vf_coef=0.5 ,gamma=0.98, learning_rate=5e-3,env=monenv, seed=225,verbose=1,device="cuda",tensorboard_log="./log")
model_trade=MaskablePPO("MlpPolicy",n_steps=int(env_config['episode_length']*2),ent_coef=0.1, vf_coef=0.5 ,gamma=0.98, learning_rate=5e-3,env=montraidingenv, seed=225,verbose=1,device="cuda",tensorboard_log="./log")
n_agents=econ.n_agents
total_required_for_episode=n_agents*env_config['episode_length']
total_required_for_episode_basic=len(mobileRecieverEconWrapper.agnet_idx)*env_config['episode_length']
total_required_for_episode_traid=len(tradeRecieverEconWrapper.agnet_idx)*env_config['episode_length']
print("this is run {}".format(runname))
eval_econ=foundation.make_env_instance(**eval_env_config)
eval_base_econ=BaseEconWrapper(eval_econ)
@@ -267,9 +270,9 @@ while True:
#Train
runname="run_{}_{}".format(run_number,"basic")
thread_model=Thread(target=train,args=(model,total_required_for_episode*10,econ,True,runname,model_db,0))
thread_model=Thread(target=train,args=(model,total_required_for_episode_basic*300,econ,True,runname,model_db,0))
runname="run_{}_{}".format(run_number,"trader")
thread_model_traid=Thread(target=train,args=(model_trade,total_required_for_episode*10,econ,False,runname,model_db,1))
thread_model_traid=Thread(target=train,args=(model_trade,total_required_for_episode_traid*300,econ,False,runname,model_db,1))
thread_model.start()
thread_model_traid.start()