I think I did the conversion correctly

This commit is contained in:
2023-01-12 18:33:06 +01:00
parent 03c6341b19
commit ee444cb56c
9 changed files with 131 additions and 60 deletions

17
main.py
View File

@@ -3,6 +3,10 @@ import numpy as np
from stable_baselines3.common.vec_env import vec_frame_stack
from stable_baselines3.common.evaluation import evaluate_policy
import envs
import wrapper
from wrapper.base_econ_wrapper import BaseEconWrapper
from wrapper.reciever_econ_wrapper import RecieverEconWrapper
from wrapper.sb3_econ_converter import SB3EconConverter
from tqdm import tqdm
import components
from stable_baselines3.common.env_checker import check_env
@@ -43,7 +47,7 @@ env_config = {
# ===== STANDARD ARGUMENTS ======
# kwargs that are used by every Scenario class (i.e. defined in BaseEnvironment)
'n_agents': 20, # Number of non-planner agents (must be > 1)
'agent_composition': {"BasicMobileAgent": 20}, # Agent class name -> number of agents of that class
'world_size': [1, 1], # [Height, Width] of the env world
'episode_length': 256, # Number of timesteps per episode
'allow_observation_scaling': True,
@@ -95,7 +99,7 @@ eval_env_config = {
# ===== STANDARD ARGUMENTS ======
# kwargs that are used by every Scenario class (i.e. defined in BaseEnvironment)
'n_agents': 20, # Number of non-planner agents (must be > 1)
'agent_composition': {"BasicMobileAgent": 20}, # Agent class name -> number of agents of that class
'world_size': [1, 1], # [Height, Width] of the env world
'episode_length': 100, # Number of timesteps per episode
'allow_observation_scaling': True,
@@ -200,9 +204,14 @@ def printReplay(econ,agentid):
print("Reward: {}".format(reward))
#Setup Env Objects
econ=foundation.make_env_instance(**env_config)
baseEconWrapper=BaseEconWrapper(econ)
baseEconWrapper.run()
mobileRecieverEconWrapper=RecieverEconWrapper(base_econ=baseEconWrapper,agent_classname="BasicMobileAgent")
sb3Converter=SB3EconConverter(mobileRecieverEconWrapper,econ,"BasicMobileAgent")
obs=sb3Converter.reset()
vecenv=EconVecEnv(env_config=env_config)
econ=vecenv.env
monenv=VecMonitor(venv=vecenv,info_keywords=["social/productivity","trend/productivity"])
normenv=VecNormalize(monenv,norm_reward=False,clip_obs=1)
stackenv=vec_frame_stack.VecFrameStack(venv=monenv,n_stack=10)