from collections import OrderedDict
from copy import deepcopy
from typing import Any, Callable, List, Optional, Sequence, Type, Union
from ai_economist.foundation.base import base_env
import gym
import gym.spaces
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import (
    VecEnv,
    VecEnvIndices,
    VecEnvObs,
    VecEnvStepReturn,
)
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
from wrapper.base_econ_wrapper import BaseEconWrapper
from ai_economist import foundation


class RecieverEconWrapper(gym.Env):
    """Receiver part of BaseEconVecEnv.

    Filters by agent class and presents the gym API to RL algorithms.
    Enables multi-threaded learning for different agent types.

    Two key spaces are used throughout:

    * ``idx``   -- the agent identifier used by the underlying econ world
      (the values stored in ``agnet_idx``).
    * ``index`` -- the zero-based position of that agent inside
      ``agnet_idx``; this is what callers of this wrapper see.
    """

    def __init__(self, base_econ: BaseEconWrapper, agent_classname: str):
        """Bind this receiver to ``base_econ`` for one agent class.

        Args:
            base_econ: Shared wrapper that owns the econ environment.
            agent_classname: Key into the world's ``_agent_class_idx_map``
                selecting which agents this receiver handles.

        Raises:
            KeyError: If ``agent_classname`` is not present in the world's
                agent class map.
        """
        self.base_econ = base_econ
        # Token returned by the base wrapper on registration; it is passed
        # back on every step request so the base can tell receivers apart.
        self.id = base_econ.register_vote()
        self.econ = base_econ.env
        self.agent_name = agent_classname
        # NOTE: the attribute name is misspelled ("agnet") but it is public,
        # so it is kept unchanged for backward compatibility.
        self.agnet_idx = list(self.econ.world._agent_class_idx_map[agent_classname])
        # Reverse lookup: econ idx -> local index within ``agnet_idx``.
        self.idx_to_index = {idx: index for index, idx in enumerate(self.agnet_idx)}

    def step_async(self, actions: dict) -> None:
        """Submit actions to the env.

        Args:
            actions: Mapping of local ``index -> action``. Keys are
                translated to econ ``idx`` values before being forwarded
                to the base wrapper.
        """
        data = self._dict_index_to_idx(actions)
        self.base_econ.reciever_request_step(data, self.id)

    def _dict_idx_to_index(self, data):
        """Re-key ``data`` from econ idx to local index.

        Entries whose key does not belong to this receiver's agent class
        are dropped.
        """
        return {
            self.idx_to_index[idx]: value
            for idx, value in data.items()
            if idx in self.idx_to_index
        }

    def _dict_index_to_idx(self, data):
        """Re-key ``data`` from local index to econ idx.

        Raises:
            IndexError: If a key is outside the range of ``agnet_idx``.
        """
        return {self.agnet_idx[index]: value for index, value in data.items()}

    def step_wait(self):
        """Collect the step result from the base wrapper.

        Returns:
            A ``(obs, rew, done, info)`` tuple. ``obs``, ``rew`` and ``info``
            are filtered to this receiver's agents and re-keyed by local
            index; ``done`` is passed through unchanged.
        """
        obs, rew, done, info = self.base_econ.reciever_block_step(self.id)
        # Convert the econ results (keyed by idx) to this receiver's view.
        c_obs = self._dict_idx_to_index(obs)
        c_rew = self._dict_idx_to_index(rew)
        c_info = self._dict_idx_to_index(info)
        return c_obs, c_rew, done, c_info

    def reset(self):
        """Request a reset from the base wrapper and return the observation.

        Returns:
            The observation dict returned by the base wrapper, filtered to
            this receiver's agents and re-keyed by local index.
        """
        # NOTE(review): unlike the step calls, the reset calls do not pass
        # ``self.id`` -- confirm against BaseEconWrapper that this is intended.
        self.base_econ.reciever_request_reset()
        obs = self.base_econ.reciever_block_reset()
        return self._dict_idx_to_index(obs)

    def step(self, action):
        """Submit ``action`` via ``step_async`` and return ``step_wait()``."""
        self.step_async(action)
        return self.step_wait()