it is working
This commit is contained in:
214
components/simple_gather.py
Normal file
214
components/simple_gather.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import numpy as np
|
||||
from numpy.random import rand
|
||||
|
||||
from ai_economist.foundation.base.base_component import (
|
||||
BaseComponent,
|
||||
component_registry,
|
||||
)
|
||||
from ai_economist.foundation.entities import resource_registry, resources
|
||||
|
||||
@component_registry.add
class SimpleGather(BaseComponent):
    """
    Allows mobile agents to pick up collectible resources from a shared pool.

    Every collectible resource in the world gets its own pickup action. The stock
    of each resource is read from, and consumed at, map position (0, 0); agent
    location is never consulted by this component.

    Can be configured to include collection skill, where agents have heterogeneous
    probabilities of collecting bonus resources without additional labor cost.

    Args:
        collect_labor (float): Labor cost charged to the agent for each successful
            pickup. Must be >= 0. Default is 1.0.
        skill_dist (str): Distribution type for sampling skills. Default ("none")
            gives all agents identical skill equal to a bonus prob of 0. "pareto" and
            "lognormal" sample skills from the associated distributions.
    """

    name = "SimpleGather"
    required_entities = ["Coin", "House", "Labor"]
    agent_subclasses = ["BasicMobileAgent"]

    def __init__(
        self,
        *base_component_args,
        collect_labor=1.0,
        skill_dist="none",
        **base_component_kwargs
    ):
        super().__init__(*base_component_args, **base_component_kwargs)

        self.collect_labor = float(collect_labor)
        assert self.collect_labor >= 0

        self.skill_dist = skill_dist.lower()
        assert self.skill_dist in ["none", "pareto", "lognormal"]

        # One entry per timestep; each entry is the list of gather events
        # recorded during that step (see get_dense_log).
        self.gathers = []

        # Resources that can be picked up. The position of a resource in this
        # list defines its action index (offset by 1 for the NO-OP).
        self.commodities = [
            r for r in self.world.resources if resource_registry.get(r).collectible
        ]

    # Required methods for implementing components
    # --------------------------------------------

    def get_n_actions(self, agent_cls_name):
        """
        See base_component.py for detailed description.

        Adds 1 action per commodity that can be picked up.
        """
        if agent_cls_name == "BasicMobileAgent":
            return len(self.commodities)
        return None

    def get_additional_state_fields(self, agent_cls_name):
        """
        See base_component.py for detailed description.

        For mobile agents, add state field for collection skill.
        """
        if agent_cls_name not in self.agent_subclasses:
            return {}
        if agent_cls_name == "BasicMobileAgent":
            return {"bonus_gather_prob": 0.0}
        raise NotImplementedError

    def component_step(self):
        """
        See base_component.py for detailed description.

        Pick up the requested resource for each agent, if any is still available.
        Agents are processed in random order, so when stock is short the agents
        served first are the ones that succeed.
        """
        gathers = []
        for agent in self.world.get_random_order_agents():
            if self.name not in agent.action:
                continue

            resource_action = agent.get_component_action(self.name)
            if resource_action == 0:  # NO-OP
                continue

            # Action indices start at 1 (0 is the NO-OP); commodities start at 0.
            resource = self.commodities[resource_action - 1]

            if self.get_num_resources(resource) > 0:
                gathers.append(self.pickup(resource, agent))
            else:
                # The requested resource ran out (e.g. taken by an agent earlier
                # in this step); flag the action instead of gathering.
                agent.bad_action = True

        self.gathers.append(gathers)

    def generate_observations(self):
        """
        See base_component.py for detailed description.

        Each agent observes its own collection skill ("bonus_gather_prob") and,
        for every commodity, "pickup_perc_<resource>": the available stock divided
        by the number of agents, capped at 1.0. The planner does not observe
        anything from this component.
        """
        agents = self.world.agents
        n_agents = len(agents)
        if n_agents == 0:
            # No agents means nobody to observe; also avoids dividing by zero.
            return {}

        availability = {
            "pickup_perc_{}".format(r): min(
                1.0, float(self.get_num_resources(r) / n_agents)
            )
            for r in self.commodities
        }

        obs = {}
        for agent in agents:
            agent_obs = {"bonus_gather_prob": agent.state["bonus_gather_prob"]}
            agent_obs.update(availability)
            obs[agent.idx] = agent_obs
        return obs

    def generate_masks(self, completions=0):
        """
        See base_component.py for detailed description.

        Disable the pickup action of every commodity that currently has no stock.
        The mask has one entry per commodity (1 = allowed, 0 = masked) and is the
        same for all agents.
        """
        mask = [1 if self.get_num_resources(r) > 0 else 0 for r in self.commodities]

        # Give each agent its own list so that a consumer mutating one agent's
        # mask cannot silently change the masks of the others.
        return {agent.idx: list(mask) for agent in self.world.agents}

    # For non-required customization
    # ------------------------------

    def additional_reset_steps(self):
        """
        See base_component.py for detailed description.

        Re-sample agents' collection skills and clear the gather log.
        """
        for agent in self.world.agents:
            if self.skill_dist == "none":
                bonus_rate = 0.0
            elif self.skill_dist == "pareto":
                bonus_rate = np.minimum(2, np.random.pareto(3)) / 2
            elif self.skill_dist == "lognormal":
                bonus_rate = np.minimum(2, np.random.lognormal(-2.022, 0.938)) / 2
            else:
                raise NotImplementedError
            agent.state["bonus_gather_prob"] = float(bonus_rate)

        self.gathers = []

    def get_dense_log(self):
        """
        Log resource collections.

        Returns:
            gathers (list): A list of gather events. Each entry corresponds to a single
                timestep and contains a description of any resource gathers that
                occurred on that timestep.
        """
        return self.gathers

    # For Components

    def get_num_resources(self, res: resources.Resource):
        """
        Return the amount of resource `res` stored at map position (0, 0).

        Args:
            res: Resource identifier, as it appears in self.commodities.
        """
        return self.world.maps.get_point(res, 0, 0)

    def pickup(self, res: resources.Resource, agent):
        """
        Transfer resource `res` to `agent` and charge the collection labor.

        The agent receives 1 unit, plus 1 bonus unit with probability equal to its
        "bonus_gather_prob" state. A single consume_resource call is made at map
        position (0, 0) regardless of whether the bonus unit was awarded.

        Args:
            res: Resource identifier, as it appears in self.commodities.
            agent: The agent performing the pickup.

        Returns:
            dict: The gather event, with keys "agent" (agent idx), "resource" and
                "n" (number of units added to the agent's inventory).
        """
        n_gathered = 1 + int(rand() < agent.state["bonus_gather_prob"])
        agent.state["inventory"][res] += n_gathered
        agent.state["endogenous"]["Labor"] += self.collect_labor
        self.world.consume_resource(res, 0, 0)

        # Log the gather
        return dict(
            agent=agent.idx,
            resource=res,
            n=n_gathered,
        )
|
||||
Reference in New Issue
Block a user