215 lines
6.7 KiB
Python
215 lines
6.7 KiB
Python
# Copyright (c) 2020, salesforce.com, inc.
|
|
# All rights reserved.
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
# For full license text, see the LICENSE file in the repo root
|
|
# or https://opensource.org/licenses/BSD-3-Clause
|
|
|
|
import numpy as np
|
|
from numpy.random import rand
|
|
|
|
from ai_economist.foundation.base.base_component import (
|
|
BaseComponent,
|
|
component_registry,
|
|
)
|
|
from ai_economist.foundation.entities import resource_registry, resources
|
|
|
|
@component_registry.add
class SimpleGather(BaseComponent):
    """
    Lets mobile agents pick up collectible resources from a shared pool.

    Each collectible resource in the world gets one pickup action. The number of
    units available for a resource is read from the world map at point (0, 0), and
    a successful pickup consumes from that same point. Choosing a resource with
    nothing available flags the agent's action as bad.

    Can be configured to include collection skill, where agents have heterogeneous
    probabilities of collecting a bonus unit without additional labor cost.

    Args:
        collect_labor (float): Labor added to an agent each time it picks up a
            resource. Must be >= 0. Default is 1.0.
        skill_dist (str): Distribution type for sampling skills. Default ("none")
            gives all agents identical skill equal to a bonus prob of 0. "pareto" and
            "lognormal" sample skills from the associated distributions.
    """

    name = "SimpleGather"
    required_entities = ["Coin", "House", "Labor"]
    agent_subclasses = ["BasicMobileAgent"]

    def __init__(
        self,
        *base_component_args,
        collect_labor=1.0,
        skill_dist="none",
        **base_component_kwargs
    ):
        super().__init__(*base_component_args, **base_component_kwargs)

        self.collect_labor = float(collect_labor)
        assert self.collect_labor >= 0

        self.skill_dist = skill_dist.lower()
        assert self.skill_dist in ["none", "pareto", "lognormal"]

        # One list of gather events is appended per call to component_step.
        self.gathers = []

        # Resources that can be picked up; action k (1-based) maps to commodities[k-1].
        self.commodities = [
            r for r in self.world.resources if resource_registry.get(r).collectible
        ]

    # Required methods for implementing components
    # --------------------------------------------

    def get_n_actions(self, agent_cls_name):
        """
        See base_component.py for detailed description.

        Adds 1 action per commodity that can be picked up.
        """
        if agent_cls_name == "BasicMobileAgent":
            return len(self.commodities)
        return None

    def get_additional_state_fields(self, agent_cls_name):
        """
        See base_component.py for detailed description.

        For mobile agents, add state field for collection skill.
        """
        if agent_cls_name not in self.agent_subclasses:
            return {}
        if agent_cls_name == "BasicMobileAgent":
            return {"bonus_gather_prob": 0.0}
        raise NotImplementedError

    def component_step(self):
        """
        See base_component.py for detailed description.

        Pick up the chosen resource for each acting agent, if any is available.

        Raises:
            ValueError: If an agent's action is outside [0, number of commodities].
        """
        world = self.world
        n_choices = len(self.commodities)

        gathers = []
        for agent in world.get_random_order_agents():
            if self.name not in agent.action:
                continue

            resource_action = agent.get_component_action(self.name)

            if resource_action == 0:  # NO-OP
                continue

            # Reject out-of-range actions explicitly; without this check a negative
            # action would silently select a commodity from the end of the list.
            if not 1 <= resource_action <= n_choices:
                raise ValueError(
                    "Invalid {} action {}; expected an integer in [0, {}].".format(
                        self.name, resource_action, n_choices
                    )
                )

            # Actions are 1-based because 0 is reserved for NO-OP.
            resource = self.commodities[resource_action - 1]

            if self.get_num_resources(resource) > 0:
                gathers.append(self.pickup(resource, agent))
            else:
                # Nothing left to pick up: flag the action as bad.
                agent.bad_action = True

        self.gathers.append(gathers)

    def generate_observations(self):
        """
        See base_component.py for detailed description.

        Each agent observes its own collection skill ("bonus_gather_prob") and, for
        every commodity, "pickup_perc_<commodity>": the available amount divided by
        the number of agents, capped at 1. Only agents receive observations from
        this component.
        """
        agents = self.world.agents
        n_agents = len(agents)
        if n_agents == 0:
            # No agents to observe anything (also avoids dividing by zero below).
            return {}

        availability = {}
        for r in self.commodities:
            share = float(self.get_num_resources(r) / n_agents)
            availability["pickup_perc_{}".format(r)] = min(1.0, share)

        obs = {}
        for agent in agents:
            agent_obs = {"bonus_gather_prob": agent.state["bonus_gather_prob"]}
            agent_obs.update(availability)
            obs[agent.idx] = agent_obs
        return obs

    def generate_masks(self, completions=0):
        """
        See base_component.py for detailed description.

        Mask pickup actions by availability: the mask holds one entry per commodity,
        1 if at least one unit is currently available and 0 otherwise. Every agent
        receives the same values in its own list.
        """
        mask = [1 if self.get_num_resources(r) > 0 else 0 for r in self.commodities]

        # Give each agent a separate copy so that editing one agent's mask downstream
        # cannot change another agent's mask.
        return {agent.idx: list(mask) for agent in self.world.agents}

    # For non-required customization
    # ------------------------------

    def additional_reset_steps(self):
        """
        See base_component.py for detailed description.

        Re-sample agents' collection skills and clear the gather log.
        """
        for agent in self.world.agents:
            if self.skill_dist == "none":
                bonus_rate = 0.0
            elif self.skill_dist == "pareto":
                # Sample is capped at 2 and halved, so the result is at most 1.
                bonus_rate = np.minimum(2, np.random.pareto(3)) / 2
            elif self.skill_dist == "lognormal":
                bonus_rate = np.minimum(2, np.random.lognormal(-2.022, 0.938)) / 2
            else:
                raise NotImplementedError
            agent.state["bonus_gather_prob"] = float(bonus_rate)

        self.gathers = []

    def get_dense_log(self):
        """
        Log resource collections.

        Returns:
            gathers (list): A list of gather events. Each entry corresponds to a single
                timestep and contains a description of any resource gathers that
                occurred on that timestep.
        """
        return self.gathers

    # For Components

    def get_num_resources(self, res: resources.Resource):
        """Return the value stored for resource `res` at map point (0, 0)."""
        return self.world.maps.get_point(res, 0, 0)

    def pickup(self, res: resources.Resource, agent):
        """
        Give `agent` one unit of `res`, plus a bonus unit with the agent's bonus prob.

        Adds the gathered amount to the agent's inventory, charges `collect_labor`
        to the agent's labor, and consumes the resource at map point (0, 0).

        Returns:
            dict: A log entry with keys "agent" (agent index), "resource", and "n"
                (number of units added to the inventory).
        """
        # The comparison is True with probability bonus_gather_prob, adding 1 unit.
        n_gathered = 1 + (rand() < agent.state["bonus_gather_prob"])
        agent.state["inventory"][res] += n_gathered
        agent.state["endogenous"]["Labor"] += self.collect_labor
        self.world.consume_resource(res, 0, 0)
        # Log the gather
        return dict(
            agent=agent.idx,
            resource=res,
            n=n_gathered,
        )
|