"""Training-side setup: configs, tensorboard callback, replay printers, envs.

NOTE(review): this file was recovered from a whitespace-collapsed source;
indentation has been reconstructed from the statement structure.
"""
import copy
import pprint
import time
from threading import Thread

import numpy as np
import yaml
from ai_economist import foundation
from sb3_contrib import RecurrentPPO
from sb3_contrib.ppo_mask import MaskablePPO
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import vec_frame_stack
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from tqdm import tqdm

import components
import envs
import resources
import wrapper
from agents import trading_agent
from envs.econ_wrapper import EconVecEnv
from wrapper.base_econ_wrapper import BaseEconWrapper
from wrapper.reciever_econ_wrapper import RecieverEconWrapper
from wrapper.sb3_econ_converter import SB3EconConverter

env_config = {
    # ===== SCENARIO CLASS =====
    # Which Scenario class to use: the class's name in the Scenario Registry
    # (foundation.scenarios). The environment object will be an instance of it.
    'scenario_name': 'econ',

    # ===== COMPONENTS =====
    # List of ("component_name", {component_kwargs}) tuples; names refer to the
    # Component Registry (foundation.components). Components reset, step, and
    # generate observations in the order listed below.
    'components': [
        # (1) Building houses
        ('Craft', {'skill_dist': "pareto", 'commodities': ["Gem"], 'max_skill_amount_benefit': 1.5}),
        # (2) Trading collectible resources
        ('ContinuousDoubleAuction', {'max_num_orders': 10}),
        # (3) Movement and resource collection
        ('SimpleGather', {}),
        ('ExternalMarket', {'market_demand': {'Gem': 15}}),
    ],

    # ===== SCENARIO CLASS ARGUMENTS =====
    # (optional) kwargs consumed by the Scenario class itself.
    'starting_agent_coin': 10,
    'fixed_four_skill_and_loc': True,

    # ===== STANDARD ARGUMENTS ======
    # kwargs used by every Scenario class (defined in BaseEnvironment).
    'agent_composition': {"BasicMobileAgent": 20, "TradingAgent": 5},  # non-planner agents (must be > 1)
    'world_size': [5, 5],      # [Height, Width] of the env world
    'episode_length': 256,     # timesteps per episode
    'allow_observation_scaling': True,
    'dense_log_frequency': 100,
    'world_dense_log_frequency': 1,
    'energy_cost': 0,
    'energy_warmup_method': "auto",
    'energy_warmup_constant': 4000,
    # In multi-action-mode, the policy selects an action per action subspace;
    # otherwise it selects only one action.
    'multi_action_mode_agents': False,
    'multi_action_mode_planner': False,
    # When flattening observations, concatenate scalar & vector obs into one array.
    'flatten_observations': False,
    # flatten_masks = True is required for masking action logits below.
    'flatten_masks': True,
}

# The eval config previously duplicated env_config verbatim except for the
# overrides below; derive it instead so the two cannot silently diverge.
eval_env_config = copy.deepcopy(env_config)
eval_env_config['world_size'] = [1, 1]        # minimal world for evaluation
eval_env_config['dense_log_frequency'] = 1    # dense-log every eval episode

num_frames = 5  # frames stacked by VecFrameStack for both agent classes


class TensorboardCallback(BaseCallback):
    """Log per-episode productivity values to tensorboard.

    On the first timestep of every episode it records the previous episode's
    total productivity and its delta versus the episode before that.
    """

    def __init__(self, econ, verbose=0):
        super().__init__(verbose)
        self.econ = econ
        # Seed the baseline so the first delta has something to diff against.
        self.metrics = econ.scenario_metrics()

    def _on_step(self) -> bool:
        # BUG FIX: the original tested the module-level global `econ` here,
        # silently ignoring the environment injected via __init__ (the trader
        # callback would have tracked the wrong env). Use self.econ.
        if self.econ.world.timestep == 0:
            prev_metrics = self.metrics
            if self.econ.previous_episode_metrics is None:
                self.metrics = self.econ.scenario_metrics()
            else:
                self.metrics = self.econ.previous_episode_metrics
            curr_prod = self.metrics["social/productivity"]
            trend_prod = curr_prod - prev_metrics["social/productivity"]
            self.logger.record("social/total_productivity", curr_prod)
            self.logger.record("social/delta_productivity", trend_prod)
        return True


def printMarket(market):
    """Pretty-print a ContinuousDoubleAuction dense log (one entry per step)."""
    for i, step in enumerate(market):
        if step:
            print("=== Step {} ===".format(i))
            for t in step:
                print("({}) {} -> {} | [{}/{}] {} Coins\n".format(
                    t["commodity"], t["seller"], t["buyer"], t["ask"], t["bid"], t["price"]))
    return ""


def printBuilds(builds):
    """Pretty-print a Craft dense log (one entry per step)."""
    for i, step in enumerate(builds):
        if step:
            for t in step:
                print("({}) Builder: {}, Skill: {}, Income {} ".format(
                    i, t["builder"], t["build_skill"], t["income"]))
    return ""


def printReplay(econ, agentid):
    """Replay the previous episode's dense log for one agent, step by step.

    Prints world resource maps, the agent's state, action, and reward for
    every step except the final one.
    """
    worldmaps = ["Stone", "Wood"]
    log = econ.previous_episode_dense_log
    agent = econ.world.agents[agentid]
    agentid = str(agentid)  # dense-log dicts are keyed by string agent ids
    max_step = len(log["states"]) - 1
    for step in range(max_step):
        print()
        print("=== Step {} ===".format(step))
        print("--- World ---")
        world = log['world'][step]
        for res in worldmaps:
            # NOTE(review): only cell [0][0] is printed — presumably a 1x1
            # eval world; confirm if replaying larger worlds.
            print("{}: {}".format(res, world[res][0][0]))
        print("--- State ---")
        pprint.pprint(log['states'][step][agentid])
        print("--- Action ---")
        action = log["actions"][step][agentid]
        if action == {}:
            print("Action: 0 -> NOOP")
        else:
            for k in action:
                print("Action: {}({})".format(k, action[k]))
        print("--- Reward ---")
        print("Reward: {}".format(log["rewards"][step][agentid]))


# --- Setup training env objects ---------------------------------------------
econ = foundation.make_env_instance(**env_config)
market = econ.get_component("ContinuousDoubleAuction")
action = market.get_n_actions("TradingAgent")
baseEconWrapper = BaseEconWrapper(econ)
baseEconWrapper.run()
time.sleep(0.5)  # give the wrapper thread time to come up before attaching receivers
mobileRecieverEconWrapper = RecieverEconWrapper(base_econ=baseEconWrapper, agent_classname="BasicMobileAgent")
tradeRecieverEconWrapper = RecieverEconWrapper(base_econ=baseEconWrapper, agent_classname="TradingAgent")
sb3_traderConverter = SB3EconConverter(tradeRecieverEconWrapper, econ, "TradingAgent", True)
sb3Converter = SB3EconConverter(mobileRecieverEconWrapper, econ, "BasicMobileAgent", True)

# attach sb3 wrappers
monenv = VecMonitor(venv=sb3Converter, info_keywords=["social/productivity", "trend/productivity"])
montraidingenv = VecMonitor(venv=sb3_traderConverter)
stackenv_basic = vec_frame_stack.VecFrameStack(venv=monenv, n_stack=num_frames)
stackenv_traid = vec_frame_stack.VecFrameStack(venv=montraidingenv, n_stack=num_frames)

# --- Setup eval env objects --------------------------------------------------
econ_eval = foundation.make_env_instance(**eval_env_config)
baseEconWrapper_eval = BaseEconWrapper(econ_eval)
baseEconWrapper_eval.run()
time.sleep(0.5)
mobileRecieverEconWrapper_eval = RecieverEconWrapper(base_econ=baseEconWrapper_eval, agent_classname="BasicMobileAgent")
tradeRecieverEconWrapper_eval = RecieverEconWrapper(base_econ=baseEconWrapper_eval, agent_classname="TradingAgent")
sb3_traderConverter_eval = SB3EconConverter(tradeRecieverEconWrapper_eval, econ_eval, "TradingAgent", False)
sb3Converter_eval = SB3EconConverter(mobileRecieverEconWrapper_eval, econ_eval, "BasicMobileAgent", False)

# attach sb3 wrappers
monenv_eval = VecMonitor(venv=sb3Converter_eval, info_keywords=["social/productivity", "trend/productivity"])
montraidingenv_eval = VecMonitor(venv=sb3_traderConverter_eval)
stackenv_basic_eval = vec_frame_stack.VecFrameStack(venv=monenv_eval, n_stack=num_frames)
stackenv_traid_eval = vec_frame_stack.VecFrameStack(venv=montraidingenv_eval, n_stack=num_frames)
obs = monenv.reset()


def train(model, timesteps, econ_call, process_bar, name, db, index):
    """Train `model` for `timesteps` steps and store the result in db[index].

    Runs as a Thread target: Thread targets cannot return values, so the
    trained model is handed back through the shared `db` list slot `index`.
    """
    db[index] = model.learn(
        total_timesteps=timesteps,
        progress_bar=process_bar,
        reset_num_timesteps=False,
        tb_log_name=name,
        callback=TensorboardCallback(econ_call),
    )


# --- prepare training ---------------------------------------------------------
run_number = int(np.random.rand() * 100)  # random run id for tensorboard log names
runname = "run_{}".format(run_number)
model_db = [None, None]  # [basic model, trader model] result slots for the threads

model = MaskablePPO(
    "MlpPolicy",
    n_steps=int(env_config['episode_length'] * 2),
    ent_coef=0.1,
    vf_coef=0.5,
    gamma=0.99,
    learning_rate=1e-5,
    env=stackenv_basic,
    seed=300,
    verbose=1,
    device="cuda",
    tensorboard_log="./log",
)
model_trade = MaskablePPO(
    "MlpPolicy",
    n_steps=int(env_config['episode_length'] * 2),
    ent_coef=0.1,
    vf_coef=0.5,
    gamma=0.99,
    learning_rate=1e-5,
    env=stackenv_traid,
    seed=300,
    verbose=1,
    device="cuda",
    tensorboard_log="./log",
)

n_agents = econ.n_agents
# One SB3 timestep is consumed per agent of the class per env step, so a full
# episode needs (agents of class) * episode_length timesteps.
total_required_for_episode_basic = len(mobileRecieverEconWrapper.agnet_idx) * env_config['episode_length']
total_required_for_episode_traid = len(tradeRecieverEconWrapper.agnet_idx) * env_config['episode_length']

print("this is run {}".format(runname))

while True:
    # --- Train both policies concurrently (50 episodes each) -----------------
    runname = "run_{}_{}".format(run_number, "basic")
    thread_model = Thread(target=train, args=(model, total_required_for_episode_basic * 50, econ, True, runname, model_db, 0))
    runname = "run_{}_{}".format(run_number, "trader")
    thread_model_traid = Thread(target=train, args=(model_trade, total_required_for_episode_traid * 50, econ, False, runname, model_db, 1))
    thread_model.start()
    thread_model_traid.start()
    thread_model.join()
    thread_model_traid.join()
    # normenv.save("temp-normalizer.ai")
    model = model_db[0]
    model_trade = model_db[1]
    model.save("basic.ai")
    model_trade.save("trade.ai")

    # --- Run eval -------------------------------------------------------------
    print("### EVAL ###")
    obs_basic = stackenv_basic_eval.reset()
    obs_trade = stackenv_traid_eval.reset()
    done = False
    for i in tqdm(range(eval_env_config['episode_length'])):
        # create masks
        masks_basic = stackenv_basic_eval.action_masks()
        masks_trade = stackenv_traid_eval.action_masks()
        # get actions; predict() returns (actions, state) so [0] is the actions
        action_basic = model.predict(obs_basic, action_masks=masks_basic)
        action_trade = model_trade.predict(obs_trade, action_masks=masks_trade)
        # submit async directly for non-blocking operation
        sb3Converter_eval.step_async(action_basic[0])
        sb3_traderConverter_eval.step_async(action_trade[0])
        # retrieve full results; both wrappers drive the same underlying
        # episode, so reusing done_e for both calls is intentional here
        obs_basic, rew_basic, done_e, info = stackenv_basic_eval.step(action_basic[0])
        obs_trade, rew_trade, done_e, info = stackenv_traid_eval.step(action_trade[0])
        done = done_e[0]

    market = econ_eval.get_component("ContinuousDoubleAuction")
    craft = econ_eval.get_component("Craft")
    # trades = market.get_dense_log()
    build = craft.get_dense_log()
    # BUG FIX: the original read `econ.previous_episode_metrics` (the TRAINING
    # env) here, while every surrounding line reports on the eval env; report
    # the metrics of the episode that was just evaluated instead.
    met = econ_eval.previous_episode_metrics
    printReplay(econ_eval, 0)
    # printMarket(trades)
    # printBuilds(builds=build)
    print("social/productivity: {}".format(met["social/productivity"]))
    print("labor/weighted_cost: {}".format(met["labor/weighted_cost"]))
    print("labor/warmup_integrator: {}".format(met["labor/warmup_integrator"]))
    time.sleep(1)