Files
ai-econ/wrapper/utils.py

72 lines
2.0 KiB
Python

import numpy as np
def convert_econ_to_gym(econ):
    """Convert an econ-style dict into a gym-style list.

    Args:
        econ: Mapping whose values are collected (keys are discarded).

    Returns:
        A list of ``econ``'s values in the dict's iteration order.

    NOTE(review): keys are dropped, so the original keys cannot be recovered
    from the result alone — confirm callers rely only on value order.
    """
    return list(econ.values())
def convert_gym_to_econ(gy):
    """Convert a gym-style sequence into an econ-style dict.

    Args:
        gy: Indexable sequence (e.g. a list) of per-entry values.

    Returns:
        A dict mapping each position ``0..len(gy)-1`` (as an int) to the
        element at that position.

    NOTE(review): keys are plain ints; if the econ side uses other key types
    (e.g. string ids), this is not the inverse of ``convert_econ_to_gym`` —
    confirm against callers.
    """
    return dict(enumerate(gy))
def build_packager(sub_obs, put_in_both=None):
    """Decide, per observation key, whether its value is flattened or kept.

    Rules applied to each ``(key, value)`` of ``sub_obs``:
      * ``"action_mask"`` and NumPy arrays with more than one dimension are
        kept as-is.
      * Everything else is flattened; if the key is also listed in
        ``put_in_both`` (e.g. ``'time'``), it is kept as-is as well.

    Args:
        sub_obs: Observation dict to inspect.
        put_in_both: Keys to include in both the kept and flattened groups.

    Returns:
        ``(keep_as_is, flatten, wrap_as_list)`` where ``keep_as_is`` is a list
        of keys in encounter order, ``flatten`` is a sorted list of keys, and
        ``wrap_as_list`` maps each flattened key to ``np.isscalar(value)``.
    """
    shared_keys = [] if put_in_both is None else put_in_both
    keep_as_is = []
    flat_keys = []
    wrap_as_list = {}

    for key, value in sub_obs.items():
        is_multi_dim = isinstance(value, np.ndarray) and value.ndim > 1
        if key == "action_mask" or is_multi_dim:
            keep_as_is.append(key)
            continue

        flat_keys.append(key)
        if key in shared_keys:
            keep_as_is.append(key)
        # Scalars must be wrapped in a list before they can be concatenated.
        wrap_as_list[key] = np.isscalar(value)

    return keep_as_is, sorted(flat_keys), wrap_as_list
def _shape_for_message(value):
    """Return ``np.shape(value)`` as text, or a placeholder if it has no regular shape."""
    try:
        return str(np.shape(value))
    except ValueError:
        # Ragged nested sequences cannot be converted to an array; don't let
        # the diagnostic step itself fail.
        return "<irregular>"


def package(obs_dict, keep_as_is, flatten, wrap_as_list):
    """Flatten an observation dict according to a packager spec.

    Args:
        obs_dict: Observation dict to package.
        keep_as_is: Keys whose values are copied unchanged into the result.
        flatten: Keys whose values are concatenated, in this order, into the
            result's ``"flat"`` entry.
        wrap_as_list: Maps each key in ``flatten`` to True when its value must
            be wrapped in a one-element list first (i.e. it is a scalar).

    Returns:
        A new dict with the ``keep_as_is`` entries plus a float32 NumPy array
        stored under ``"flat"``.

    Raises:
        KeyError: If a key named in the spec is missing from ``obs_dict`` or
            ``wrap_as_list``.
        ValueError: If the values to flatten cannot be concatenated or cast to
            float32. The message lists each component's shape, and the
            original NumPy error is chained as the cause.
    """
    new_obs = {k: obs_dict[k] for k in keep_as_is}

    if len(flatten) == 1:
        # A single component needs no concatenation; convert it directly.
        k = flatten[0]
        o = obs_dict[k]
        if wrap_as_list[k]:
            o = [o]
        new_obs["flat"] = np.array(o, dtype=np.float32)
        return new_obs

    to_flatten = [
        [obs_dict[k]] if wrap_as_list[k] else obs_dict[k] for k in flatten
    ]
    try:
        new_obs["flat"] = np.concatenate(to_flatten).astype(np.float32)
    except ValueError as err:
        # Attach per-component shapes to the error instead of printing them,
        # so the diagnostics travel with the exception.
        details = ", ".join(
            f"{k}: shape={_shape_for_message(v)}"
            for k, v in zip(flatten, to_flatten)
        )
        raise ValueError(
            "Cannot concatenate observation components into 'flat' "
            f"({details or 'no components'})"
        ) from err
    return new_obs