# Files
# ai-econ/components/simple_gather.py
# 2023-01-11 19:04:20 +01:00
#
# 215 lines
# 6.7 KiB
# Python

# Copyright (c) 2020, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root
# or https://opensource.org/licenses/BSD-3-Clause
import numpy as np
from numpy.random import rand
from ai_economist.foundation.base.base_component import (
BaseComponent,
component_registry,
)
from ai_economist.foundation.entities import resource_registry, resources
@component_registry.add
class SimpleGather(BaseComponent):
    """
    Allows mobile agents to pick up collectible resources from a shared pool.

    The stock of every collectible resource is read from, and consumed at, map
    coordinate (0, 0). Each step an agent may choose one collectible resource to
    pick up; a successful pickup adds the resource to the agent's inventory and
    charges the agent labor. Choosing a resource whose stock is empty flags the
    agent with ``bad_action``.

    Can be configured to include collection skill, where agents have heterogeneous
    probabilities of collecting bonus resources without additional labor cost.

    Args:
        collect_labor (float): Labor cost charged each time an agent picks up a
            resource. Must be >= 0. Default is 1.0.
        skill_dist (str): Distribution type for sampling skills. Default ("none")
            gives all agents identical skill equal to a bonus prob of 0. "pareto" and
            "lognormal" sample skills from the associated distributions.
    """

    name = "SimpleGather"
    required_entities = ["Coin", "House", "Labor"]
    agent_subclasses = ["BasicMobileAgent"]

    def __init__(
        self,
        *base_component_args,
        collect_labor=1.0,
        skill_dist="none",
        **base_component_kwargs
    ):
        super().__init__(*base_component_args, **base_component_kwargs)

        self.collect_labor = float(collect_labor)
        assert self.collect_labor >= 0, "collect_labor must be >= 0"

        self.skill_dist = skill_dist.lower()
        assert self.skill_dist in [
            "none",
            "pareto",
            "lognormal",
        ], "skill_dist must be one of 'none', 'pareto', 'lognormal'"

        # One list of gather events per elapsed timestep (see get_dense_log).
        self.gathers = []

        # Only resources flagged as collectible in the registry can be picked up.
        self.commodities = [
            r for r in self.world.resources if resource_registry.get(r).collectible
        ]

    # Required methods for implementing components
    # --------------------------------------------

    def get_n_actions(self, agent_cls_name):
        """
        See base_component.py for detailed description.

        Adds 1 action per commodity that can be picked up.
        """
        if agent_cls_name == "BasicMobileAgent":
            return len(self.commodities)
        return None

    def get_additional_state_fields(self, agent_cls_name):
        """
        See base_component.py for detailed description.

        For mobile agents, add state field for collection skill.
        """
        if agent_cls_name not in self.agent_subclasses:
            return {}
        if agent_cls_name == "BasicMobileAgent":
            return {"bonus_gather_prob": 0.0}
        raise NotImplementedError

    def component_step(self):
        """
        See base_component.py for detailed description.

        Pick up the resource each agent selected, if it is available in the pool.
        Agents that select an unavailable resource get ``bad_action`` set to True.
        The gathers of this step are appended to the dense log as one list.
        """
        world = self.world
        step_gathers = []

        for agent in world.get_random_order_agents():
            if self.name not in agent.action:
                continue

            resource_action = agent.get_component_action(self.name)
            if resource_action == 0:  # NO-OP
                continue

            # Action 0 is NO-OP, so action k refers to self.commodities[k - 1].
            resource = self.commodities[resource_action - 1]

            if self.get_num_resources(resource) > 0:
                step_gathers.append(self.pickup(resource, agent))
            else:
                agent.bad_action = True

        self.gathers.append(step_gathers)

    def generate_observations(self):
        """
        See base_component.py for detailed description.

        Each agent observes its own collection skill ("bonus_gather_prob") and, for
        every commodity, "pickup_perc_<commodity>": the pool stock divided by the
        number of agents, capped at 1.0. No entry is generated for the planner.
        """
        agents = self.world.agents
        num_agents = len(agents)
        if num_agents == 0:
            # Nobody to observe anything; also avoids dividing by zero below.
            return {}

        availability = {
            "pickup_perc_{}".format(r): min(
                1.0, float(self.get_num_resources(r) / num_agents)
            )
            for r in self.commodities
        }

        obs = {}
        for agent in agents:
            agent_obs = {"bonus_gather_prob": agent.state["bonus_gather_prob"]}
            agent_obs.update(availability)
            obs[agent.idx] = agent_obs
        return obs

    def generate_masks(self, completions=0):
        """
        See base_component.py for detailed description.

        Mask out pickup actions for commodities whose pool is currently empty.
        The mask has one entry per commodity (1 = available, 0 = unavailable), in
        the order of self.commodities. ``completions`` is not used here.
        """
        availability = [
            1 if self.get_num_resources(r) > 0 else 0 for r in self.commodities
        ]
        # Hand every agent its own copy so that mutating one agent's mask
        # downstream cannot silently change the masks of the others.
        return {agent.idx: list(availability) for agent in self.world.agents}

    # For non-required customization
    # ------------------------------

    def additional_reset_steps(self):
        """
        See base_component.py for detailed description.

        Re-sample agents' collection skills and clear the dense log.
        """
        for agent in self.world.agents:
            if self.skill_dist == "none":
                bonus_rate = 0.0
            elif self.skill_dist == "pareto":
                # Sample is capped at 2 and halved, so the probability is <= 1.
                bonus_rate = np.minimum(2, np.random.pareto(3)) / 2
            elif self.skill_dist == "lognormal":
                bonus_rate = np.minimum(2, np.random.lognormal(-2.022, 0.938)) / 2
            else:
                raise NotImplementedError
            agent.state["bonus_gather_prob"] = float(bonus_rate)

        self.gathers = []

    def get_dense_log(self):
        """
        Log resource collections.

        Returns:
            gathers (list): A list of gather events. Each entry corresponds to a single
                timestep and contains a description of any resource gathers that
                occurred on that timestep.
        """
        return self.gathers

    # For Components

    def get_num_resources(self, res: resources.Resource):
        """
        Return the current pool stock of a resource.

        Args:
            res: The resource to look up. Within this class the values passed are
                entries of self.commodities (presumably resource names; confirm
                against world.maps.get_point).

        Returns:
            The value stored for ``res`` at map coordinate (0, 0).
        """
        return self.world.maps.get_point(res, 0, 0)

    def pickup(self, res: resources.Resource, agent):
        """
        Give ``agent`` one unit of ``res`` (plus a possible bonus unit) from the pool.

        The agent is charged ``collect_labor`` labor once per pickup, with no extra
        charge for the bonus unit.

        Args:
            res: The resource being picked up (an entry of self.commodities).
            agent: The agent performing the pickup. Its inventory and Labor
                endogenous state are updated in place.

        Returns:
            dict: The gather event, with keys "agent" (agent idx), "resource" and
                "n" (number of units added to the inventory, 1 or 2).
        """
        # int() keeps the logged count a builtin int rather than a numpy scalar.
        n_gathered = 1 + int(rand() < agent.state["bonus_gather_prob"])

        agent.state["inventory"][res] += n_gathered
        agent.state["endogenous"]["Labor"] += self.collect_labor

        # consume_resource is called once per pickup, regardless of n_gathered.
        self.world.consume_resource(res, 0, 0)

        # Log the gather
        return dict(
            agent=agent.idx,
            resource=res,
            n=n_gathered,
        )