# Listing metadata left over from the code viewer this file was copied from:
# 284 lines, 10 KiB, plaintext.
import time

import numpy as np
import yaml
from ai_economist import foundation
from sb3_contrib import RecurrentPPO
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import vec_frame_stack
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from tqdm import tqdm

# Project-local packages. Kept in their original relative order and after
# ai_economist (presumably they register custom scenarios/components on
# import -- verify before reordering).
import envs
import components
from envs.econ_wrapper import EconVecEnv

# Configuration of the TRAINING environment (passed to EconVecEnv below).
env_config = {
    # ===== SCENARIO CLASS =====
    # Which Scenario class to use: the class's name in the Scenario Registry (foundation.scenarios).
    # The environment object will be an instance of the Scenario class.
    'scenario_name': 'simple_market',

    # ===== COMPONENTS =====
    # Which components to use (specified as list of ("component_name", {component_kwargs}) tuples).
    #   "component_name" refers to the Component class's name in the Component Registry (foundation.components)
    #   {component_kwargs} is a dictionary of kwargs passed to the Component class
    # The order in which components reset, step, and generate obs follows their listed order below.
    'components': [
        # (1) Building houses
        ('SimpleCraft', {'skill_dist': "none", 'payment_max_skill_multiplier': 3}),
        # (2) Trading collectible resources (currently disabled)
        #('ContinuousDoubleAuction', {'max_num_orders': 10}),
        # (3) Movement and resource collection
        ('SimpleGather', {}),
    ],

    # ===== SCENARIO CLASS ARGUMENTS =====
    # (optional) kwargs that are added by the Scenario class (i.e. not defined in BaseEnvironment)
    'starting_agent_coin': 0,
    'fixed_four_skill_and_loc': True,

    # ===== STANDARD ARGUMENTS ======
    # kwargs that are used by every Scenario class (i.e. defined in BaseEnvironment)
    'n_agents': 20,  # Number of non-planner agents (must be > 1)
    'world_size': [1, 1],  # [Height, Width] of the env world
    'episode_length': 256,  # Number of timesteps per episode
    'allow_observation_scaling': True,
    'dense_log_frequency': 100,
    'world_dense_log_frequency': 1,
    'energy_cost': 0,
    'energy_warmup_method': "auto",
    'energy_warmup_constant': 0,

    # In multi-action-mode, the policy selects an action for each action subspace (defined in component code).
    # Otherwise, the policy selects only 1 action.
    'multi_action_mode_agents': False,
    'multi_action_mode_planner': False,

    # When flattening observations, concatenate scalar & vector observations before output.
    # Otherwise, return observations with minimal processing.
    'flatten_observations': False,
    # When flattening masks, concatenate each action subspace mask into a single array.
    # Note: flatten_masks = True is required for masking action logits.
    # NOTE(review): it is left disabled here -- confirm no consumer of this
    # config relies on flattened masks.
    'flatten_masks': False,
}


# Configuration of the EVALUATION environment. Identical to the training
# configuration except for a shorter episode ('episode_length': 100 instead
# of 256) and more frequent dense logging ('dense_log_frequency': 10 instead
# of 100).
eval_env_config = {
    # ===== SCENARIO CLASS =====
    # Which Scenario class to use: the class's name in the Scenario Registry (foundation.scenarios).
    # The environment object will be an instance of the Scenario class.
    'scenario_name': 'simple_market',

    # ===== COMPONENTS =====
    # Which components to use (specified as list of ("component_name", {component_kwargs}) tuples).
    #   "component_name" refers to the Component class's name in the Component Registry (foundation.components)
    #   {component_kwargs} is a dictionary of kwargs passed to the Component class
    # The order in which components reset, step, and generate obs follows their listed order below.
    'components': [
        # (1) Building houses
        ('SimpleCraft', {'skill_dist': "none", 'payment_max_skill_multiplier': 3}),
        # (2) Trading collectible resources (currently disabled)
        #('ContinuousDoubleAuction', {'max_num_orders': 10}),
        # (3) Movement and resource collection
        ('SimpleGather', {}),
    ],

    # ===== SCENARIO CLASS ARGUMENTS =====
    # (optional) kwargs that are added by the Scenario class (i.e. not defined in BaseEnvironment)
    'starting_agent_coin': 0,
    'fixed_four_skill_and_loc': True,

    # ===== STANDARD ARGUMENTS ======
    # kwargs that are used by every Scenario class (i.e. defined in BaseEnvironment)
    'n_agents': 20,  # Number of non-planner agents (must be > 1)
    'world_size': [1, 1],  # [Height, Width] of the env world
    'episode_length': 100,  # Number of timesteps per episode
    'allow_observation_scaling': True,
    'dense_log_frequency': 10,
    'world_dense_log_frequency': 1,
    'energy_cost': 0,
    'energy_warmup_method': "auto",
    'energy_warmup_constant': 0,

    # In multi-action-mode, the policy selects an action for each action subspace (defined in component code).
    # Otherwise, the policy selects only 1 action.
    'multi_action_mode_agents': False,
    'multi_action_mode_planner': False,

    # When flattening observations, concatenate scalar & vector observations before output.
    # Otherwise, return observations with minimal processing.
    'flatten_observations': False,
    # When flattening masks, concatenate each action subspace mask into a single array.
    # Note: flatten_masks = True is required for masking action logits.
    # NOTE(review): it is left disabled here -- confirm no consumer of this
    # config relies on flattened masks.
    'flatten_masks': False,
}

# NOTE(review): not read anywhere in the visible part of this file (the frame
# stack created further down hard-codes n_stack=10) -- confirm whether it is
# still needed.
num_frames = 2


class TensorboardCallback(BaseCallback):
    """Custom callback for plotting additional values in tensorboard.

    On every callback step it records the current "social/productivity"
    metric of the wrapped environment and its change relative to the value
    seen on the previous callback step.
    """

    def __init__(self, econ, verbose=0):
        """Store the environment and take an initial metrics snapshot.

        Args:
            econ: Environment object exposing ``scenario_metrics()`` and
                ``previous_episode_metrics``.
            verbose: Verbosity level forwarded to ``BaseCallback``.
        """
        super().__init__(verbose)
        self.econ = econ
        self.metrics = econ.scenario_metrics()

    def _on_step(self) -> bool:
        """Record productivity and its delta; always let training continue.

        Returns:
            ``True`` unconditionally.
        """
        prev_metrics = self.metrics

        # Prefer the metrics of the last finished episode; before any episode
        # has finished, fall back to the metrics of the running one.
        if self.econ.previous_episode_metrics is None:
            self.metrics = self.econ.scenario_metrics()
        else:
            self.metrics = self.econ.previous_episode_metrics

        curr_prod = self.metrics["social/productivity"]
        delta_prod = curr_prod - prev_metrics["social/productivity"]
        self.logger.record("social/total_productivity", curr_prod)
        self.logger.record("social/delta_productivity", delta_prod)

        return True


def sample_random_action(agent, mask):
    """Sample random UNMASKED action(s) for agent.

    Actions are drawn with probability proportional to their mask value, so
    entries whose mask is 0 are never selected.

    Args:
        agent: Object exposing ``multi_action_mode`` and ``action_spaces``
            (an int in single-action mode; an array of per-subspace sizes in
            multi-action mode).
        mask: 1-D numeric array; in multi-action mode it is the concatenation
            of the per-subspace masks.

    Returns:
        A list with one action index per action subspace in multi-action
        mode, otherwise a single action index.

    Note:
        A (sub-)mask that sums to 0 produces invalid probabilities and makes
        ``np.random.choice`` fail.
    """
    # Multi-action mode: one action per action subspace.
    if agent.multi_action_mode:
        # Cut the flat mask at the cumulative subspace sizes.
        split_masks = np.split(mask, agent.action_spaces.cumsum()[:-1])
        return [
            np.random.choice(np.arange(len(m_)), p=m_ / m_.sum())
            for m_ in split_masks
        ]

    # Single-action mode: one action overall.
    return np.random.choice(np.arange(agent.action_spaces), p=mask / mask.sum())


def sample_random_actions(env, obs):
    """Return action 0 for every entry of obs.

    Despite the name, nothing is sampled: each index ``0 .. len(obs) - 1`` is
    mapped to the constant action 0 (presumably the NO-OP action -- verify
    against the environment's action definitions).

    Args:
        env: Unused; kept so the signature stays compatible with callers.
        obs: Sized collection of observations; only its length is used.

    Returns:
        Dict mapping each integer index in ``range(len(obs))`` to 0.
    """
    return dict.fromkeys(range(len(obs)), 0)


def printMarket(market):
    """Print every transaction of a market log, grouped by step.

    Args:
        market: Sequence indexed by step; each entry is a sequence of
            transaction dicts with the keys "commodity", "seller", "buyer",
            "ask", "bid" and "price".

    Returns:
        Always the empty string.
    """
    for step_idx, transactions in enumerate(market):
        # Steps without transactions print nothing, not even a header.
        if len(transactions) == 0:
            continue
        print("=== Step {} ===".format(step_idx))
        for t in transactions:
            transstring = "({}) {} -> {} | [{}/{}] {} Coins\n".format(
                t["commodity"], t["seller"], t["buyer"], t["ask"], t["bid"], t["price"]
            )
            print(transstring)
    return ""


def printBuilds(builds):
    """Print one line per build event of a build log.

    Args:
        builds: Sequence indexed by step; each entry is a sequence of build
            dicts with the keys "builder", "build_skill" and "income".

    Returns:
        Always the empty string.
    """
    for step_idx, step_builds in enumerate(builds):
        for build in step_builds:
            transstring = "({}) Builder: {}, Skill: {}, Income {} ".format(
                step_idx, build["builder"], build["build_skill"], build["income"]
            )
            print(transstring)
    return ""


def printReplay(econ, agentid):
    """Print a step-by-step replay of one agent from the last episode's dense log.

    For each logged step this prints the "Stone" and "Wood" world values at
    cell [0][0], the agent's state (as YAML), its action(s) and its reward.

    Args:
        econ: Environment object exposing ``previous_episode_dense_log`` with
            the keys "states", "world", "actions" and "rewards".
        agentid: Agent index; converted to ``str`` to look up the per-agent
            entries of the log.
    """
    worldmaps = ["Stone", "Wood"]

    log = econ.previous_episode_dense_log
    agent_key = str(agentid)

    # The last "states" entry is skipped (presumably it is the terminal state
    # and has no matching action/reward entry -- verify against the log format).
    max_step = len(log["states"]) - 1

    for step in range(max_step):
        print()
        print("=== Step {} ===".format(step))

        print("--- World ---")
        world = log['world'][step]
        for res in worldmaps:
            print("{}: {}".format(res, world[res][0][0]))

        print("--- State ---")
        state = log['states'][step][agent_key]
        print(yaml.dump(state))

        print("--- Action ---")
        action = log["actions"][step][agent_key]
        # An empty action dict is reported as a NO-OP.
        if action == {}:
            print("Action: 0 -> NOOP")
        else:
            for name in action:
                print("Action: {}({})".format(name, action[name]))

        print("--- Reward ---")
        reward = log["rewards"][step][agent_key]
        print("Reward: {}".format(reward))


# Setup Env Objects

# Training environment: EconVecEnv wrapped in a VecMonitor that also records
# the listed info keys.
vecenv = EconVecEnv(env_config=env_config)
econ = vecenv.env
monenv = VecMonitor(
    venv=vecenv,
    info_keywords=["social/productivity", "trend/productivity"],
)
# NOTE(review): normenv and stackenv both wrap monenv directly and the model
# below is trained on monenv, so neither observation normalization nor frame
# stacking is part of the training pipeline -- confirm this is intended.
normenv = VecNormalize(monenv, norm_reward=False, clip_obs=1)
stackenv = vec_frame_stack.VecFrameStack(venv=monenv, n_stack=10)
obs = stackenv.reset()

# Tensorboard run name with a random suffix in 0..99.
runname = "run_{}".format(int(np.random.rand() * 100))

model = PPO(
    "MlpPolicy",
    env=monenv,
    n_steps=int(env_config['episode_length'] * 2),
    ent_coef=0.1,
    vf_coef=0.8,
    gamma=0.95,
    learning_rate=5e-3,
    verbose=1,
    device="cuda",
    tensorboard_log="./log",
)

# Agent-steps in one episode (agents * timesteps).
total_required_for_episode = env_config['n_agents'] * env_config['episode_length']
print("this is run {}".format(runname))

# Alternate forever between a training phase and one printed evaluation episode.
while True:
    # Create Eval ENV (a fresh one every iteration)
    vec_env_eval = EconVecEnv(env_config=eval_env_config)
    vec_mon_eval = VecMonitor(venv=vec_env_eval)
    norm_env_eval = VecNormalize(vec_mon_eval, norm_reward=False, training=False)
    eval_econ = vec_env_eval.env

    # Train (timestep counter and tensorboard run continue across iterations)
    model = model.learn(
        total_timesteps=total_required_for_episode * 50,
        progress_bar=True,
        reset_num_timesteps=False,
        tb_log_name=runname,
        callback=TensorboardCallback(econ=econ),
    )
    normenv.save("temp-normalizer.ai")

    ## Run Eval
    print("### EVAL ###")
    # NOTE(review): the return value of load() is discarded and the episode
    # below steps vec_mon_eval, not norm_env_eval, so the saved normalizer has
    # no effect on evaluation -- confirm against the VecNormalize.load docs.
    norm_env_eval.load("temp-normalizer.ai", vec_mon_eval)
    obs = vec_mon_eval.reset()
    done = False
    for _ in tqdm(range(eval_env_config['episode_length'])):
        action, _state = model.predict(obs)
        obs, rew, done_e, info = vec_mon_eval.step(action)
        done = done_e[0]

    #market=eval_econ.get_component("ContinuousDoubleAuction")
    craft = eval_econ.get_component("SimpleCraft")
    # trades=market.get_dense_log()
    build = craft.get_dense_log()
    # NOTE(review): metrics are read from the TRAINING env (econ) while the
    # replay and build log come from eval_econ -- confirm which one is meant.
    met = econ.previous_episode_metrics
    printReplay(eval_econ, 0)
    # printMarket(trades)
    printBuilds(builds=build)
    print("social/productivity: {}".format(met["social/productivity"]))
    print("labor/weighted_cost: {}".format(met["labor/weighted_cost"]))
    print("labor/warmup_integrator: {}".format(met["labor/warmup_integrator"]))

    time.sleep(1)