68 lines
2.3 KiB
Python
68 lines
2.3 KiB
Python
from collections import OrderedDict
|
|
from copy import deepcopy
|
|
from typing import Any, Callable, List, Optional, Sequence, Type, Union
|
|
from ai_economist.foundation.base import base_env
|
|
|
|
import gym
|
|
import gym.spaces
|
|
import numpy as np
|
|
from base_econ_wrapper import BaseEconVecEnv
|
|
|
|
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
|
|
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
|
|
|
|
from ai_economist import foundation
|
|
|
|
class RecieverEconVecEnv(gym.Env):
    """Receiver side of a shared ``BaseEconVecEnv``, scoped to one agent class.

    Several receivers may share one ``BaseEconVecEnv``. Each receiver talks to
    the shared environment through the ``reciever_*`` request/block calls and
    presents a gym-style API over the agents of a single agent class, so that
    different agent types can be trained by separate RL algorithms.

    Two identifier spaces are involved:

    * ``idx``   -- the agent identifier used by the underlying economy, i.e. an
      entry of ``world._agent_class_idx_map[agent_classname]``.
    * ``index`` -- the position (0, 1, 2, ...) of that agent within this
      receiver's agent class; this is what the RL algorithm sees.
    """

    def __init__(self, base_econ: BaseEconVecEnv, agent_classname: str):
        """Attach this receiver to ``base_econ`` for agents of ``agent_classname``.

        Args:
            base_econ: shared wrapper that this receiver submits steps/resets to.
            agent_classname: key into the world's ``_agent_class_idx_map``
                selecting which agents this receiver exposes.

        Raises:
            KeyError: if ``agent_classname`` is not present in the world's
                agent-class map.
        """
        self.base_econ = base_econ
        # NOTE(review): the vote is registered before the agent-class lookup
        # below, so an unknown ``agent_classname`` leaves a registered vote
        # behind. Order kept because register_vote()'s effects aren't visible here.
        base_econ.register_vote()
        self.econ = base_econ.env
        self.agent_name = agent_classname
        # Economy idx of every agent in this class, ordered by local index.
        # (Attribute-name misspelling kept for compatibility with existing users.)
        self.agnet_idx = list(self.econ.world._agent_class_idx_map[agent_classname])
        # Inverse mapping: economy idx -> local index.
        self.idx_to_index = {idx: index for index, idx in enumerate(self.agnet_idx)}

    def step(self, actions: dict):
        """Synchronous gym ``step``: submit ``actions`` and wait for the result.

        Args:
            actions: mapping of local agent index -> action (see ``step_async``).

        Returns:
            The ``(obs, rew, done, info)`` tuple produced by ``step_wait``.
        """
        self.step_async(actions)
        return self.step_wait()

    def step_async(self, actions: dict) -> None:
        """Submit actions for this receiver's agents to the shared environment.

        Args:
            actions: mapping of local agent index -> action. Each key is used to
                index ``self.agnet_idx`` to obtain the economy idx before the
                re-keyed dict is passed to ``base_econ.reciever_request_step``.

        Raises:
            IndexError: if an integer key is outside ``self.agnet_idx``.
        """
        self.base_econ.reciever_request_step(self._dict_index_to_idx(actions))

    def _dict_idx_to_index(self, data):
        """Return a copy of ``data`` re-keyed from economy idx to local index.

        Raises:
            KeyError: if a key of ``data`` is not an agent of this receiver's class.
        """
        return {self.idx_to_index[idx]: value for idx, value in data.items()}

    def _dict_index_to_idx(self, data):
        """Return a copy of ``data`` re-keyed from local index to economy idx."""
        return {self.agnet_idx[index]: value for index, value in data.items()}

    def step_wait(self):
        """Block until the shared environment has stepped; return this receiver's view.

        Returns:
            ``(obs, rew, done, info)`` -- the four dicts returned by
            ``base_econ.reciever_block_step()``, each re-keyed from economy idx
            to local index.
        """
        obs, rew, done, info = self.base_econ.reciever_block_step()
        # NOTE(review): every key in these dicts must belong to this agent class,
        # otherwise _dict_idx_to_index raises KeyError; presumably base_econ
        # already filters per receiver -- confirm.
        return (
            self._dict_idx_to_index(obs),
            self._dict_idx_to_index(rew),
            self._dict_idx_to_index(done),
            self._dict_idx_to_index(info),
        )

    def reset(self):
        """Request a reset of the shared environment and wait for it.

        Returns:
            The observations from ``base_econ.reciever_block_reset()``, re-keyed
            from economy idx to local index.
        """
        self.base_econ.reciever_request_reset()
        obs = self.base_econ.reciever_block_reset()
        return self._dict_idx_to_index(obs)
|
|
|
|
|