72 lines
2.0 KiB
Python
72 lines
2.0 KiB
Python
import numpy as np
|
|
def convert_econ_to_gym(econ):
    """Convert an econ-style dict into a gym-style list of its values.

    Keys are discarded. Values are returned in the dict's iteration
    (insertion) order.

    Args:
        econ: Mapping whose values should be collected.

    Returns:
        A new list containing ``econ``'s values.
    """
    return list(econ.values())
|
|
|
|
def convert_gym_to_econ(gy):
    """Convert a gym-style sequence into an econ-style dict keyed by position.

    Args:
        gy: Sequence to convert.

    Returns:
        A new dict mapping each integer index ``i`` to ``gy[i]``.

    Note: keys are always the ints ``0..len(gy)-1``. This is therefore not an
    exact inverse of ``convert_econ_to_gym`` when the source dict used other
    keys.
    """
    return dict(enumerate(gy))
|
|
|
|
|
|
|
|
def build_packager(sub_obs, put_in_both=None):
    """Decide which observation entries are flattened and which are kept as-is.

    Args:
        sub_obs: Example observation dict (key -> value) used to build the
            packager.
        put_in_both: Keys that are flattened AND also kept as-is
            (e.g., 'time'). Defaults to no keys.

    Returns:
        A ``(keep_as_is, flatten, wrap_as_list)`` tuple, meant to be passed
        to ``package``:

        - keep_as_is: keys copied unchanged, in ``sub_obs`` iteration order.
          These are "action_mask", every ndarray with more than one
          dimension, and any flattened key listed in ``put_in_both``.
        - flatten: sorted keys whose values are concatenated into "flat".
        - wrap_as_list: key -> True when the value is a scalar (including a
          0-d ndarray) and must be wrapped in a list before concatenation.
    """
    if put_in_both is None:
        put_in_both = []

    keep_as_is = []
    flatten = []
    wrap_as_list = {}

    for k, v in sub_obs.items():
        is_array = isinstance(v, np.ndarray)
        multi_d_array = is_array and v.ndim > 1

        if k == "action_mask" or multi_d_array:
            keep_as_is.append(k)
        else:
            flatten.append(k)
            if k in put_in_both:
                keep_as_is.append(k)

        # np.isscalar() is False for 0-d arrays, but 0-d arrays cannot be
        # passed to np.concatenate unwrapped, so treat them as scalars too.
        wrap_as_list[k] = np.isscalar(v) or (is_array and v.ndim == 0)

    # Sort so the layout of the flattened vector is deterministic.
    flatten.sort()

    return keep_as_is, flatten, wrap_as_list
|
|
|
|
|
|
def package(obs_dict, keep_as_is, flatten, wrap_as_list):
    """Flatten an observation dict using a packager.

    Args:
        obs_dict: Observation dict to package.
        keep_as_is: Keys copied into the result unchanged.
        flatten: Keys whose values are concatenated, in this order, into a
            single float32 vector stored under "flat".
        wrap_as_list: key -> True when the value must be wrapped in a
            one-element list before flattening (``build_packager`` sets this
            for scalar values).

    Returns:
        A new dict with the ``keep_as_is`` entries plus a float32 array under
        "flat". When ``flatten`` is empty, "flat" is an empty float32 array.

    Raises:
        KeyError: If a listed key is missing from ``obs_dict`` or
            ``wrap_as_list``.
        ValueError: If the values in ``flatten`` cannot be concatenated. Each
            key, its shape and its value are printed before re-raising.
    """
    new_obs = {k: obs_dict[k] for k in keep_as_is}

    if not flatten:
        # np.concatenate([]) raises "need at least one array to concatenate";
        # an empty vector keeps the "flat" key consistently present.
        new_obs["flat"] = np.zeros(0, dtype=np.float32)
    elif len(flatten) == 1:
        k = flatten[0]
        o = obs_dict[k]
        if wrap_as_list[k]:
            o = [o]
        new_obs["flat"] = np.array(o, dtype=np.float32)
    else:
        to_flatten = [
            [obs_dict[k]] if wrap_as_list[k] else obs_dict[k] for k in flatten
        ]
        try:
            new_obs["flat"] = np.concatenate(to_flatten).astype(np.float32)
        except ValueError:
            # Dump every piece to help locate the key with a bad shape.
            for k, v in zip(flatten, to_flatten):
                print(k, np.array(v).shape)
                print(v)
                print("")
            raise

    return new_obs
|
|
|
|
|