adding ai_economist for modding
This commit is contained in:
5
ai_economist/foundation/base/__init__.py
Normal file
5
ai_economist/foundation/base/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
490
ai_economist/foundation/base/base_agent.py
Normal file
490
ai_economist/foundation/base/base_agent.py
Normal file
@@ -0,0 +1,490 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai_economist.foundation.base.registrar import Registry
|
||||
|
||||
|
||||
class BaseAgent:
    """Base class for Agent classes.

    Instances of Agent classes are created for each agent in the environment. Agent
    instances are stateful, capturing location, inventory, endogenous variables,
    and any additional state fields created by environment components during
    construction (see BaseComponent.get_additional_state_fields in base_component.py).

    They also provide a simple API for getting/setting actions for each of their
    registered action subspaces (which depend on the components used to build
    the environment).

    Args:
        idx (int or str): Index that uniquely identifies the agent object amongst the
            other agent objects registered in its environment. Defaults to 0.
        multi_action_mode (bool): Whether to allow the agent to take one action for
            each of its registered action subspaces each timestep (if True),
            or to limit the agent to take only one action each timestep (if False).
            Defaults to False.
    """

    # Subclasses must override this with a non-empty name (asserted in __init__).
    name = ""

    def __init__(self, idx=None, multi_action_mode=None):
        assert self.name

        if idx is None:
            idx = 0

        if multi_action_mode is None:
            multi_action_mode = False

        # String indices are kept verbatim; anything else is coerced to int.
        if isinstance(idx, str):
            self._idx = idx
        else:
            self._idx = int(idx)

        self.multi_action_mode = bool(multi_action_mode)
        # Used to convert single-action-mode actions to the general format:
        # {flat action index: [action subspace name, index within that subspace]}
        self.single_action_map = {}

        self.action = dict()
        self.action_dim = dict()
        self._action_names = []
        self._multi_action_dict = {}
        self._unique_actions = 0
        self._total_actions = 0

        self.state = dict(loc=[0, 0], inventory={}, escrow={}, endogenous={})

        self._registered_inventory = False
        self._registered_endogenous = False
        self._registered_components = False
        self._noop_action_dict = dict()

        # Special flag to allow logic for multi-action-mode agents
        # that are not given any actions.
        self._passive_multi_action_agent = False

        # If this gets set to true, we can make masks faster
        self._one_component_single_action = False
        self._premask = None

    @property
    def idx(self):
        """Index used to identify this agent. Must be unique within the environment."""
        return self._idx

    def register_inventory(self, resources):
        """Used during environment construction to populate inventory/escrow fields.

        Each name in resources gets a 0 entry in both inventory and escrow.
        May only be called once per agent.
        """
        assert not self._registered_inventory
        for entity_name in resources:
            self.inventory[entity_name] = 0
            self.escrow[entity_name] = 0
        self._registered_inventory = True

    def register_endogenous(self, endogenous):
        """Used during environment construction to populate endogenous state fields.

        Each name in endogenous gets a 0 entry. May only be called once per agent.
        """
        assert not self._registered_endogenous
        for entity_name in endogenous:
            self.endogenous[entity_name] = 0
        self._registered_endogenous = True

    def _incorporate_component(self, action_name, n):
        """Register one action subspace with n (non-NO-OP) actions."""
        # Each sub-action has its own NO-OP in multi action mode
        extra_n = 1 if self.multi_action_mode else 0
        self.action[action_name] = 0
        self.action_dim[action_name] = n + extra_n
        self._action_names.append(action_name)
        self._multi_action_dict[action_name] = False
        self._unique_actions += 1
        if self.multi_action_mode:
            self._total_actions += n + extra_n
        else:
            # In single action mode every real action gets its own flat index;
            # flat index 0 is reserved for the universal NO-OP.
            for action_n in range(1, n + 1):
                self._total_actions += 1
                self.single_action_map[int(self._total_actions)] = [
                    action_name,
                    action_n,
                ]

    def register_components(self, components):
        """Used during environment construction to set up state/action spaces.

        Args:
            components: Iterable of component objects exposing name,
                get_n_actions(agent_name) and
                get_additional_state_fields(agent_name).

        Raises:
            NameError: If a component reports a sub-action name containing ".".
            TypeError: If get_n_actions returns something other than None,
                an integer, or a tuple/list of (sub_name, n) pairs.
        """
        assert not self._registered_components
        for component in components:
            n = component.get_n_actions(self.name)
            if n is None:
                continue

            # Most components will have a single action-per-agent, so n is an int
            # (NumPy integer scalars are accepted and normalized to int).
            if isinstance(n, (int, np.integer)):
                if n == 0:
                    continue
                self._incorporate_component(component.name, int(n))

            # They can also internally handle multiple actions-per-agent,
            # so n is a tuple or list
            elif isinstance(n, (tuple, list)):
                for action_sub_name, n_ in n:
                    if n_ == 0:
                        continue
                    # "." is the separator between component and sub-action names.
                    if "." in action_sub_name:
                        raise NameError(
                            "Sub-action {} of component {} "
                            "is illegally named.".format(
                                action_sub_name, component.name
                            )
                        )
                    self._incorporate_component(
                        "{}.{}".format(component.name, action_sub_name), n_
                    )

            # If that's not what we got something is funky.
            else:
                raise TypeError(
                    "Received unexpected type ({}) from {}.get_n_actions('{}')".format(
                        type(n), component.name, self.name
                    )
                )

            for k, v in component.get_additional_state_fields(self.name).items():
                self.state[k] = v

        # Currently no actions are available to this agent. Give it a placeholder.
        if len(self.action) == 0 and self.multi_action_mode:
            self._incorporate_component("PassiveAgentPlaceholder", 0)
            self._passive_multi_action_agent = True

        elif len(self.action) == 1 and not self.multi_action_mode:
            self._one_component_single_action = True
            # Slot 0 is the NO-OP (always 1); flatten_masks overwrites the rest.
            self._premask = np.ones(1 + self._total_actions, dtype=np.float32)

        self._registered_components = True

        self._noop_action_dict = {k: v * 0 for k, v in self.action.items()}

    @property
    def action_spaces(self):
        """
        if self.multi_action_mode == True:
            Returns an integer array with length equal to the number of action
            subspaces that the agent registered. The i'th element of the array
            indicates the number of actions associated with the i'th action subspace.
            In multi_action_mode, each subspace includes a NO-OP.
            Note: self._action_names describes which action subspace each element of
            the array refers to.

            Example:
                >> self.multi_action_mode
                True
                >> self.action_spaces
                [2, 5]
                >> self._action_names
                ["Build", "Gather"]
                # [1 Build action + Build NO-OP, 4 Gather actions + Gather NO-OP]

        if self.multi_action_mode == False:
            Returns a single integer equal to the total number of actions that the
            agent can take.

            Example:
                >> self.multi_action_mode
                False
                >> self.action_spaces
                6
                >> self._action_names
                ["Build", "Gather"]
                # 1 NO-OP + 1 Build action + 4 Gather actions.
        """
        if self.multi_action_mode:
            action_dims = []
            for m in self._action_names:
                action_dims.append(np.array(self.action_dim[m]).reshape(-1))
            return np.concatenate(action_dims).astype(np.int32)
        n_actions = 1  # (NO-OP)
        for m in self._action_names:
            n_actions += self.action_dim[m]
        return n_actions

    @property
    def loc(self):
        """2D list of [row, col] representing agent's location in the environment."""
        return self.state["loc"]

    @property
    def endogenous(self):
        """Dictionary representing endogenous quantities (i.e. "Labor").

        Example:
            >> self.endogenous
            {"Labor": 30.25}
        """
        return self.state["endogenous"]

    @property
    def inventory(self):
        """Dictionary representing quantities of resources in agent's inventory.

        Example:
            >> self.inventory
            {"Wood": 3, "Stone": 20, "Coin": 1002.83}
        """
        return self.state["inventory"]

    @property
    def escrow(self):
        """Dictionary representing quantities of resources in agent's escrow.

        https://en.wikipedia.org/wiki/Escrow
        Escrow is used to manage any portion of the agent's inventory that is
        reserved for a particular purpose. Typically, something enters escrow as part
        of a contractual arrangement to disburse that something when another
        condition is met. An example is found in the ContinuousDoubleAuction
        Component class (see ../components/continuous_double_auction.py). When an
        agent creates an order to sell a unit of Wood, for example, the component
        moves one unit of Wood from the agent's inventory to its escrow. If another
        agent buys the Wood, it is moved from escrow to the other agent's inventory. By
        placing the Wood in escrow, it prevents the first agent from using it for
        something else (i.e. building a house).

        Notes:
            The inventory and escrow share the same keys. An agent's endowment refers
            to the total quantity it has in its inventory and escrow.

            Escrow is provided to simplify inventory management but its intended
            semantics are not enforced directly. It is up to Component classes to
            enforce these semantics.

        Example:
            >> self.escrow
            {"Wood": 0, "Stone": 1, "Coin": 3}
        """
        return self.state["escrow"]

    def inventory_to_escrow(self, resource, amount):
        """Move some amount of a resource from agent inventory to agent escrow.

        Amount transferred is capped to the amount of resource in agent inventory.

        Args:
            resource (str): The name of the resource to move (i.e. "Wood", "Coin").
            amount (float): The amount to be moved from inventory to escrow. Must be
                non-negative.

        Returns:
            Amount of resource actually transferred (as a float). Will be less than
            amount argument if amount argument exceeded the amount of resource in
            the inventory. Calculated as:
                transferred = np.minimum(self.state["inventory"][resource], amount)
        """
        assert amount >= 0
        transferred = float(np.minimum(self.state["inventory"][resource], amount))
        self.state["inventory"][resource] -= transferred
        self.state["escrow"][resource] += transferred
        return float(transferred)

    def escrow_to_inventory(self, resource, amount):
        """Move some amount of a resource from agent escrow to agent inventory.

        Amount transferred is capped to the amount of resource in agent escrow.

        Args:
            resource (str): The name of the resource to move (i.e. "Wood", "Coin").
            amount (float): The amount to be moved from escrow to inventory. Must be
                non-negative.

        Returns:
            Amount of resource actually transferred (as a float). Will be less than
            amount argument if amount argument exceeded the amount of resource in
            escrow. Calculated as:
                transferred = np.minimum(self.state["escrow"][resource], amount)
        """
        assert amount >= 0
        transferred = float(np.minimum(self.state["escrow"][resource], amount))
        self.state["escrow"][resource] -= transferred
        self.state["inventory"][resource] += transferred
        return float(transferred)

    def total_endowment(self, resource):
        """Get the combined inventory+escrow endowment of resource.

        Args:
            resource (str): Name of the resource

        Returns:
            The amount of resource in the agents inventory and escrow.
        """
        return self.inventory[resource] + self.escrow[resource]

    def reset_actions(self, component=None):
        """Reset all actions to the NO-OP action (the 0'th action index).

        If component is specified, only reset action(s) for that component.
        Matching is case-insensitive. A component name containing "." is compared
        against the full "Component.sub_action" key; otherwise it is compared
        against the part of each key before the first ".".
        """
        if not component:
            self.action.update(self._noop_action_dict)
            return

        target = component.lower()
        match_full_name = "." in component
        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        for k, v in self.action.items():
            candidate = k if match_full_name else k.split(".")[0]
            if candidate.lower() == target:
                self.action[k] = v * 0

    def has_component(self, component_name):
        """Returns True if the agent has component_name as a registered subaction."""
        return component_name in self.action

    def get_random_action(self):
        """
        Select a component at random and randomly choose one of its actions (other
        than NO-OP).

        Returns:
            A single-entry dict {action subspace name: action index}, with the
            index in 1..n where n is the number of real actions in the subspace.

        Raises:
            IndexError: If the agent has no action subspaces, or the chosen
                subspace has no real actions (e.g. the passive placeholder).
        """
        random_component = random.choice(self._action_names)
        # action_dim counts the per-subspace NO-OP only in multi action mode, so
        # the exclusive upper bound on valid indices differs between the modes.
        upper = self.action_dim[random_component]
        if not self.multi_action_mode:
            upper += 1
        component_action = random.choice(range(1, upper))
        return {random_component: component_action}

    def get_component_action(self, component_name, sub_action_name=None):
        """
        Return the action(s) taken for component_name component, or None if the
        agent does not use that component.

        If sub_action_name is given, the single "component.sub_action" entry is
        looked up. Otherwise every registered subspace belonging to component_name
        is collected: one match returns its action, several return a list.
        """
        if sub_action_name is not None:
            return self.action.get(component_name + "." + sub_action_name, None)
        matching_names = [
            m for m in self._action_names if m.split(".")[0] == component_name
        ]
        if len(matching_names) == 0:
            return None
        if len(matching_names) == 1:
            return self.action.get(matching_names[0], None)
        return [self.action.get(m, None) for m in matching_names]

    def set_component_action(self, component_name, action):
        """Set the action(s) taken for component_name component.

        Raises:
            KeyError: If component_name is not a registered action subspace.
        """
        if component_name not in self.action:
            raise KeyError(
                "Agent {} of type {} does not have {} registered as a subaction".format(
                    self.idx, self.name, component_name
                )
            )
        if self._multi_action_dict[component_name]:
            self.action[component_name] = np.array(action, dtype=np.int32)
        else:
            self.action[component_name] = int(action)

    def populate_random_actions(self):
        """Fill the action buffer with random actions. This is for testing.

        Every registered subspace gets an index drawn uniformly from
        0 (NO-OP) through its last real action.
        """
        # action_dim excludes the NO-OP slot in single action mode, so widen the
        # sampling range by one there to make the last real action reachable.
        noop_offset = 0 if self.multi_action_mode else 1
        for component, d in self.action_dim.items():
            if isinstance(d, int):
                self.set_component_action(
                    component, np.random.randint(0, d + noop_offset)
                )
            else:
                d_array = np.array(d) + noop_offset
                self.set_component_action(
                    component, np.floor(np.random.rand(*d_array.shape) * d_array)
                )

    def parse_actions(self, actions):
        """Parse the actions array to fill each component's action buffers.

        Args:
            actions: In multi action mode, a sequence with one action index per
                registered subspace (ordered like self._action_names). In single
                action mode, either a dict {subspace name: action index} with at
                most one entry, or a flat index into the combined action space
                (0 is the universal NO-OP).

        Raises:
            ValueError: In single action mode, if a flat index is not 0 and does
                not correspond to any registered action.
        """
        if self.multi_action_mode:
            assert len(actions) == self._unique_actions
            if len(actions) == 1:
                self.set_component_action(self._action_names[0], actions[0])
            else:
                for action_name, action in zip(self._action_names, actions):
                    self.set_component_action(action_name, int(action))
            return

        # Single action mode
        # Action was supplied as an index of a specific subaction.
        # No need to do any lookup.
        if isinstance(actions, dict):
            if len(actions) == 0:
                return
            assert len(actions) == 1
            action_name, action = next(iter(actions.items()))
            if action == 0:
                return
            self.set_component_action(action_name, action)
            return

        # Action was supplied as an index into the full set of combined actions
        action = int(actions)
        # Universal NO-OP
        if action == 0:
            return
        mapped = self.single_action_map.get(action)
        if mapped is None:
            raise ValueError(
                "Agent {} of type {} has no action with index {} "
                "(valid non-NO-OP indices are 1..{})".format(
                    self.idx, self.name, action, self._total_actions
                )
            )
        action_name, action = mapped
        self.set_component_action(action_name, action)

    def flatten_masks(self, mask_dict):
        """Convert a dictionary of component action masks into a single mask vector.

        Args:
            mask_dict (dict): {action subspace name: mask over its real actions}.

        Returns:
            A float32 array. In single action mode it starts with one NO-OP entry
            followed by every subspace mask; in multi action mode each subspace
            mask is preceded by its own NO-OP entry.

        Raises:
            KeyError: If a registered subspace has no entry in mask_dict.

        Note:
            For single-action-mode agents with exactly one subspace, the returned
            array is a buffer that is reused (and overwritten) on every call.
        """
        if self._one_component_single_action:
            self._premask[1:] = mask_dict[self._action_names[0]]
            return self._premask

        no_op_mask = [1]

        if self._passive_multi_action_agent:
            return np.array(no_op_mask).astype(np.float32)

        list_of_masks = []
        if not self.multi_action_mode:
            list_of_masks.append(no_op_mask)
        for m in self._action_names:
            if m not in mask_dict:
                raise KeyError("No mask provided for {} (agent {})".format(m, self.idx))
            if self.multi_action_mode:
                list_of_masks.append(no_op_mask)
            list_of_masks.append(mask_dict[m])
        return np.concatenate(list_of_masks).astype(np.float32)
|
||||
|
||||
|
||||
# Module-level Registry instance for Agent classes, constructed with BaseAgent.
# The string literal that follows documents how the registry is meant to be used.
agent_registry = Registry(BaseAgent)
"""The registry for Agent classes.

This creates a registry object for Agent classes. This registry requires that all
added classes are subclasses of BaseAgent. To make an Agent class available through
the registry, decorate the class definition with @agent_registry.add.

Example:
    from ai_economist.foundation.base.base_agent import BaseAgent, agent_registry

    @agent_registry.add
    class ExampleAgent(BaseAgent):
        name = "Example"
        pass

    assert agent_registry.has("Example")

    AgentClass = agent_registry.get("Example")
    agent = AgentClass(...)
    assert isinstance(agent, ExampleAgent)

Notes:
    The foundation package exposes the agent registry as: foundation.agents

    An Agent class that is defined and registered following the above example will
    only be visible in foundation.agents if defined/registered in a file that is
    imported in ../agents/__init__.py.
"""
|
||||
406
ai_economist/foundation/base/base_component.py
Normal file
406
ai_economist/foundation/base/base_component.py
Normal file
@@ -0,0 +1,406 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai_economist.foundation.agents import agent_registry
|
||||
from ai_economist.foundation.base.registrar import Registry
|
||||
from ai_economist.foundation.base.world import World
|
||||
|
||||
|
||||
class BaseComponent(ABC):
|
||||
"""
|
||||
Base Component class. Should be used as the parent class for Component classes.
|
||||
Component instances are used to add some particular dynamics to an environment.
|
||||
They also add action spaces through which agents can interact with the
|
||||
environment via the component instance.
|
||||
|
||||
Environments expand the agents' state/action spaces by querying:
|
||||
get_n_actions
|
||||
get_additional_state_fields
|
||||
|
||||
Environments expand their dynamics by querying:
|
||||
component_step
|
||||
generate_observations
|
||||
generate_masks
|
||||
|
||||
Environments expand logging behavior by querying:
|
||||
get_metrics
|
||||
get_dense_log
|
||||
|
||||
Because they are built as Python objects, component instances can also be
|
||||
stateful. Stateful attributes are reset via calls to:
|
||||
additional_reset_steps
|
||||
|
||||
The semantics of each method, and how they can be used to construct an instance
|
||||
of the Component class, are detailed below.
|
||||
|
||||
Refer to ../components/move.py for an example of a Component class that enables
|
||||
mobile agents to move and collect resources in the environment world.
|
||||
"""
|
||||
|
||||
# The name associated with this Component class (must be unique).
|
||||
# Note: This is what will identify the Component class in the component registry.
|
||||
name = ""
|
||||
|
||||
# An optional shorthand description of the what the component implements (i.e.
|
||||
# "Trading", "Building", etc.). See BaseEnvironment.get_component and
|
||||
# BaseEnvironment._finalize_logs to see where this may add convenience.
|
||||
# Does not need to be unique.
|
||||
component_type = None
|
||||
|
||||
# The (sub)classes of agents that this component applies to
|
||||
agent_subclasses = None # Replace with list or tuple (can be empty)
|
||||
|
||||
# The (non-agent) game entities that are expected to be in play
|
||||
required_entities = None # Replace with list or tuple (can be empty)
|
||||
|
||||
def __init__(self, world, episode_length, inventory_scale=1):
    """Validate the class-level configuration and bind the component to a world.

    Args:
        world: World object of the environment; validated via check_world().
        episode_length (int): Positive episode length of the environment.
        inventory_scale: Scaling value exposed through inv_scale; stored as float.
    """
    # A concrete component class must define a non-empty name.
    assert self.name

    # agent_subclasses must be a non-empty tuple/list of registered agent names.
    subclass_names = self.agent_subclasses
    assert isinstance(subclass_names, (tuple, list))
    assert len(subclass_names) > 0
    if len(subclass_names) > 1:
        # No listed agent class may be a subclass of another listed one;
        # every ordered pair of distinct entries is checked.
        for i, name_i in enumerate(subclass_names):
            for j, name_j in enumerate(subclass_names):
                if i != j:
                    cls_i = agent_registry.get(name_i)
                    cls_j = agent_registry.get(name_j)
                    assert not issubclass(cls_i, cls_j)

    assert isinstance(self.required_entities, (tuple, list))

    self.check_world(world)
    self._world = world

    length_is_valid = isinstance(episode_length, int) and episode_length > 0
    assert length_is_valid
    self._episode_length = episode_length

    # Cache frequently used world attributes on the component.
    self.n_agents = world.n_agents
    self.resources = world.resources
    self.landmarks = world.landmarks

    self.timescale = 1
    assert self.timescale >= 1

    self._inventory_scale = float(inventory_scale)
|
||||
|
||||
@property
def world(self):
    """World object of the environment that this component instance belongs to.

    Spatial and agent state are exposed by the world object through:
        world.maps     # Reference to maps object representing spatial state
        world.agents   # List of self.n_agents mobile agent objects
        world.planner  # Reference to planner agent object

    See world.py and base_agent.py for additional API details.
    """
    owning_world = self._world
    return owning_world
|
||||
|
||||
@property
def episode_length(self):
    """Length (as an int) of an episode in the environment using this component."""
    stored_length = self._episode_length
    return int(stored_length)
|
||||
|
||||
@property
def inv_scale(self):
    """
    Scaling value to apply to quantities when generating observations.

    Note: The value comes from the inventory_scale constructor argument (stored
    as a float), which lets every component instance built with the same
    argument refer to the same scaling value. How the value is actually used
    depends on the component's implementation of generate_observations().
    """
    scale = self._inventory_scale
    return scale
|
||||
|
||||
@property
def shorthand(self):
    """Return component_type when it is defined, otherwise fall back to name."""
    if self.component_type is None:
        return self.name
    return self.component_type
|
||||
|
||||
@staticmethod
def check_world(world):
    """Assert that the supplied object is a World instance."""
    is_world_instance = isinstance(world, World)
    assert is_world_instance
|
||||
|
||||
def reset(self):
    """Restore the component-managed part of the state to its reset values.

    Every mobile agent and the planner has its state updated with the fields
    returned by get_additional_state_fields() for its agent type; afterwards the
    component-specific additional_reset_steps() hook is invoked.
    """
    env_world = self.world
    for member in env_world.agents + [env_world.planner]:
        reset_fields = self.get_additional_state_fields(member.name)
        member.state.update(reset_fields)

    # Hook that lets components define additional reset steps
    self.additional_reset_steps()
|
||||
|
||||
def obs(self):
    """
    Return this component's observations for the current world/agents/component
    state, with every top-level key converted to a string.
    """
    # Mostly a formatting guard around generate_observations().
    raw_obs = self.generate_observations()
    assert isinstance(raw_obs, dict)
    formatted = {}
    for agent_idx, agent_obs in raw_obs.items():
        formatted[str(agent_idx)] = agent_obs
    return formatted
|
||||
|
||||
# Required methods for implementing components
|
||||
# --------------------------------------------
|
||||
|
||||
@abstractmethod
def get_n_actions(self, agent_cls_name):
    """
    Return the number of actions (not including NO-OPs) for agents of type
    agent_cls_name.

    Args:
        agent_cls_name (str): name of the Agent class for which number of actions
            is being queried. For example, "BasicMobileAgent".

    Returns:
        action_space (None, int, or list): If the component does not add any
            actions for agents of type agent_cls_name, return None. If it adds a
            single action space, return an integer specifying the number of
            actions in the action space. If it adds multiple action spaces,
            return a list of tuples ("action_set_name", num_actions_in_set).
            See below for further detail.

    If agents of type agent_cls_name do not participate in the component, simply
    return None.

    In the next simplest case, the component adds one set of n different actions
    for agents of type agent_cls_name. In this case, return n (as an int). For
    example, if Component implements moving up, down, left, or right for
    "BasicMobileAgent" agents, then Component.get_n_actions("BasicMobileAgent")
    should return 4.

    If the component adds multiple sets of actions for a given agent type, this
    method should return a list of tuples:
        [("action_set_name_1", n_1), ..., ("action_set_name_M", n_M)],
        where M is the number of different sets of actions, and n_k is the number of
        actions in action set k.
    For example, if Component allows agent 'Planner' to set some tax for each of
    individual Mobile agents, and there are 3 such agents, then:
        Component.get_n_actions('Planner') should return, e.g.,
        [('Tax_0', 10), ('Tax_1', 10), ('Tax_2', 10)],
        where, in this example, the Planner agent can choose 10 different tax
        levels for each Mobile agent.

    Note: action set names must not contain "."; BaseAgent.register_components
    raises a NameError for such names because "." separates the component name
    from the action set name.
    """
|
||||
|
||||
@abstractmethod
def get_additional_state_fields(self, agent_cls_name):
    """
    Return a dictionary of {state_field: reset_val} managed by this Component
    class for agents of type agent_cls_name. This also partially controls reset
    behavior.

    Args:
        agent_cls_name (str): name of the Agent class for which additional states
            are being queried. For example, "BasicMobileAgent".

    Returns:
        extra_state_dict (dict): A dictionary of {"state_field": reset_val} for
            each extra state field that this component adds/manages to agents of
            type agent_cls_name. This extra_state_dict is incorporated into
            agent.state for each agent of this type. Note that the keyed fields
            will be reset to reset_val when the environment is reset (reset()
            applies this dictionary to agent.state via dict.update).

    If the component has its own internal state, the protocol for resetting that
    should be written into the custom method 'additional_reset_steps()' [see below].

    States that are meant to be internal to the component do not need to be
    registered as agent state fields. Rather, adding to the agent state fields is
    most useful when two or more components refer to or affect the same state. In
    general, however, if the component expects a particular state field to exist,
    it should return that field (and its reset value) here.
    """
|
||||
|
||||
@abstractmethod
def component_step(self):
    """
    For all relevant agents, execute the actions specific to this Component class.
    This is essentially where the component logic is implemented and what allows
    components to create environment dynamics.

    If the component expects certain resources/landmarks/entities to be in play,
    it must declare them in the 'required_entities' class attribute so that they
    can be registered as part of the world and, where appropriate, part of the
    agent inventory.

    If the component expects non-standard fields to exist in agent.state for one
    or more agent types, that must be reflected in get_additional_state_fields().
    """
|
||||
|
||||
@abstractmethod
def generate_observations(self):
    """
    Generate observations associated with this Component class.

    A component does not need to produce observations and can provide observations
    for only some agent types; however, for a given environment, the structure of
    the observations returned by this component should be identical between
    subsequent calls to generate_observations. That is, the agents that receive
    observations should remain consistent as should the structure of their
    individual observations.

    Returns:
        obs (dict): A dictionary of {agent.idx: agent_obs_dict}. In words,
            return a dictionary with an entry for each agent (which can include
            the planner) for which this component provides an observation. For each
            entry, the key specifies the index of the agent and the value contains
            its associated observation dictionary.

    Note: obs() wraps this method; it asserts that the returned value is a dict
    and converts every top-level key to a string.
    """
|
||||
|
||||
@abstractmethod
def generate_masks(self, completions=0):
    """
    Create action masks to indicate which actions are and are not valid. Actions
    that are valid should be given a value of 1 and 0 otherwise. Do not generate
    a mask for the NO-OP action, which is always available.

    Args:
        completions (int): The number of completed episodes. This is intended to
            be used in the case that actions may be masked or unmasked as part of a
            learning curriculum. The default implementation below ignores it.

    Returns:
        masks (dict): A dictionary of {agent.idx: mask} with an entry for each
            agent that can interact with this component. See below.

    Raises:
        TypeError: If get_n_actions() returns something other than None, a number,
            or a tuple/list of ("name", N) pairs for one of the agents.

    The expected output parallels the action subspaces defined by get_n_actions():
    The output should be a dictionary of {agent.idx: mask} keyed for all agents
    that take actions via this component.

    For example, say the component defines a set of 4 actions for agents of type
    "BasicMobileAgent" (self.get_n_actions("BasicMobileAgent") --> 4). Because all
    action spaces include a NO-OP action, there are 5 available actions,
    interpreted in this example as: NO-OP (index=0), moving up (index=1),
    down (index=2), left (index=3), or right (index=4). Say also that agent-0 (the
    agent with agent.idx=0) is prevented from moving left but can otherwise move.
    In this case, the entry of the returned dictionary for agent-0 (keyed by its
    agent.idx) should point to a length-4 binary array, specifically [1, 1, 0, 1].
    Note that the mask is length 4 while technically 5 actions are available.
    This is because NO-OP should be ignored when constructing masks.

    In the more complex case where the component defines several action sets for
    an agent, say the planner agent (the agent with agent.idx='p'), then
    the entry for 'p' should point to a dictionary of
    {"action_set_name_m": mask_m} for each of the M action sets associated with
    agent p's type. Each such value, mask_m, should be a binary array whose
    length matches the number of actions in "action_set_name_m".

    The default behavior (below) keeps all actions available. The code gives an
    example of expected formatting.
    """
    world = self.world
    masks = {}
    # For all the agents in the environment (mobile agents plus the planner)
    for agent in world.agents + [world.planner]:
        # Get any action space(s) defined by this component for this agent
        n_actions = self.get_n_actions(agent.name)

        # If no action spaces are defined, just move on.
        if n_actions is None:
            continue

        # If a single action space is defined, n_actions corresponds to the
        # number of (non NO-OP) actions. Return an array of ones of that length,
        # enabling all actions.
        if isinstance(n_actions, (int, float)):
            masks[agent.idx] = np.ones(int(n_actions))

        # If multiple action spaces are defined, n_actions corresponds to the
        # tuple or list giving ("name", N) for each action space, where "name"
        # is the unique name and N is the number of (non NO-OP) actions
        # associated with that action space.
        # Return a dictionary of {"name": length-N ones array}, enabling all
        # actions in all the action spaces.
        elif isinstance(n_actions, (tuple, list)):
            masks[agent.idx] = {
                sub_name: np.ones(int(sub_n)) for sub_name, sub_n in n_actions
            }

        else:
            # Same exception type as before, but say which agent type and which
            # value caused it so a misbehaving get_n_actions() is easy to find.
            raise TypeError(
                'get_n_actions("{}") must return None, a number, or a tuple/list '
                "of (name, n) pairs; got {}.".format(
                    agent.name, type(n_actions).__name__
                )
            )

    return masks
|
||||
|
||||
# For non-required customization
|
||||
# ------------------------------
|
||||
|
||||
def additional_reset_steps(self):
    """
    Optional hook for extra work the component needs to do at reset, for
    example resetting internal trackers.

    The default implementation does nothing. Overrides should not return
    anything.
    """
    return None
|
||||
|
||||
def get_metrics(self):
    """
    Report custom metrics that describe the episode from this component's
    point of view.

    As an illustration: if Build is a subclass of BaseComponent that implements
    building, Build.get_metrics() might return a dictionary with terms relating
    to the number of things each agent built.

    Returns:
        metrics (dict or None): A dictionary of {"metric_key": metric_value}
            entries for the metrics this component calculates, where each
            metric_value is expected to be a scalar. The environment combines
            scenario metrics with the metric dictionaries produced by its
            component instances.
            Returning None (the default) makes the environment ignore this
            component when constructing the full metric report.
    """
    metrics = None
    return metrics
|
||||
|
||||
def get_dense_log(self):
    """
    Provide this component's dense log of the episode, as a tuple, list, or
    dict.

    Components that do not yield a dense log return None, which is what this
    default implementation does.
    """
    dense_log = None
    return dense_log
|
||||
|
||||
|
||||
# Module-level registry through which Component classes are looked up by name.
component_registry = Registry(BaseComponent)
"""The registry for Component classes.

This creates a registry object for Component classes. This registry requires that all
added classes are subclasses of BaseComponent. To make a Component class available
through the registry, decorate the class definition with @component_registry.add.

Example:
    from ai_economist.foundation.base.base_component import (
        BaseComponent,
        component_registry,
    )

    @component_registry.add
    class ExampleComponent(BaseComponent):
        name = "Example"
        pass

    assert component_registry.has("Example")

    ComponentClass = component_registry.get("Example")
    component = ComponentClass(...)
    assert isinstance(component, ExampleComponent)

Notes:
    The foundation package exposes the component registry as: foundation.components

    A Component class that is defined and registered following the above example will
    only be visible in foundation.components if defined/registered in a file that is
    imported in ../components/__init__.py.
"""
|
||||
1157
ai_economist/foundation/base/base_env.py
Normal file
1157
ai_economist/foundation/base/base_env.py
Normal file
File diff suppressed because it is too large
Load Diff
103
ai_economist/foundation/base/registrar.py
Normal file
103
ai_economist/foundation/base/registrar.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
|
||||
class Registry:
    """Utility for registering sets of similar classes and looking them up by name.

    Registries provide a simple API for getting classes used to build environment
    instances. Their main purpose is to organize such "building block" classes (i.e.
    Components, Scenarios, Agents) for easy reference as well as to ensure that all
    classes within a particular registry inherit from the same Base Class.

    Lookups (get/has) ignore the casing of the name; the names reported by
    `entries` keep the casing they were registered with.

    Args:
        base_class (class): The class that all entries in the registry must be a
            subclass of. If omitted (None), no subclass check is performed.

    Example:
        class BaseClass:
            pass

        registry = Registry(BaseClass)

        @registry.add
        class ExampleSubclassA(BaseClass):
            name = "ExampleA"
            pass

        @registry.add
        class ExampleSubclassB(BaseClass):
            name = "ExampleB"
            pass

        print(registry.entries)
        # ["ExampleA", "ExampleB"]

        assert registry.has("ExampleA")
        assert registry.get("ExampleB") is ExampleSubclassB
    """

    def __init__(self, base_class=None):
        self.base_class = base_class
        # Registered names, in their original casing and in registration order.
        self._entries = []
        # Lower-cased name -> registered class (makes lookups case-insensitive).
        self._lookup = {}

    def add(self, cls):
        """Add cls to this registry.

        Registering a second class under a name that differs only by casing
        replaces the class returned by get() for that name.

        Args:
            cls: The class to add to this registry. Must be a subclass of
                self.base_class (when a base class was given) and must have a
                "name" attribute that does not contain a ".".

        Returns:
            cls (to allow decoration with @registry.add)

        Raises:
            AssertionError: If cls.name contains "." or cls is not a subclass of
                self.base_class.

        See Registry class docstring for example.
        """
        assert "." not in cls.name, 'Registered names may not contain ".": {!r}'.format(
            cls.name
        )
        if self.base_class:
            assert issubclass(cls, self.base_class), "{!r} is not a subclass of {!r}".format(
                cls, self.base_class
            )
        self._lookup[cls.name.lower()] = cls
        if cls.name not in self._entries:
            self._entries.append(cls.name)
        return cls

    def get(self, cls_name):
        """Return registered class with name cls_name.

        Args:
            cls_name (str): Name of the registered class to get.

        Returns:
            Registered class cls, where cls.name matches cls_name (ignoring casing).

        Raises:
            KeyError: If no class is registered under cls_name.

        See Registry class docstring for example.
        """
        try:
            return self._lookup[cls_name.lower()]
        except KeyError:
            # Re-raise with the name as the caller spelled it; the lower-cased
            # key of the internal lookup is an implementation detail.
            raise KeyError(
                '"{}" is not a name of a registered class'.format(cls_name)
            ) from None

    def has(self, cls_name):
        """Return True if a class with name cls_name is registered.

        Args:
            cls_name (str): Name of class to check (casing is ignored).

        See Registry class docstring for example.
        """
        return cls_name.lower() in self._lookup

    @property
    def entries(self):
        """Names of classes in this registry.

        Returns:
            A new, sorted list of strings corresponding to the names of classes
            registered in this registry object.

        See Registry class docstring for example.
        """
        return sorted(self._entries)
|
||||
495
ai_economist/foundation/base/world.py
Normal file
495
ai_economist/foundation/base/world.py
Normal file
@@ -0,0 +1,495 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai_economist.foundation.agents import agent_registry
|
||||
from ai_economist.foundation.entities import landmark_registry, resource_registry
|
||||
|
||||
|
||||
class Maps:
    """Manages the spatial configuration of the world as a set of entity maps.

    A maps object is built during world construction, which is a part of environment
    construction. The maps object is accessible through the world object. The maps
    object maintains a map state for each of the spatial entities that are involved
    in the constructed environment (which are determined by the "required_entities"
    attributes of the Scenario and Component classes used to build the environment).

    The Maps class also implements some of the basic spatial logic of the game,
    such as which locations agents can occupy based on other agent locations and
    locations of various landmarks.

    Args:
        size (list): A length-2 list specifying the dimensions of the 2D world.
            Interpreted as [height, width].
        n_agents (int): The number of mobile agents (does not include planner).
        world_resources (list): The resources registered during environment
            construction.
        world_landmarks (list): The landmarks registered during environment
            construction. NOTE: this list is extended in place with a
            "<resource>SourceBlock" entry for every collectible resource.
    """

    def __init__(self, size, n_agents, world_resources, world_landmarks):
        self.size = size
        self.sz_h, self.sz_w = size

        self.n_agents = n_agents

        self.resources = world_resources
        self.landmarks = world_landmarks
        # Snapshot taken before the source-block landmarks are appended below.
        self.entities = world_resources + world_landmarks

        self._maps = {}  # All maps
        self._blocked = []  # Solid objects that no agent can move through
        self._private = []  # Solid objects that only permit movement for parent agents
        self._public = []  # Non-solid objects that agents can move on top of
        self._resources = []  # Non-solid objects that can be collected

        self._private_landmark_types = []
        self._resource_source_blocks = []

        self._map_keys = []

        # Entity name -> index of its layer in self._accessibility. Only
        # blocking and private landmarks restrict movement and get a layer.
        self._accessibility_lookup = {}

        for resource in self.resources:
            resource_cls = resource_registry.get(resource)
            if resource_cls.collectible:
                self._maps[resource] = np.zeros(shape=self.size)
                self._resources.append(resource)
                self._map_keys.append(resource)

                # Mutates the caller's landmark list (self.landmarks aliases it).
                self.landmarks.append("{}SourceBlock".format(resource))

        for landmark in self.landmarks:
            dummy_landmark = landmark_registry.get(landmark)()

            if dummy_landmark.public:
                self._maps[landmark] = np.zeros(shape=self.size)
                self._public.append(landmark)
                self._map_keys.append(landmark)

            elif dummy_landmark.blocking:
                self._maps[landmark] = np.zeros(shape=self.size)
                self._blocked.append(landmark)
                self._map_keys.append(landmark)
                self._accessibility_lookup[landmark] = len(self._accessibility_lookup)

            elif dummy_landmark.private:
                self._private_landmark_types.append(landmark)
                self._maps[landmark] = dict(
                    owner=-np.ones(shape=self.size, dtype=np.int16),
                    health=np.zeros(shape=self.size),
                )
                self._private.append(landmark)
                self._map_keys.append(landmark)
                self._accessibility_lookup[landmark] = len(self._accessibility_lookup)

            else:
                raise NotImplementedError

        # _idx_map[i] is a map filled with the value i; comparing an ownership
        # map against it yields, per agent, the cells that agent owns.
        self._idx_map = np.stack(
            [i * np.ones(shape=self.size) for i in range(self.n_agents)]
        )
        self._idx_array = np.arange(self.n_agents)

        # list(...) so that a tuple `size` works as well as a list.
        map_shape = list(self.size)
        if self._accessibility_lookup:
            # One boolean [n_agents, height, width] layer per restricting entity.
            self._accessibility = np.ones(
                shape=[len(self._accessibility_lookup), self.n_agents] + map_shape,
                dtype=bool,
            )
            # Combined (net) accessibility is computed lazily on first access.
            self._net_accessibility = None
        else:
            # Nothing restricts movement: everything is always accessible.
            self._accessibility = None
            self._net_accessibility = np.ones(
                shape=[self.n_agents] + map_shape, dtype=bool
            )

        self._agent_locs = [None for _ in range(self.n_agents)]
        self._unoccupied = np.ones(self.size, dtype=bool)

    def clear(self, entity_name=None):
        """Clear resource and landmark maps.

        Clears only the map of entity_name if given, otherwise all maps. In both
        cases every accessibility layer is reset to "accessible".
        """
        if entity_name is not None:
            assert entity_name in self._maps
            if entity_name in self._private_landmark_types:
                self._maps[entity_name] = dict(
                    owner=-np.ones(shape=self.size, dtype=np.int16),
                    health=np.zeros(shape=self.size),
                )
            else:
                self._maps[entity_name] *= 0

        else:
            for name in self.keys():
                self.clear(entity_name=name)

        if self._accessibility is not None:
            self._accessibility = np.ones_like(self._accessibility)
            self._net_accessibility = None

    def clear_agent_loc(self, agent=None):
        """Remove agents or agent from the world map."""
        # Clear all agent locations
        if agent is None:
            self._agent_locs = [None for _ in range(self.n_agents)]
            self._unoccupied[:, :] = 1

        # Clear the location of the provided agent
        else:
            i = agent.idx
            if self._agent_locs[i] is None:
                return
            r, c = self._agent_locs[i]
            self._unoccupied[r, c] = 1
            self._agent_locs[i] = None

    def set_agent_loc(self, agent, r, c):
        """Set the location of agent to [r, c].

        Note:
            Things might break if you set the agent's location to somewhere it
            cannot access. Don't do that.
        """
        assert (0 <= r < self.size[0]) and (0 <= c < self.size[1])
        i = agent.idx
        # If the agent is currently on the board...
        if self._agent_locs[i] is not None:
            curr_r, curr_c = self._agent_locs[i]
            # If the agent isn't actually moving, just return
            if (curr_r, curr_c) == (r, c):
                return
            # Make the location the agent is currently at as unoccupied
            # (since the agent is going to move)
            self._unoccupied[curr_r, curr_c] = 1

        # Set the agent location to the specified coordinates
        # and update the occupation map
        agent.state["loc"] = [r, c]
        self._agent_locs[i] = [r, c]
        self._unoccupied[r, c] = 0

    def keys(self):
        """Return an iterable over map keys."""
        return self._maps.keys()

    def values(self):
        """Return an iterable over map values."""
        return self._maps.values()

    def items(self):
        """Return an iterable over map (key, value) pairs."""
        return self._maps.items()

    def get(self, entity_name, owner=False):
        """Return the map or ownership for entity_name.

        Args:
            entity_name (str): Name of a registered map.
            owner (bool): Only meaningful for private landmarks: if truthy return
                the ownership map, otherwise the health map. Ignored for all
                other entities.
        """
        assert entity_name in self._maps
        if entity_name in self._private_landmark_types:
            sub_key = "owner" if owner else "health"
            return self._maps[entity_name][sub_key]
        return self._maps[entity_name]

    def set(self, entity_name, map_state):
        """Set the map for entity_name.

        For private landmarks map_state must be a dict with "owner" and "health"
        arrays; for every other entity it is a single array. Values are clipped
        to be non-negative, and accessibility is updated for blocking/private
        landmarks.
        """
        if entity_name in self._private_landmark_types:
            assert "owner" in map_state
            assert self.get(entity_name, owner=True).shape == map_state["owner"].shape
            assert "health" in map_state
            assert self.get(entity_name, owner=False).shape == map_state["health"].shape

            h = np.maximum(0.0, map_state["health"])
            o = map_state["owner"].astype(np.int16)

            # Cells without health have no owner; cells with health must have one.
            o[h <= 0] = -1
            tmp = o[h > 0]
            if len(tmp) > 0:
                assert np.min(tmp) >= 0

            self._maps[entity_name] = dict(owner=o, health=h)

            # A private cell is accessible to its owner, or to all if unowned.
            owned_by_agent = o[None] == self._idx_map
            owned_by_none = o[None] == -1
            self._accessibility[
                self._accessibility_lookup[entity_name]
            ] = np.logical_or(owned_by_agent, owned_by_none)
            self._net_accessibility = None

        else:
            assert self.get(entity_name).shape == map_state.shape
            clipped = np.maximum(0, map_state)
            self._maps[entity_name] = clipped

            if entity_name in self._blocked:
                # Derive accessibility from the stored (clipped) map so that the
                # two cannot disagree; broadcasts over the agent axis.
                self._accessibility[
                    self._accessibility_lookup[entity_name]
                ] = clipped == 0
                self._net_accessibility = None

    def set_add(self, entity_name, map_state):
        """Add map_state to the existing map for entity_name."""
        assert entity_name not in self._private_landmark_types
        self.set(entity_name, self.get(entity_name) + map_state)

    def get_point(self, entity_name, r, c, **kwargs):
        """Return the entity state at the specified coordinates.

        Keyword arguments are forwarded to get() (i.e. the boolean `owner` flag).
        """
        point_map = self.get(entity_name, **kwargs)
        return point_map[r, c]

    def set_point(self, entity_name, r, c, val, owner=None):
        """Set the entity state at the specified coordinates.

        Args:
            entity_name (str): Name of a registered map.
            r, c (int): Row and column of the cell to set.
            val: New value; negative values are stored as 0.
            owner: Index of the owning agent. Required for private landmarks,
                unused otherwise.
        """
        if entity_name in self._private_landmark_types:
            assert owner is not None
            h = self._maps[entity_name]["health"]
            o = self._maps[entity_name]["owner"]
            # Only an unowned cell or the agent's own cell may be modified.
            assert o[r, c] == -1 or o[r, c] == int(owner)
            h[r, c] = np.maximum(0, val)
            if h[r, c] == 0:
                o[r, c] = -1
            else:
                o[r, c] = int(owner)

            self._maps[entity_name]["owner"] = o
            self._maps[entity_name]["health"] = h

            self._accessibility[
                self._accessibility_lookup[entity_name], :, r, c
            ] = np.logical_or(o[r, c] == self._idx_array, o[r, c] == -1).astype(bool)
            self._net_accessibility = None

        else:
            self._maps[entity_name][r, c] = np.maximum(0, val)

            if entity_name in self._blocked:
                # Update only the cell that changed, for all agents. (Assigning a
                # length-n_agents vector to the whole layer either fails to
                # broadcast or overwrites unrelated cells.)
                self._accessibility[
                    self._accessibility_lookup[entity_name], :, r, c
                ] = self._maps[entity_name][r, c] == 0
                self._net_accessibility = None

    def set_point_add(self, entity_name, r, c, value, **kwargs):
        """Add value to the existing entity state at the specified coordinates.

        Keyword arguments (i.e. `owner`, the owning agent's index) are forwarded
        to set_point(). They are deliberately NOT forwarded to get_point(): there
        `owner` is a boolean flag selecting the ownership map, so passing an
        agent index would read the owner id instead of the current value.
        """
        current = self.get_point(entity_name, r, c)
        self.set_point(entity_name, r, c, value + current, **kwargs)

    def is_accessible(self, r, c, agent_id):
        """Return True if agent with id agent_id can occupy the location [r, c]."""
        return bool(self.accessibility[agent_id, r, c])

    def location_resources(self, r, c):
        """Return {resource: health} dictionary for any resources at location [r, c]."""
        return {
            k: self._maps[k][r, c] for k in self._resources if self._maps[k][r, c] > 0
        }

    def location_landmarks(self, r, c):
        """Return {landmark: health} dictionary for any landmarks at location [r, c]."""
        tmp = {k: self.get_point(k, r, c) for k in self.keys()}
        return {k: v for k, v in tmp.items() if k not in self._resources and v > 0}

    @property
    def unoccupied(self):
        """Return a boolean map indicating which locations are unoccupied."""
        return self._unoccupied

    @property
    def accessibility(self):
        """Return a boolean map indicating which locations are accessible.

        The result is indexed as [agent, row, column]. It is cached and only
        recomputed after a blocking/private landmark map has changed.
        """
        if self._net_accessibility is None:
            # A cell is accessible only if every restricting entity allows it.
            self._net_accessibility = self._accessibility.all(axis=0)
        return self._net_accessibility

    @property
    def empty(self):
        """Return a boolean map indicating which locations are empty.

        Empty locations have no landmarks or resources."""
        return self.state.sum(axis=0) == 0

    @property
    def state(self):
        """Return the concatenated maps of landmark and resources."""
        return np.stack([self.get(k) for k in self.keys()]).astype(np.float32)

    @property
    def owner_state(self):
        """Return the concatenated ownership maps of private landmarks."""
        return np.stack(
            [self.get(k, owner=True) for k in self._private_landmark_types]
        ).astype(np.int16)

    @property
    def state_dict(self):
        """Return a dictionary of the map states."""
        return self._maps
|
||||
|
||||
|
||||
class World:
    """Manages the environment's spatial- and agent-states.

    The world object represents the state of the environment, minus whatever state
    information is implicitly maintained by separate components. The world object
    maintains the spatial state through an instance of the Maps class. Agent states
    are maintained through instances of Agent classes (subclasses of BaseAgent),
    with one such instance for each of the agents in the environment.

    The world object is built during the environment construction, after the
    required entities have been registered. As part of the world object construction,
    it instantiates a map object and the agent objects.

    The World class adds some functionality for interfacing with the spatial state
    (the maps object) and setting/resetting agent locations. But its function is
    mostly to wrap the stateful, non-component environment objects.

    Args:
        world_size (list): A length-2 list specifying the dimensions of the 2D world.
            Interpreted as [height, width].
        agent_composition (dict): Maps the registered name of each agent class to
            the number of agents of that class to create. An iterable of
            (name, count) pairs is accepted as well. The total number of agents
            created (not including the planner) is stored in self.n_agents.
        world_resources (list): The resources registered during environment
            construction.
        world_landmarks (list): The landmarks registered during environment
            construction.
        multi_action_mode_agents (bool): Whether "mobile" agents use multi action mode
            (see BaseEnvironment in base_env.py).
        multi_action_mode_planner (bool): Whether the planner agent uses multi action
            mode (see BaseEnvironment in base_env.py).
    """

    def __init__(
        self,
        world_size,
        agent_composition,
        world_resources,
        world_landmarks,
        multi_action_mode_agents,
        multi_action_mode_planner,
    ):
        self.world_size = world_size

        self.resources = world_resources
        self.landmarks = world_landmarks
        self.multi_action_mode_agents = bool(multi_action_mode_agents)
        self.multi_action_mode_planner = bool(multi_action_mode_planner)

        # Agent class name -> list of the indices of the agents of that class.
        self._agent_class_idx_map = {}

        # Create the agents, numbering them consecutively across all classes.
        self.agent_composition = agent_composition
        self.n_agents = 0
        self._agents = []
        for class_name, count in self._composition_items(agent_composition):
            self._agent_class_idx_map[class_name] = []
            for _ in range(count):
                agent_class = agent_registry.get(class_name)
                # Same keyword as used for the planner below.
                self._agents.append(
                    agent_class(
                        self.n_agents, multi_action_mode=self.multi_action_mode_agents
                    )
                )
                self._agent_class_idx_map[class_name].append(self.n_agents)
                self.n_agents += 1

        self.maps = Maps(world_size, self.n_agents, world_resources, world_landmarks)

        planner_class = agent_registry.get("BasicPlanner")
        self._planner = planner_class(multi_action_mode=self.multi_action_mode_planner)

        self.timestep = 0

        # CUDA-related attributes (for GPU simulations).
        # These will be set via the env_wrapper, if required.
        self.use_cuda = False
        self.cuda_function_manager = None
        self.cuda_data_manager = None

    @staticmethod
    def _composition_items(agent_composition):
        """Return agent_composition as a list of (class_name, count) pairs.

        Accepts either a mapping {class_name: count} or an iterable of
        (class_name, count) pairs. Iterating a dict directly would yield only
        its keys, so mappings go through .items().
        """
        if hasattr(agent_composition, "items"):
            return list(agent_composition.items())
        return [(class_name, count) for class_name, count in agent_composition]

    @property
    def agents(self):
        """Return a list of the agent objects in the world (sorted by index)."""
        return self._agents

    @property
    def planner(self):
        """Return the planner agent object."""
        return self._planner

    @property
    def loc_map(self):
        """Return a map indicating the agent index occupying each location.

        Locations with a value of -1 are not occupied by an agent. Agents that
        are off the board (e.g. at [-1, -1] after clear_agent_locs) are skipped.
        """
        idx_map = -np.ones(shape=self.world_size, dtype=np.int16)
        for agent in self.agents:
            r, c = agent.loc
            # Without this guard a location of [-1, -1] would wrap around and
            # mark the last cell of the map as occupied.
            if self.is_valid(r, c):
                idx_map[r, c] = int(agent.idx)
        return idx_map

    def get_random_order_agents(self):
        """The agent list in a randomized order."""
        agent_order = np.random.permutation(self.n_agents)
        agents = self.agents
        return [agents[i] for i in agent_order]

    def get_agent_class(self, idx):
        """Return class name of agent"""
        return self.agents[idx].name

    def is_valid(self, r, c):
        """Return True if the coordinates [r, c] are within the game boundaries."""
        return (0 <= r < self.world_size[0]) and (0 <= c < self.world_size[1])

    def is_location_accessible(self, r, c, agent):
        """Return True if location [r, c] is accessible to agent."""
        if not self.is_valid(r, c):
            return False
        return self.maps.is_accessible(r, c, agent.idx)

    def can_agent_occupy(self, r, c, agent):
        """Return True if location [r, c] is accessible to agent and unoccupied."""
        if not self.is_location_accessible(r, c, agent):
            return False
        if self.maps.unoccupied[r, c]:
            return True
        return False

    def clear_agent_locs(self):
        """Take all agents off the board. Useful for resetting."""
        for agent in self.agents:
            agent.state["loc"] = [-1, -1]
        self.maps.clear_agent_loc()

    def agent_locs_are_valid(self):
        """Returns True if all agent locations comply with world semantics."""
        return all(
            self.is_location_accessible(*agent.loc, agent) for agent in self.agents
        )

    def set_agent_loc(self, agent, r, c):
        """Set the agent's location to coordinates [r, c] if possible.

        If agent cannot occupy [r, c], do nothing.

        Returns:
            The agent's (possibly unchanged) location as a list of ints."""
        if self.can_agent_occupy(r, c, agent):
            self.maps.set_agent_loc(agent, r, c)
        return [int(coord) for coord in agent.loc]

    def location_resources(self, r, c):
        """Return {resource: health} dictionary for any resources at location [r, c]."""
        if not self.is_valid(r, c):
            return {}
        return self.maps.location_resources(r, c)

    def location_landmarks(self, r, c):
        """Return {landmark: health} dictionary for any landmarks at location [r, c]."""
        if not self.is_valid(r, c):
            return {}
        return self.maps.location_landmarks(r, c)

    def create_landmark(self, landmark_name, r, c, agent_idx=None):
        """Place a landmark on the world map.

        Place landmark of type landmark_name at the given coordinates, indicating
        agent ownership if applicable."""
        self.maps.set_point(landmark_name, r, c, 1, owner=agent_idx)

    def consume_resource(self, resource_name, r, c):
        """Consume a unit of resource_name from location [r, c]."""
        self.maps.set_point_add(resource_name, r, c, -1)
|
||||
Reference in New Issue
Block a user