adding ai_economist for modding
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,2 +1,4 @@
|
||||
logs/*
|
||||
_pycache_
|
||||
_pycache_
|
||||
*.pyc
|
||||
*tfevents*
|
||||
11
agents/consumer_agent.py
Normal file
11
agents/consumer_agent.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from ai_economist.foundation.base.base_agent import BaseAgent, agent_registry
|
||||
|
||||
|
||||
@agent_registry.add
|
||||
class ConsumerAgent(BaseAgent):
|
||||
"""
|
||||
A basic mobile agent represents an individual actor in the economic simulation.
|
||||
"Mobile" refers to agents of this type being able to move around in the 2D world.
|
||||
"""
|
||||
|
||||
name = "ConsumerAgent"
|
||||
11
agents/trading_agent.py
Normal file
11
agents/trading_agent.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from ai_economist.foundation.base.base_agent import BaseAgent, agent_registry
|
||||
|
||||
|
||||
@agent_registry.add
|
||||
class TradingAgent(BaseAgent):
|
||||
"""
|
||||
A basic mobile agent represents an individual actor in the economic simulation.
|
||||
"Mobile" refers to agents of this type being able to move around in the 2D world.
|
||||
"""
|
||||
|
||||
name = "TradingAgent"
|
||||
12
ai_economist/LICENSE.txt
Normal file
12
ai_economist/LICENSE.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
Copyright (c) 2020, Salesforce.com, Inc.
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
7
ai_economist/__init__.py
Normal file
7
ai_economist/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
from ai_economist import foundation
|
||||
5
ai_economist/datasets/__init__.py
Normal file
5
ai_economist/datasets/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
27
ai_economist/datasets/covid19_datasets/README.md
Normal file
27
ai_economist/datasets/covid19_datasets/README.md
Normal file
@@ -0,0 +1,27 @@
|
||||
## List of COVID-19 datasources used
|
||||
|
||||
1. **US state government policies** (Oxford Covid-19 Government Response Tracker (OxCGRT))
|
||||
|
||||
https://github.com/OxCGRT/USA-covid-policy
|
||||
|
||||
|
||||
2. **US federal government direct payments** (Committee for a Responsible Federal Budget)
|
||||
|
||||
https://www.covidmoneytracker.org/
|
||||
|
||||
https://docs.google.com/spreadsheets/d/1Nr_J5wLfUT4IzqSXkYbdOXrRgEkBxhX0/edit#gid=682404301
|
||||
|
||||
|
||||
3. **US deaths data** (COVID-19 Data Repository by the Center for Systems Science and Engineering (CSSE) at Johns Hopkins University)
|
||||
|
||||
https://github.com/CSSEGISandData/COVID-19
|
||||
|
||||
|
||||
4. **US unemployment** (Bureau of Labor and Statistics)
|
||||
|
||||
https://www.bls.gov/lau/
|
||||
|
||||
|
||||
5. **US vaccinations** (Our World in Data)
|
||||
|
||||
https://ourworldindata.org/covid-vaccinations
|
||||
5
ai_economist/datasets/covid19_datasets/__init__.py
Normal file
5
ai_economist/datasets/covid19_datasets/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1 @@
|
||||
{"DATE_FORMAT": "%Y-%m-%d", "STRINGENCY_POLICY_KEY": "StringencyIndex", "NUM_STRINGENCY_LEVELS": 10, "SIR_SMOOTHING_STD": 10, "SIR_MORTALITY": 0.02, "SIR_GAMMA": 0.07142857142857142, "US_STATE_IDX_TO_STATE_NAME": {"0": "Alabama", "1": "Alaska", "2": "Arizona", "3": "Arkansas", "4": "California", "5": "Colorado", "6": "Connecticut", "7": "Delaware", "8": "District of Columbia", "9": "Florida", "10": "Georgia", "11": "Hawaii", "12": "Idaho", "13": "Illinois", "14": "Indiana", "15": "Iowa", "16": "Kansas", "17": "Kentucky", "18": "Louisiana", "19": "Maine", "20": "Maryland", "21": "Massachusetts", "22": "Michigan", "23": "Minnesota", "24": "Mississippi", "25": "Missouri", "26": "Montana", "27": "Nebraska", "28": "Nevada", "29": "New Hampshire", "30": "New Jersey", "31": "New Mexico", "32": "New York", "33": "North Carolina", "34": "North Dakota", "35": "Ohio", "36": "Oklahoma", "37": "Oregon", "38": "Pennsylvania", "39": "Rhode Island", "40": "South Carolina", "41": "South Dakota", "42": "Tennessee", "43": "Texas", "44": "Utah", "45": "Vermont", "46": "Virginia", "47": "Washington", "48": "West Virginia", "49": "Wisconsin", "50": "Wyoming"}, "US_STATE_POPULATION": [4903185, 740995, 7278717, 3017804, 39512223, 5758736, 3565287, 973764, 705749, 21477737, 10617423, 1415872, 1787065, 12671821, 6732219, 3155070, 2913314, 4467673, 4648794, 1344212, 6045680, 6892503, 9986857, 5639632, 2976149, 6626371, 1068778, 1934408, 3080156, 1359711, 8882190, 2096829, 19453561, 10488084, 762062, 11689100, 3956971, 4217737, 12801989, 1059361, 5148714, 884659, 6829174, 28995881, 3205958, 623989, 8535519, 7614893, 1792147, 5822434, 578759], "US_POPULATION": 328737916, "GDP_PER_CAPITA": 65300}
|
||||
Binary file not shown.
1450
ai_economist/datasets/covid19_datasets/fit_model_parameters.ipynb
Normal file
1450
ai_economist/datasets/covid19_datasets/fit_model_parameters.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,846 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Copyright (c) 2021, salesforce.com, inc. \n",
|
||||
"All rights reserved. \n",
|
||||
"SPDX-License-Identifier: BSD-3-Clause \n",
|
||||
"For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# This notebook will be used to gather real-world data and perform data processing in order to use it in the covid-19 simulation.\n",
|
||||
"\n",
|
||||
"### All the downloaded data will be formatted into pandas dataframes."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Below is the list of COVID-19 data sources used in this notebook\n",
|
||||
"\n",
|
||||
"1. **US state government policies** (Oxford Covid-19 Government Response Tracker (OxCGRT))\n",
|
||||
"\n",
|
||||
" https://github.com/OxCGRT/USA-covid-policy\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"2. **US federal government direct payments** (Committee for a Responsible Federal Budget)\n",
|
||||
"\n",
|
||||
" https://www.covidmoneytracker.org/\n",
|
||||
" \n",
|
||||
" https://docs.google.com/spreadsheets/d/1Nr_J5wLfUT4IzqSXkYbdOXrRgEkBxhX0/edit#gid=682404301\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"3. **US deaths data** (COVID-19 Data Repository by the Center for Systems Science and Engineering (CSSE) at Johns Hopkins University)\n",
|
||||
"\n",
|
||||
" https://github.com/CSSEGISandData/COVID-19\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"4. **US vaccinations** (Our World in Data)\n",
|
||||
" \n",
|
||||
" https://ourworldindata.org/covid-vaccinations\n",
|
||||
" \n",
|
||||
" \n",
|
||||
"5. **US unemployment** (Bureau of Labor and Statistics)\n",
|
||||
"\n",
|
||||
" https://www.bls.gov/lau/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Dependencies"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datetime import datetime, timedelta\n",
|
||||
"import json\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import numpy as np\n",
|
||||
"import os\n",
|
||||
"import pandas as pd\n",
|
||||
"import pickle\n",
|
||||
"import scipy\n",
|
||||
"from scipy.signal import convolve"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Classes to fetch the real-world data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from ai_economist.datasets.covid19_datasets.us_policies import DatasetCovidPoliciesUS\n",
|
||||
"from ai_economist.datasets.covid19_datasets.us_deaths import DatasetCovidDeathsUS\n",
|
||||
"from ai_economist.datasets.covid19_datasets.us_vaccinations import DatasetCovidVaccinationsUS\n",
|
||||
"from ai_economist.datasets.covid19_datasets.us_unemployment import DatasetCovidUnemploymentUS"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Set a base directory where you would like to download real world data. The latest data will be downloaded into a folder within the base directory, named using the current date"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"BASE_DATA_DIR_PATH = \"/tmp/covid19_data\" # SPECIFY A BASE DIRECTORY TO STORE ALL THE DOWNLOADED DATA"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DOWNLOAD_LATEST_DATA = True # Download the latest data or use whatever is saved earlier \n",
|
||||
"CURRENT_DATE = datetime.now()\n",
|
||||
"DATE_FORMAT = \"%Y-%m-%d\"\n",
|
||||
"date_string = CURRENT_DATE.strftime(DATE_FORMAT).replace('/','-')\n",
|
||||
"data_dir = os.path.join(BASE_DATA_DIR_PATH, date_string)\n",
|
||||
"\n",
|
||||
"print(\"All the data will be downloaded to the directory: '{}'.\".format(data_dir))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Set up dictionary to write model constants\n",
|
||||
"model_constants = {}\n",
|
||||
"model_constants_filename = \"model_constants.json\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Gather real-world data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 1. COVID-19 US State Government Policies\n",
|
||||
"### Source: Oxford Covid-19 Government Response Tracker (OxCGRT) \n",
|
||||
"(https://github.com/OxCGRT/USA-covid-policy)\n",
|
||||
"\n",
|
||||
"**NOTE:** All data will use the same format as **policy_df** (below) and use the same date index"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"covid_policies_us = DatasetCovidPoliciesUS(\n",
|
||||
" data_dir=data_dir,\n",
|
||||
" download_latest_data=DOWNLOAD_LATEST_DATA\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Which of the policy indicators to treat as the open/close level\n",
|
||||
"STRINGENCY_POLICY_KEY = 'StringencyIndex'\n",
|
||||
"# Number of levels to discretize the stringency policy into. \n",
|
||||
"# In the context of reinforcement learning, this also determines the action space of the agents.\n",
|
||||
"NUM_STRINGENCY_LEVELS = 10\n",
|
||||
"\n",
|
||||
"policies_us_df = covid_policies_us.process_policy_data(\n",
|
||||
" stringency_policy_key=STRINGENCY_POLICY_KEY,\n",
|
||||
" num_stringency_levels=NUM_STRINGENCY_LEVELS\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"Policy data are available between {} and {}\".format(policies_us_df[\"Date\"].min(), \n",
|
||||
" policies_us_df[\"Date\"].max()))\n",
|
||||
"\n",
|
||||
"policy_df = policies_us_df.pivot(\n",
|
||||
" index=\"Date\", columns=\"RegionName\", values=STRINGENCY_POLICY_KEY\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# This is the common date index that all the dataframes will use\n",
|
||||
"COMMON_DATE_INDEX = policy_df.index"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# This is the list of states (in order) all the dataframes will use\n",
|
||||
"US_STATE_ORDER = policy_df.columns.values"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Visualize the stringency level for a specified US state\n",
|
||||
"state = \"California\"\n",
|
||||
"policy_df[state].plot(figsize=(15,5), x='Date', title=\"Stringency Level for {}\".format(state), grid=True);"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 2. COVID-19 Federal government subsidies (direct payments) to the states\n",
|
||||
"### Source: Committee For A Responsible Federal Budget\n",
|
||||
"https://www.covidmoneytracker.org/\n",
|
||||
"\n",
|
||||
"### Direct payments provided by the Federal Government so far are recorded in this google spreadsheet\n",
|
||||
"https://docs.google.com/spreadsheets/d/1Nr_J5wLfUT4IzqSXkYbdOXrRgEkBxhX0/edit#gid=682404301\n",
|
||||
"### Read as (date: direct payment amount)\n",
|
||||
"2020-04-15: 274B\n",
|
||||
"\n",
|
||||
"2020-12-27: 142B\n",
|
||||
"\n",
|
||||
"2021-03-11: 386B"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"subsidy_df = pd.DataFrame(policy_df.index).set_index(\"Date\")\n",
|
||||
"subsidy_df[\"USA\"] = 0.0\n",
|
||||
"\n",
|
||||
"subsidy_df.loc[\"2020-04-15\", \"USA\"] = 274e9\n",
|
||||
"subsidy_df.loc[\"2020-12-27\", \"USA\"] = 142e9\n",
|
||||
"subsidy_df.loc[\"2021-03-11\", \"USA\"] = 386e9"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 3. COVID-19 Deaths data\n",
|
||||
"### Source: COVID-19 Data Repository by the Center for Systems Science and Engineering (CSSE) at Johns Hopkins University \n",
|
||||
"(https://github.com/CSSEGISandData/COVID-19)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"deaths_us_df = DatasetCovidDeathsUS(\n",
|
||||
" data_dir=data_dir,\n",
|
||||
" download_latest_data=DOWNLOAD_LATEST_DATA\n",
|
||||
").df\n",
|
||||
"\n",
|
||||
"print(\"COVID-19 death data for the US is available between {} and {}\".format(\n",
|
||||
" deaths_us_df.columns[12], deaths_us_df.columns[-1]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Retain just the states in US_STATE_ORDER\n",
|
||||
"deaths_us_df = deaths_us_df[deaths_us_df.Province_State.isin(US_STATE_ORDER)]\n",
|
||||
"\n",
|
||||
"# We will visualize this later in the notebook"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 4. COVID-19 Vaccination Data\n",
|
||||
"### Source: Our World in Data\n",
|
||||
"(https://ourworldindata.org/covid-vaccinations)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vaccinations_us_df = DatasetCovidVaccinationsUS(\n",
|
||||
" data_dir=data_dir,\n",
|
||||
" download_latest_data=DOWNLOAD_LATEST_DATA\n",
|
||||
").df\n",
|
||||
"\n",
|
||||
"vaccination_dates = sorted(vaccinations_us_df.date.unique())\n",
|
||||
"print(\"Vaccination data is available between {} and {}\".format(min(vaccination_dates), max(vaccination_dates)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vaccinated_df = vaccinations_us_df.pivot(\n",
|
||||
" index=\"date\", columns=\"location\", values=\"people_fully_vaccinated\"\n",
|
||||
")[US_STATE_ORDER]\n",
|
||||
"\n",
|
||||
"vaccinated_df.index = pd.to_datetime(vaccinated_df.index)\n",
|
||||
"vaccinated_df = vaccinated_df.reindex(COMMON_DATE_INDEX).fillna(0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Visualize the vaccinations for a specified US state\n",
|
||||
"# Warning: the last value may not be updated (may show it to be 0)\n",
|
||||
"\n",
|
||||
"state = \"California\"\n",
|
||||
"vaccinated_df[state].plot(figsize=(15,5), x='Date', title=\"Vaccinations for {}\".format(state), grid=True);"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using deaths and vaccinations to compute the susceptible-infected-recovered (SIR) numbers\n",
|
||||
"\n",
|
||||
"Our SIR data will only treat **deaths** as ground-truth.\n",
|
||||
"\n",
|
||||
"Given death data and some assumed constants about the _death rate_ and _recovery rate_ , we can apply some \"SIR algebra\" (i.e. solve for unknowns using the SIR equations) to _infer_ quantities like total \"recovered\", number of infected people, and ultimately **Beta**, which is the rate of transmission times the number of people an infected person comes into contact with."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# For data representation, we will want to build a dataframe for...\n",
|
||||
"# ... deaths...\n",
|
||||
"deaths_df = pd.DataFrame(COMMON_DATE_INDEX, columns=['Date']).set_index('Date')\n",
|
||||
"smoothed_deaths_df = pd.DataFrame(COMMON_DATE_INDEX, columns=['Date']).set_index('Date')\n",
|
||||
"# ... (inferred) SIR states...\n",
|
||||
"susceptible_df = pd.DataFrame(COMMON_DATE_INDEX, columns=['Date']).set_index('Date')\n",
|
||||
"infected_df = pd.DataFrame(COMMON_DATE_INDEX, columns=['Date']).set_index('Date')\n",
|
||||
"recovered_df = pd.DataFrame(COMMON_DATE_INDEX, columns=['Date']).set_index('Date')\n",
|
||||
"# ... and (inferred) Beta.\n",
|
||||
"beta_df = pd.DataFrame(COMMON_DATE_INDEX, columns=['Date']).set_index('Date')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# STD of the Gaussian smoothing window applied to the death data.\n",
|
||||
"SIR_SMOOTHING_STD = 10"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Fill the death dataframe from (smoothed) raw data\n",
|
||||
"\n",
|
||||
"def smooth(x, gauss_std=10):\n",
|
||||
" \"\"\"\n",
|
||||
" gauss_std: standard deviation of the Gaussian smoothing window applied to the death data.\n",
|
||||
" \"\"\"\n",
|
||||
" if gauss_std <= 0:\n",
|
||||
" return x\n",
|
||||
" # To invalidate the near-edge results, bookend the input x with nans\n",
|
||||
" x = np.concatenate([[np.nan], np.array(x), [np.nan]])\n",
|
||||
" \n",
|
||||
" kernel = scipy.stats.norm.pdf(\n",
|
||||
" np.linspace(-3*gauss_std, 3*gauss_std, 1+6*gauss_std),\n",
|
||||
" scale=gauss_std\n",
|
||||
" )\n",
|
||||
" normer = np.ones_like(x)\n",
|
||||
" smoothed_x = convolve(x, kernel, mode='same') / convolve(normer, kernel, mode='same')\n",
|
||||
" \n",
|
||||
" # Remove the indices added by the nan padding\n",
|
||||
" return smoothed_x[1:-1]\n",
|
||||
"\n",
|
||||
"for us_state_name in US_STATE_ORDER:\n",
|
||||
" state_deaths = deaths_us_df[deaths_us_df['Province_State']==us_state_name]\n",
|
||||
" cumulative_state_deaths = []\n",
|
||||
" for d in COMMON_DATE_INDEX:\n",
|
||||
" date_string = '{d.month}/{d.day}/{y}'.format(d=d, y=d.year % 2000)\n",
|
||||
" if date_string in state_deaths:\n",
|
||||
" cumulative_state_deaths.append(\n",
|
||||
" state_deaths[date_string].sum()\n",
|
||||
" )\n",
|
||||
" else:\n",
|
||||
" cumulative_state_deaths.append(\n",
|
||||
" np.nan\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" # Store raw numbers (for direct comparison)\n",
|
||||
" deaths_df[us_state_name] = cumulative_state_deaths\n",
|
||||
" \n",
|
||||
" # Store smoothed numbers (for beta analysis)\n",
|
||||
" smoothed_cumulative_state_deaths = smooth(cumulative_state_deaths, gauss_std=SIR_SMOOTHING_STD)\n",
|
||||
" smoothed_deaths_df[us_state_name] = smoothed_cumulative_state_deaths"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"state_deaths = deaths_us_df[deaths_us_df['Province_State']==\"California\"]\n",
|
||||
"cumulative_state_deaths = []\n",
|
||||
"for d in COMMON_DATE_INDEX:\n",
|
||||
" date_string = '{d.month}/{d.day}/{y}'.format(d=d, y=d.year % 2000)\n",
|
||||
" if date_string in state_deaths:\n",
|
||||
" cumulative_state_deaths.append(\n",
|
||||
" state_deaths[date_string].sum()\n",
|
||||
" )\n",
|
||||
" else:\n",
|
||||
" cumulative_state_deaths.append(\n",
|
||||
" np.nan\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Visualize the deaths for a specified US state\n",
|
||||
"state = \"California\"\n",
|
||||
"\n",
|
||||
"# Some values near the ends may be \"missing\" because of smoothing\n",
|
||||
"deaths_df[state].plot(figsize=(15,5), x='Date', ylim=[0, 65000]);\n",
|
||||
"smoothed_deaths_df[state].plot(figsize=(15,5), x='Date', title=\"COVID deaths in {}\".format(state), ylim=[0, 65000], grid=True);"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Death rate: fraction of infected persons who die\n",
|
||||
"SIR_MORTALITY = 0.02\n",
|
||||
"\n",
|
||||
"# Recovery rate: the inverse of expected time someone remains infected\n",
|
||||
"SIR_GAMMA = 1 / 14"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# This is the core \"SIR algebra\" used to infer S, I, R, and Beta at each date.\n",
|
||||
"\n",
|
||||
"def infer_sir_and_beta(us_state_name):\n",
|
||||
" state_population = deaths_us_df[deaths_us_df['Province_State']==us_state_name]['Population'].sum()\n",
|
||||
" \n",
|
||||
" # Helpful to do this math in normalized numbers\n",
|
||||
" dead = np.array(smoothed_deaths_df[us_state_name]) / state_population\n",
|
||||
" vaccinated = np.array(vaccinated_df[us_state_name]) / state_population\n",
|
||||
" \n",
|
||||
" # Dead is the fraction of \"recovered\" that did not survive\n",
|
||||
" # Also, the vaccinated lot is part of the recovered\n",
|
||||
" recovered = dead / SIR_MORTALITY + vaccinated\n",
|
||||
" \n",
|
||||
" # The daily change in recovered (ignoring the vaccinated) is a fraction of the infected population on the previous day\n",
|
||||
" infected = np.nan * np.zeros_like(dead)\n",
|
||||
" infected[:-1] = (recovered[1:] - recovered[:-1] - (vaccinated[1:] - vaccinated[:-1])) / SIR_GAMMA\n",
|
||||
" \n",
|
||||
" # S+I+R must always = 1\n",
|
||||
" susceptible = 1 - infected - recovered\n",
|
||||
" \n",
|
||||
" # Here's where things get interesting. The change in infected is due to...\n",
|
||||
" change_in_i = infected[1:] - infected[:-1]\n",
|
||||
" # ... infected people that transition to the recovered state (decreases I)...\n",
|
||||
" expected_change_from_recovery = -infected[:-1] * SIR_GAMMA\n",
|
||||
" # ... and susceptible people that transition to the infected state (increases I).\n",
|
||||
" new_infections = change_in_i - expected_change_from_recovery\n",
|
||||
" \n",
|
||||
" # With these pieces, we can solve for Beta.\n",
|
||||
" beta_ = new_infections / (infected[:-1] * susceptible[:-1] + 1e-6)\n",
|
||||
" beta_ = np.clip(beta_, 0, 1)\n",
|
||||
" # Apply a threshold in terms of normalized daily deaths (if too low, beta estimates are bad)\n",
|
||||
" normalized_daily_deaths = dead[1:]-dead[:-1]\n",
|
||||
" ndd_lookback = np.zeros_like(new_infections)\n",
|
||||
" lookback_window = 3*SIR_SMOOTHING_STD\n",
|
||||
" ndd_cutoff = 1e-8\n",
|
||||
" ndd_lookback[lookback_window:] = normalized_daily_deaths[:-lookback_window]\n",
|
||||
" beta_[np.logical_not(ndd_lookback > 1e-8)] = np.nan\n",
|
||||
" \n",
|
||||
" beta = np.nan * np.zeros_like(dead)\n",
|
||||
" beta[:-1] = beta_\n",
|
||||
" \n",
|
||||
" # Undo normalization\n",
|
||||
" susceptible *= state_population\n",
|
||||
" infected *= state_population\n",
|
||||
" recovered *= state_population\n",
|
||||
" \n",
|
||||
" return susceptible, infected, recovered, beta"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Fill the SIR and Beta dataframes with their inferred values\n",
|
||||
"for st in US_STATE_ORDER:\n",
|
||||
" susceptible_df[st], infected_df[st], recovered_df[st], beta_df[st] = infer_sir_and_beta(us_state_name=st)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"## Visualize the SIR and BETA for a specified US state\n",
|
||||
"# Warning: some values near the ends may be \"missing\" because of smoothing\n",
|
||||
"\n",
|
||||
"state = \"California\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"susceptible_df[state].plot(figsize=(15,3), x='Date', title=\"(Inferred) Susceptible Population in {}\".format(state), grid=True);"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"infected_df[state].plot(figsize=(15,3), x='Date', title=\"(Inferred) Infected Population in {}\".format(state), grid=True);"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"recovered_df[state].plot(figsize=(15,3), x='Date', title=\"(Inferred) Recovered Population in {}\".format(state), grid=True);"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"beta_df[state].plot(figsize=(15,3), x='Date', title=\"(Inferred) SIR Beta in {}\".format(state), grid=True);"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 5. COVID-19 Unemployment data\n",
|
||||
"### Source: Bureau of Labor and Statistics\n",
|
||||
"\n",
|
||||
"https://www.bls.gov/lau/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"monthly_unemployment_us = DatasetCovidUnemploymentUS(\n",
|
||||
" data_dir=data_dir,\n",
|
||||
" download_latest_data=DOWNLOAD_LATEST_DATA).data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sample_monthly_unemployment = monthly_unemployment_us['California']\n",
|
||||
"unemp_year_keys = sorted(sample_monthly_unemployment.keys())\n",
|
||||
"unemp_starting_month_key = sorted(sample_monthly_unemployment[unemp_year_keys[0]].keys())[0]\n",
|
||||
"unemp_ending_month_key = sorted(sample_monthly_unemployment[unemp_year_keys[-1]].keys())[-1]\n",
|
||||
"unemp_starting_date = datetime.strptime(\n",
|
||||
" str(unemp_year_keys[0]) + '-' + str(unemp_ending_month_key+1) + '-1', DATE_FORMAT)\n",
|
||||
"unemp_ending_date = datetime.strptime(\n",
|
||||
" str(unemp_year_keys[-1]) + '-' + str(unemp_ending_month_key+1) + '-1', DATE_FORMAT) - timedelta(1)\n",
|
||||
"\n",
|
||||
"print(\"Unemployment data is available between {} and {}\".format(datetime.strftime(unemp_starting_date, DATE_FORMAT),\n",
|
||||
" datetime.strftime(unemp_ending_date, DATE_FORMAT)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Convert this to a daily unemployment dataframe\n",
|
||||
"\n",
|
||||
"unemployment_df = pd.DataFrame(COMMON_DATE_INDEX, columns=['Date']).set_index('Date')\n",
|
||||
"\n",
|
||||
"for us_state_name in monthly_unemployment_us.keys():\n",
|
||||
" unemployment_df[us_state_name] = [\n",
|
||||
" monthly_unemployment_us[us_state_name][x.year].get(x.month, np.nan)\n",
|
||||
" for x in unemployment_df.index\n",
|
||||
" ]\n",
|
||||
"unemployment_df = unemployment_df[US_STATE_ORDER]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"## Visualize the unemployment rate for a specified US state\n",
|
||||
"# There is likely going to be some unemployment data missing at the tail end, \n",
|
||||
"# as the unemployment data isn't updated as frequently as the other data.\n",
|
||||
"\n",
|
||||
"state = \"California\"\n",
|
||||
"unemployment_df[state].plot(figsize=(15,5), x='Date', title=\"Unemployment for {} (%)\".format(state), grid=True);"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Unemployment rate -> unemployed (the number of unemployed people)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"us_state_to_pop_dict = {}\n",
|
||||
"for us_state in US_STATE_ORDER:\n",
|
||||
" us_state_to_pop_dict[us_state] = deaths_us_df[deaths_us_df.Province_State==us_state].Population.sum()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"unemployed_df = unemployment_df.multiply([us_state_to_pop_dict[col]/100.0 for col in unemployment_df.columns])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Saving"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Save some of the data processing constants for use within the environment"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_constants_dict = {}\n",
|
||||
"\n",
|
||||
"model_constants_dict[\"DATE_FORMAT\"] = DATE_FORMAT\n",
|
||||
"model_constants_dict[\"STRINGENCY_POLICY_KEY\"] = STRINGENCY_POLICY_KEY\n",
|
||||
"model_constants_dict[\"NUM_STRINGENCY_LEVELS\"] = int(NUM_STRINGENCY_LEVELS)\n",
|
||||
"model_constants_dict[\"SIR_SMOOTHING_STD\"] = SIR_SMOOTHING_STD\n",
|
||||
"model_constants_dict[\"SIR_MORTALITY\"] = SIR_MORTALITY\n",
|
||||
"model_constants_dict[\"SIR_GAMMA\"] = SIR_GAMMA\n",
|
||||
"model_constants_dict[\"US_STATE_IDX_TO_STATE_NAME\"] = {\n",
|
||||
" us_state_idx: us_state for us_state_idx, us_state in enumerate(US_STATE_ORDER)\n",
|
||||
"}\n",
|
||||
"model_constants_dict[\"US_STATE_POPULATION\"] = [int(us_state_to_pop_dict[us_state]) for us_state in US_STATE_ORDER]\n",
|
||||
"model_constants_dict[\"US_POPULATION\"] = int(sum([us_state_to_pop_dict[us_state] for us_state in US_STATE_ORDER]))\n",
|
||||
"\n",
|
||||
"# 2019: https://data.worldbank.org/indicator/NY.GDP.PCAP.CD?locations=US&view=chart\n",
|
||||
"model_constants_dict[\"GDP_PER_CAPITA\"] = 65300 # TODO: Load this in from model_constants.json.\n",
|
||||
"\n",
|
||||
"model_constants_filename = \"model_constants.json\"\n",
|
||||
"with open(os.path.join(data_dir, model_constants_filename), \"w\") as fp: \n",
|
||||
" json.dump(model_constants_dict, fp)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Save all the processed dataframes in order to use for model fitting notebook (fit_model_parameters.ipynb)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataframes = {\n",
|
||||
" \"policy\": policy_df,\n",
|
||||
" \"subsidy\": subsidy_df,\n",
|
||||
" \"deaths\": deaths_df,\n",
|
||||
" \"vaccinated\": vaccinated_df,\n",
|
||||
" \"smoothed_deaths\": smoothed_deaths_df,\n",
|
||||
" \"susceptible\": susceptible_df,\n",
|
||||
" \"infected\": infected_df,\n",
|
||||
" \"recovered\": recovered_df,\n",
|
||||
" \"beta\": beta_df,\n",
|
||||
" \"unemployment\": unemployment_df,\n",
|
||||
" \"unemployed\": unemployed_df,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"for k, df in dataframes.items():\n",
|
||||
" if k == \"subsidy\": # This is at the USA level, not at the US states level\n",
|
||||
" continue\n",
|
||||
" assert (df.columns.to_list() == US_STATE_ORDER).all()\n",
|
||||
"\n",
|
||||
"with open(os.path.join(data_dir, 'dataframes.pkl'), 'wb') as F:\n",
|
||||
" pickle.dump(dataframes, F)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Also save all the data as numpy arrays for use within the covid19 simulation environment"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"real_world_data = {}\n",
|
||||
"for key in dataframes:\n",
|
||||
" real_world_data[key] = dataframes[key].values\n",
|
||||
" \n",
|
||||
"# Save the real-world data as a .npz for use within the environment\n",
|
||||
"np.savez(os.path.join(data_dir, \"real_world_data.npz\"), **real_world_data) "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Finally, in order to use this gathered real-world data when you run the covid19 simulation, you will need to also"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 1. Run the \"fit_model_parameters.ipynb\" notebook with the base data directory specified below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"BASE_DATA_DIR_PATH = '{}'\".format(BASE_DATA_DIR_PATH))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 2. Set \"path_to_data_and_fitted_params\" in the env config also to the data directory below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"path_to_data_and_fitted_params = '{}'\".format(data_dir))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
54
ai_economist/datasets/covid19_datasets/us_deaths.py
Normal file
54
ai_economist/datasets/covid19_datasets/us_deaths.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import os
|
||||
from io import BytesIO
|
||||
|
||||
import pandas as pd
|
||||
import requests
|
||||
|
||||
|
||||
class DatasetCovidDeathsUS:
    """
    Class to load COVID-19 deaths data for the US.
    Source: https://github.com/CSSEGISandData/COVID-19
    Note: in this dataset, reporting deaths only started on the 22nd of January 2020.

    Attributes:
        df: Timeseries dataframe of confirmed COVID deaths for all the US states
    """

    def __init__(self, data_dir="", download_latest_data=True):
        """Download (or load a cached copy of) the JHU deaths timeseries.

        Args:
            data_dir (str): Directory where the CSV cache is read/written.
            download_latest_data (bool): If True (default), always re-download
                and overwrite the cache; otherwise load the cache if present.
        """
        if not os.path.exists(data_dir):
            print(
                "Creating a dynamic data directory to store "
                "COVID-19 deaths data: {}".format(data_dir)
            )
            os.makedirs(data_dir)

        filename = "daily_us_deaths.csv"
        if download_latest_data or filename not in os.listdir(data_dir):
            # BUG FIX: "John Hopkins" -> "Johns Hopkins" (the university's name).
            print(
                "Fetching latest U.S. COVID-19 deaths data from Johns Hopkins, "
                "and saving it in {}".format(data_dir)
            )

            req = requests.get(
                "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/"
                "csse_covid_19_data/csse_covid_19_time_series/"
                "time_series_covid19_deaths_US.csv"
            )
            # low_memory=False for parity with the cached-load branch below,
            # so both code paths infer dtypes the same way.
            self.df = pd.read_csv(BytesIO(req.content), low_memory=False)
            # index=False keeps the cached CSV identical in shape to the
            # downloaded one (the original wrote the row index as an extra
            # unnamed column, so reloading the cache produced a different df).
            self.df.to_csv(
                os.path.join(data_dir, filename), index=False
            )  # Note: performs an overwrite
        else:
            print(
                "Not fetching the latest U.S. COVID-19 deaths data from Johns Hopkins."
                " Using whatever was saved earlier in {}!!".format(data_dir)
            )
            assert filename in os.listdir(data_dir)
            self.df = pd.read_csv(os.path.join(data_dir, filename), low_memory=False)
|
||||
122
ai_economist/datasets/covid19_datasets/us_policies.py
Normal file
122
ai_economist/datasets/covid19_datasets/us_policies.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import requests
|
||||
|
||||
|
||||
class DatasetCovidPoliciesUS:
    """
    Class to load COVID-19 government policies for the US states.
    Source: https://github.com/OxCGRT/USA-covid-policy

    Other references:
    - Codebook: https://github.com/OxCGRT/covid-policy-tracker/blob/master/
    documentation/codebook.md
    - Index computation methodology: https://github.com/OxCGRT/covid-policy-tracker/
    blob/master/documentation/index_methodology.md

    Attributes:
        df: Timeseries dataframe of state-wide policies
    """

    def __init__(self, data_dir="", download_latest_data=True):
        """Download (or load a cached copy of) the OxCGRT US policy tracker.

        Args:
            data_dir (str): Directory where the CSV cache is read/written.
            download_latest_data (bool): If True (default), always re-download
                and overwrite the cache; otherwise load the cache if present.
        """
        if not os.path.exists(data_dir):
            print(
                "Creating a dynamic data directory to store COVID-19 "
                "policy tracking data: {}".format(data_dir)
            )
            os.makedirs(data_dir)

        filename = "daily_us_policies.csv"
        if download_latest_data or filename not in os.listdir(data_dir):
            print(
                "Fetching latest U.S. COVID-19 policies data from OxCGRT, "
                "and saving it in {}".format(data_dir)
            )
            req = requests.get(
                "https://raw.githubusercontent.com/OxCGRT/USA-covid-policy/master/"
                "data/OxCGRT_US_latest.csv"
            )
            self.df = pd.read_csv(BytesIO(req.content), low_memory=False)
            # Dates are encoded as YYYYMMDD integers; parse to datetimes.
            self.df["Date"] = self.df["Date"].apply(
                lambda x: datetime.strptime(str(x), "%Y%m%d")
            )

            # Fetch only the state-wide policies
            self.df = self.df.loc[self.df["Jurisdiction"] != "NAT_GOV"]

            self.df.to_csv(
                os.path.join(data_dir, filename)
            )  # Note: performs an overwrite
        else:
            print(
                "Not fetching the latest U.S. COVID-19 policies data from OxCGRT. "
                "Using whatever was saved earlier in {}!!".format(data_dir)
            )
            assert filename in os.listdir(data_dir)
            self.df = pd.read_csv(os.path.join(data_dir, filename), low_memory=False)

    def process_policy_data(
        self, stringency_policy_key="StringencyIndex", num_stringency_levels=10
    ):
        """
        Gather the relevant policy indicator from the dataframe,
        fill in the null values (if any),
        and discretize/quantize the policy into num_stringency_levels.
        Note: Possible values for stringency_policy_key are
        ["StringencyIndex", "Government response index",
        "Containment and health index", "Economic Support index".]
        Reference: https://github.com/OxCGRT/covid-policy-tracker/blob/master/
        documentation/index_methodology.md

        Returns:
            Dataframe with columns ["RegionName", "Date", stringency_policy_key],
            sorted by (RegionName, Date), where the policy column holds integer
            levels in [1, num_stringency_levels].
        """

        def discretize(policies, num_indicator_levels=10):
            """
            Discretize the policies (a Pandas series) into num_indicator_levels.
            Each value is snapped to the nearer of its two surrounding bin edges.
            """
            # Indices are normalized to be in [0, 100]
            bins = np.linspace(0, 100, num_indicator_levels)
            # Find left and right values of bin and find the nearer edge.
            # NOTE(review): values outside [0, 100] or NaNs would index past
            # the end of `bins` and raise -- assumes clean, filled inputs.
            bin_index = np.digitize(policies, bins, right=True)
            bin_left_edges = bins[bin_index - 1]
            bin_right_edges = bins[bin_index]
            discretized_policies = bin_index + np.argmin(
                np.stack(
                    (
                        np.abs(policies.values - bin_left_edges),
                        np.abs(policies.values - bin_right_edges),
                    )
                ),
                axis=0,
            )
            return discretized_policies

        # Gather just the relevant columns
        policy_df = self.df[["RegionName", "Date", stringency_policy_key]].copy()

        # Fill in null values via a "forward fill".
        # BUG FIX: the original used
        #   policy_df[key].fillna(method="ffill", inplace=True)
        # which is a chained in-place fill on a __getitem__ result -- the
        # `method=` parameter is deprecated and the write is silently
        # ineffective under pandas copy-on-write. Assign the filled column back.
        policy_df[stringency_policy_key] = policy_df[stringency_policy_key].ffill()

        # Discretize the stringency indices
        discretized_stringency_policies = discretize(
            policy_df[stringency_policy_key], num_indicator_levels=num_stringency_levels
        )
        policy_df.loc[:, stringency_policy_key] = discretized_stringency_policies

        # Replace Washington DC by District of Columbia to keep consistent
        # (with the other data sources)
        policy_df = policy_df.replace("Washington DC", "District of Columbia")

        policy_df = policy_df.sort_values(by=["RegionName", "Date"])

        return policy_df
|
||||
128
ai_economist/datasets/covid19_datasets/us_unemployment.py
Normal file
128
ai_economist/datasets/covid19_datasets/us_unemployment.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import bz2
|
||||
import os
|
||||
import pickle
|
||||
import queue
|
||||
import threading
|
||||
import urllib.request as urllib2
|
||||
|
||||
import pandas as pd
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
|
||||
class DatasetCovidUnemploymentUS:
    """
    Class to load COVID-19 unemployment data for the US states.
    Source: https://www.bls.gov/lau/

    Attributes:
        data: Nested dict of monthly unemployment rates, keyed as
            data[state_name][year][month] -> rate (percent).
    """

    def __init__(self, data_dir="", download_latest_data=True):
        """Scrape (or load a cached bz2-pickled copy of) BLS unemployment data.

        Args:
            data_dir (str): Directory where the bz2 cache is read/written.
            download_latest_data (bool): If True (default), always re-scrape
                and overwrite the cache; otherwise load the cache if present.
        """
        if not os.path.exists(data_dir):
            print(
                "Creating a dynamic data directory to store COVID-19 "
                "unemployment data: {}".format(data_dir)
            )
            os.makedirs(data_dir)

        filename = "monthly_us_unemployment.bz2"
        if download_latest_data or filename not in os.listdir(data_dir):
            # Construct the U.S. state to FIPS code mapping
            state_fips_df = pd.read_excel(
                "https://www2.census.gov/programs-surveys/popest/geographies/2017/"
                "state-geocodes-v2017.xlsx",
                header=5,
            )
            # remove all statistical areas and cities
            state_fips_df = state_fips_df.loc[state_fips_df["State (FIPS)"] != 0]
            self.us_state_to_fips_dict = pd.Series(
                state_fips_df["State (FIPS)"].values, index=state_fips_df.Name
            ).to_dict()

            print(
                "Fetching the U.S. unemployment data from "
                "Bureau of Labor and Statistics, and saving it in {}".format(data_dir)
            )
            self.data = self.scrape_bls_data()
            # BUG FIX: the original opened the file without a context manager
            # (fp = bz2.BZ2File(...); pickle.dump(...); fp.close()), leaking the
            # handle if pickling raised. `with` guarantees the close.
            with bz2.BZ2File(os.path.join(data_dir, filename), "wb") as fp:
                pickle.dump(self.data, fp)

        else:
            print(
                "Not fetching the U.S. unemployment data from Bureau of Labor and"
                " Statistics. Using whatever was saved earlier in {}!!".format(data_dir)
            )
            assert filename in os.listdir(data_dir)
            # (The original also called fp.close() inside this `with` block;
            # that was redundant -- the context manager closes the file.)
            with bz2.BZ2File(os.path.join(data_dir, filename), "rb") as fp:
                self.data = pickle.load(fp)
            # NOTE(review): on this cache-load path, us_state_to_fips_dict is
            # never populated, so a later scrape_bls_data() call would fail --
            # confirm whether that is intended.

    # Scrape monthly unemployment from the Bureau of Labor Statistics website
    def get_monthly_bls_unemployment_rates(self, state_fips):
        """Return {year: {month: unemployment_rate}} for one state's FIPS code."""
        with urllib2.urlopen(
            "https://data.bls.gov/timeseries/LASST{:02d}0000000000003".format(
                state_fips
            )
        ) as response:
            html_doc = response.read()

        soup = BeautifulSoup(html_doc, "html.parser")
        # The rates live in the second table on the page.
        table = soup.find_all("table")[1]
        table_rows = table.find_all("tr")

        unemployment_dict = {}

        # Three-letter month abbreviation -> month number.
        mth2idx = {
            "Jan": 1,
            "Feb": 2,
            "Mar": 3,
            "Apr": 4,
            "May": 5,
            "Jun": 6,
            "Jul": 7,
            "Aug": 8,
            "Sep": 9,
            "Oct": 10,
            "Nov": 11,
            "Dec": 12,
        }

        # Skip the header row and the trailing footer row.
        for tr in table_rows[1:-1]:
            td = tr.find_all("td")[-1]
            # Keep only digits and the decimal point (drops footnote markers).
            unemp = float("".join([c for c in td.text if c.isdigit() or c == "."]))
            th = tr.find_all("th")
            year = int(th[0].text)
            month = mth2idx[th[1].text]
            if year not in unemployment_dict:
                unemployment_dict[year] = {}
            unemployment_dict[year][month] = unemp

        return unemployment_dict

    def scrape_bls_data(self):
        """Scrape every state's monthly rates, one thread per state.

        Returns:
            Dict mapping state name -> {year: {month: rate}}.
        """

        def do_scrape(us_state, fips, queue_obj):
            # Runs in a worker thread; results are funneled through the queue.
            out = self.get_monthly_bls_unemployment_rates(fips)
            queue_obj.put([us_state, out])

        print("Getting BLS Data. This might take a minute...")
        result = queue.Queue()
        threads = [
            threading.Thread(target=do_scrape, args=(us_state, fips, result))
            for us_state, fips in self.us_state_to_fips_dict.items()
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        monthly_unemployment = {}
        while not result.empty():
            us_state, data = result.get()
            monthly_unemployment[us_state] = data

        return monthly_unemployment
|
||||
61
ai_economist/datasets/covid19_datasets/us_vaccinations.py
Normal file
61
ai_economist/datasets/covid19_datasets/us_vaccinations.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import os
|
||||
from io import BytesIO
|
||||
|
||||
import pandas as pd
|
||||
import requests
|
||||
|
||||
|
||||
class DatasetCovidVaccinationsUS:
    """
    Class to load COVID-19 vaccination data for the US.
    Source: https://ourworldindata.org/covid-vaccinations

    Attributes:
        df: Timeseries dataframe of COVID vaccinations for all the US states
    """

    def __init__(self, data_dir="", download_latest_data=True):
        """Download (or load a cached copy of) the OWID vaccination timeseries.

        Args:
            data_dir (str): Directory where the CSV cache is read/written.
            download_latest_data (bool): If True (default), always re-download
                and overwrite the cache; otherwise load the cache if present.
        """
        if not os.path.exists(data_dir):
            print(
                "Creating a dynamic data directory to store COVID-19 "
                "vaccination data: {}".format(data_dir)
            )
            os.makedirs(data_dir)

        filename = "daily_us_vaccinations.csv"
        if download_latest_data or filename not in os.listdir(data_dir):
            print(
                "Fetching latest U.S. COVID-19 vaccination data from "
                "Our World in Data, and saving it in {}".format(data_dir)
            )

            req = requests.get(
                "https://raw.githubusercontent.com/owid/covid-19-data/master/"
                "public/data/vaccinations/us_state_vaccinations.csv"
            )
            self.df = pd.read_csv(BytesIO(req.content))

            # Rename New York State to New York for consistency with other datasets
            self.df = self.df.replace("New York State", "New York")

            # Interpolate missing values
            self.df = self.df.interpolate(method="linear")

            self.df.to_csv(
                os.path.join(data_dir, filename)
            )  # Note: performs an overwrite
        else:
            # BUG FIX: the original message said "deaths data" -- a copy-paste
            # from the deaths loader. This class loads vaccination data.
            print(
                "Not fetching the latest U.S. COVID-19 vaccination data from "
                "Our World in Data. Using whatever was saved earlier in {}!!".format(
                    data_dir
                )
            )
            assert filename in os.listdir(data_dir)
            self.df = pd.read_csv(os.path.join(data_dir, filename), low_memory=False)
|
||||
18
ai_economist/foundation/__init__.py
Normal file
18
ai_economist/foundation/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
from ai_economist.foundation import utils
|
||||
from ai_economist.foundation.agents import agent_registry as agents
|
||||
from ai_economist.foundation.components import component_registry as components
|
||||
from ai_economist.foundation.entities import endogenous_registry as endogenous
|
||||
from ai_economist.foundation.entities import landmark_registry as landmarks
|
||||
from ai_economist.foundation.entities import resource_registry as resources
|
||||
from ai_economist.foundation.scenarios import scenario_registry as scenarios
|
||||
|
||||
|
||||
def make_env_instance(scenario_name, **kwargs):
    """Look up `scenario_name` in the scenario registry and construct it.

    Any keyword arguments are forwarded to the scenario's constructor.
    """
    return scenarios.get(scenario_name)(**kwargs)
|
||||
12
ai_economist/foundation/agents/__init__.py
Normal file
12
ai_economist/foundation/agents/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
from ai_economist.foundation.base.base_agent import agent_registry
|
||||
|
||||
from . import mobiles, planners
|
||||
|
||||
# Import files that add Agent class(es) to agent_registry
|
||||
# -------------------------------------------------------
|
||||
18
ai_economist/foundation/agents/mobiles.py
Normal file
18
ai_economist/foundation/agents/mobiles.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
from ai_economist.foundation.base.base_agent import BaseAgent, agent_registry
|
||||
|
||||
|
||||
@agent_registry.add
class BasicMobileAgent(BaseAgent):
    """A simple embodied ("mobile") agent.

    Each instance represents an individual actor in the economic simulation.
    "Mobile" means agents of this type can move around in the 2D world.
    """

    name = "BasicMobileAgent"
|
||||
40
ai_economist/foundation/agents/planners.py
Normal file
40
ai_economist/foundation/agents/planners.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
from ai_economist.foundation.base.base_agent import BaseAgent, agent_registry
|
||||
|
||||
|
||||
@agent_registry.add
class BasicPlanner(BaseAgent):
    """A social planner agent that sets macroeconomic policy.

    Unlike mobile agents, the planner is not embodied in the world
    environment, so "loc" is removed from its state.

    The planner is also expected to be unique -- there should only ever be
    one -- so whatever idx is passed at construction is ignored and the
    agent index is always "p".
    """

    name = "BasicPlanner"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Planners are not embodied; drop the location state field entirely.
        self.state.pop("loc")

        # Force the index to "p" regardless of what was passed in.
        # (Subclass BasicPlanner if a game needs multiple planners.)
        self._idx = "p"

    @property
    def loc(self):
        """Planner agents never occupy a location; accessing loc raises."""
        raise AttributeError("BasicPlanner agents do not occupy a location.")
|
||||
5
ai_economist/foundation/base/__init__.py
Normal file
5
ai_economist/foundation/base/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
490
ai_economist/foundation/base/base_agent.py
Normal file
490
ai_economist/foundation/base/base_agent.py
Normal file
@@ -0,0 +1,490 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai_economist.foundation.base.registrar import Registry
|
||||
|
||||
|
||||
class BaseAgent:
|
||||
"""Base class for Agent classes.
|
||||
|
||||
Instances of Agent classes are created for each agent in the environment. Agent
|
||||
instances are stateful, capturing location, inventory, endogenous variables,
|
||||
and any additional state fields created by environment components during
|
||||
construction (see BaseComponent.get_additional_state_fields in base_component.py).
|
||||
|
||||
They also provide a simple API for getting/setting actions for each of their
|
||||
registered action subspaces (which depend on the components used to build
|
||||
the environment).
|
||||
|
||||
Args:
|
||||
idx (int or str): Index that uniquely identifies the agent object amongst the
|
||||
other agent objects registered in its environment.
|
||||
multi_action_mode (bool): Whether to allow the agent to take one action for
|
||||
each of its registered action subspaces each timestep (if True),
|
||||
or to limit the agent to take only one action each timestep (if False).
|
||||
"""
|
||||
|
||||
name = ""
|
||||
|
||||
def __init__(self, idx=None, multi_action_mode=None):
|
||||
assert self.name
|
||||
|
||||
if idx is None:
|
||||
idx = 0
|
||||
|
||||
if multi_action_mode is None:
|
||||
multi_action_mode = False
|
||||
|
||||
if isinstance(idx, str):
|
||||
self._idx = idx
|
||||
else:
|
||||
self._idx = int(idx)
|
||||
|
||||
self.multi_action_mode = bool(multi_action_mode)
|
||||
self.single_action_map = (
|
||||
{}
|
||||
) # Used to convert single-action-mode actions to the general format
|
||||
|
||||
self.action = dict()
|
||||
self.action_dim = dict()
|
||||
self._action_names = []
|
||||
self._multi_action_dict = {}
|
||||
self._unique_actions = 0
|
||||
self._total_actions = 0
|
||||
|
||||
self.state = dict(loc=[0, 0], inventory={}, escrow={}, endogenous={})
|
||||
|
||||
self._registered_inventory = False
|
||||
self._registered_endogenous = False
|
||||
self._registered_components = False
|
||||
self._noop_action_dict = dict()
|
||||
|
||||
# Special flag to allow logic for multi-action-mode agents
|
||||
# that are not given any actions.
|
||||
self._passive_multi_action_agent = False
|
||||
|
||||
# If this gets set to true, we can make masks faster
|
||||
self._one_component_single_action = False
|
||||
self._premask = None
|
||||
|
||||
@property
|
||||
def idx(self):
|
||||
"""Index used to identify this agent. Must be unique within the environment."""
|
||||
return self._idx
|
||||
|
||||
def register_inventory(self, resources):
|
||||
"""Used during environment construction to populate inventory/escrow fields."""
|
||||
assert not self._registered_inventory
|
||||
for entity_name in resources:
|
||||
self.inventory[entity_name] = 0
|
||||
self.escrow[entity_name] = 0
|
||||
self._registered_inventory = True
|
||||
|
||||
def register_endogenous(self, endogenous):
|
||||
"""Used during environment construction to populate endogenous state fields."""
|
||||
assert not self._registered_endogenous
|
||||
for entity_name in endogenous:
|
||||
self.endogenous[entity_name] = 0
|
||||
self._registered_endogenous = True
|
||||
|
||||
def _incorporate_component(self, action_name, n):
|
||||
extra_n = (
|
||||
1 if self.multi_action_mode else 0
|
||||
) # Each sub-action has a NO-OP in multi action mode)
|
||||
self.action[action_name] = 0
|
||||
self.action_dim[action_name] = n + extra_n
|
||||
self._action_names.append(action_name)
|
||||
self._multi_action_dict[action_name] = False
|
||||
self._unique_actions += 1
|
||||
if self.multi_action_mode:
|
||||
self._total_actions += n + extra_n
|
||||
else:
|
||||
for action_n in range(1, n + 1):
|
||||
self._total_actions += 1
|
||||
self.single_action_map[int(self._total_actions)] = [
|
||||
action_name,
|
||||
action_n,
|
||||
]
|
||||
|
||||
def register_components(self, components):
|
||||
"""Used during environment construction to set up state/action spaces."""
|
||||
assert not self._registered_components
|
||||
for component in components:
|
||||
n = component.get_n_actions(self.name)
|
||||
if n is None:
|
||||
continue
|
||||
|
||||
# Most components will have a single action-per-agent, so n is an int
|
||||
if isinstance(n, int):
|
||||
if n == 0:
|
||||
continue
|
||||
self._incorporate_component(component.name, n)
|
||||
|
||||
# They can also internally handle multiple actions-per-agent,
|
||||
# so n is an tuple or list
|
||||
elif isinstance(n, (tuple, list)):
|
||||
for action_sub_name, n_ in n:
|
||||
if n_ == 0:
|
||||
continue
|
||||
if "." in action_sub_name:
|
||||
raise NameError(
|
||||
"Sub-action {} of component {} "
|
||||
"is illegally named.".format(
|
||||
action_sub_name, component.name
|
||||
)
|
||||
)
|
||||
self._incorporate_component(
|
||||
"{}.{}".format(component.name, action_sub_name), n_
|
||||
)
|
||||
|
||||
# If that's not what we got something is funky.
|
||||
else:
|
||||
raise TypeError(
|
||||
"Received unexpected type ({}) from {}.get_n_actions('{}')".format(
|
||||
type(n), component.name, self.name
|
||||
)
|
||||
)
|
||||
|
||||
for k, v in component.get_additional_state_fields(self.name).items():
|
||||
self.state[k] = v
|
||||
|
||||
# Currently no actions are available to this agent. Give it a placeholder.
|
||||
if len(self.action) == 0 and self.multi_action_mode:
|
||||
self._incorporate_component("PassiveAgentPlaceholder", 0)
|
||||
self._passive_multi_action_agent = True
|
||||
|
||||
elif len(self.action) == 1 and not self.multi_action_mode:
|
||||
self._one_component_single_action = True
|
||||
self._premask = np.ones(1 + self._total_actions, dtype=np.float32)
|
||||
|
||||
self._registered_components = True
|
||||
|
||||
self._noop_action_dict = {k: v * 0 for k, v in self.action.items()}
|
||||
|
||||
verbose = False
|
||||
if verbose:
|
||||
print(self.name, self.idx, "constructed action map:")
|
||||
for k, v in self.single_action_map.items():
|
||||
print("single action map:", k, v)
|
||||
for k, v in self.action.items():
|
||||
print("action:", k, v)
|
||||
for k, v in self.action_dim.items():
|
||||
print("action_dim:", k, v)
|
||||
|
||||
@property
|
||||
def action_spaces(self):
|
||||
"""
|
||||
if self.multi_action_mode == True:
|
||||
Returns an integer array with length equal to the number of action
|
||||
subspaces that the agent registered. The i'th element of the array
|
||||
indicates the number of actions associated with the i'th action subspace.
|
||||
In multi_action_mode, each subspace includes a NO-OP.
|
||||
Note: self._action_names describes which action subspace each element of
|
||||
the array refers to.
|
||||
|
||||
Example:
|
||||
>> self.multi_action_mode
|
||||
True
|
||||
>> self.action_spaces
|
||||
[2, 5]
|
||||
>> self._action_names
|
||||
["Build", "Gather"]
|
||||
# [1 Build action + Build NO-OP, 4 Gather actions + Gather NO-OP]
|
||||
|
||||
if self.multi_action_mode == False:
|
||||
Returns a single integer equal to the total number of actions that the
|
||||
agent can take.
|
||||
|
||||
Example:
|
||||
>> self.multi_action_mode
|
||||
False
|
||||
>> self.action_spaces
|
||||
6
|
||||
>> self._action_names
|
||||
["Build", "Gather"]
|
||||
# 1 NO-OP + 1 Build action + 4 Gather actions.
|
||||
"""
|
||||
if self.multi_action_mode:
|
||||
action_dims = []
|
||||
for m in self._action_names:
|
||||
action_dims.append(np.array(self.action_dim[m]).reshape(-1))
|
||||
return np.concatenate(action_dims).astype(np.int32)
|
||||
n_actions = 1 # (NO-OP)
|
||||
for m in self._action_names:
|
||||
n_actions += self.action_dim[m]
|
||||
return n_actions
|
||||
|
||||
@property
|
||||
def loc(self):
|
||||
"""2D list of [row, col] representing agent's location in the environment."""
|
||||
return self.state["loc"]
|
||||
|
||||
@property
def endogenous(self):
    """Dictionary of the agent's endogenous quantities (e.g. "Labor").

    Example:
        >> self.endogenous
        {"Labor": 30.25}
    """
    return self.state["endogenous"]
|
||||
|
||||
@property
def inventory(self):
    """Dictionary of resource quantities held in the agent's inventory.

    Example:
        >> self.inventory
        {"Wood": 3, "Stone": 20, "Coin": 1002.83}
    """
    return self.state["inventory"]
|
||||
|
||||
@property
def escrow(self):
    """Dictionary of resource quantities held in the agent's escrow.

    https://en.wikipedia.org/wiki/Escrow
    Escrow manages any portion of the agent's inventory that is reserved for a
    particular purpose, typically as part of a contractual arrangement to
    disburse it when some condition is met. For example, in the
    ContinuousDoubleAuction Component class
    (see ../components/continuous_double_auction.py), when an agent creates an
    order to sell a unit of Wood, the component moves one unit of Wood from the
    agent's inventory into its escrow. If another agent buys the Wood, it moves
    from escrow to the buyer's inventory. Holding the Wood in escrow prevents
    the selling agent from using it for something else (such as building a
    house) in the meantime.

    Notes:
        The inventory and escrow share the same keys. An agent's endowment
        refers to the total quantity it holds across inventory and escrow.

        Escrow simplifies inventory management, but its intended semantics are
        not enforced directly here; Component classes are responsible for
        enforcing them.

    Example:
        >> self.escrow
        {"Wood": 0, "Stone": 1, "Coin": 3}
    """
    return self.state["escrow"]
|
||||
|
||||
def inventory_to_escrow(self, resource, amount):
    """Move some amount of a resource from agent inventory to agent escrow.

    The amount transferred is capped to the amount of resource in the agent's
    inventory.

    Args:
        resource (str): The name of the resource to move (i.e. "Wood", "Coin").
        amount (float): The amount to be moved from inventory to escrow. Must be
            positive.

    Returns:
        float: The amount of resource actually transferred, computed as
            min(self.state["inventory"][resource], amount). This is less than
            the amount argument if the inventory held less than that amount.
    """
    assert amount >= 0
    # Cap the transfer at what the inventory actually holds.
    transferred = float(np.minimum(self.state["inventory"][resource], amount))
    self.state["inventory"][resource] -= transferred
    self.state["escrow"][resource] += transferred
    return transferred
|
||||
|
||||
def escrow_to_inventory(self, resource, amount):
    """Move some amount of a resource from agent escrow back to agent inventory.

    The amount transferred is capped to the amount of resource currently held
    in escrow.

    Args:
        resource (str): Name of the resource to move (i.e. "Wood", "Coin").
        amount (float): Amount to move from escrow to inventory. Must be
            positive.

    Returns:
        float: Amount of resource actually transferred, computed as
            min(self.state["escrow"][resource], amount). Less than the amount
            argument if escrow held less than that amount.
    """
    assert amount >= 0
    available = self.state["escrow"][resource]
    transferred = float(np.minimum(available, amount))
    self.state["escrow"][resource] -= transferred
    self.state["inventory"][resource] += transferred
    return float(transferred)
|
||||
|
||||
def total_endowment(self, resource):
    """Return the agent's combined inventory + escrow holdings of resource.

    Args:
        resource (str): Name of the resource.

    Returns:
        The total amount of resource across the agent's inventory and escrow.
    """
    held = self.inventory[resource]
    reserved = self.escrow[resource]
    return held + reserved
|
||||
|
||||
def reset_actions(self, component=None):
    """Reset action(s) to the NO-OP action (the 0'th action index).

    Args:
        component (str, optional): If given, only reset action(s) matching this
            component. A name containing "." is treated as a fully qualified
            subaction name and matched exactly (case-insensitively); otherwise
            it is matched against the base component name (the part of each
            action key before any "."). If omitted/falsy, reset everything.
    """
    if not component:
        self.action.update(self._noop_action_dict)
        return
    target = component.lower()
    # Hoist the loop-invariant check: "." marks a fully qualified subaction.
    exact_match = "." in component
    for k, v in self.action.items():
        key = k.lower() if exact_match else k.split(".")[0].lower()
        if key == target:
            # v * 0 zeroes both plain ints and numpy arrays in place-shape.
            self.action[k] = v * 0
|
||||
|
||||
def has_component(self, component_name):
    """Return True if component_name is a registered subaction of this agent."""
    # "in" already evaluates to a bool; no wrapper needed.
    return component_name in self.action
|
||||
|
||||
def get_random_action(self):
    """
    Pick one registered component uniformly at random and return one of its
    actions (excluding the NO-OP), as a single-entry dict.
    """
    component = random.choice(self._action_names)
    # Valid non-NO-OP actions are indices 1 .. action_dim - 1.
    valid_actions = list(range(1, self.action_dim[component]))
    return {component: random.choice(valid_actions)}
|
||||
|
||||
def get_component_action(self, component_name, sub_action_name=None):
    """
    Return the action(s) taken for the component_name component, or None if the
    agent does not use that component.

    Args:
        component_name (str): Base name of the component to query.
        sub_action_name (str, optional): If given, look up exactly the
            "component_name.sub_action_name" subaction.

    Returns:
        A single action value, a list of values (one per matching subaction
        when the component registered several), or None if nothing matches.
    """
    if sub_action_name is not None:
        full_name = "{}.{}".format(component_name, sub_action_name)
        return self.action.get(full_name, None)
    matches = [
        name for name in self._action_names
        if name.split(".")[0] == component_name
    ]
    if not matches:
        return None
    if len(matches) == 1:
        return self.action.get(matches[0], None)
    return [self.action.get(name, None) for name in matches]
|
||||
|
||||
def set_component_action(self, component_name, action):
    """Set the action(s) taken for the component_name component.

    Multi-action components store an int32 numpy array; single-action
    components store a plain int.

    Raises:
        KeyError: If component_name is not registered as a subaction of this
            agent.
    """
    if component_name not in self.action:
        raise KeyError(
            "Agent {} of type {} does not have {} registered as a subaction".format(
                self.idx, self.name, component_name
            )
        )
    if self._multi_action_dict[component_name]:
        stored = np.array(action, dtype=np.int32)
    else:
        stored = int(action)
    self.action[component_name] = stored
|
||||
|
||||
def populate_random_actions(self):
    """Fill the action buffer with random actions (testing helper only)."""
    for component, dim in self.action_dim.items():
        if isinstance(dim, int):
            # Single action subspace: sample one index in [0, dim).
            self.set_component_action(component, np.random.randint(0, dim))
        else:
            # Multi-dimensional subspace: sample one index per entry of dim.
            dims = np.array(dim)
            sampled = np.floor(np.random.rand(*dims.shape) * dims)
            self.set_component_action(component, sampled)
|
||||
|
||||
def parse_actions(self, actions):
    """Parse the actions array to fill each component's action buffers.

    Args:
        actions: In multi-action mode, a sequence with exactly
            self._unique_actions entries (one per registered action subspace).
            In single-action mode, either a dict of at most one
            {action_name: action} entry, or a single integer indexing into the
            full combined action space (see self.single_action_map).
    """
    if self.multi_action_mode:
        assert len(actions) == self._unique_actions
        if len(actions) == 1:
            # Single registered subspace: pass the action through as-is.
            self.set_component_action(self._action_names[0], actions[0])
        else:
            # One action per subspace, paired positionally with action names.
            for action_name, action in zip(self._action_names, actions):
                self.set_component_action(action_name, int(action))

    # Single action mode
    else:
        # Action was supplied as an index of a specific subaction.
        # No need to do any lookup.
        if isinstance(actions, dict):
            if len(actions) == 0:
                return
            assert len(actions) == 1
            action_name = list(actions.keys())[0]
            action = list(actions.values())[0]
            # Index 0 is the subaction's NO-OP; nothing to set.
            if action == 0:
                return
            self.set_component_action(action_name, action)

        # Action was supplied as an index into the full set of combined actions
        else:
            action = int(actions)
            # Universal NO-OP
            if action == 0:
                return
            # NOTE(review): single_action_map.get(action) returns None for an
            # out-of-range index, which would raise TypeError on unpacking —
            # presumably indices are validated upstream; confirm.
            action_name, action = self.single_action_map.get(action)
            self.set_component_action(action_name, action)
|
||||
|
||||
def flatten_masks(self, mask_dict):
    """Convert a dictionary of component action masks into a single mask vector.

    Args:
        mask_dict (dict): {action_name: mask} with one entry per registered
            action subspace; a missing entry raises KeyError.

    Returns:
        Flat float32 binary mask. In single-action mode a single leading 1
        (the universal NO-OP) is prepended; in multi-action mode a 1 is
        prepended before each subspace's mask (one NO-OP per subspace).
    """
    # Fast path: one component in single-action mode. Reuse the preallocated
    # buffer; index 0 stays 1 for the NO-OP, the rest is overwritten each call.
    if self._one_component_single_action:
        self._premask[1:] = mask_dict[self._action_names[0]]
        return self._premask

    no_op_mask = [1]

    # A multi-action-mode agent with no registered subactions: only NO-OP.
    if self._passive_multi_action_agent:
        return np.array(no_op_mask).astype(np.float32)

    list_of_masks = []
    if not self.multi_action_mode:
        list_of_masks.append(no_op_mask)
    for m in self._action_names:
        if m not in mask_dict:
            raise KeyError("No mask provided for {} (agent {})".format(m, self.idx))
        if self.multi_action_mode:
            list_of_masks.append(no_op_mask)
        list_of_masks.append(mask_dict[m])
    return np.concatenate(list_of_masks).astype(np.float32)
|
||||
|
||||
|
||||
# Module-level singleton through which all Agent classes are registered and
# looked up (exposed by the foundation package as foundation.agents).
agent_registry = Registry(BaseAgent)
"""The registry for Agent classes.

This creates a registry object for Agent classes. This registry requires that all
added classes are subclasses of BaseAgent. To make an Agent class available through
the registry, decorate the class definition with @agent_registry.add.

Example:
    from ai_economist.foundation.base.base_agent import BaseAgent, agent_registry

    @agent_registry.add
    class ExampleAgent(BaseAgent):
        name = "Example"
        pass

    assert agent_registry.has("Example")

    AgentClass = agent_registry.get("Example")
    agent = AgentClass(...)
    assert isinstance(agent, ExampleAgent)

Notes:
    The foundation package exposes the agent registry as: foundation.agents

    An Agent class that is defined and registered following the above example will
    only be visible in foundation.agents if defined/registered in a file that is
    imported in ../agents/__init__.py.
"""
|
||||
406
ai_economist/foundation/base/base_component.py
Normal file
406
ai_economist/foundation/base/base_component.py
Normal file
@@ -0,0 +1,406 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai_economist.foundation.agents import agent_registry
|
||||
from ai_economist.foundation.base.registrar import Registry
|
||||
from ai_economist.foundation.base.world import World
|
||||
|
||||
|
||||
class BaseComponent(ABC):
    """
    Base Component class. Should be used as the parent class for Component classes.
    Component instances are used to add some particular dynamics to an environment.
    They also add action spaces through which agents can interact with the
    environment via the component instance.

    Environments expand the agents' state/action spaces by querying:
        get_n_actions
        get_additional_state_fields

    Environments expand their dynamics by querying:
        component_step
        generate_observations
        generate_masks

    Environments expand logging behavior by querying:
        get_metrics
        get_dense_log

    Because they are built as Python objects, component instances can also be
    stateful. Stateful attributes are reset via calls to:
        additional_reset_steps

    The semantics of each method, and how they can be used to construct an instance
    of the Component class, are detailed below.

    Refer to ../components/move.py for an example of a Component class that enables
    mobile agents to move and collect resources in the environment world.
    """

    # The name associated with this Component class (must be unique).
    # Note: This is what will identify the Component class in the component registry.
    name = ""

    # An optional shorthand description of what the component implements (i.e.
    # "Trading", "Building", etc.). See BaseEnvironment.get_component and
    # BaseEnvironment._finalize_logs to see where this may add convenience.
    # Does not need to be unique.
    component_type = None

    # The (sub)classes of agents that this component applies to
    agent_subclasses = None  # Replace with list or tuple (can be empty)

    # The (non-agent) game entities that are expected to be in play
    required_entities = None  # Replace with list or tuple (can be empty)
|
||||
|
||||
def __init__(self, world, episode_length, inventory_scale=1):
    """
    Args:
        world (World): The World object of the environment this component
            belongs to.
        episode_length (int): Number of timesteps in an episode. Must be a
            positive int.
        inventory_scale (float): Scaling value shared by all components of the
            environment for use when generating observations (see inv_scale).
    """
    # Subclasses must declare a (unique) name.
    assert self.name

    # agent_subclasses must be a non-empty tuple/list of registered agent class
    # names, none of which may be a subclass of another (each entry must apply
    # to a distinct branch of the agent class hierarchy).
    assert isinstance(self.agent_subclasses, (tuple, list))
    assert len(self.agent_subclasses) > 0
    for i in range(len(self.agent_subclasses)):
        for j in range(len(self.agent_subclasses)):
            if i == j:
                continue
            a_i = agent_registry.get(self.agent_subclasses[i])
            a_j = agent_registry.get(self.agent_subclasses[j])
            assert not issubclass(a_i, a_j)

    assert isinstance(self.required_entities, (tuple, list))

    self.check_world(world)
    self._world = world

    assert isinstance(episode_length, int) and episode_length > 0
    self._episode_length = episode_length

    # Convenience references to world-level quantities.
    self.n_agents = world.n_agents
    self.resources = world.resources
    self.landmarks = world.landmarks

    # Components currently act on every timestep.
    self.timescale = 1

    self._inventory_scale = float(inventory_scale)
|
||||
|
||||
@property
def world(self):
    """The World object of the environment that owns this component instance.

    The world object exposes the spatial/agent states through:
        world.maps     # Reference to maps object representing spatial state
        world.agents   # List of self.n_agents mobile agent objects
        world.planner  # Reference to planner agent object

    See world.py and base_agent.py for additional API details.
    """
    return self._world
|
||||
|
||||
@property
def episode_length(self):
    """Episode length of the owning environment, coerced to int."""
    return int(self._episode_length)
|
||||
|
||||
@property
def inv_scale(self):
    """
    Scale factor applied to quantities when generating observations.

    Set by the environment during construction so that every component instance
    in the environment shares the same scaling value. How the value is actually
    used depends on each component's get_observations() implementation.
    """
    return self._inventory_scale
|
||||
|
||||
@property
def shorthand(self):
    """component_type if one is defined, otherwise the component's name."""
    if self.component_type is None:
        return self.name
    return self.component_type
|
||||
|
||||
@staticmethod
def check_world(world):
    """Validate that world is a World instance before the component uses it."""
    assert isinstance(world, World)
|
||||
|
||||
def reset(self):
    """Reset any portion of the state managed by this component.

    Merges this component's additional state fields (from
    get_additional_state_fields) into the state of every agent — mobile agents
    and the planner — then runs any component-specific reset logic via
    additional_reset_steps().
    """
    world = self.world
    all_agents = world.agents + [world.planner]
    for agent in all_agents:
        agent.state.update(self.get_additional_state_fields(agent.name))

    # This method allows components to define additional reset steps
    self.additional_reset_steps()
|
||||
|
||||
def obs(self):
    """
    Observation produced by this component, given current world/agents/component
    state, with all keys normalized to strings.
    """
    raw = self.generate_observations()
    # Enforce formatting: the result must be a dict, keyed by string agent idx.
    assert isinstance(raw, dict)
    return {str(key): val for key, val in raw.items()}
|
||||
|
||||
# Required methods for implementing components
|
||||
# --------------------------------------------
|
||||
|
||||
@abstractmethod
def get_n_actions(self, agent_cls_name):
    """
    Return the number of actions (not including NO-OPs) for agents of type
    agent_cls_name.

    Args:
        agent_cls_name (str): name of the Agent class for which number of actions
            is being queried. For example, "BasicMobileAgent".

    Returns:
        action_space (None, int, or list): If the component does not add any
            actions for agents of type agent_cls_name, return None. If it adds a
            single action space, return an integer specifying the number of
            actions in the action space. If it adds multiple action spaces,
            return a list of tuples ("action_set_name", num_actions_in_set).
            See below for further detail.

    If agent_cls_name type agents do not participate in the component, simply
    return None.

    In the next simplest case, the component adds one set of n different actions
    for agents of type agent_cls_name. In this case, return n (as an int). For
    example, if the component implements moving up, down, left, or right for
    "BasicMobileAgent" agents, then get_n_actions("BasicMobileAgent") should
    return 4.

    If the component adds multiple sets of actions for a given agent type, this
    method should return a list of tuples:
        [("action_set_name_1", n_1), ..., ("action_set_name_M", n_M)],
    where M is the number of different sets of actions, and n_k is the number of
    actions in action set k.
    For example, if the component allows a "Planner" agent to set a tax level for
    each of 3 individual mobile agents, choosing among 10 levels for each, then
    get_n_actions("Planner") should return:
        [("Tax_0", 10), ("Tax_1", 10), ("Tax_2", 10)].
    """
|
||||
|
||||
@abstractmethod
def get_additional_state_fields(self, agent_cls_name):
    """
    Return a dictionary of {state_field: reset_val} managed by this Component
    class for agents of type agent_cls_name. This also partially controls reset
    behavior.

    Args:
        agent_cls_name (str): name of the Agent class for which additional states
            are being queried. For example, "BasicMobileAgent".

    Returns:
        extra_state_dict (dict): A dictionary of {"state_field": reset_val} for
            each extra state field that this component adds to/manages for agents
            of type agent_cls_name. This extra_state_dict is incorporated into
            agent.state for each agent of this type. Note that the keyed fields
            will be reset to reset_val when the environment is reset.

    If the component has its own internal state, the protocol for resetting that
    should be written into the custom method 'additional_reset_steps()'
    [see below].

    States that are meant to be internal to the component do not need to be
    registered as agent state fields. Rather, adding to the agent state fields is
    most useful when two or more components refer to or affect the same state. In
    general, however, if the component expects a particular state field to exist,
    it should return that field (and its reset value) here.
    """
|
||||
|
||||
@abstractmethod
def component_step(self):
    """
    For all relevant agents, execute the actions specific to this Component
    class. This is where the component logic is implemented and what allows
    components to create environment dynamics.

    If the component expects certain resources/landmarks/entities to be in play,
    it must declare them in 'required_entities' so that they can be registered as
    part of the world and, where appropriate, part of the agent inventory.

    If the component expects non-standard fields to exist in agent.state for one
    or more agent types, that must be reflected in get_additional_state_fields().
    """
|
||||
|
||||
@abstractmethod
def generate_observations(self):
    """
    Generate observations associated with this Component class.

    A component does not need to produce observations and can provide
    observations for only some agent types; however, for a given environment,
    the structure of the observations returned by this component should be
    identical between subsequent calls to generate_observations. That is, the
    agents that receive observations should remain consistent, as should the
    structure of their individual observations.

    Returns:
        obs (dict): A dictionary of {agent.idx: agent_obs_dict} with an entry
            for each agent (which can include the planner) for which this
            component provides an observation. Each key is the index of an
            agent and each value is that agent's observation dictionary.
    """
|
||||
|
||||
@abstractmethod
def generate_masks(self, completions=0):
    """
    Create action masks to indicate which actions are and are not valid. Actions
    that are valid should be given a value of 1 and 0 otherwise. Do not generate
    a mask for the NO-OP action, which is always available.

    Args:
        completions (int): The number of completed episodes. This is intended to
            be used in the case that actions may be masked or unmasked as part
            of a learning curriculum.

    Returns:
        masks (dict): A dictionary of {agent.idx: mask} with an entry for each
            agent that can interact with this component. See below.

    The expected output parallels the action subspaces defined by
    get_n_actions(): the output should be a dictionary of {agent.idx: mask}
    keyed for all agents that take actions via this component.

    For example, say the component defines a set of 4 actions for agents of type
    "BasicMobileAgent" (self.get_n_actions("BasicMobileAgent") --> 4). Because
    all action spaces include a NO-OP action, there are 5 available actions,
    interpreted in this example as: NO-OP (index=0), moving up (index=1),
    down (index=2), left (index=3), or right (index=4). Say also that agent-0
    (the agent with agent.idx=0) is prevented from moving left but can otherwise
    move. In this case, generate_masks()['0'] should point to a length-4 binary
    array, specifically [1, 1, 0, 1]. Note that the mask is length 4 while
    technically 5 actions are available: NO-OP is ignored when constructing
    masks.

    In the more complex case where the component defines several action sets for
    an agent, say the planner agent (the agent with agent.idx='p'), then
    generate_masks()['p'] should point to a dictionary of
    {"action_set_name_m": mask_m} for each of the M action sets associated with
    agent p's type. Each such value, mask_m, should be a binary array whose
    length matches the number of actions in "action_set_name_m".

    The default behavior (below) keeps all actions available. The code gives an
    example of expected formatting.
    """
    world = self.world
    masks = {}
    # For all the agents in the environment
    for agent in world.agents + [world.planner]:
        # Get any action space(s) defined by this component for this agent
        n_actions = self.get_n_actions(agent.name)

        # If no action spaces are defined, just move on.
        if n_actions is None:
            continue

        # A single action space: n_actions is the number of (non NO-OP)
        # actions. Enable all of them with a ones array of that length.
        if isinstance(n_actions, (int, float)):
            masks[agent.idx] = np.ones(int(n_actions))

        # Multiple action spaces: n_actions is a tuple/list of ("name", N)
        # pairs. Enable everything with {"name": length-N ones array}.
        elif isinstance(n_actions, (tuple, list)):
            masks[agent.idx] = {
                sub_name: np.ones(int(sub_n)) for sub_name, sub_n in n_actions
            }

        else:
            raise TypeError(
                "get_n_actions must return None, a number, or a list of "
                "(name, n) tuples; got {}".format(type(n_actions))
            )

    return masks
|
||||
|
||||
# For non-required customization
|
||||
# ------------------------------
|
||||
|
||||
def additional_reset_steps(self):
    """
    Optional hook for extra steps the component should perform at reset, such
    as clearing internal trackers. Must not return anything.
    """
    return
|
||||
|
||||
def get_metrics(self):
    """
    Return a dictionary of custom, scalar-valued metrics describing the episode
    through the lens of this component. For example, a Build component's
    get_metrics() might report how many things each agent built.

    The environment merges each component's {"metric_key": metric_value}
    dictionary with the scenario metrics when constructing the full metric
    report. Returning None (the default) excludes this component from that
    report.
    """
    return None
|
||||
|
||||
def get_dense_log(self):
    """
    Return the dense log (a tuple, list, or dict) of the episode through the
    lens of this component, or None (the default) if this component does not
    yield a dense log.
    """
    return None
|
||||
|
||||
|
||||
# Module-level singleton through which all Component classes are registered and
# looked up (exposed by the foundation package as foundation.components).
component_registry = Registry(BaseComponent)
"""The registry for Component classes.

This creates a registry object for Component classes. This registry requires that all
added classes are subclasses of BaseComponent. To make a Component class available
through the registry, decorate the class definition with @component_registry.add.

Example:
    from ai_economist.foundation.base.base_component import (
        BaseComponent, component_registry
    )

    @component_registry.add
    class ExampleComponent(BaseComponent):
        name = "Example"
        pass

    assert component_registry.has("Example")

    ComponentClass = component_registry.get("Example")
    component = ComponentClass(...)
    assert isinstance(component, ExampleComponent)

Notes:
    The foundation package exposes the component registry as: foundation.components

    A Component class that is defined and registered following the above example will
    only be visible in foundation.components if defined/registered in a file that is
    imported in ../components/__init__.py.
"""
|
||||
1157
ai_economist/foundation/base/base_env.py
Normal file
1157
ai_economist/foundation/base/base_env.py
Normal file
File diff suppressed because it is too large
Load Diff
103
ai_economist/foundation/base/registrar.py
Normal file
103
ai_economist/foundation/base/registrar.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
|
||||
class Registry:
    """Utility for registering sets of similar classes and looking them up by name.

    Registries provide a simple API for getting classes used to build environment
    instances. Their main purpose is to organize such "building block" classes
    (i.e. Components, Scenarios, Agents) for easy reference as well as to ensure
    that all classes within a particular registry inherit from the same Base
    Class. Lookups are case-insensitive on the class's `name` attribute.

    Args:
        base_class (class): The class that all entries in the registry must be a
            subclass of.

    Example:
        class BaseClass:
            pass

        registry = Registry(BaseClass)

        @registry.add
        class ExampleSubclassA(BaseClass):
            name = "ExampleA"
            pass

        @registry.add
        class ExampleSubclassB(BaseClass):
            name = "ExampleB"
            pass

        print(registry.entries)
        # ["ExampleA", "ExampleB"]

        assert registry.has("ExampleA")
        assert registry.get("ExampleB") is ExampleSubclassB
    """

    def __init__(self, base_class=None):
        self.base_class = base_class
        self._entries = []  # Registered names, in the casing they were added with
        self._lookup = {}   # Lowercased name -> class

    def add(self, cls):
        """Add cls to this registry.

        Args:
            cls: The class to add to this registry. Must be a subclass of
                self.base_class, and cls.name must not contain ".".

        Returns:
            cls (to allow decoration with @registry.add)

        See Registry class docstring for example.
        """
        # "." is reserved for qualified subaction names elsewhere in the codebase.
        assert "." not in cls.name
        if self.base_class:
            assert issubclass(cls, self.base_class)
        self._lookup[cls.name.lower()] = cls
        if cls.name not in self._entries:
            self._entries.append(cls.name)
        return cls

    def get(self, cls_name):
        """Return registered class with name cls_name.

        Args:
            cls_name (str): Name of the registered class to get.

        Returns:
            Registered class cls, where cls.name matches cls_name (ignoring
            casing).

        Raises:
            KeyError: If no class is registered under cls_name.

        See Registry class docstring for example.
        """
        key = cls_name.lower()
        if key not in self._lookup:
            raise KeyError('"{}" is not a name of a registered class'.format(cls_name))
        return self._lookup[key]

    def has(self, cls_name):
        """Return True if a class with name cls_name is registered.

        Args:
            cls_name (str): Name of class to check (casing ignored).

        See Registry class docstring for example.
        """
        return cls_name.lower() in self._lookup

    @property
    def entries(self):
        """Names of classes in this registry, sorted alphabetically.

        Returns:
            A list of strings corresponding to the names of classes registered
            in this registry object.

        See Registry class docstring for example.
        """
        # sorted() already returns a fresh list; no extra copy needed.
        return sorted(self._entries)
|
||||
495
ai_economist/foundation/base/world.py
Normal file
495
ai_economist/foundation/base/world.py
Normal file
@@ -0,0 +1,495 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai_economist.foundation.agents import agent_registry
|
||||
from ai_economist.foundation.entities import landmark_registry, resource_registry
|
||||
|
||||
|
||||
class Maps:
    """Manages the spatial configuration of the world as a set of entity maps.

    A maps object is built during world construction, which is a part of environment
    construction. The maps object is accessible through the world object. The maps
    object maintains a map state for each of the spatial entities that are involved
    in the constructed environment (which are determined by the "required_entities"
    attributes of the Scenario and Component classes used to build the environment).

    The Maps class also implements some of the basic spatial logic of the game,
    such as which locations agents can occupy based on other agent locations and
    locations of various landmarks.

    Args:
        size (list): A length-2 list specifying the dimensions of the 2D world.
            Interpreted as [height, width].
        n_agents (int): The number of mobile agents (does not include planner).
        world_resources (list): The resources registered during environment
            construction.
        world_landmarks (list): The landmarks registered during environment
            construction.
    """

    def __init__(self, size, n_agents, world_resources, world_landmarks):
        self.size = size
        self.sz_h, self.sz_w = size

        self.n_agents = n_agents

        self.resources = world_resources
        self.landmarks = world_landmarks
        self.entities = world_resources + world_landmarks

        self._maps = {}  # All maps
        self._blocked = []  # Solid objects that no agent can move through
        self._private = []  # Solid objects that only permit movement for parent agents
        self._public = []  # Non-solid objects that agents can move on top of
        self._resources = []  # Non-solid objects that can be collected

        self._private_landmark_types = []
        self._resource_source_blocks = []

        self._map_keys = []

        # Maps each accessibility-relevant entity (blocked or private) to its
        # slice index in self._accessibility.
        self._accessibility_lookup = {}

        for resource in self.resources:
            resource_cls = resource_registry.get(resource)
            if resource_cls.collectible:
                self._maps[resource] = np.zeros(shape=self.size)
                self._resources.append(resource)
                self._map_keys.append(resource)

                # Every collectible resource implies a paired source-block
                # landmark, handled by the landmark loop below.
                self.landmarks.append("{}SourceBlock".format(resource))

        for landmark in self.landmarks:
            dummy_landmark = landmark_registry.get(landmark)()

            if dummy_landmark.public:
                self._maps[landmark] = np.zeros(shape=self.size)
                self._public.append(landmark)
                self._map_keys.append(landmark)

            elif dummy_landmark.blocking:
                self._maps[landmark] = np.zeros(shape=self.size)
                self._blocked.append(landmark)
                self._map_keys.append(landmark)
                self._accessibility_lookup[landmark] = len(self._accessibility_lookup)

            elif dummy_landmark.private:
                # Private landmarks track a per-cell owner (agent index, -1 for
                # unowned) alongside a per-cell health value.
                self._private_landmark_types.append(landmark)
                self._maps[landmark] = dict(
                    owner=-np.ones(shape=self.size, dtype=np.int16),
                    health=np.zeros(shape=self.size),
                )
                self._private.append(landmark)
                self._map_keys.append(landmark)
                self._accessibility_lookup[landmark] = len(self._accessibility_lookup)

            else:
                raise NotImplementedError

        self._idx_map = np.stack(
            [i * np.ones(shape=self.size) for i in range(self.n_agents)]
        )
        self._idx_array = np.arange(self.n_agents)
        if self._accessibility_lookup:
            # Shape: [entity, agent, row, col] -> may this agent stand here?
            self._accessibility = np.ones(
                shape=[len(self._accessibility_lookup), self.n_agents] + self.size,
                dtype=bool,
            )
            # Lazily recomputed product over entities; see the accessibility
            # property.
            self._net_accessibility = None
        else:
            self._accessibility = None
            self._net_accessibility = np.ones(
                shape=[self.n_agents] + self.size, dtype=bool
            )

        self._agent_locs = [None for _ in range(self.n_agents)]
        self._unoccupied = np.ones(self.size, dtype=bool)

    def clear(self, entity_name=None):
        """Clear resource and landmark maps.

        Args:
            entity_name (str, optional): If given, clear only this entity's map;
                otherwise clear every map and reset accessibility.
        """
        if entity_name is not None:
            assert entity_name in self._maps
            if entity_name in self._private_landmark_types:
                self._maps[entity_name] = dict(
                    owner=-np.ones(shape=self.size, dtype=np.int16),
                    health=np.zeros(shape=self.size),
                )
            else:
                self._maps[entity_name] *= 0

        else:
            for name in self.keys():
                self.clear(entity_name=name)

            if self._accessibility is not None:
                self._accessibility = np.ones_like(self._accessibility)
                self._net_accessibility = None

    def clear_agent_loc(self, agent=None):
        """Remove agents or agent from the world map."""
        # Clear all agent locations
        if agent is None:
            self._agent_locs = [None for _ in range(self.n_agents)]
            self._unoccupied[:, :] = 1

        # Clear the location of the provided agent
        else:
            i = agent.idx
            if self._agent_locs[i] is None:
                return
            r, c = self._agent_locs[i]
            self._unoccupied[r, c] = 1
            self._agent_locs[i] = None

    def set_agent_loc(self, agent, r, c):
        """Set the location of agent to [r, c].

        Note:
            Things might break if you set the agent's location to somewhere it
            cannot access. Don't do that.
        """
        assert (0 <= r < self.size[0]) and (0 <= c < self.size[1])
        i = agent.idx
        # If the agent is currently on the board...
        if self._agent_locs[i] is not None:
            curr_r, curr_c = self._agent_locs[i]
            # If the agent isn't actually moving, just return
            if (curr_r, curr_c) == (r, c):
                return
            # Mark the location the agent is currently at as unoccupied
            # (since the agent is going to move)
            self._unoccupied[curr_r, curr_c] = 1

        # Set the agent location to the specified coordinates
        # and update the occupation map
        agent.state["loc"] = [r, c]
        self._agent_locs[i] = [r, c]
        self._unoccupied[r, c] = 0

    def keys(self):
        """Return an iterable over map keys."""
        return self._maps.keys()

    def values(self):
        """Return an iterable over map values."""
        return self._maps.values()

    def items(self):
        """Return an iterable over map (key, value) pairs."""
        return self._maps.items()

    def get(self, entity_name, owner=False):
        """Return the map or ownership for entity_name.

        Args:
            entity_name (str): Name of a registered entity map.
            owner (bool): For private landmarks only, return the owner map
                instead of the health map.
        """
        assert entity_name in self._maps
        if entity_name in self._private_landmark_types:
            sub_key = "owner" if owner else "health"
            return self._maps[entity_name][sub_key]
        return self._maps[entity_name]

    def set(self, entity_name, map_state):
        """Set the map for entity_name.

        For private landmarks, map_state must be a dict with "owner" and
        "health" arrays; otherwise it is a single array matching the map shape.
        Health values are clipped at 0 and any cell with zero health loses its
        owner.
        """
        if entity_name in self._private_landmark_types:
            assert "owner" in map_state
            assert self.get(entity_name, owner=True).shape == map_state["owner"].shape
            assert "health" in map_state
            assert self.get(entity_name, owner=False).shape == map_state["health"].shape

            h = np.maximum(0.0, map_state["health"])
            o = map_state["owner"].astype(np.int16)

            # Dead cells are unowned; live cells must have a valid owner.
            o[h <= 0] = -1
            tmp = o[h > 0]
            if len(tmp) > 0:
                assert np.min(tmp) >= 0

            self._maps[entity_name] = dict(owner=o, health=h)

            # A private cell is accessible to its owner and to everyone when
            # unowned.
            owned_by_agent = o[None] == self._idx_map
            owned_by_none = o[None] == -1
            self._accessibility[
                self._accessibility_lookup[entity_name]
            ] = np.logical_or(owned_by_agent, owned_by_none)
            self._net_accessibility = None

        else:
            assert self.get(entity_name).shape == map_state.shape
            self._maps[entity_name] = np.maximum(0, map_state)

            if entity_name in self._blocked:
                # Blocked cells are inaccessible to every agent wherever the
                # map value is nonzero.
                self._accessibility[
                    self._accessibility_lookup[entity_name]
                ] = np.repeat(map_state[None] == 0, self.n_agents, axis=0)
                self._net_accessibility = None

    def set_add(self, entity_name, map_state):
        """Add map_state to the existing map for entity_name."""
        assert entity_name not in self._private_landmark_types
        self.set(entity_name, self.get(entity_name) + map_state)

    def get_point(self, entity_name, r, c, **kwargs):
        """Return the entity state at the specified coordinates."""
        point_map = self.get(entity_name, **kwargs)
        return point_map[r, c]

    def set_point(self, entity_name, r, c, val, owner=None):
        """Set the entity state at the specified coordinates."""
        if entity_name in self._private_landmark_types:
            assert owner is not None
            h = self._maps[entity_name]["health"]
            o = self._maps[entity_name]["owner"]
            # A cell may only be set if it is unowned or already owned by owner.
            assert o[r, c] == -1 or o[r, c] == int(owner)
            h[r, c] = np.maximum(0, val)
            if h[r, c] == 0:
                o[r, c] = -1
            else:
                o[r, c] = int(owner)

            self._maps[entity_name]["owner"] = o
            self._maps[entity_name]["health"] = h

            self._accessibility[
                self._accessibility_lookup[entity_name], :, r, c
            ] = np.logical_or(o[r, c] == self._idx_array, o[r, c] == -1).astype(bool)
            self._net_accessibility = None

        else:
            self._maps[entity_name][r, c] = np.maximum(0, val)

            if entity_name in self._blocked:
                # Bug fix: only the targeted cell's accessibility changes,
                # mirroring the private-landmark branch above. Previously this
                # assigned to the entire per-entity accessibility slice
                # (shape [n_agents, h, w]) from a single point value.
                self._accessibility[
                    self._accessibility_lookup[entity_name], :, r, c
                ] = np.repeat(np.array([val]) == 0, self.n_agents, axis=0)
                self._net_accessibility = None

    def set_point_add(self, entity_name, r, c, value, **kwargs):
        """Add value to the existing entity state at the specified coordinates."""
        self.set_point(
            entity_name,
            r,
            c,
            value + self.get_point(entity_name, r, c, **kwargs),
            **kwargs
        )

    def is_accessible(self, r, c, agent_id):
        """Return True if agent with id agent_id can occupy the location [r, c]."""
        return bool(self.accessibility[agent_id, r, c])

    def location_resources(self, r, c):
        """Return {resource: health} dictionary for any resources at location [r, c]."""
        return {
            k: self._maps[k][r, c] for k in self._resources if self._maps[k][r, c] > 0
        }

    def location_landmarks(self, r, c):
        """Return {landmark: health} dictionary for any landmarks at location [r, c]."""
        tmp = {k: self.get_point(k, r, c) for k in self.keys()}
        return {k: v for k, v in tmp.items() if k not in self._resources and v > 0}

    @property
    def unoccupied(self):
        """Return a boolean map indicating which locations are unoccupied."""
        return self._unoccupied

    @property
    def accessibility(self):
        """Return a boolean map indicating which locations are accessible."""
        if self._net_accessibility is None:
            # Accessible iff accessible w.r.t. every tracked entity.
            self._net_accessibility = self._accessibility.prod(axis=0).astype(bool)
        return self._net_accessibility

    @property
    def empty(self):
        """Return a boolean map indicating which locations are empty.

        Empty locations have no landmarks or resources."""
        return self.state.sum(axis=0) == 0

    @property
    def state(self):
        """Return the concatenated maps of landmark and resources."""
        return np.stack([self.get(k) for k in self.keys()]).astype(np.float32)

    @property
    def owner_state(self):
        """Return the concatenated ownership maps of private landmarks."""
        return np.stack(
            [self.get(k, owner=True) for k in self._private_landmark_types]
        ).astype(np.int16)

    @property
    def state_dict(self):
        """Return a dictionary of the map states."""
        return self._maps
|
||||
|
||||
|
||||
class World:
    """Manages the environment's spatial- and agent-states.

    The world object represents the state of the environment, minus whatever state
    information is implicitly maintained by separate components. The world object
    maintains the spatial state through an instance of the Maps class. Agent states
    are maintained through instances of Agent classes (subclasses of BaseAgent),
    with one such instance for each of the agents in the environment.

    The world object is built during the environment construction, after the
    required entities have been registered. As part of the world object construction,
    it instantiates a map object and the agent objects.

    The World class adds some functionality for interfacing with the spatial state
    (the maps object) and setting/resetting agent locations. But its function is
    mostly to wrap the stateful, non-component environment objects.

    Args:
        world_size (list): A length-2 list specifying the dimensions of the 2D world.
            Interpreted as [height, width].
        agent_composition (dict): Maps agent class names to the number of agents
            of that class to create. A list of (name, count) pairs also works.
        world_resources (list): The resources registered during environment
            construction.
        world_landmarks (list): The landmarks registered during environment
            construction.
        multi_action_mode_agents (bool): Whether "mobile" agents use multi action mode
            (see BaseEnvironment in base_env.py).
        multi_action_mode_planner (bool): Whether the planner agent uses multi action
            mode (see BaseEnvironment in base_env.py).
    """

    def __init__(
        self,
        world_size,
        agent_composition,
        world_resources,
        world_landmarks,
        multi_action_mode_agents,
        multi_action_mode_planner,
    ):
        self.world_size = world_size

        self.resources = world_resources
        self.landmarks = world_landmarks
        self.multi_action_mode_agents = bool(multi_action_mode_agents)
        self.multi_action_mode_planner = bool(multi_action_mode_planner)
        self._agent_class_idx_map = {}

        # Create agents according to agent_composition.
        self.agent_composition = agent_composition
        self.n_agents = 0
        self._agents = []
        # Bug fix: iterating a dict directly yields only keys, so the original
        # "for k, v in agent_composition:" could not unpack (name, count)
        # pairs from a dict. dict(...) normalizes both a dict and a list of
        # (name, count) pairs, preserving backward compatibility.
        for agent_cls_name, count in dict(agent_composition).items():
            self._agent_class_idx_map[agent_cls_name] = []
            agent_class = agent_registry.get(agent_cls_name)
            for _ in range(count):
                # NOTE(review): upstream BaseAgent's constructor kwarg is
                # "multi_action_mode"; confirm the modded agent classes accept
                # "multi_action_mode_agents".
                self._agents.append(
                    agent_class(
                        self.n_agents,
                        multi_action_mode_agents=self.multi_action_mode_agents,
                    )
                )
                self._agent_class_idx_map[agent_cls_name].append(self.n_agents)
                self.n_agents += 1
        self.maps = Maps(world_size, self.n_agents, world_resources, world_landmarks)

        planner_class = agent_registry.get("BasicPlanner")
        self._planner = planner_class(multi_action_mode=self.multi_action_mode_planner)

        self.timestep = 0

        # CUDA-related attributes (for GPU simulations).
        # These will be set via the env_wrapper, if required.
        self.use_cuda = False
        self.cuda_function_manager = None
        self.cuda_data_manager = None

    @property
    def agents(self):
        """Return a list of the agent objects in the world (sorted by index)."""
        return self._agents

    @property
    def planner(self):
        """Return the planner agent object."""
        return self._planner

    @property
    def loc_map(self):
        """Return a map indicating the agent index occupying each location.

        Locations with a value of -1 are not occupied by an agent.
        """
        idx_map = -np.ones(shape=self.world_size, dtype=np.int16)
        for agent in self.agents:
            r, c = agent.loc
            idx_map[r, c] = int(agent.idx)
        return idx_map

    def get_random_order_agents(self):
        """The agent list in a randomized order."""
        agent_order = np.random.permutation(self.n_agents)
        agents = self.agents
        return [agents[i] for i in agent_order]

    def get_agent_class(self, idx):
        """Return the class name of the agent with index idx."""
        return self.agents[idx].name

    def is_valid(self, r, c):
        """Return True if the coordinates [r, c] are within the game boundaries."""
        return (0 <= r < self.world_size[0]) and (0 <= c < self.world_size[1])

    def is_location_accessible(self, r, c, agent):
        """Return True if location [r, c] is accessible to agent."""
        if not self.is_valid(r, c):
            return False
        return self.maps.is_accessible(r, c, agent.idx)

    def can_agent_occupy(self, r, c, agent):
        """Return True if location [r, c] is accessible to agent and unoccupied."""
        if not self.is_location_accessible(r, c, agent):
            return False
        if self.maps.unoccupied[r, c]:
            return True
        return False

    def clear_agent_locs(self):
        """Take all agents off the board. Useful for resetting."""
        for agent in self.agents:
            agent.state["loc"] = [-1, -1]
        self.maps.clear_agent_loc()

    def agent_locs_are_valid(self):
        """Returns True if all agent locations comply with world semantics."""
        return all(
            self.is_location_accessible(*agent.loc, agent) for agent in self.agents
        )

    def set_agent_loc(self, agent, r, c):
        """Set the agent's location to coordinates [r, c] if possible.

        If agent cannot occupy [r, c], do nothing."""
        if self.can_agent_occupy(r, c, agent):
            self.maps.set_agent_loc(agent, r, c)
        return [int(coord) for coord in agent.loc]

    def location_resources(self, r, c):
        """Return {resource: health} dictionary for any resources at location [r, c]."""
        if not self.is_valid(r, c):
            return {}
        return self.maps.location_resources(r, c)

    def location_landmarks(self, r, c):
        """Return {landmark: health} dictionary for any landmarks at location [r, c]."""
        if not self.is_valid(r, c):
            return {}
        return self.maps.location_landmarks(r, c)

    def create_landmark(self, landmark_name, r, c, agent_idx=None):
        """Place a landmark on the world map.

        Place landmark of type landmark_name at the given coordinates, indicating
        agent ownership if applicable."""
        self.maps.set_point(landmark_name, r, c, 1, owner=agent_idx)

    def consume_resource(self, resource_name, r, c):
        """Consume a unit of resource_name from location [r, c]."""
        self.maps.set_point_add(resource_name, r, c, -1)
|
||||
19
ai_economist/foundation/components/__init__.py
Normal file
19
ai_economist/foundation/components/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
from ai_economist.foundation.base.base_component import component_registry
|
||||
|
||||
from . import (
|
||||
build,
|
||||
continuous_double_auction,
|
||||
covid19_components,
|
||||
move,
|
||||
redistribution,
|
||||
simple_labor,
|
||||
)
|
||||
|
||||
# Import files that add Component class(es) to component_registry
|
||||
# ---------------------------------------------------------------
|
||||
266
ai_economist/foundation/components/build.py
Normal file
266
ai_economist/foundation/components/build.py
Normal file
@@ -0,0 +1,266 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai_economist.foundation.base.base_component import (
|
||||
BaseComponent,
|
||||
component_registry,
|
||||
)
|
||||
|
||||
|
||||
@component_registry.add
class Build(BaseComponent):
    """
    Allows mobile agents to build house landmarks in the world using stone and wood,
    earning income.

    Can be configured to include heterogeneous building skill where agents earn
    different levels of income when building.

    Args:
        payment (int): Default amount of coin agents earn from building.
            Must be >= 0. Default is 10.
        payment_max_skill_multiplier (int): Maximum skill multiplier that an agent
            can sample. Must be >= 1. Default is 1.
        skill_dist (str): Distribution type for sampling skills. Default ("none")
            gives all agents identical skill equal to a multiplier of 1. "pareto" and
            "lognormal" sample skills from the associated distributions.
        build_labor (float): Labor cost associated with building a house.
            Must be >= 0. Default is 10.
    """

    name = "Build"
    component_type = "Build"
    required_entities = ["Wood", "Stone", "Coin", "House", "Labor"]
    agent_subclasses = ["BasicMobileAgent"]

    def __init__(
        self,
        *base_component_args,
        payment=10,
        payment_max_skill_multiplier=1,
        skill_dist="none",
        build_labor=10.0,
        **base_component_kwargs
    ):
        super().__init__(*base_component_args, **base_component_kwargs)

        # NOTE(review): payment == 0 passes this assert but makes
        # generate_observations divide by zero — confirm configs use payment > 0.
        self.payment = int(payment)
        assert self.payment >= 0

        self.payment_max_skill_multiplier = int(payment_max_skill_multiplier)
        assert self.payment_max_skill_multiplier >= 1

        # Fixed per-house resource cost.
        self.resource_cost = {"Wood": 1, "Stone": 1}

        self.build_labor = float(build_labor)
        assert self.build_labor >= 0

        self.skill_dist = skill_dist.lower()
        assert self.skill_dist in ["none", "pareto", "lognormal"]

        # {agent_idx: skill}; re-sampled each reset in additional_reset_steps.
        self.sampled_skills = {}

        # Per-timestep list of build events; appended by component_step.
        self.builds = []

    def agent_can_build(self, agent):
        """Return True if agent can actually build in its current location."""
        # See if the agent has the resources necessary to complete the action
        for resource, cost in self.resource_cost.items():
            if agent.state["inventory"][resource] < cost:
                return False

        # Do nothing if this spot is already occupied by a landmark or resource
        if self.world.location_resources(*agent.loc):
            return False
        if self.world.location_landmarks(*agent.loc):
            return False
        # If we made it here, the agent can build.
        return True

    # Required methods for implementing components
    # --------------------------------------------

    def get_n_actions(self, agent_cls_name):
        """
        See base_component.py for detailed description.

        Add a single action (build) for mobile agents.
        """
        # This component adds 1 action that mobile agents can take: build a house
        if agent_cls_name == "BasicMobileAgent":
            return 1

        return None

    def get_additional_state_fields(self, agent_cls_name):
        """
        See base_component.py for detailed description.

        For mobile agents, add state fields for building skill.
        """
        if agent_cls_name not in self.agent_subclasses:
            return {}
        if agent_cls_name == "BasicMobileAgent":
            return {"build_payment": float(self.payment), "build_skill": 1}
        raise NotImplementedError

    def component_step(self):
        """
        See base_component.py for detailed description.

        Convert stone+wood to house+coin for agents that choose to build and can.
        """
        world = self.world
        build = []
        # Apply any building actions taken by the mobile agents
        # (randomized order so no agent has a systematic advantage).
        for agent in world.get_random_order_agents():

            action = agent.get_component_action(self.name)

            # This component doesn't apply to this agent!
            if action is None:
                continue

            # NO-OP!
            if action == 0:
                pass

            # Build! (If you can.)
            elif action == 1:
                if self.agent_can_build(agent):
                    # Remove the resources
                    for resource, cost in self.resource_cost.items():
                        agent.state["inventory"][resource] -= cost

                    # Place a house where the agent is standing
                    loc_r, loc_c = agent.loc
                    world.create_landmark("House", loc_r, loc_c, agent.idx)

                    # Receive payment for the house
                    agent.state["inventory"]["Coin"] += agent.state["build_payment"]

                    # Incur the labor cost for building
                    agent.state["endogenous"]["Labor"] += self.build_labor

                    build.append(
                        {
                            "builder": agent.idx,
                            "loc": np.array(agent.loc),
                            "income": float(agent.state["build_payment"]),
                        }
                    )

            # Any other action value is out of range for this component.
            else:
                raise ValueError

        self.builds.append(build)

    def generate_observations(self):
        """
        See base_component.py for detailed description.

        Here, agents observe their build skill. The planner does not observe anything
        from this component.
        """

        obs_dict = dict()
        for agent in self.world.agents:
            obs_dict[agent.idx] = {
                # Payment is normalized by the component-wide base payment.
                "build_payment": agent.state["build_payment"] / self.payment,
                "build_skill": self.sampled_skills[agent.idx],
            }

        return obs_dict

    def generate_masks(self, completions=0):
        """
        See base_component.py for detailed description.

        Prevent building only if a landmark already occupies the agent's location.
        """

        masks = {}
        # Mobile agents' build action is masked if they cannot build with their
        # current location and/or endowment
        for agent in self.world.agents:
            masks[agent.idx] = np.array([self.agent_can_build(agent)])

        return masks

    # For non-required customization
    # ------------------------------

    def get_metrics(self):
        """
        Metrics that capture what happened through this component.

        Returns:
            metrics (dict): A dictionary of {"metric_name": metric_value},
                where metric_value is a scalar.
        """
        world = self.world

        build_stats = {a.idx: {"n_builds": 0} for a in world.agents}
        for builds in self.builds:
            for build in builds:
                idx = build["builder"]
                build_stats[idx]["n_builds"] += 1

        out_dict = {}
        for a in world.agents:
            for k, v in build_stats[a.idx].items():
                out_dict["{}/{}".format(a.idx, k)] = v

        # Count cells with a standing house rather than summing build events.
        num_houses = np.sum(world.maps.get("House") > 0)
        out_dict["total_builds"] = num_houses

        return out_dict

    def additional_reset_steps(self):
        """
        See base_component.py for detailed description.

        Re-sample agents' building skills.
        """
        world = self.world

        self.sampled_skills = {agent.idx: 1 for agent in world.agents}

        PMSM = self.payment_max_skill_multiplier

        for agent in world.agents:
            if self.skill_dist == "none":
                sampled_skill = 1
                pay_rate = 1
            elif self.skill_dist == "pareto":
                sampled_skill = np.random.pareto(4)
                # Pay rate is capped at the maximum skill multiplier.
                pay_rate = np.minimum(PMSM, (PMSM - 1) * sampled_skill + 1)
            elif self.skill_dist == "lognormal":
                sampled_skill = np.random.lognormal(-1, 0.5)
                pay_rate = np.minimum(PMSM, (PMSM - 1) * sampled_skill + 1)
            else:
                raise NotImplementedError

            agent.state["build_payment"] = float(pay_rate * self.payment)
            agent.state["build_skill"] = float(sampled_skill)

            self.sampled_skills[agent.idx] = sampled_skill

        self.builds = []

    def get_dense_log(self):
        """
        Log builds.

        Returns:
            builds (list): A list of build events. Each entry corresponds to a single
                timestep and contains a description of any builds that occurred on
                that timestep.

        """
        return self.builds
|
||||
679
ai_economist/foundation/components/continuous_double_auction.py
Normal file
679
ai_economist/foundation/components/continuous_double_auction.py
Normal file
@@ -0,0 +1,679 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai_economist.foundation.base.base_component import (
|
||||
BaseComponent,
|
||||
component_registry,
|
||||
)
|
||||
from ai_economist.foundation.entities import resource_registry
|
||||
|
||||
|
||||
@component_registry.add
|
||||
class ContinuousDoubleAuction(BaseComponent):
|
||||
"""Allows mobile agents to buy/sell collectible resources with one another.
|
||||
|
||||
Implements a commodity-exchange-style market where agents may sell a unit of
|
||||
resource by submitting an ask (saying the minimum it will accept in payment)
|
||||
or may buy a resource by submitting a bid (saying the maximum it will pay in
|
||||
exchange for a unit of a given resource).
|
||||
|
||||
Args:
|
||||
max_bid_ask (int): Maximum amount of coin that an agent can bid or ask for.
|
||||
Must be >= 1. Default is 10 coin.
|
||||
order_labor (float): Amount of labor incurred when an agent creates an order.
|
||||
Must be >= 0. Default is 0.25.
|
||||
order_duration (int): Number of environment timesteps before an unfilled
|
||||
bid/ask expires. Must be >= 1. Default is 50 timesteps.
|
||||
max_num_orders (int, optional): Maximum number of bids + asks that an agent can
|
||||
have open for a given resource. Must be >= 1. Default is no limit to
|
||||
number of orders.
|
||||
"""
|
||||
|
||||
name = "ContinuousDoubleAuction"
|
||||
component_type = "Trade"
|
||||
required_entities = ["Coin", "Labor"]
|
||||
agent_subclasses = ["BasicMobileAgent"]
|
||||
|
||||
def __init__(
    self,
    *args,
    max_bid_ask=10,
    order_labor=0.25,
    order_duration=50,
    max_num_orders=None,
    **kwargs
):
    """Configure market limits and initialize (empty) per-episode order books.

    See the class docstring for argument descriptions. Positional and keyword
    args are forwarded to BaseComponent.__init__, which is assumed to set up
    self.world and self.n_agents used below — TODO confirm against
    base_component.py.
    """
    super().__init__(*args, **kwargs)

    # The max amount (in coin) that an agent can bid/ask for 1 unit of a commodity
    self.max_bid_ask = int(max_bid_ask)
    assert self.max_bid_ask >= 1
    self.price_floor = 0
    self.price_ceiling = int(max_bid_ask)

    # The amount of time (in timesteps) that an order stays in the books
    # before it expires
    self.order_duration = int(order_duration)
    assert self.order_duration >= 1

    # The maximum number of bid+ask orders an agent can have open
    # for each type of commodity
    # (defaults to order_duration when max_num_orders is None or 0)
    self.max_num_orders = int(max_num_orders or self.order_duration)
    assert self.max_num_orders >= 1

    # The labor cost associated with creating a bid or ask order
    # (negative values are clamped to 0)

    self.order_labor = float(order_labor)
    self.order_labor = max(self.order_labor, 0.0)

    # Each collectible resource in the world can be traded via this component
    self.commodities = [
        r for r in self.world.resources if resource_registry.get(r).collectible
    ]

    # These get reset at the start of an episode:
    # open asks/bids per commodity, per-agent open-order counts, executed
    # trades, and per-agent price-level histograms.
    self.asks = {c: [] for c in self.commodities}
    self.bids = {c: [] for c in self.commodities}
    self.n_orders = {
        c: {i: 0 for i in range(self.n_agents)} for c in self.commodities
    }
    self.executed_trades = []
    self.price_history = {
        c: {i: self._price_zeros() for i in range(self.n_agents)}
        for c in self.commodities
    }
    self.bid_hists = {
        c: {i: self._price_zeros() for i in range(self.n_agents)}
        for c in self.commodities
    }
    self.ask_hists = {
        c: {i: self._price_zeros() for i in range(self.n_agents)}
        for c in self.commodities
    }
|
||||
|
||||
# Convenience methods
|
||||
# -------------------
|
||||
|
||||
def _price_zeros(self):
|
||||
if 1 + self.price_ceiling - self.price_floor <= 0:
|
||||
print("ERROR!", self.price_ceiling, self.price_floor)
|
||||
|
||||
return np.zeros(1 + self.price_ceiling - self.price_floor)
|
||||
|
||||
def available_asks(self, resource, agent):
|
||||
"""
|
||||
Get a histogram of asks for resource to which agent could bid against.
|
||||
|
||||
Args:
|
||||
resource (str): Name of the resource
|
||||
agent (BasicMobileAgent or None): Object of agent for which available
|
||||
asks are being queried. If None, all asks are considered available.
|
||||
|
||||
Returns:
|
||||
ask_hist (ndarray): For each possible price level, the number of
|
||||
available asks.
|
||||
"""
|
||||
if agent is None:
|
||||
a_idx = -1
|
||||
else:
|
||||
a_idx = agent.idx
|
||||
ask_hist = self._price_zeros()
|
||||
for i, h in self.ask_hists[resource].items():
|
||||
if a_idx != i:
|
||||
ask_hist += h
|
||||
return ask_hist
|
||||
|
||||
def available_bids(self, resource, agent):
|
||||
"""
|
||||
Get a histogram of bids for resource to which agent could ask against.
|
||||
|
||||
Args:
|
||||
resource (str): Name of the resource
|
||||
agent (BasicMobileAgent or None): Object of agent for which available
|
||||
bids are being queried. If None, all bids are considered available.
|
||||
|
||||
Returns:
|
||||
bid_hist (ndarray): For each possible price level, the number of
|
||||
available bids.
|
||||
"""
|
||||
if agent is None:
|
||||
a_idx = -1
|
||||
else:
|
||||
a_idx = agent.idx
|
||||
bid_hist = self._price_zeros()
|
||||
for i, h in self.bid_hists[resource].items():
|
||||
if a_idx != i:
|
||||
bid_hist += h
|
||||
return bid_hist
|
||||
|
||||
def can_bid(self, resource, agent):
|
||||
"""If agent can submit a bid for resource."""
|
||||
return self.n_orders[resource][agent.idx] < self.max_num_orders
|
||||
|
||||
def can_ask(self, resource, agent):
|
||||
"""If agent can submit an ask for resource."""
|
||||
return (
|
||||
self.n_orders[resource][agent.idx] < self.max_num_orders
|
||||
and agent.state["inventory"][resource] > 0
|
||||
)
|
||||
|
||||
# Core components for this market
|
||||
# -------------------------------
|
||||
|
||||
def create_bid(self, resource, agent, max_payment):
|
||||
"""Create a new bid for resource, with agent offering max_payment.
|
||||
|
||||
On a successful trade, payment will be at most max_payment, possibly less.
|
||||
|
||||
The agent places the bid coin into escrow so that it may not be spent on
|
||||
something else while the order exists.
|
||||
"""
|
||||
|
||||
# The agent is past the max number of orders
|
||||
# or doesn't have enough money, do nothing
|
||||
if (not self.can_bid(resource, agent)) or agent.state["inventory"][
|
||||
"Coin"
|
||||
] < max_payment:
|
||||
return
|
||||
|
||||
assert self.price_floor <= max_payment <= self.price_ceiling
|
||||
|
||||
bid = {"buyer": agent.idx, "bid": int(max_payment), "bid_lifetime": 0}
|
||||
|
||||
# Add this to the bid book
|
||||
self.bids[resource].append(bid)
|
||||
self.bid_hists[resource][bid["buyer"]][bid["bid"] - self.price_floor] += 1
|
||||
self.n_orders[resource][agent.idx] += 1
|
||||
|
||||
# Set aside whatever money the agent is willing to pay
|
||||
# (will get excess back if price ends up being less)
|
||||
_ = agent.inventory_to_escrow("Coin", int(max_payment))
|
||||
|
||||
# Incur the labor cost of creating an order
|
||||
agent.state["endogenous"]["Labor"] += self.order_labor
|
||||
|
||||
def create_ask(self, resource, agent, min_income):
|
||||
"""
|
||||
Create a new ask for resource, with agent asking for min_income.
|
||||
|
||||
On a successful trade, income will be at least min_income, possibly more.
|
||||
|
||||
The agent places one unit of resource into escrow so that it may not be used
|
||||
for something else while the order exists.
|
||||
"""
|
||||
# The agent is past the max number of orders
|
||||
# or doesn't the resource it's trying to sell, do nothing
|
||||
if not self.can_ask(resource, agent):
|
||||
return
|
||||
|
||||
# is there an upper limit?
|
||||
assert self.price_floor <= min_income <= self.price_ceiling
|
||||
|
||||
ask = {"seller": agent.idx, "ask": int(min_income), "ask_lifetime": 0}
|
||||
|
||||
# Add this to the ask book
|
||||
self.asks[resource].append(ask)
|
||||
self.ask_hists[resource][ask["seller"]][ask["ask"] - self.price_floor] += 1
|
||||
self.n_orders[resource][agent.idx] += 1
|
||||
|
||||
# Set aside the resource the agent is willing to sell
|
||||
amount = agent.inventory_to_escrow(resource, 1)
|
||||
assert amount == 1
|
||||
|
||||
# Incur the labor cost of creating an order
|
||||
agent.state["endogenous"]["Labor"] += self.order_labor
|
||||
|
||||
    def match_orders(self):
        """
        This implements the continuous double auction by identifying valid bid/ask
        pairs and executing trades accordingly.

        Higher (lower) bids (asks) are given priority over lower (higher) bids (asks).
        Trades are executed using the price of whichever bid/ask order was placed
        first: bid price if bid was placed first, ask price otherwise.

        Trading removes the payment and resource from bidder's and asker's escrow,
        respectively, and puts them in the other's inventory.
        """
        # One list of trades per timestep; this call's trades append to the tail.
        self.executed_trades.append([])

        for resource in self.commodities:
            # possible_match[i] is False once buyer i provably can't match this round.
            possible_match = [True for _ in range(self.n_agents)]
            keep_checking = True

            # Highest price first; ties broken toward the older bid
            # (larger bid_lifetime).
            bids = sorted(
                self.bids[resource],
                key=lambda b: (b["bid"], b["bid_lifetime"]),
                reverse=True,
            )
            # Lowest price first; ties broken toward the older ask
            # (larger ask_lifetime, via the negated key).
            asks = sorted(
                self.asks[resource], key=lambda a: (a["ask"], -a["ask_lifetime"])
            )

            while any(possible_match) and keep_checking:
                idx_bid, idx_ask = 0, 0
                while True:
                    # Out of bids to check. Exit both loops.
                    if idx_bid >= len(bids):
                        keep_checking = False
                        break

                    # Already know this buyer is no good for this round.
                    # Skip to next bid.
                    if not possible_match[bids[idx_bid]["buyer"]]:
                        idx_bid += 1

                    # Out of asks to check. This buyer won't find a match on this round.
                    # (maybe) Restart inner loop.
                    elif idx_ask >= len(asks):
                        possible_match[bids[idx_bid]["buyer"]] = False
                        break

                    # Skip to next ask if this ask comes from the buyer
                    # of the current bid (agents can't trade with themselves).
                    elif asks[idx_ask]["seller"] == bids[idx_bid]["buyer"]:
                        idx_ask += 1

                    # If this bid/ask pair can't be matched, this buyer
                    # can't be matched (asks only get more expensive from here).
                    # (maybe) Restart inner loop.
                    elif bids[idx_bid]["bid"] < asks[idx_ask]["ask"]:
                        possible_match[bids[idx_bid]["buyer"]] = False
                        break

                    # TRADE! (then restart inner loop)
                    else:
                        bid = bids.pop(idx_bid)
                        ask = asks.pop(idx_ask)

                        # Trade record carries the commodity plus both orders' fields.
                        trade = {"commodity": resource}
                        trade.update(bid)
                        trade.update(ask)

                        # Price is set by whichever order was placed first
                        # (larger lifetime = older).
                        if (
                            bid["bid_lifetime"] <= ask["ask_lifetime"]
                        ):  # Ask came earlier. (in other words,
                            # trade triggered by new bid)
                            trade["price"] = int(trade["ask"])
                        else:  # Bid came earlier. (in other words,
                            # trade triggered by new ask)
                            trade["price"] = int(trade["bid"])
                        trade["cost"] = trade["price"]  # What the buyer pays in total
                        trade["income"] = trade[
                            "price"
                        ]  # What the seller receives in total

                        buyer = self.world.agents[trade["buyer"]]
                        seller = self.world.agents[trade["seller"]]

                        # Bookkeeping: remove both orders from the histograms
                        # and per-agent order counters, log the trade.
                        self.bid_hists[resource][bid["buyer"]][
                            bid["bid"] - self.price_floor
                        ] -= 1
                        self.ask_hists[resource][ask["seller"]][
                            ask["ask"] - self.price_floor
                        ] -= 1
                        self.n_orders[trade["commodity"]][seller.idx] -= 1
                        self.n_orders[trade["commodity"]][buyer.idx] -= 1
                        self.executed_trades[-1].append(trade)
                        self.price_history[resource][trade["seller"]][
                            trade["price"]
                        ] += 1

                        # The resource goes from the seller's escrow
                        # to the buyer's inventory
                        seller.state["escrow"][resource] -= 1
                        buyer.state["inventory"][resource] += 1

                        # Buyer's money (already set aside at bid time) leaves escrow
                        pre_payment = int(trade["bid"])
                        buyer.state["escrow"]["Coin"] -= pre_payment
                        assert buyer.state["escrow"]["Coin"] >= 0

                        # Payment is removed from the pre_payment
                        # and given to the seller. Excess returned to buyer.
                        payment_to_seller = int(trade["price"])
                        excess_payment_from_buyer = pre_payment - payment_to_seller
                        assert excess_payment_from_buyer >= 0
                        seller.state["inventory"]["Coin"] += payment_to_seller
                        buyer.state["inventory"]["Coin"] += excess_payment_from_buyer

                        # Restart the inner loop
                        break

            # Keep the unfilled bids/asks
            self.bids[resource] = bids
            self.asks[resource] = asks
|
||||
|
||||
def remove_expired_orders(self):
|
||||
"""
|
||||
Increment the time counter for any unfilled bids/asks and remove expired
|
||||
orders from the market.
|
||||
|
||||
When orders expire, the payment or resource is removed from escrow and
|
||||
returned to the inventory and the associated order is removed from the order
|
||||
books.
|
||||
"""
|
||||
world = self.world
|
||||
|
||||
for resource in self.commodities:
|
||||
|
||||
bids_ = []
|
||||
for bid in self.bids[resource]:
|
||||
bid["bid_lifetime"] += 1
|
||||
# If the bid is not expired, keep it in the bids
|
||||
if bid["bid_lifetime"] <= self.order_duration:
|
||||
bids_.append(bid)
|
||||
# Otherwise, remove it and do the associated bookkeeping
|
||||
else:
|
||||
# Return the set aside money to the buyer
|
||||
amount = world.agents[bid["buyer"]].escrow_to_inventory(
|
||||
"Coin", bid["bid"]
|
||||
)
|
||||
assert amount == bid["bid"]
|
||||
# Adjust the bid histogram to reflect the removal of the bid
|
||||
self.bid_hists[resource][bid["buyer"]][
|
||||
bid["bid"] - self.price_floor
|
||||
] -= 1
|
||||
# Adjust the order counter
|
||||
self.n_orders[resource][bid["buyer"]] -= 1
|
||||
|
||||
asks_ = []
|
||||
for ask in self.asks[resource]:
|
||||
ask["ask_lifetime"] += 1
|
||||
# If the ask is not expired, keep it in the asks
|
||||
if ask["ask_lifetime"] <= self.order_duration:
|
||||
asks_.append(ask)
|
||||
# Otherwise, remove it and do the associated bookkeeping
|
||||
else:
|
||||
# Return the set aside resource to the seller
|
||||
resource_unit = world.agents[ask["seller"]].escrow_to_inventory(
|
||||
resource, 1
|
||||
)
|
||||
assert resource_unit == 1
|
||||
# Adjust the ask histogram to reflect the removal of the ask
|
||||
self.ask_hists[resource][ask["seller"]][
|
||||
ask["ask"] - self.price_floor
|
||||
] -= 1
|
||||
# Adjust the order counter
|
||||
self.n_orders[resource][ask["seller"]] -= 1
|
||||
|
||||
self.bids[resource] = bids_
|
||||
self.asks[resource] = asks_
|
||||
|
||||
# Required methods for implementing components
|
||||
# --------------------------------------------
|
||||
|
||||
def get_n_actions(self, agent_cls_name):
|
||||
"""
|
||||
See base_component.py for detailed description.
|
||||
|
||||
Adds 2*C action spaces [ (bid+ask) * n_commodities ], each with 1 + max_bid_ask
|
||||
actions corresponding to price levels 0 to max_bid_ask.
|
||||
"""
|
||||
# This component adds 2*(1+max_bid_ask)*n_resources possible actions:
|
||||
# buy/sell x each-price x each-resource
|
||||
if agent_cls_name == "BasicMobileAgent":
|
||||
trades = []
|
||||
for c in self.commodities:
|
||||
trades.append(
|
||||
("Buy_{}".format(c), 1 + self.max_bid_ask)
|
||||
) # How much willing to pay for c
|
||||
trades.append(
|
||||
("Sell_{}".format(c), 1 + self.max_bid_ask)
|
||||
) # How much need to receive to sell c
|
||||
return trades
|
||||
|
||||
return None
|
||||
|
||||
def get_additional_state_fields(self, agent_cls_name):
|
||||
"""
|
||||
See base_component.py for detailed description.
|
||||
"""
|
||||
# This component doesn't add any state fields
|
||||
return {}
|
||||
|
||||
    def component_step(self):
        """
        See base_component.py for detailed description.

        Create new bids and asks, match and execute valid order pairs, and manage
        order expiration.
        """
        world = self.world

        for resource in self.commodities:
            for agent in world.agents:
                # Exponentially decay the per-seller trade-price histogram
                # (factor 0.995 per step) so recent trades dominate the
                # market-rate observation.
                self.price_history[resource][agent.idx] *= 0.995

                # Create bid action
                # -----------------
                resource_action = agent.get_component_action(
                    self.name, "Buy_{}".format(resource)
                )

                # No-op
                if resource_action == 0:
                    pass

                # Create a bid: action value k maps to max_payment = k - 1.
                elif resource_action <= self.max_bid_ask + 1:
                    self.create_bid(resource, agent, max_payment=resource_action - 1)

                else:
                    # Action outside the declared space; indicates a bug upstream.
                    raise ValueError

                # Create ask action
                # -----------------
                resource_action = agent.get_component_action(
                    self.name, "Sell_{}".format(resource)
                )

                # No-op
                if resource_action == 0:
                    pass

                # Create an ask: action value k maps to min_income = k - 1.
                elif resource_action <= self.max_bid_ask + 1:
                    self.create_ask(resource, agent, min_income=resource_action - 1)

                else:
                    # Action outside the declared space; indicates a bug upstream.
                    raise ValueError

        # Here's where the magic happens:
        self.match_orders()  # Pair bids and asks
        self.remove_expired_orders()  # Get rid of orders that have expired
|
||||
|
||||
    def generate_observations(self):
        """
        See base_component.py for detailed description.

        Here, agents and the planner both observe historical market behavior and
        outstanding bids/asks for each tradable commodity. Agents only see the
        outstanding bids/asks to which they could respond (that is, that they did not
        submit). Agents also see their own outstanding bids/asks.
        """
        world = self.world

        obs = {a.idx: {} for a in world.agents + [world.planner]}

        # All representable price levels, aligned with the histogram bins.
        prices = np.arange(self.price_floor, self.price_ceiling + 1)
        for c in self.commodities:
            # Aggregate the (decayed) trade-price histograms over all sellers.
            net_price_history = np.sum(
                np.stack([self.price_history[c][i] for i in range(self.n_agents)]),
                axis=0,
            )
            # Histogram-weighted average trade price; the max() guards against
            # division by zero when no trades have been recorded.
            market_rate = prices.dot(net_price_history) / np.maximum(
                0.001, np.sum(net_price_history)
            )
            # NOTE(review): self.inv_scale is defined outside this chunk —
            # presumably a normalization constant set by the base component;
            # confirm where it is assigned.
            scaled_price_history = net_price_history * self.inv_scale

            # With agent=None no orders are excluded, so these are the complete
            # open-order histograms; the planner observes them in full.
            full_asks = self.available_asks(c, agent=None)
            full_bids = self.available_bids(c, agent=None)

            obs[world.planner.idx].update(
                {
                    "market_rate-{}".format(c): market_rate,
                    "price_history-{}".format(c): scaled_price_history,
                    "full_asks-{}".format(c): full_asks,
                    "full_bids-{}".format(c): full_bids,
                }
            )

            for _, agent in enumerate(world.agents):
                # Private to the agent: all open orders minus its own
                # (what it can trade against), plus its own open orders.
                obs[agent.idx].update(
                    {
                        "market_rate-{}".format(c): market_rate,
                        "price_history-{}".format(c): scaled_price_history,
                        "available_asks-{}".format(c): full_asks
                        - self.ask_hists[c][agent.idx],
                        "available_bids-{}".format(c): full_bids
                        - self.bid_hists[c][agent.idx],
                        "my_asks-{}".format(c): self.ask_hists[c][agent.idx],
                        "my_bids-{}".format(c): self.bid_hists[c][agent.idx],
                    }
                )

        return obs
|
||||
|
||||
def generate_masks(self, completions=0):
|
||||
"""
|
||||
See base_component.py for detailed description.
|
||||
|
||||
Agents cannot submit bids/asks for resources where they are at the order
|
||||
limit. In addition, they may only submit asks for resources they possess and
|
||||
bids for which they can pay.
|
||||
"""
|
||||
world = self.world
|
||||
|
||||
masks = dict()
|
||||
|
||||
for agent in world.agents:
|
||||
masks[agent.idx] = {}
|
||||
|
||||
can_pay = np.arange(self.max_bid_ask + 1) <= agent.inventory["Coin"]
|
||||
|
||||
for resource in self.commodities:
|
||||
if not self.can_ask(resource, agent): # asks_maxed:
|
||||
masks[agent.idx]["Sell_{}".format(resource)] = np.zeros(
|
||||
1 + self.max_bid_ask
|
||||
)
|
||||
else:
|
||||
masks[agent.idx]["Sell_{}".format(resource)] = np.ones(
|
||||
1 + self.max_bid_ask
|
||||
)
|
||||
|
||||
if not self.can_bid(resource, agent):
|
||||
masks[agent.idx]["Buy_{}".format(resource)] = np.zeros(
|
||||
1 + self.max_bid_ask
|
||||
)
|
||||
else:
|
||||
masks[agent.idx]["Buy_{}".format(resource)] = can_pay.astype(
|
||||
np.int32
|
||||
)
|
||||
|
||||
return masks
|
||||
|
||||
# For non-required customization
|
||||
# ------------------------------
|
||||
|
||||
    def get_metrics(self):
        """
        Metrics that capture what happened through this component.

        Returns:
            metrics (dict): A dictionary of {"metric_name": metric_value},
                where metric_value is a scalar. Keys take the form
                "<agent_idx>/<Sell|Buy><commodity>/<price|cost|income|n_sales>",
                plus a global "n_trades" count.
        """
        world = self.world

        trade_keys = ["price", "cost", "income"]

        # Per-agent, per-commodity accumulators for the sell side...
        selling_stats = {
            a.idx: {
                c: {k: 0 for k in trade_keys + ["n_sales"]} for c in self.commodities
            }
            for a in world.agents
        }
        # ...and the buy side.
        buying_stats = {
            a.idx: {
                c: {k: 0 for k in trade_keys + ["n_sales"]} for c in self.commodities
            }
            for a in world.agents
        }

        n_trades = 0

        # Accumulate totals over every trade executed so far this episode.
        for trades in self.executed_trades:
            for trade in trades:
                n_trades += 1
                i_s, i_b, c = trade["seller"], trade["buyer"], trade["commodity"]
                selling_stats[i_s][c]["n_sales"] += 1
                buying_stats[i_b][c]["n_sales"] += 1
                for k in trade_keys:
                    selling_stats[i_s][c][k] += trade[k]
                    buying_stats[i_b][c][k] += trade[k]

        out_dict = {}
        for a in world.agents:
            for c in self.commodities:
                for stats, prefix in zip(
                    [selling_stats, buying_stats], ["Sell", "Buy"]
                ):
                    n = stats[a.idx][c]["n_sales"]
                    # Convert totals to per-trade means in place; NaN marks
                    # "no trades" for this agent/commodity/side.
                    if n == 0:
                        for k in trade_keys:
                            stats[a.idx][c][k] = np.nan
                    else:
                        for k in trade_keys:
                            stats[a.idx][c][k] /= n

                    for k, v in stats[a.idx][c].items():
                        out_dict["{}/{}{}/{}".format(a.idx, prefix, c, k)] = v

        out_dict["n_trades"] = n_trades

        return out_dict
|
||||
|
||||
def additional_reset_steps(self):
|
||||
"""
|
||||
See base_component.py for detailed description.
|
||||
|
||||
Reset the order books.
|
||||
"""
|
||||
self.bids = {c: [] for c in self.commodities}
|
||||
self.asks = {c: [] for c in self.commodities}
|
||||
self.n_orders = {
|
||||
c: {i: 0 for i in range(self.n_agents)} for c in self.commodities
|
||||
}
|
||||
|
||||
self.price_history = {
|
||||
c: {i: self._price_zeros() for i in range(self.n_agents)}
|
||||
for c in self.commodities
|
||||
}
|
||||
self.bid_hists = {
|
||||
c: {i: self._price_zeros() for i in range(self.n_agents)}
|
||||
for c in self.commodities
|
||||
}
|
||||
self.ask_hists = {
|
||||
c: {i: self._price_zeros() for i in range(self.n_agents)}
|
||||
for c in self.commodities
|
||||
}
|
||||
|
||||
self.executed_trades = []
|
||||
|
||||
def get_dense_log(self):
|
||||
"""
|
||||
Log executed trades.
|
||||
|
||||
Returns:
|
||||
trades (list): A list of trade events. Each entry corresponds to a single
|
||||
timestep and contains a description of any trades that occurred on
|
||||
that timestep.
|
||||
"""
|
||||
return self.executed_trades
|
||||
663
ai_economist/foundation/components/covid19_components.py
Normal file
663
ai_economist/foundation/components/covid19_components.py
Normal file
@@ -0,0 +1,663 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import GPUtil
|
||||
import numpy as np
|
||||
|
||||
from ai_economist.foundation.base.base_component import (
|
||||
BaseComponent,
|
||||
component_registry,
|
||||
)
|
||||
|
||||
try:
|
||||
num_gpus_available = len(GPUtil.getAvailable())
|
||||
print(f"Inside covid19_components.py: {num_gpus_available} GPUs are available.")
|
||||
if num_gpus_available == 0:
|
||||
print("No GPUs found! Running the simulation on a CPU.")
|
||||
else:
|
||||
from warp_drive.utils.constants import Constants
|
||||
from warp_drive.utils.data_feed import DataFeed
|
||||
|
||||
_OBSERVATIONS = Constants.OBSERVATIONS
|
||||
_ACTIONS = Constants.ACTIONS
|
||||
except ModuleNotFoundError:
|
||||
print(
|
||||
"Warning: The 'WarpDrive' package is not found and cannot be used! "
|
||||
"If you wish to use WarpDrive, please run "
|
||||
"'pip install rl-warp-drive' first."
|
||||
)
|
||||
except ValueError:
|
||||
print("No GPUs found! Running the simulation on a CPU.")
|
||||
|
||||
|
||||
@component_registry.add
class ControlUSStateOpenCloseStatus(BaseComponent):
    """
    Sets the open/close stringency levels for states.

    Args:
        n_stringency_levels (int): number of stringency levels the states can
            choose from. (Must match the number in the model constants dictionary
            referenced by the parent scenario.)
        action_cooldown_period (int): action cooldown period in days.
            Once a stringency level is set, the state(s) cannot switch to another
            level for a certain number of days (referred to as the
            "action_cooldown_period")
    """

    name = "ControlUSStateOpenCloseStatus"
    required_entities = []
    agent_subclasses = ["BasicMobileAgent"]

    def __init__(
        self,
        *base_component_args,
        n_stringency_levels=10,
        action_cooldown_period=28,
        **base_component_kwargs,
    ):

        self.action_cooldown_period = action_cooldown_period
        super().__init__(*base_component_args, **base_component_kwargs)
        self.np_int_dtype = np.int32

        self.n_stringency_levels = int(n_stringency_levels)
        assert self.n_stringency_levels >= 2
        # Consistency with world.n_stringency_levels is checked lazily on the
        # first CPU component_step (see component_step below).
        self._checked_n_stringency_levels = False

        self.masks = dict()
        # Per-level action masks (1 = allowed, 0 = masked); action 0 (NO-OP)
        # is handled separately when these are pushed to the device.
        self.default_agent_action_mask = [1 for _ in range(self.n_stringency_levels)]
        self.no_op_agent_action_mask = [0 for _ in range(self.n_stringency_levels)]
        # Shape (n_stringency_levels, n_agents): one mask column per agent.
        self.masks["a"] = np.repeat(
            np.array(self.no_op_agent_action_mask)[:, np.newaxis],
            self.n_agents,
            axis=-1,
        )

        # (This will be overwritten during reset; see additional_reset_steps)
        self.action_in_cooldown_until = None

    def get_additional_state_fields(self, agent_cls_name):
        # This component adds no extra per-agent state fields.
        return {}

    def additional_reset_steps(self):
        # Store the times when the next set of actions can be taken.
        self.action_in_cooldown_until = np.array(
            [self.world.timestep for _ in range(self.n_agents)]
        )

    def get_n_actions(self, agent_cls_name):
        # One action per stringency level for each state agent; no actions
        # for other agent classes.
        if agent_cls_name == "BasicMobileAgent":
            return self.n_stringency_levels
        return None

    def generate_masks(self, completions=0):
        """Unmask each agent's actions only once its cooldown has elapsed
        (always unmasked when replaying real-world policies)."""
        for agent in self.world.agents:
            if self.world.use_real_world_policies:
                self.masks["a"][:, agent.idx] = self.default_agent_action_mask
            else:
                if self.world.timestep < self.action_in_cooldown_until[agent.idx]:
                    # Keep masking the actions
                    self.masks["a"][:, agent.idx] = self.no_op_agent_action_mask
                else:  # self.world.timestep == self.action_in_cooldown_until[agent.idx]
                    # Cooldown period has ended; unmask the "subsequent" action
                    self.masks["a"][:, agent.idx] = self.default_agent_action_mask
        return self.masks

    def get_data_dictionary(self):
        """
        Create a dictionary of data to push to the GPU (device).
        """
        data_dict = DataFeed()
        data_dict.add_data(
            name="action_cooldown_period",
            data=self.action_cooldown_period,
        )
        data_dict.add_data(
            name="action_in_cooldown_until",
            data=self.action_in_cooldown_until,
            save_copy_and_apply_at_reset=True,
        )
        data_dict.add_data(
            name="num_stringency_levels",
            data=self.n_stringency_levels,
        )
        # The prepended 1 covers the NO-OP (action 0) slot, which is always
        # allowed on the device side.
        data_dict.add_data(
            name="default_agent_action_mask",
            data=[1] + self.default_agent_action_mask,
        )
        data_dict.add_data(
            name="no_op_agent_action_mask",
            data=[1] + self.no_op_agent_action_mask,
        )
        return data_dict

    def get_tensor_dictionary(self):
        """
        Create a dictionary of (Pytorch-accessible) data to push to the GPU (device).
        """
        # No torch tensors are required by this component.
        tensor_dict = DataFeed()
        return tensor_dict

    def component_step(self):
        """Apply the chosen stringency levels for this timestep.

        On the CUDA path, the whole update is delegated to this component's
        device kernel. On the CPU path, each agent's action (or the recorded
        real-world policy when replaying) updates the global
        "Stringency Level" state and the per-agent cooldown schedule.
        """
        if self.world.use_cuda:
            # GPU: hand all relevant device buffers to the CUDA implementation.
            self.world.cuda_component_step[self.name](
                self.world.cuda_data_manager.device_data("stringency_level"),
                self.world.cuda_data_manager.device_data("action_cooldown_period"),
                self.world.cuda_data_manager.device_data("action_in_cooldown_until"),
                self.world.cuda_data_manager.device_data("default_agent_action_mask"),
                self.world.cuda_data_manager.device_data("no_op_agent_action_mask"),
                self.world.cuda_data_manager.device_data("num_stringency_levels"),
                self.world.cuda_data_manager.device_data(f"{_ACTIONS}_a"),
                self.world.cuda_data_manager.device_data(
                    f"{_OBSERVATIONS}_a_{self.name}-agent_policy_indicators"
                ),
                self.world.cuda_data_manager.device_data(
                    f"{_OBSERVATIONS}_a_action_mask"
                ),
                self.world.cuda_data_manager.device_data(
                    f"{_OBSERVATIONS}_p_{self.name}-agent_policy_indicators"
                ),
                self.world.cuda_data_manager.device_data("_timestep_"),
                self.world.cuda_data_manager.meta_info("n_agents"),
                self.world.cuda_data_manager.meta_info("episode_length"),
                block=self.world.cuda_function_manager.block,
                grid=self.world.cuda_function_manager.grid,
            )
        else:
            # One-time validation that this component's level count matches the
            # world's fitted model.
            if not self._checked_n_stringency_levels:
                if self.n_stringency_levels != self.world.n_stringency_levels:
                    raise ValueError(
                        "The environment was not configured correctly. For the given "
                        "model fit, you need to set the number of stringency levels to "
                        "be {}".format(self.world.n_stringency_levels)
                    )
                self._checked_n_stringency_levels = True

            for agent in self.world.agents:
                if self.world.use_real_world_policies:
                    # Use the action taken in the previous timestep
                    action = self.world.real_world_stringency_policy[
                        self.world.timestep - 1, agent.idx
                    ]
                else:
                    action = agent.get_component_action(self.name)
                assert 0 <= action <= self.n_stringency_levels

                # We only update the stringency level if the action is not a NO-OP:
                # when action == 0 the previous level carries forward, otherwise
                # the action value itself becomes the new level.
                self.world.global_state["Stringency Level"][
                    self.world.timestep, agent.idx
                ] = (
                    self.world.global_state["Stringency Level"][
                        self.world.timestep - 1, agent.idx
                    ]
                    * (action == 0)
                    + action
                )

                agent.state[
                    "Current Open Close Stringency Level"
                ] = self.world.global_state["Stringency Level"][
                    self.world.timestep, agent.idx
                ]

                # Check if the action cooldown period has ended, and set the next
                # time until action cooldown. If current action is a no-op
                # (i.e., no new action was taken), the agent can take an action
                # in the very next step, otherwise it needs to wait for
                # self.action_cooldown_period steps. When in the action cooldown
                # period, whatever actions the agents take are masked out,
                # so it's always a NO-OP (see generate_masks() above)
                # The logic below influences the action masks.
                if self.world.timestep == self.action_in_cooldown_until[agent.idx] + 1:
                    if action == 0:  # NO-OP
                        self.action_in_cooldown_until[agent.idx] += 1
                    else:
                        self.action_in_cooldown_until[
                            agent.idx
                        ] += self.action_cooldown_period

    def generate_observations(self):
        """Return the current stringency levels as normalized policy indicators.

        Levels are scaled into [0, 1] by dividing by n_stringency_levels; the
        same vector is provided to the batched agents (key "a") and the planner.
        """
        # Normalized observations
        obs_dict = dict()
        agent_policy_indicators = self.world.global_state["Stringency Level"][
            self.world.timestep
        ]
        obs_dict["a"] = {
            "agent_policy_indicators": agent_policy_indicators
            / self.n_stringency_levels
        }
        obs_dict[self.world.planner.idx] = {
            "agent_policy_indicators": agent_policy_indicators
            / self.n_stringency_levels
        }

        return obs_dict
|
||||
|
||||
|
||||
@component_registry.add
|
||||
class FederalGovernmentSubsidy(BaseComponent):
|
||||
"""
|
||||
Args:
|
||||
subsidy_interval (int): The number of days over which the total subsidy amount
|
||||
is evenly rolled out.
|
||||
Note: shortening the subsidy interval increases the total amount of money
|
||||
that the planner could possibly spend. For instance, if the subsidy
|
||||
interval is 30, the planner can create a subsidy every 30 days.
|
||||
num_subsidy_levels (int): The number of subsidy levels.
|
||||
Note: with max_annual_subsidy_per_person=10000, one round of subsidies at
|
||||
the maximum subsidy level equals an expenditure of roughly $3.3 trillion
|
||||
(given the US population of 330 million).
|
||||
If the planner chooses the maximum subsidy amount, the $3.3 trillion
|
||||
is rolled out gradually over the subsidy interval.
|
||||
max_annual_subsidy_per_person (float): The maximum annual subsidy that may be
|
||||
allocated per person.
|
||||
"""
|
||||
|
||||
name = "FederalGovernmentSubsidy"
|
||||
required_entities = []
|
||||
agent_subclasses = ["BasicPlanner"]
|
||||
|
||||
    def __init__(
        self,
        *base_component_args,
        subsidy_interval=90,
        num_subsidy_levels=20,
        max_annual_subsidy_per_person=20000,
        **base_component_kwargs,
    ):
        """Validate and store the subsidy configuration.

        Args:
            subsidy_interval (int): Days over which each subsidy is rolled out;
                must be >= 1.
            num_subsidy_levels (int): Number of non-zero subsidy levels; must
                be >= 1.
            max_annual_subsidy_per_person (float): Cap on the annual subsidy
                allocated per person; must be >= 0.
        """
        self.subsidy_interval = int(subsidy_interval)
        assert self.subsidy_interval >= 1

        self.num_subsidy_levels = int(num_subsidy_levels)
        assert self.num_subsidy_levels >= 1

        self.max_annual_subsidy_per_person = float(max_annual_subsidy_per_person)
        assert self.max_annual_subsidy_per_person >= 0

        self.np_int_dtype = np.int32

        # (This will be overwritten during component_step; see below)
        self._subsidy_amount_per_level = None
        self._subsidy_level_array = None

        super().__init__(*base_component_args, **base_component_kwargs)

        # Planner action masks: one slot per non-zero subsidy level
        # (1 = allowed, 0 = masked).
        self.default_planner_action_mask = [1 for _ in range(self.num_subsidy_levels)]
        self.no_op_planner_action_mask = [0 for _ in range(self.num_subsidy_levels)]

        # Placeholder (a 0-d int array built from n_agents); replaced with the
        # real per-state daily caps during reset (see additional_reset_steps).
        self.max_daily_subsidy_per_state = np.array(
            self.n_agents, dtype=self.np_int_dtype
        )
|
||||
|
||||
def get_additional_state_fields(self, agent_cls_name):
|
||||
if agent_cls_name == "BasicPlanner":
|
||||
return {"Total Subsidy": 0, "Current Subsidy Level": 0}
|
||||
return {}
|
||||
|
||||
def additional_reset_steps(self):
|
||||
# Pre-compute maximum state-specific subsidy levels
|
||||
self.max_daily_subsidy_per_state = (
|
||||
self.world.us_state_population * self.max_annual_subsidy_per_person / 365
|
||||
)
|
||||
|
||||
def get_n_actions(self, agent_cls_name):
|
||||
if agent_cls_name == "BasicPlanner":
|
||||
# Number of non-zero subsidy levels
|
||||
# (the action 0 pertains to the no-subsidy case)
|
||||
return self.num_subsidy_levels
|
||||
return None
|
||||
|
||||
def generate_masks(self, completions=0):
|
||||
masks = {}
|
||||
if self.world.use_real_world_policies:
|
||||
masks[self.world.planner.idx] = self.default_planner_action_mask
|
||||
else:
|
||||
if self.world.timestep % self.subsidy_interval == 0:
|
||||
masks[self.world.planner.idx] = self.default_planner_action_mask
|
||||
else:
|
||||
masks[self.world.planner.idx] = self.no_op_planner_action_mask
|
||||
return masks
|
||||
|
||||
def get_data_dictionary(self):
|
||||
"""
|
||||
Create a dictionary of data to push to the device
|
||||
"""
|
||||
data_dict = DataFeed()
|
||||
data_dict.add_data(
|
||||
name="subsidy_interval",
|
||||
data=self.subsidy_interval,
|
||||
)
|
||||
data_dict.add_data(
|
||||
name="num_subsidy_levels",
|
||||
data=self.num_subsidy_levels,
|
||||
)
|
||||
data_dict.add_data(
|
||||
name="max_daily_subsidy_per_state",
|
||||
data=self.max_daily_subsidy_per_state,
|
||||
)
|
||||
data_dict.add_data(
|
||||
name="default_planner_action_mask",
|
||||
data=[1] + self.default_planner_action_mask,
|
||||
)
|
||||
data_dict.add_data(
|
||||
name="no_op_planner_action_mask",
|
||||
data=[1] + self.no_op_planner_action_mask,
|
||||
)
|
||||
return data_dict
|
||||
|
||||
def get_tensor_dictionary(self):
|
||||
"""
|
||||
Create a dictionary of (Pytorch-accessible) data to push to the device
|
||||
"""
|
||||
tensor_dict = DataFeed()
|
||||
return tensor_dict
|
||||
|
||||
    def component_step(self):
        """Advance the subsidy state by one environment timestep.

        On GPU, the full update (subsidy level, per-state subsidy, obs, and
        planner action mask) is delegated to this component's CUDA kernel.
        On CPU, the planner's subsidy level is resolved (either replayed
        from real-world data or read from the planner's action), converted
        into a per-state daily subsidy, and written into the global state.
        """
        if self.world.use_cuda:
            # Argument order here must match the CUDA kernel signature
            # (CudaFederalGovernmentSubsidyStep) exactly.
            self.world.cuda_component_step[self.name](
                self.world.cuda_data_manager.device_data("subsidy_level"),
                self.world.cuda_data_manager.device_data("subsidy"),
                self.world.cuda_data_manager.device_data("subsidy_interval"),
                self.world.cuda_data_manager.device_data("num_subsidy_levels"),
                self.world.cuda_data_manager.device_data("max_daily_subsidy_per_state"),
                self.world.cuda_data_manager.device_data("default_planner_action_mask"),
                self.world.cuda_data_manager.device_data("no_op_planner_action_mask"),
                self.world.cuda_data_manager.device_data(f"{_ACTIONS}_p"),
                self.world.cuda_data_manager.device_data(
                    f"{_OBSERVATIONS}_a_{self.name}-t_until_next_subsidy"
                ),
                self.world.cuda_data_manager.device_data(
                    f"{_OBSERVATIONS}_a_{self.name}-current_subsidy_level"
                ),
                self.world.cuda_data_manager.device_data(
                    f"{_OBSERVATIONS}_p_{self.name}-t_until_next_subsidy"
                ),
                self.world.cuda_data_manager.device_data(
                    f"{_OBSERVATIONS}_p_{self.name}-current_subsidy_level"
                ),
                self.world.cuda_data_manager.device_data(
                    f"{_OBSERVATIONS}_p_action_mask"
                ),
                self.world.cuda_data_manager.device_data("_timestep_"),
                self.world.cuda_data_manager.meta_info("n_agents"),
                self.world.cuda_data_manager.meta_info("episode_length"),
                block=self.world.cuda_function_manager.block,
                grid=self.world.cuda_function_manager.grid,
            )
        else:
            if self.world.use_real_world_policies:
                # Lazily initialize the replay machinery on first use.
                if self._subsidy_amount_per_level is None:
                    # Dollar amount represented by one subsidy level over
                    # one subsidy interval, at the national population scale.
                    self._subsidy_amount_per_level = (
                        self.world.us_population
                        * self.max_annual_subsidy_per_person
                        / self.num_subsidy_levels
                        * self.subsidy_interval
                        / 365
                    )
                    self._subsidy_level_array = np.zeros((self._episode_length + 1))
                # Use the action taken in the previous timestep
                current_subsidy_amount = self.world.real_world_subsidy[
                    self.world.timestep - 1
                ]
                if current_subsidy_amount > 0:
                    # Quantize the real-world dollar amount into a level and
                    # spread it over the coming subsidy interval.
                    _subsidy_level = np.round(
                        (current_subsidy_amount / self._subsidy_amount_per_level)
                    )
                    for t_idx in range(
                        self.world.timestep - 1,
                        min(
                            len(self._subsidy_level_array),
                            self.world.timestep - 1 + self.subsidy_interval,
                        ),
                    ):
                        self._subsidy_level_array[t_idx] += _subsidy_level
                subsidy_level = self._subsidy_level_array[self.world.timestep - 1]
            else:
                # Update the subsidy level only every self.subsidy_interval, since the
                # other actions are masked out.
                if (self.world.timestep - 1) % self.subsidy_interval == 0:
                    subsidy_level = self.world.planner.get_component_action(self.name)
                else:
                    subsidy_level = self.world.planner.state["Current Subsidy Level"]

            assert 0 <= subsidy_level <= self.num_subsidy_levels
            self.world.planner.state["Current Subsidy Level"] = np.array(
                subsidy_level
            ).astype(self.np_int_dtype)

            # Update subsidy level: scale each state's daily cap by the
            # fraction of the maximum level currently selected.
            subsidy_level_frac = subsidy_level / self.num_subsidy_levels
            daily_statewise_subsidy = (
                subsidy_level_frac * self.max_daily_subsidy_per_state
            )

            self.world.global_state["Subsidy"][
                self.world.timestep
            ] = daily_statewise_subsidy
            self.world.planner.state["Total Subsidy"] += np.sum(daily_statewise_subsidy)
|
||||
|
||||
def generate_observations(self):
|
||||
# Allow the agents/planner to know when the next subsidy might come.
|
||||
# Obs should = 0 when the next timestep could include a subsidy
|
||||
t_since_last_subsidy = self.world.timestep % self.subsidy_interval
|
||||
# (this is normalized to 0<-->1)
|
||||
t_until_next_subsidy = self.subsidy_interval - t_since_last_subsidy
|
||||
t_vec = t_until_next_subsidy * np.ones(self.n_agents)
|
||||
|
||||
current_subsidy_level = self.world.planner.state["Current Subsidy Level"]
|
||||
sl_vec = current_subsidy_level * np.ones(self.n_agents)
|
||||
|
||||
# Normalized observations
|
||||
obs_dict = dict()
|
||||
obs_dict["a"] = {
|
||||
"t_until_next_subsidy": t_vec / self.subsidy_interval,
|
||||
"current_subsidy_level": sl_vec / self.num_subsidy_levels,
|
||||
}
|
||||
obs_dict[self.world.planner.idx] = {
|
||||
"t_until_next_subsidy": t_until_next_subsidy / self.subsidy_interval,
|
||||
"current_subsidy_level": current_subsidy_level / self.num_subsidy_levels,
|
||||
}
|
||||
|
||||
return obs_dict
|
||||
|
||||
|
||||
@component_registry.add
class VaccinationCampaign(BaseComponent):
    """
    Implements a (passive) component for delivering vaccines to agents once a certain
    amount of time has elapsed.

    Args:
        daily_vaccines_per_million_people (int): The number of vaccines available per
            million people everyday.
        delivery_interval (int): The number of days between vaccine deliveries.
        vaccine_delivery_start_date (string): The date (YYYY-MM-DD) when the
            vaccination begins.
        observe_rate (bool): Whether to include the upcoming vaccination rate in
            the observations. Defaults to False.
    """

    name = "VaccinationCampaign"
    required_entities = []
    agent_subclasses = ["BasicMobileAgent"]

    def __init__(
        self,
        *base_component_args,
        daily_vaccines_per_million_people=4500,
        delivery_interval=1,
        vaccine_delivery_start_date="2020-12-22",
        observe_rate=False,
        **base_component_kwargs,
    ):
        self.daily_vaccines_per_million_people = int(daily_vaccines_per_million_people)
        assert 0 <= self.daily_vaccines_per_million_people <= 1e6

        self.delivery_interval = int(delivery_interval)
        assert 1 <= self.delivery_interval <= 5000

        try:
            self.vaccine_delivery_start_date = datetime.strptime(
                vaccine_delivery_start_date, "%Y-%m-%d"
            )
        except ValueError:
            # Fail fast. Previously this only printed a warning and continued,
            # which left `self.vaccine_delivery_start_date` unset and caused a
            # confusing AttributeError later.
            raise ValueError("Incorrect data format, should be YYYY-MM-DD")

        # (This will be overwritten during component_step (see below))
        self._time_when_vaccine_delivery_begins = None

        self.np_int_dtype = np.int32

        self.observe_rate = bool(observe_rate)

        super().__init__(*base_component_args, **base_component_kwargs)

        # (This will be overwritten during reset; see below)
        self._num_vaccines_per_delivery = None
        # Convenience for obs (see usage below):
        self._t_first_delivery = None

    @property
    def num_vaccines_per_delivery(self):
        """Per-state vaccines delivered each delivery interval (lazy, cached)."""
        if self._num_vaccines_per_delivery is None:
            # Pre-compute dispersal numbers
            millions_of_residents = self.world.us_state_population / 1e6
            daily_vaccines = (
                millions_of_residents * self.daily_vaccines_per_million_people
            )
            num_vaccines_per_delivery = np.floor(
                self.delivery_interval * daily_vaccines
            )
            self._num_vaccines_per_delivery = np.array(
                num_vaccines_per_delivery, dtype=self.np_int_dtype
            )
        return self._num_vaccines_per_delivery

    @property
    def time_when_vaccine_delivery_begins(self):
        """Timestep (days since world start date) when deliveries begin (lazy)."""
        if self._time_when_vaccine_delivery_begins is None:
            self._time_when_vaccine_delivery_begins = (
                self.vaccine_delivery_start_date - self.world.start_date
            ).days
        return self._time_when_vaccine_delivery_begins

    def get_additional_state_fields(self, agent_cls_name):
        """Add vaccination bookkeeping state to mobile agents only."""
        if agent_cls_name == "BasicMobileAgent":
            return {"Total Vaccinated": 0, "Vaccines Available": 0}
        return {}

    def additional_reset_steps(self):
        pass

    def get_n_actions(self, agent_cls_name):
        return  # Passive component

    def generate_masks(self, completions=0):
        return {}  # Passive component

    def get_data_dictionary(self):
        """
        Create a dictionary of data to push to the device
        """
        data_dict = DataFeed()
        data_dict.add_data(
            name="num_vaccines_per_delivery",
            data=self.num_vaccines_per_delivery,
        )
        data_dict.add_data(
            name="delivery_interval",
            data=self.delivery_interval,
        )
        data_dict.add_data(
            name="time_when_vaccine_delivery_begins",
            data=self.time_when_vaccine_delivery_begins,
        )
        data_dict.add_data(
            name="num_vaccines_available_t",
            data=np.zeros(self.n_agents),
            save_copy_and_apply_at_reset=True,
        )
        return data_dict

    def get_tensor_dictionary(self):
        """
        Create a dictionary of (Pytorch-accessible) data to push to the device
        """
        tensor_dict = DataFeed()
        return tensor_dict

    def component_step(self):
        """Deliver vaccines at the start of each delivery interval."""
        if self.world.use_cuda:
            # Argument order must match CudaVaccinationCampaignStep exactly.
            self.world.cuda_component_step[self.name](
                self.world.cuda_data_manager.device_data("vaccinated"),
                self.world.cuda_data_manager.device_data("num_vaccines_per_delivery"),
                self.world.cuda_data_manager.device_data("num_vaccines_available_t"),
                self.world.cuda_data_manager.device_data("delivery_interval"),
                self.world.cuda_data_manager.device_data(
                    "time_when_vaccine_delivery_begins"
                ),
                self.world.cuda_data_manager.device_data(
                    f"{_OBSERVATIONS}_a_{self.name}-t_until_next_vaccines"
                ),
                self.world.cuda_data_manager.device_data(
                    f"{_OBSERVATIONS}_p_{self.name}-t_until_next_vaccines"
                ),
                self.world.cuda_data_manager.device_data("_timestep_"),
                self.world.cuda_data_manager.meta_info("n_agents"),
                self.world.cuda_data_manager.meta_info("episode_length"),
                block=self.world.cuda_function_manager.block,
                grid=self.world.cuda_function_manager.grid,
            )
        else:
            # Do nothing if vaccines are not available yet
            if self.world.timestep < self.time_when_vaccine_delivery_begins:
                return

            # Do nothing if this is not the start of a delivery interval.
            # Vaccines are delivered at the start of each interval.
            if (self.world.timestep % self.delivery_interval) != 0:
                return

            # Deliver vaccines to each state
            for aidx, vaccines in enumerate(self.num_vaccines_per_delivery):
                self.world.agents[aidx].state["Vaccines Available"] += vaccines

    def generate_observations(self):
        """Let the agents/planner observe when the next vaccines might come.

        The time obs is normalized to 0..1 and equals 0 when the next
        timestep will deliver vaccines.
        """
        if self._t_first_delivery is None:
            # Round the delivery start time up to the next interval boundary.
            self._t_first_delivery = int(self.time_when_vaccine_delivery_begins)
            while (self._t_first_delivery % self.delivery_interval) != 0:
                self._t_first_delivery += 1

        next_t = self.world.timestep + 1
        if next_t <= self._t_first_delivery:
            # NOTE(review): this branch's value is already a 0..1 fraction but
            # is divided by delivery_interval again below, unlike the else
            # branch — confirm the intended normalization.
            t_until_next_vac = np.minimum(
                1, (self._t_first_delivery - next_t) / self.delivery_interval
            )
            next_vax_rate = 0.0
        else:
            t_since_last_vac = next_t % self.delivery_interval
            t_until_next_vac = self.delivery_interval - t_since_last_vac
            next_vax_rate = self.daily_vaccines_per_million_people / 1e6
        t_vec = t_until_next_vac * np.ones(self.n_agents)
        r_vec = next_vax_rate * np.ones(self.n_agents)

        # Normalized observations
        planner_key = self.world.planner.idx
        obs_dict = dict()
        obs_dict["a"] = {"t_until_next_vaccines": t_vec / self.delivery_interval}
        obs_dict[planner_key] = {
            "t_until_next_vaccines": t_until_next_vac / self.delivery_interval
        }

        if self.observe_rate:
            obs_dict["a"]["next_vaccination_rate"] = r_vec
            # Use the planner's idx rather than hard-coding "p"; the planner
            # entry above is keyed by planner.idx, so hard-coding "p" would
            # KeyError whenever the planner uses a different idx.
            obs_dict[planner_key]["next_vaccination_rate"] = float(next_vax_rate)

        return obs_dict
|
||||
263
ai_economist/foundation/components/covid19_components_step.cu
Normal file
263
ai_economist/foundation/components/covid19_components_step.cu
Normal file
@@ -0,0 +1,263 @@
|
||||
// Copyright (c) 2021, salesforce.com, inc.
|
||||
// All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
// For full license text, see the LICENSE file in the repo root
|
||||
// or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
extern "C" {
|
||||
// CUDA version of the components in
|
||||
// "ai_economist.foundation.components.covid19_components.py"
|
||||
// Per-env, per-agent step for the US-state open/close (stringency) policy.
// One CUDA block per environment; thread i handles agent i, with the last
// thread (kNumAgents - 1) acting as the planner.
__global__ void CudaControlUSStateOpenCloseStatusStep(
    int * stringency_level,
    const int kActionCooldownPeriod,
    int * action_in_cooldown_until,
    const int * kDefaultAgentActionMask,
    const int * kNoOpAgentActionMask,
    const int kNumStringencyLevels,
    int * actions,
    float * obs_a_stringency_policy_indicators,
    float * obs_a_action_mask,
    float * obs_p_stringency_policy_indicators,
    int * env_timestep_arr,
    const int kNumAgents,
    const int kEpisodeLength
) {
    const int kEnvId = blockIdx.x;
    const int kAgentId = threadIdx.x;

    // Increment time ONCE -- only 1 thread can do this.
    if (kAgentId == 0) {
        env_timestep_arr[kEnvId] += 1;
    }

    // Wait here until timestep has been updated
    __syncthreads();

    assert(env_timestep_arr[kEnvId] > 0 &&
        env_timestep_arr[kEnvId] <= kEpisodeLength);
    assert (kAgentId <= kNumAgents - 1);

    // Update the stringency levels for the US states
    if (kAgentId < (kNumAgents - 1)) {
        // Indices for time-dependent and time-independent arrays
        // Time dependent arrays have shapes
        // (num_envs, kEpisodeLength + 1, kNumAgents - 1)
        // Time independent arrays have shapes (num_envs, kNumAgents - 1)
        const int kArrayIdxOffset = kEnvId * (kEpisodeLength + 1) *
            (kNumAgents - 1);
        int time_dependent_array_index_curr_t = kArrayIdxOffset +
            env_timestep_arr[kEnvId] * (kNumAgents - 1) + kAgentId;
        int time_dependent_array_index_prev_t = kArrayIdxOffset +
            (env_timestep_arr[kEnvId] - 1) * (kNumAgents - 1) + kAgentId;
        const int time_independent_array_index = kEnvId * (kNumAgents - 1) +
            kAgentId;

        // action is not a NO-OP: adopt the new stringency level;
        // otherwise carry the previous timestep's level forward.
        if (actions[time_independent_array_index] != 0) {
            stringency_level[time_dependent_array_index_curr_t] =
                actions[time_independent_array_index];
        } else {
            stringency_level[time_dependent_array_index_curr_t] =
                stringency_level[time_dependent_array_index_prev_t];
        }

        if (env_timestep_arr[kEnvId] == action_in_cooldown_until[
                time_independent_array_index] + 1) {
            if (actions[time_independent_array_index] != 0) {
                // BUG FIX: the previous `assert(0 <= x <= k)` parsed as
                // `(0 <= x) <= k` in C/C++ and was always true. Check both
                // bounds explicitly.
                assert(actions[time_independent_array_index] >= 0 &&
                    actions[time_independent_array_index] <=
                        kNumStringencyLevels);
                action_in_cooldown_until[time_independent_array_index] +=
                    kActionCooldownPeriod;
            } else {
                action_in_cooldown_until[time_independent_array_index] += 1;
            }
        }

        // Normalized stringency observation in [0, 1].
        obs_a_stringency_policy_indicators[
            time_independent_array_index
        ] = stringency_level[time_dependent_array_index_curr_t] /
            static_cast<float>(kNumStringencyLevels);

        // CUDA version of generate_masks(): mask out all non-no-op actions
        // while the agent is still in its action cooldown window.
        for (int action_id = 0; action_id < (kNumStringencyLevels + 1);
            action_id++) {
            int action_mask_array_index =
                kEnvId * (kNumStringencyLevels + 1) *
                (kNumAgents - 1) + action_id * (kNumAgents - 1) + kAgentId;
            if (env_timestep_arr[kEnvId] < action_in_cooldown_until[
                    time_independent_array_index]
            ) {
                obs_a_action_mask[action_mask_array_index] =
                    kNoOpAgentActionMask[action_id];
            } else {
                obs_a_action_mask[action_mask_array_index] =
                    kDefaultAgentActionMask[action_id];
            }
        }
    }

    // Update planner obs after all the agents' obs are updated
    __syncthreads();

    if (kAgentId == kNumAgents - 1) {
        for (int ag_id = 0; ag_id < (kNumAgents - 1); ag_id++) {
            const int kIndex = kEnvId * (kNumAgents - 1) + ag_id;
            obs_p_stringency_policy_indicators[kIndex] =
                obs_a_stringency_policy_indicators[kIndex];
        }
    }
}
|
||||
|
||||
// Per-env step for the federal-government subsidy component. Thread i
// handles US state i; the last thread (kNumAgents - 1) is the planner.
__global__ void CudaFederalGovernmentSubsidyStep(
    int * subsidy_level,
    float * subsidy,
    const int kSubsidyInterval,
    const int kNumSubsidyLevels,
    const float * KMaxDailySubsidyPerState,
    const int * kDefaultPlannerActionMask,
    const int * kNoOpPlannerActionMask,
    int * actions,
    float * obs_a_time_until_next_subsidy,
    float * obs_a_current_subsidy_level,
    float * obs_p_time_until_next_subsidy,
    float * obs_p_current_subsidy_level,
    float * obs_p_action_mask,
    int * env_timestep_arr,
    const int kNumAgents,
    const int kEpisodeLength
) {
    const int kEnvId = blockIdx.x;
    const int kAgentId = threadIdx.x;

    assert(env_timestep_arr[kEnvId] > 0 &&
        env_timestep_arr[kEnvId] <= kEpisodeLength);
    assert (kAgentId <= kNumAgents - 1);

    int t_since_last_subsidy = env_timestep_arr[kEnvId] %
        kSubsidyInterval;

    // Setting the (federal government) planner's subsidy level
    // to be the subsidy level for all the US states
    if (kAgentId < kNumAgents - 1) {
        // Indices for time-dependent and time-independent arrays
        // Time dependent arrays have shapes (num_envs,
        // kEpisodeLength + 1, kNumAgents - 1)
        // Time independent arrays have shapes (num_envs, kNumAgents - 1)
        const int kArrayIdxOffset = kEnvId * (kEpisodeLength + 1) *
            (kNumAgents - 1);
        int time_dependent_array_index_curr_t = kArrayIdxOffset +
            env_timestep_arr[kEnvId] * (kNumAgents - 1) + kAgentId;
        int time_dependent_array_index_prev_t = kArrayIdxOffset +
            (env_timestep_arr[kEnvId] - 1) * (kNumAgents - 1) + kAgentId;
        const int time_independent_array_index = kEnvId *
            (kNumAgents - 1) + kAgentId;

        if ((env_timestep_arr[kEnvId] - 1) % kSubsidyInterval == 0) {
            // BUG FIX: the previous `assert(0 <= x <= k)` parsed as
            // `(0 <= x) <= k` and was always true. Check both bounds.
            assert(actions[kEnvId] >= 0 &&
                actions[kEnvId] <= kNumSubsidyLevels);
            subsidy_level[time_dependent_array_index_curr_t] =
                actions[kEnvId];
        } else {
            // Carry the previous subsidy level between decision steps.
            subsidy_level[time_dependent_array_index_curr_t] =
                subsidy_level[time_dependent_array_index_prev_t];
        }
        // Setting the subsidies for the US states
        // based on the federal government's subsidy level
        subsidy[time_dependent_array_index_curr_t] =
            subsidy_level[time_dependent_array_index_curr_t] *
            KMaxDailySubsidyPerState[kAgentId] / kNumSubsidyLevels;

        obs_a_time_until_next_subsidy[time_independent_array_index] =
            1 - (t_since_last_subsidy /
                static_cast<float>(kSubsidyInterval));
        obs_a_current_subsidy_level[time_independent_array_index] =
            subsidy_level[time_dependent_array_index_curr_t] /
            static_cast<float>(kNumSubsidyLevels);
    } else if (kAgentId == (kNumAgents - 1)) {
        // Planner thread: build the planner's action mask. Non-no-op
        // actions are only available on subsidy-interval boundaries.
        for (int action_id = 0; action_id < kNumSubsidyLevels + 1;
            action_id++) {
            int action_mask_array_index = kEnvId *
                (kNumSubsidyLevels + 1) + action_id;
            if (env_timestep_arr[kEnvId] % kSubsidyInterval == 0) {
                obs_p_action_mask[action_mask_array_index] =
                    kDefaultPlannerActionMask[action_id];
            } else {
                obs_p_action_mask[action_mask_array_index] =
                    kNoOpPlannerActionMask[action_id];
            }
        }
    }

    // BUG FIX: __syncthreads() was previously inside the planner-only
    // `else if` branch. Calling __syncthreads() from divergent control flow
    // is undefined behavior (deadlock); the barrier must be reached by all
    // threads of the block, as in the other kernels in this file.
    // Update planner obs after the agents' obs are updated.
    __syncthreads();

    if (kAgentId == (kNumAgents - 1)) {
        // Just use the values for agent id 0
        obs_p_time_until_next_subsidy[kEnvId] =
            obs_a_time_until_next_subsidy[
                kEnvId * (kNumAgents - 1)
            ];
        obs_p_current_subsidy_level[kEnvId] =
            obs_a_current_subsidy_level[
                kEnvId * (kNumAgents - 1)
            ];
    }
}
|
||||
|
||||
// Per-env step for the (passive) vaccination campaign. Thread i handles
// US state i; the last thread (kNumAgents - 1) is the planner.
__global__ void CudaVaccinationCampaignStep(
    int * vaccinated,
    const int * kNumVaccinesPerDelivery,
    int * num_vaccines_available_t,
    const int kDeliveryInterval,
    const int kTimeWhenVaccineDeliveryBegins,
    float * obs_a_vaccination_campaign_t_until_next_vaccines,
    float * obs_p_vaccination_campaign_t_until_next_vaccines,
    int * env_timestep_arr,
    int kNumAgents,
    int kEpisodeLength
) {
    const int kEnvId = blockIdx.x;
    const int kAgentId = threadIdx.x;

    assert(env_timestep_arr[kEnvId] > 0 && env_timestep_arr[kEnvId] <=
        kEpisodeLength);
    assert(kTimeWhenVaccineDeliveryBegins > 0);
    assert (kAgentId <= kNumAgents - 1);

    // CUDA version of generate observations()
    // BUG FIX: round the delivery start time UP to the next multiple of
    // kDeliveryInterval, matching the Python implementation (which
    // increments until t % delivery_interval == 0). The previous
    // `k + k % kDeliveryInterval` does not produce a multiple of the
    // interval (e.g. 5 + 5 % 3 = 7, whereas the Python version yields 6).
    int t_first_delivery = kTimeWhenVaccineDeliveryBegins;
    if (t_first_delivery % kDeliveryInterval != 0) {
        t_first_delivery += kDeliveryInterval -
            (t_first_delivery % kDeliveryInterval);
    }
    int next_t = env_timestep_arr[kEnvId] + 1;
    float t_until_next_vac;
    if (next_t <= t_first_delivery) {
        // BUG FIX: the previous `(int) / (int)` division truncated the
        // fraction to 0 or 1; use float division as in the Python version.
        t_until_next_vac = min(
            1.0f,
            (t_first_delivery - next_t) /
                static_cast<float>(kDeliveryInterval));
    } else {
        float t_since_last_vac = next_t % kDeliveryInterval;
        t_until_next_vac = 1 - (t_since_last_vac / kDeliveryInterval);
    }

    // Update the vaccinated numbers for just the US states
    if (kAgentId < (kNumAgents - 1)) {
        const int time_independent_array_index = kEnvId *
            (kNumAgents - 1) + kAgentId;
        // Vaccines become available at the start of each delivery interval
        // once the campaign has begun; otherwise none are available.
        if ((env_timestep_arr[kEnvId] >= kTimeWhenVaccineDeliveryBegins) &&
            (env_timestep_arr[kEnvId] % kDeliveryInterval == 0)) {
            num_vaccines_available_t[time_independent_array_index] =
                kNumVaccinesPerDelivery[kAgentId];
        } else {
            num_vaccines_available_t[time_independent_array_index] = 0;
        }
        obs_a_vaccination_campaign_t_until_next_vaccines[
            time_independent_array_index] = t_until_next_vac;
    } else if (kAgentId == kNumAgents - 1) {
        obs_p_vaccination_campaign_t_until_next_vaccines[kEnvId] =
            t_until_next_vac;
    }
}
|
||||
}
|
||||
222
ai_economist/foundation/components/move.py
Normal file
222
ai_economist/foundation/components/move.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import numpy as np
|
||||
from numpy.random import rand
|
||||
|
||||
from ai_economist.foundation.base.base_component import (
|
||||
BaseComponent,
|
||||
component_registry,
|
||||
)
|
||||
|
||||
|
||||
@component_registry.add
class Gather(BaseComponent):
    """
    Allows mobile agents to move around the world and collect resources and prevents
    agents from moving to invalid locations.

    Can be configured to include collection skill, where agents have heterogeneous
    probabilities of collecting bonus resources without additional labor cost.

    Args:
        move_labor (float): Labor cost associated with movement. Must be >= 0.
            Default is 1.0.
        collect_labor (float): Labor cost associated with collecting resources. This
            cost is added (in addition to any movement cost) when the agent lands on
            a tile that is populated with resources (triggering collection).
            Must be >= 0. Default is 1.0.
        skill_dist (str): Distribution type for sampling skills. Default ("none")
            gives all agents identical skill equal to a bonus prob of 0. "pareto" and
            "lognormal" sample skills from the associated distributions.
    """

    name = "Gather"
    required_entities = ["Coin", "House", "Labor"]
    agent_subclasses = ["BasicMobileAgent"]

    def __init__(
        self,
        *base_component_args,
        move_labor=1.0,
        collect_labor=1.0,
        skill_dist="none",
        **base_component_kwargs
    ):
        super().__init__(*base_component_args, **base_component_kwargs)

        self.move_labor = float(move_labor)
        assert self.move_labor >= 0

        self.collect_labor = float(collect_labor)
        assert self.collect_labor >= 0

        self.skill_dist = skill_dist.lower()
        assert self.skill_dist in ["none", "pareto", "lognormal"]

        # Per-timestep log of gather events (see get_dense_log).
        self.gathers = []

        # Helpers for vectorized mask generation (see generate_masks):
        # _aidx[i] = [i, i, i, i]; _roff/_coff are the row/col offsets of the
        # 4 move actions in order (Left, Right, Up, Down).
        self._aidx = np.arange(self.n_agents)[:, None].repeat(4, axis=1)
        self._roff = np.array([[0, 0, -1, 1]])
        self._coff = np.array([[-1, 1, 0, 0]])

    # Required methods for implementing components
    # --------------------------------------------

    def get_n_actions(self, agent_cls_name):
        """
        See base_component.py for detailed description.

        Adds 4 actions (move up, down, left, or right) for mobile agents.
        """
        # This component adds 4 action that agents can take:
        # move up, down, left, or right
        if agent_cls_name == "BasicMobileAgent":
            return 4
        return None

    def get_additional_state_fields(self, agent_cls_name):
        """
        See base_component.py for detailed description.

        For mobile agents, add state field for collection skill.
        """
        if agent_cls_name not in self.agent_subclasses:
            return {}
        if agent_cls_name == "BasicMobileAgent":
            return {"bonus_gather_prob": 0.0}
        raise NotImplementedError

    def component_step(self):
        """
        See base_component.py for detailed description.

        Move to adjacent, unoccupied locations. Collect resources when moving to
        populated resource tiles, adding the resource to the agent's inventory and
        de-populating it from the tile.
        """
        world = self.world

        gathers = []
        for agent in world.get_random_order_agents():

            # NOTE(review): `return` here aborts the step for ALL remaining
            # agents (not just this one) and skips the dense-log append below;
            # presumably every agent always has this component's action
            # registered, so this never fires mid-loop — confirm.
            if self.name not in agent.action:
                return
            action = agent.get_component_action(self.name)

            r, c = [int(x) for x in agent.loc]

            if action == 0:  # NO-OP!
                new_r, new_c = r, c

            elif action <= 4:
                if action == 1:  # Left
                    new_r, new_c = r, c - 1
                elif action == 2:  # Right
                    new_r, new_c = r, c + 1
                elif action == 3:  # Up
                    new_r, new_c = r - 1, c
                else:  # action == 4, # Down
                    new_r, new_c = r + 1, c

                # Attempt to move the agent (if the new coordinates aren't accessible,
                # nothing will happen)
                new_r, new_c = world.set_agent_loc(agent, new_r, new_c)

                # If the agent did move, incur the labor cost of moving
                if (new_r != r) or (new_c != c):
                    agent.state["endogenous"]["Labor"] += self.move_labor

            else:
                raise ValueError

            # Collect any resources present on the tile the agent now occupies.
            # A skill "bonus" gather happens with probability bonus_gather_prob
            # at no extra labor cost.
            for resource, health in world.location_resources(new_r, new_c).items():
                if health >= 1:
                    n_gathered = 1 + (rand() < agent.state["bonus_gather_prob"])
                    agent.state["inventory"][resource] += n_gathered
                    world.consume_resource(resource, new_r, new_c)
                    # Incur the labor cost of collecting a resource
                    agent.state["endogenous"]["Labor"] += self.collect_labor
                    # Log the gather
                    gathers.append(
                        dict(
                            agent=agent.idx,
                            resource=resource,
                            n=n_gathered,
                            loc=[new_r, new_c],
                        )
                    )

        self.gathers.append(gathers)

    def generate_observations(self):
        """
        See base_component.py for detailed description.

        Here, agents observe their collection skill. The planner does not observe
        anything from this component.
        """
        return {
            str(agent.idx): {"bonus_gather_prob": agent.state["bonus_gather_prob"]}
            for agent in self.world.agents
        }

    def generate_masks(self, completions=0):
        """
        See base_component.py for detailed description.

        Prevent moving to adjacent tiles that are already occupied (or outside the
        boundaries of the world)
        """
        world = self.world

        # Row/col indices of each agent's 4 candidate destinations; the +1
        # accounts for the 1-tile padding added below, so off-map moves index
        # into the padded (blocked) border instead of wrapping.
        coords = np.array([agent.loc for agent in world.agents])[:, :, None]
        ris = coords[:, 0] + self._roff + 1
        cis = coords[:, 1] + self._coff + 1

        # Pad with zeros: padded cells read as occupied/inaccessible.
        occ = np.pad(world.maps.unoccupied, ((1, 1), (1, 1)))
        acc = np.pad(world.maps.accessibility, ((0, 0), (1, 1), (1, 1)))
        # A move is allowed only if the target tile is both unoccupied and
        # accessible to that specific agent.
        mask_array = np.logical_and(occ[ris, cis], acc[self._aidx, ris, cis]).astype(
            np.float32
        )

        masks = {agent.idx: mask_array[i] for i, agent in enumerate(world.agents)}

        return masks

    # For non-required customization
    # ------------------------------

    def additional_reset_steps(self):
        """
        See base_component.py for detailed description.

        Re-sample agents' collection skills.
        """
        for agent in self.world.agents:
            if self.skill_dist == "none":
                bonus_rate = 0.0
            elif self.skill_dist == "pareto":
                # Clip at 2, then rescale so the bonus prob lies in [0, 1].
                bonus_rate = np.minimum(2, np.random.pareto(3)) / 2
            elif self.skill_dist == "lognormal":
                bonus_rate = np.minimum(2, np.random.lognormal(-2.022, 0.938)) / 2
            else:
                raise NotImplementedError
            agent.state["bonus_gather_prob"] = float(bonus_rate)

        self.gathers = []

    def get_dense_log(self):
        """
        Log resource collections.

        Returns:
            gathers (list): A list of gather events. Each entry corresponds to a single
                timestep and contains a description of any resource gathers that
                occurred on that timestep.

        """
        return self.gathers
|
||||
1202
ai_economist/foundation/components/redistribution.py
Normal file
1202
ai_economist/foundation/components/redistribution.py
Normal file
File diff suppressed because it is too large
Load Diff
134
ai_economist/foundation/components/simple_labor.py
Normal file
134
ai_economist/foundation/components/simple_labor.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai_economist.foundation.base.base_component import (
|
||||
BaseComponent,
|
||||
component_registry,
|
||||
)
|
||||
|
||||
|
||||
@component_registry.add
class SimpleLabor(BaseComponent):
    """
    Allows Agents to select a level of labor, which earns income based on skill.

    Labor is "simple" because this simplifies labor to a choice along a 1D axis. More
    concretely, this component adds 100 labor actions, each representing a choice of
    how many hours to work, e.g. action 50 represents doing 50 hours of work; each
    Agent earns income proportional to the product of its labor amount (representing
    hours worked) and its skill (representing wage), with higher skill and higher labor
    yielding higher income.

    This component is intended to be used with the 'PeriodicBracketTax' component and
    the 'one-step-economy' scenario.

    Args:
        mask_first_step (bool): Defaults to True. If True, masks all non-0 labor
            actions on the first step of the environment. When combined with the
            intended component/scenario, the first env step is used to set taxes
            (via the 'redistribution' component) and the second step is used to
            select labor (via this component).
        payment_max_skill_multiplier (float): When determining the skill level of
            each Agent, sampled skills are clipped to this maximum value.
        pareto_param (float): Shape parameter of the Pareto distribution used to
            sample agent skills. Must be > 0. Defaults to 4.0.
    """

    name = "SimpleLabor"
    required_entities = ["Coin"]
    agent_subclasses = ["BasicMobileAgent"]

    def __init__(
        self,
        *base_component_args,
        mask_first_step=True,
        payment_max_skill_multiplier=3,
        pareto_param=4.0,
        **base_component_kwargs
    ):
        super().__init__(*base_component_args, **base_component_kwargs)

        # This defines the size of the action space (the max # hours an agent can work).
        self.num_labor_hours = 100  # max 100 hours

        assert isinstance(mask_first_step, bool)
        self.mask_first_step = mask_first_step

        self.is_first_step = True
        # Pre-built action masks shared across steps: all-ones allows every
        # labor level, all-zeros disallows every labor level (NO-OP remains).
        self.common_mask_on = {
            agent.idx: np.ones((self.num_labor_hours,)) for agent in self.world.agents
        }
        self.common_mask_off = {
            agent.idx: np.zeros((self.num_labor_hours,)) for agent in self.world.agents
        }

        # Skill distribution
        self.pareto_param = float(pareto_param)
        assert self.pareto_param > 0
        self.payment_max_skill_multiplier = float(payment_max_skill_multiplier)
        pmsm = self.payment_max_skill_multiplier
        num_agents = len(self.world.agents)
        # Generate a batch (1000) of num_agents (sorted/clipped) Pareto samples.
        # Bug fix: use self.pareto_param here; previously the shape parameter
        # was hard-coded to 4, silently ignoring the pareto_param argument.
        pareto_samples = np.random.pareto(
            self.pareto_param, size=(1000, num_agents)
        )
        clipped_skills = np.minimum(pmsm, (pmsm - 1) * pareto_samples + 1)
        sorted_clipped_skills = np.sort(clipped_skills, axis=1)
        # The skill level of the i-th skill-ranked agent is the average of the
        # i-th ranked samples throughout the batch.
        self.skills = sorted_clipped_skills.mean(axis=0)

    def get_additional_state_fields(self, agent_cls_name):
        """Add 'skill' and 'production' state fields to mobile agents."""
        if agent_cls_name == "BasicMobileAgent":
            return {"skill": 0, "production": 0}
        return {}

    def additional_reset_steps(self):
        """Re-arm the first-step mask and assign each agent its fixed skill."""
        self.is_first_step = True
        for agent in self.world.agents:
            agent.state["skill"] = self.skills[agent.idx]

    def get_n_actions(self, agent_cls_name):
        """Mobile agents get one action per possible labor hour; others none."""
        if agent_cls_name == "BasicMobileAgent":
            return self.num_labor_hours
        return None

    def generate_masks(self, completions=0):
        """Mask all labor actions on the first step if mask_first_step is set."""
        if self.is_first_step:
            self.is_first_step = False
            if self.mask_first_step:
                return self.common_mask_off

        return self.common_mask_on

    def component_step(self):
        """Apply each agent's labor choice: record labor and pay out income.

        Raises:
            ValueError: If an agent's action exceeds num_labor_hours.
        """
        for agent in self.world.get_random_order_agents():

            action = agent.get_component_action(self.name)

            if action == 0:  # NO-OP.
                # Agent is not interacting with this component.
                continue

            if 1 <= action <= self.num_labor_hours:

                hours_worked = action  # NO-OP is 0 hours.
                agent.state["endogenous"]["Labor"] = hours_worked

                # Income is hours worked times the agent's (wage-like) skill.
                payoff = hours_worked * agent.state["skill"]
                agent.state["production"] += payoff
                agent.inventory["Coin"] += payoff

            else:
                # If action > num_labor_hours, this is an error.
                raise ValueError

    def generate_observations(self):
        """Each agent observes its own skill, normalized to [0, 1]."""
        obs_dict = dict()
        for agent in self.world.agents:
            obs_dict[str(agent.idx)] = {
                "skill": agent.state["skill"] / self.payment_max_skill_multiplier
            }
        return obs_dict
|
||||
115
ai_economist/foundation/components/utils.py
Normal file
115
ai_economist/foundation/components/utils.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def annealed_tax_limit(completions, warmup_period, slope, final_max_tax_value=1.0):
    """
    Compute the maximum tax rate available at this stage of tax annealing.

    Uses the number of episode completions and the annealing schedule
    (warmup_period, slope, & final_max_tax_value) to determine the current
    maximum tax rate. This enables a tax curriculum: early episodes are
    restricted to low tax rates, and higher values unlock as training proceeds.

    Args:
        completions (int): Number of completed episodes. Expected to be >= 0.
        warmup_period (int): Until warmup_period completions, only allow 0 tax.
            A negative value enables non-0 taxes at 0 completions.
        slope (float): After warmup, fraction of the full tax value unmasked
            per additional completion.
        final_max_tax_value (float): The maximum tax value once annealing ends.

    Returns:
        A scalar: the maximum tax allowed at this stage of annealing.

    Example:
        >> WARMUP = 100
        >> SLOPE = 0.01
        >> annealed_tax_limit(0, WARMUP, SLOPE)
        0.0
        >> annealed_tax_limit(150, WARMUP, SLOPE)
        0.5
        >> annealed_tax_limit(1000, WARMUP, SLOPE)
        1.0
    """
    # Fraction of the full tax range currently unmasked, clamped to [0, 1]:
    # 0 until warmup_period completions, then rising linearly with `slope`.
    unmasked_fraction = np.clip(slope * (completions - warmup_period), 0.0, 1.0)
    return unmasked_fraction * final_max_tax_value
|
||||
|
||||
|
||||
def annealed_tax_mask(completions, warmup_period, slope, tax_values):
    """
    Generate a mask over a set of tax values for the purpose of tax annealing.

    Uses the number of episode completions and the annealing schedule to decide
    which tax values are currently valid. The most extreme tax/subsidy values
    are unmasked last; zero tax is always unmasked (i.e. always valid). This
    enables a tax curriculum: early episodes are restricted to low tax rates,
    and higher values unlock as training proceeds.

    Args:
        completions (int): Number of completed episodes. Expected to be >= 0.
        warmup_period (int): Until warmup_period completions, only allow 0 tax.
            A negative value enables non-0 taxes at 0 completions.
        slope (float): After warmup, fraction of the full tax value unmasked
            per additional completion.
        tax_values (list): The tax values associated with each action to which
            this mask applies.

    Returns:
        A binary mask with the same shape as tax_values, indicating which tax
        values are currently valid.

    Example:
        >> WARMUP = 100
        >> SLOPE = 0.01
        >> TAX_VALUES = [0.0, 0.25, 0.50, 0.75, 1.0]
        >> annealed_tax_mask(0, WARMUP, SLOPE, TAX_VALUES)
        [1, 0, 0, 0, 0]
        >> annealed_tax_mask(150, WARMUP, SLOPE, TAX_VALUES)
        [1, 1, 1, 0, 0]
        >> annealed_tax_mask(200, WARMUP, SLOPE, TAX_VALUES)
        [1, 1, 1, 1, 1]
    """
    # Magnitudes decide masking: taxes and subsidies of equal size unmask together.
    magnitudes = np.abs(tax_values)

    # The current cap is the annealed fraction of the most extreme supplied tax
    # (same schedule as annealed_tax_limit).
    unmasked_fraction = np.clip(slope * (completions - warmup_period), 0.0, 1.0)
    cap = unmasked_fraction * np.max(magnitudes)

    # Binary mask: 1 for values at or below the cap (0 tax is always <= cap).
    return (magnitudes <= cap).astype(np.float32)
|
||||
9
ai_economist/foundation/entities/__init__.py
Normal file
9
ai_economist/foundation/entities/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
from .endogenous import endogenous_registry
|
||||
from .landmarks import landmark_registry
|
||||
from .resources import resource_registry
|
||||
36
ai_economist/foundation/entities/endogenous.py
Normal file
36
ai_economist/foundation/entities/endogenous.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
from ai_economist.foundation.base.registrar import Registry
|
||||
|
||||
|
||||
class Endogenous:
    """Base class for endogenous entity classes.

    Endogenous entities conceptually describe the internal state of an agent.
    This cleanly separates physical entities (which may exist in the world, be
    exchanged among agents, or are otherwise in principle observable by
    others) from endogenous entities (such as the amount of labor effort an
    agent has experienced).

    Endogenous entities are registered in the "endogenous" portion of an
    agent's state and should only be observable by the agent itself.
    """

    # Subclasses must override this with a unique string identifier.
    name = None

    def __init__(self):
        # Guard against instantiating the base class directly, or a subclass
        # that forgot to set a name.
        assert self.name is not None
|
||||
|
||||
|
||||
# Registry instance collecting Endogenous subclasses (populated via the
# @endogenous_registry.add decorator below and in other modules).
endogenous_registry = Registry(Endogenous)


@endogenous_registry.add
class Labor(Endogenous):
    """Labor accumulated through working. Included in all environments by default."""

    # Key under which this entity appears in an agent's "endogenous" state.
    name = "Labor"
|
||||
88
ai_economist/foundation/entities/landmarks.py
Normal file
88
ai_economist/foundation/entities/landmarks.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai_economist.foundation.base.registrar import Registry
|
||||
from ai_economist.foundation.entities.resources import resource_registry
|
||||
|
||||
|
||||
class Landmark:
    """Base class for Landmark entity classes.

    Landmarks are entities that exist exclusively in the environment world:
    they are never part of an agent's inventory and are only observable
    through observations from the spatial world.

    Subclasses declare the following properties:
        ownable: Whether each landmark instance belongs to an agent. For
            example, a "House" is ownable and belongs to the agent that
            constructs it, whereas "Water" is not ownable.
        solid: Whether the landmark creates a physical barrier to movement
            (agents are prevented from occupying cells with the landmark).
            Importantly, if the landmark is ownable, its owner may still
            occupy its cell even when the landmark is solid.
    """

    name = None
    color = None  # array of RGB values [0 - 1]
    ownable = None
    solid = True  # Solid = Cannot be passed through
    # (unless it is owned by the agent trying to pass through)

    def __init__(self):
        # Subclasses must define all three of these class attributes.
        assert self.name is not None
        assert self.color is not None
        assert self.ownable is not None

        # Derived movement semantics (mutually exclusive flags):
        # public: unowned and freely passable by any agent.
        self.public = not self.solid and not self.ownable
        # private: solid but owned, so only the owning agent can pass through.
        self.private = self.solid and self.ownable
        # blocking: solid and unowned, so no agent can pass through.
        self.blocking = self.solid and not self.ownable
|
||||
|
||||
|
||||
# Registry instance collecting Landmark subclasses (populated via the
# @landmark_registry.add decorator).
landmark_registry = Registry(Landmark)

# Registering each collectible resource's source block
# allows treating source blocks in a specific way
for resource_name in resource_registry.entries:
    resource = resource_registry.get(resource_name)
    if not resource.collectible:
        continue

    # Dynamically define and register one SourceBlock landmark per collectible
    # resource (e.g. "WoodSourceBlock"), reusing the resource's color.
    @landmark_registry.add
    class SourceBlock(Landmark):
        """Special Landmark for generating resources. Not ownable. Not solid."""

        name = "{}SourceBlock".format(resource.name)
        color = np.array(resource.color)
        ownable = False
        solid = False
|
||||
|
||||
|
||||
@landmark_registry.add
class House(Landmark):
    """House landmark. Ownable. Solid."""

    name = "House"
    # RGB in [0, 1].
    color = np.array([220, 20, 220]) / 255.0
    ownable = True  # solid + ownable: only the owner may occupy this cell
    solid = True
|
||||
|
||||
|
||||
@landmark_registry.add
class Water(Landmark):
    """Water Landmark. Not ownable. Solid."""

    name = "Water"
    # RGB in [0, 1].
    color = np.array([50, 50, 250]) / 255.0
    ownable = False  # solid + not ownable: blocks movement for all agents
    solid = True
|
||||
84
ai_economist/foundation/entities/resources.py
Normal file
84
ai_economist/foundation/entities/resources.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai_economist.foundation.base.registrar import Registry
|
||||
|
||||
|
||||
class Resource:
    """Base class for Resource entity classes.

    Resources are entities that can be part of an agent's inventory.

    A resource with collectible=True also exists in the world: for each such
    class, a complementary ResourceSourceBlock Landmark class is created in
    landmarks.py, and the world map includes a source block channel
    (landmarks where the collectible resource is generated) and a resource
    channel (locations where collectible resources have been generated).
    """

    name = None
    color = None  # array of RGB values [0 - 1]
    collectible = None  # Does this exist in the world (vs. only ownable)?
    craft_recp = None  # dict mapping resource name -> amount needed to craft
    craft_labour_base = 0.0  # base labor cost of crafting this resource

    def __init__(self):
        # Subclasses must define all three of these class attributes.
        assert self.name is not None
        assert self.color is not None
        assert self.collectible is not None
|
||||
|
||||
|
||||
# Registry instance collecting Resource subclasses (populated via the
# @resource_registry.add decorator).
resource_registry = Registry(Resource)


@resource_registry.add
class Wood(Resource):
    """Wood resource. Collectible."""

    name = "Wood"
    color = np.array([107, 143, 113]) / 255.0
    collectible = True


@resource_registry.add
class Stone(Resource):
    """Stone resource. Collectible."""

    name = "Stone"
    color = np.array([241, 233, 219]) / 255.0
    collectible = True


@resource_registry.add
class Coin(Resource):
    """Coin resource. Included in all environments by default. Not collectible."""

    name = "Coin"
    color = np.array([229, 211, 82]) / 255.0
    collectible = False


@resource_registry.add
class RawGem(Resource):
    """Raw gem that can be processed further. Collectible."""

    name = "Raw_Gem"
    color = np.array([241, 233, 219]) / 255.0
    collectible = True


@resource_registry.add
class Gem(Resource):
    """Processed gem, crafted from one raw gem. Not collectible."""

    name = "Gem"
    color = np.array([241, 233, 219]) / 255.0
    collectible = False
    craft_recp = {"Raw_Gem": 1}
    craft_labour_base = 1
|
||||
418
ai_economist/foundation/env_wrapper.py
Normal file
418
ai_economist/foundation/env_wrapper.py
Normal file
@@ -0,0 +1,418 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
"""
|
||||
The env wrapper class
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import GPUtil
|
||||
|
||||
try:
|
||||
num_gpus_available = len(GPUtil.getAvailable())
|
||||
print(f"Inside env_wrapper.py: {num_gpus_available} GPUs are available.")
|
||||
if num_gpus_available == 0:
|
||||
print("No GPUs found! Running the simulation on a CPU.")
|
||||
else:
|
||||
from warp_drive.managers.data_manager import CUDADataManager
|
||||
from warp_drive.managers.function_manager import (
|
||||
CUDAEnvironmentReset,
|
||||
CUDAFunctionManager,
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
print(
|
||||
"Warning: The 'WarpDrive' package is not found and cannot be used! "
|
||||
"If you wish to use WarpDrive, please run "
|
||||
"'pip install rl-warp-drive' first."
|
||||
)
|
||||
except ValueError:
|
||||
print("No GPUs found! Running the simulation on a CPU.")
|
||||
|
||||
import numpy as np
|
||||
from gym.spaces import Box, Dict, Discrete, MultiDiscrete
|
||||
|
||||
BIG_NUMBER = 1e20
|
||||
|
||||
|
||||
def recursive_obs_dict_to_spaces_dict(obs):
    """Recursively build the observation-space Dict for a dict of observations.

    Args:
        obs (dict): A dictionary of observations keyed by agent index for a
            multi-agent environment. Values may be nested dicts (recursed
            into), lists, scalars, or numpy arrays.

    Returns:
        Dict: A dictionary (space.Dict) of observation spaces.
    """
    assert isinstance(obs, dict)
    spaces = {}
    for key, value in obs.items():

        # Normalize lists and scalars into numpy arrays; dicts pass through.
        if isinstance(value, list):
            normalized = np.array(value)
        elif isinstance(value, (int, np.integer, float, np.floating)):
            normalized = np.array([value])
        else:
            normalized = value

        if isinstance(normalized, dict):
            spaces[key] = recursive_obs_dict_to_spaces_dict(normalized)
        elif isinstance(normalized, np.ndarray):
            bound = float(BIG_NUMBER)
            box = Box(low=-bound, high=bound, shape=normalized.shape, dtype=normalized.dtype)
            # Halve the bound until it is representable in the array's dtype
            # (avoids overflow when the dtype cannot hold BIG_NUMBER).
            while not ((box.low < 0).all() and (box.high > 0).all()):
                bound = bound // 2
                box = Box(low=-bound, high=bound, shape=normalized.shape, dtype=normalized.dtype)
            spaces[key] = box
        else:
            raise TypeError
    return Dict(spaces)
|
||||
|
||||
|
||||
class FoundationEnvWrapper:
|
||||
"""
|
||||
The environment wrapper class for Foundation.
|
||||
This wrapper determines whether the environment reset and steps happen on the
|
||||
CPU or the GPU, and proceeds accordingly.
|
||||
If the environment runs on the CPU, the reset() and step() calls also occur on
|
||||
the CPU.
|
||||
If the environment runs on the GPU, only the first reset() happens on the CPU,
|
||||
all the relevant data is copied over the GPU after, and the subsequent steps
|
||||
all happen on the GPU.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_obj=None,
|
||||
env_name=None,
|
||||
env_config=None,
|
||||
num_envs=1,
|
||||
use_cuda=False,
|
||||
env_registrar=None,
|
||||
event_messenger=None,
|
||||
process_id=0,
|
||||
):
|
||||
"""
|
||||
'env_obj': an environment object
|
||||
'env_name': an environment name that is registered on the
|
||||
WarpDrive environment registrar
|
||||
'env_config': environment configuration to instantiate
|
||||
an environment from the registrar
|
||||
'use_cuda': if True, step through the environment on the GPU, else on the CPU
|
||||
'num_envs': the number of parallel environments to instantiate. Note: this is
|
||||
only relevant when use_cuda is True
|
||||
'env_registrar': EnvironmentRegistrar object
|
||||
it provides the customized env info (like src path) for the build
|
||||
'event_messenger': multiprocessing Event to sync up the build
|
||||
when using multiple processes
|
||||
'process_id': id of the process running WarpDrive
|
||||
"""
|
||||
# Need to pass in an environment instance
|
||||
if env_obj is not None:
|
||||
self.env = env_obj
|
||||
else:
|
||||
assert (
|
||||
env_name is not None
|
||||
and env_config is not None
|
||||
and env_registrar is not None
|
||||
)
|
||||
self.env = env_registrar.get(env_name, use_cuda)(**env_config)
|
||||
|
||||
self.n_agents = self.env.num_agents
|
||||
self.episode_length = self.env.episode_length
|
||||
|
||||
assert self.env.name
|
||||
self.name = self.env.name
|
||||
|
||||
# Add observation space to the env
|
||||
# --------------------------------
|
||||
# Note: when the collated agent "a" is present, add obs keys
|
||||
# for each individual agent to the env
|
||||
# and remove the collated agent "a" from the observation
|
||||
obs = self.obs_at_reset()
|
||||
self.env.observation_space = recursive_obs_dict_to_spaces_dict(obs)
|
||||
|
||||
# Add action space to the env
|
||||
# ---------------------------
|
||||
self.env.action_space = {}
|
||||
for agent_id in range(len(self.env.world.agents)):
|
||||
if self.env.world.agents[agent_id].multi_action_mode:
|
||||
self.env.action_space[str([agent_id])] = MultiDiscrete(
|
||||
self.env.get_agent(str(agent_id)).action_spaces
|
||||
)
|
||||
else:
|
||||
self.env.action_space[str(agent_id)] = Discrete(
|
||||
self.env.get_agent(str(agent_id)).action_spaces
|
||||
)
|
||||
self.env.action_space[str(agent_id)].dtype = np.int32
|
||||
|
||||
if self.env.world.planner.multi_action_mode:
|
||||
self.env.action_space["p"] = MultiDiscrete(
|
||||
self.env.get_agent("p").action_spaces
|
||||
)
|
||||
else:
|
||||
self.env.action_space["p"] = Discrete(self.env.get_agent("p").action_spaces)
|
||||
self.env.action_space["p"].dtype = np.int32
|
||||
|
||||
# Ensure the observation and action spaces share the same keys
|
||||
assert set(self.env.observation_space.keys()) == set(
|
||||
self.env.action_space.keys()
|
||||
)
|
||||
|
||||
# CUDA-specific initializations
|
||||
# -----------------------------
|
||||
# Flag to determine whether to use CUDA or not
|
||||
self.use_cuda = use_cuda
|
||||
if self.use_cuda:
|
||||
assert len(GPUtil.getAvailable()) > 0, (
|
||||
"The env wrapper needs a GPU to run" " when use_cuda is True!"
|
||||
)
|
||||
assert hasattr(self.env, "use_cuda")
|
||||
assert hasattr(self.env, "cuda_data_manager")
|
||||
assert hasattr(self.env, "cuda_function_manager")
|
||||
|
||||
assert hasattr(self.env.world, "use_cuda")
|
||||
assert hasattr(self.env.world, "cuda_data_manager")
|
||||
assert hasattr(self.env.world, "cuda_function_manager")
|
||||
self.env.use_cuda = use_cuda
|
||||
self.env.world.use_cuda = self.use_cuda
|
||||
|
||||
# Flag to determine where the reset happens (host or device)
|
||||
# First reset is always on the host (CPU), and subsequent resets are on
|
||||
# the device (GPU)
|
||||
self.reset_on_host = True
|
||||
|
||||
# Steps specific to GPU runs
|
||||
# --------------------------
|
||||
if self.use_cuda:
|
||||
logging.info("USING CUDA...")
|
||||
|
||||
# Number of environments to run in parallel
|
||||
assert num_envs >= 1
|
||||
self.n_envs = num_envs
|
||||
|
||||
logging.info("Initializing the CUDA data manager...")
|
||||
self.cuda_data_manager = CUDADataManager(
|
||||
num_agents=self.n_agents,
|
||||
episode_length=self.episode_length,
|
||||
num_envs=self.n_envs,
|
||||
)
|
||||
|
||||
logging.info("Initializing the CUDA function manager...")
|
||||
self.cuda_function_manager = CUDAFunctionManager(
|
||||
num_agents=int(self.cuda_data_manager.meta_info("n_agents")),
|
||||
num_envs=int(self.cuda_data_manager.meta_info("n_envs")),
|
||||
process_id=process_id,
|
||||
)
|
||||
self.cuda_function_manager.compile_and_load_cuda(
|
||||
env_name=self.name,
|
||||
template_header_file="template_env_config.h",
|
||||
template_runner_file="template_env_runner.cu",
|
||||
customized_env_registrar=env_registrar,
|
||||
event_messenger=event_messenger,
|
||||
)
|
||||
|
||||
# Register the CUDA step() function for the env
|
||||
# Note: generate_observation() and compute_reward()
|
||||
# should be part of the step function itself
|
||||
step_function = f"Cuda{self.name}Step"
|
||||
self.cuda_function_manager.initialize_functions([step_function])
|
||||
self.env.cuda_step = self.cuda_function_manager.get_function(step_function)
|
||||
|
||||
# Register additional cuda functions (other than the scenario step)
|
||||
# Component step
|
||||
# Create a cuda_component_step dictionary
|
||||
self.env.world.cuda_component_step = {}
|
||||
for component in self.env.components:
|
||||
self.cuda_function_manager.initialize_functions(
|
||||
["Cuda" + component.name + "Step"]
|
||||
)
|
||||
self.env.world.cuda_component_step[
|
||||
component.name
|
||||
] = self.cuda_function_manager.get_function(
|
||||
"Cuda" + component.name + "Step"
|
||||
)
|
||||
|
||||
# Compute reward
|
||||
self.cuda_function_manager.initialize_functions(["CudaComputeReward"])
|
||||
self.env.cuda_compute_reward = self.cuda_function_manager.get_function(
|
||||
"CudaComputeReward"
|
||||
)
|
||||
|
||||
# Add wrapper attributes for use within env
|
||||
self.env.cuda_data_manager = self.cuda_data_manager
|
||||
self.env.cuda_function_manager = self.cuda_function_manager
|
||||
|
||||
# Register the env resetter
|
||||
self.env_resetter = CUDAEnvironmentReset(
|
||||
function_manager=self.cuda_function_manager
|
||||
)
|
||||
|
||||
# Add to self.env.world for use in components
|
||||
self.env.world.cuda_data_manager = self.cuda_data_manager
|
||||
self.env.world.cuda_function_manager = self.cuda_function_manager
|
||||
|
||||
def reset_all_envs(self):
|
||||
"""
|
||||
Reset the state of the environment to initialize a new episode.
|
||||
if self.reset_on_host is True:
|
||||
calls the CPU env to prepare and return the initial state
|
||||
if self.use_cuda is True:
|
||||
if self.reset_on_host is True:
|
||||
expands initial state to parallel example_envs and push to GPU once
|
||||
sets self.reset_on_host = False
|
||||
else:
|
||||
calls device hard reset managed by the CUDAResetter
|
||||
"""
|
||||
self.env.world.timestep = 0
|
||||
|
||||
if self.reset_on_host:
|
||||
# Produce observation
|
||||
obs = self.obs_at_reset()
|
||||
else:
|
||||
assert self.use_cuda
|
||||
|
||||
if self.use_cuda: # GPU version
|
||||
if self.reset_on_host:
|
||||
|
||||
# Helper function to repeat data across the env dimension
|
||||
def repeat_across_env_dimension(array, num_envs):
|
||||
return np.stack([array for _ in range(num_envs)], axis=0)
|
||||
|
||||
# Copy host data and tensors to device
|
||||
# Note: this happens only once after the first reset on the host
|
||||
|
||||
scenario_and_components = [self.env] + self.env.components
|
||||
|
||||
for item in scenario_and_components:
|
||||
# Add env dimension to data
|
||||
# if "save_copy_and_apply_at_reset" is True
|
||||
data_dictionary = item.get_data_dictionary()
|
||||
tensor_dictionary = item.get_tensor_dictionary()
|
||||
for key in data_dictionary:
|
||||
if data_dictionary[key]["attributes"][
|
||||
"save_copy_and_apply_at_reset"
|
||||
]:
|
||||
data_dictionary[key]["data"] = repeat_across_env_dimension(
|
||||
data_dictionary[key]["data"], self.n_envs
|
||||
)
|
||||
|
||||
for key in tensor_dictionary:
|
||||
if tensor_dictionary[key]["attributes"][
|
||||
"save_copy_and_apply_at_reset"
|
||||
]:
|
||||
tensor_dictionary[key][
|
||||
"data"
|
||||
] = repeat_across_env_dimension(
|
||||
tensor_dictionary[key]["data"], self.n_envs
|
||||
)
|
||||
|
||||
self.cuda_data_manager.push_data_to_device(data_dictionary)
|
||||
|
||||
self.cuda_data_manager.push_data_to_device(
|
||||
tensor_dictionary, torch_accessible=True
|
||||
)
|
||||
|
||||
# All subsequent resets happen on the GPU
|
||||
self.reset_on_host = False
|
||||
|
||||
# Return the obs
|
||||
return obs
|
||||
# Returns an empty dictionary for all subsequent resets on the GPU
|
||||
# as arrays are modified in place
|
||||
self.env_resetter.reset_when_done(
|
||||
self.cuda_data_manager, mode="force_reset"
|
||||
)
|
||||
return {}
|
||||
return obs # CPU version
|
||||
|
||||
def reset_only_done_envs(self):
|
||||
"""
|
||||
This function only works for GPU example_envs.
|
||||
It will check all the running example_envs,
|
||||
and only resets those example_envs that are observing done flag is True
|
||||
"""
|
||||
assert self.use_cuda and not self.reset_on_host, (
|
||||
"reset_only_done_envs() only works "
|
||||
"for self.use_cuda = True and self.reset_on_host = False"
|
||||
)
|
||||
|
||||
self.env_resetter.reset_when_done(self.cuda_data_manager, mode="if_done")
|
||||
return {}
|
||||
|
||||
def step_all_envs(self, actions=None):
|
||||
"""
|
||||
Step through all the environments' components and scenario
|
||||
"""
|
||||
if self.use_cuda:
|
||||
# Step through each component
|
||||
for component in self.env.components:
|
||||
component.component_step()
|
||||
|
||||
# Scenario step
|
||||
self.env.scenario_step()
|
||||
|
||||
# Compute rewards
|
||||
self.env.generate_rewards()
|
||||
|
||||
result = None # Do not return anything
|
||||
else:
|
||||
assert actions is not None, "Please provide actions to step with."
|
||||
obs, rew, done, info = self.env.step(actions)
|
||||
obs = self._reformat_obs(obs)
|
||||
rew = self._reformat_rew(rew)
|
||||
result = obs, rew, done, info
|
||||
return result
|
||||
|
||||
def obs_at_reset(self):
|
||||
"""
|
||||
Calls the (Python) env to reset and return the initial state
|
||||
"""
|
||||
obs = self.env.reset()
|
||||
obs = self._reformat_obs(obs)
|
||||
return obs
|
||||
|
||||
def _reformat_obs(self, obs):
    """
    Expand a collated "a" observation entry into per-agent keys.

    If obs contains the key "a" (per-key arrays stacked along the last
    axis, one slice per agent), replace it with one entry per agent id
    ("0", "1", ...) holding that agent's slice, as expected by
    WarpDrive. Mutates and returns obs.
    """
    if "a" not in obs:
        return obs
    collated = obs["a"]
    for idx in range(self.env.n_agents):
        obs[str(idx)] = {
            key: value[..., idx] for key, value in collated.items()
        }
    del obs["a"]  # the collated entry has been fully expanded
    return obs
|
||||
|
||||
def _reformat_rew(self, rew):
    """
    Expand a collated "a" reward entry into per-agent keys.

    If rew contains the key "a" (rewards indexed by agent along the
    first axis), replace it with one entry per agent id ("0", "1", ...),
    as expected by WarpDrive. Mutates and returns rew.
    """
    if "a" not in rew:
        return rew
    assert isinstance(rew, dict)
    collated = rew.pop("a")
    for idx in range(self.env.n_agents):
        rew[str(idx)] = collated[idx]
    return rew
|
||||
|
||||
def reset(self):
    """
    Gym-style alias: delegates to reset_all_envs() (used on CPU).
    """
    return self.reset_all_envs()
|
||||
|
||||
def step(self, actions=None):
    """
    Gym-style alias: delegates to step_all_envs(actions) (used on CPU).
    """
    return self.step_all_envs(actions)
|
||||
14
ai_economist/foundation/scenarios/__init__.py
Normal file
14
ai_economist/foundation/scenarios/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
from ai_economist.foundation.base.base_env import scenario_registry
|
||||
|
||||
from .covid19 import covid19_env
|
||||
from .one_step_economy import one_step_economy
|
||||
from .simple_wood_and_stone import dynamic_layout, layout_from_file
|
||||
|
||||
# Import files that add Scenario class(es) to scenario_registry
|
||||
# -------------------------------------------------------------
|
||||
5
ai_economist/foundation/scenarios/covid19/__init__.py
Normal file
5
ai_economist/foundation/scenarios/covid19/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
13
ai_economist/foundation/scenarios/covid19/covid19_build.cu
Normal file
13
ai_economist/foundation/scenarios/covid19/covid19_build.cu
Normal file
@@ -0,0 +1,13 @@
|
||||
// Copyright (c) 2021, salesforce.com, inc.
|
||||
// All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
// For full license text, see the LICENSE file in the repo root
|
||||
// or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
#ifndef CUDA_INCLUDES_COVID19_CONST_H_
|
||||
#define CUDA_INCLUDES_COVID19_CONST_H_
|
||||
|
||||
#include "../../components/covid19_components_step.cu"
|
||||
#include "covid19_env_step.cu"
|
||||
|
||||
#endif // CUDA_INCLUDES_COVID19_CONST_H_
|
||||
1687
ai_economist/foundation/scenarios/covid19/covid19_env.py
Normal file
1687
ai_economist/foundation/scenarios/covid19/covid19_env.py
Normal file
File diff suppressed because it is too large
Load Diff
620
ai_economist/foundation/scenarios/covid19/covid19_env_step.cu
Normal file
620
ai_economist/foundation/scenarios/covid19/covid19_env_step.cu
Normal file
@@ -0,0 +1,620 @@
|
||||
// Copyright (c) 2021, salesforce.com, inc.
|
||||
// All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
// For full license text, see the LICENSE file in the repo root
|
||||
// or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
__constant__ float kEpsilon = 1.0e-10; // used to prevent division by 0
|
||||
|
||||
extern "C" {
|
||||
// CUDA version of the scenario_step() in
|
||||
// "ai_economist.foundation.scenarios.covid19_env.py"
|
||||
|
||||
// CUDA version of the sir_step() in
|
||||
// "ai_economist.foundation.scenarios.covid19_env.py"
|
||||
// Single SIR-with-vaccination update for one (env, agent) pair.
// Reads the previous-timestep compartment values, applies the policy-lagged
// beta, and writes the current-timestep values in place.
// Array-index conventions: kArrayIdxCurrentTime / kArrayIdxPrevTime index
// time-dependent arrays; kTimeIndependentArrayIdx indexes per-(env, agent)
// arrays (see caller for the exact shapes).
__device__ void cuda_sir_step(
    float* susceptible,
    float* infected,
    float* recovered,
    float* vaccinated,
    float* deaths,
    int* num_vaccines_available_t,
    const int* kRealWorldStringencyPolicyHistory,
    const float kStatePopulation,
    const int kNumAgents,
    const int kBetaDelay,
    const float kBetaSlope,
    const float kbetaIntercept,
    int* stringency_level,
    float* beta,
    const float kGamma,
    const float kDeathRate,
    const int kEnvId,
    const int kAgentId,
    int timestep,
    const int kEpisodeLength,
    const int kArrayIdxCurrentTime,
    const int kArrayIdxPrevTime,
    const int kTimeIndependentArrayIdx
) {
    // Fraction of the susceptible pool that can be vaccinated this step;
    // kEpsilon guards against division by zero when no one is susceptible.
    float susceptible_fraction_vaccinated = min(
        1.0,
        num_vaccines_available_t[kTimeIndependentArrayIdx] /
        (susceptible[kArrayIdxPrevTime] + kEpsilon));
    // Cannot vaccinate more people than are currently susceptible.
    float vaccinated_t = min(
        static_cast<float>(num_vaccines_available_t[
            kTimeIndependentArrayIdx]),
        susceptible[kArrayIdxPrevTime]);

    // (S/N) * I in place of (S*I) / N to prevent overflow
    float neighborhood_SI_over_N = susceptible[kArrayIdxPrevTime] /
        kStatePopulation * infected[kArrayIdxPrevTime];
    // Beta responds to the stringency level from kBetaDelay steps ago;
    // before that horizon, fall back to the real-world policy history.
    int stringency_level_tmk;
    if (timestep < kBetaDelay) {
        stringency_level_tmk = kRealWorldStringencyPolicyHistory[
            (timestep - 1) * (kNumAgents - 1) + kAgentId];
    } else {
        stringency_level_tmk = stringency_level[kEnvId * (
            kEpisodeLength + 1) * (kNumAgents - 1) +
            (timestep - kBetaDelay) * (kNumAgents - 1) + kAgentId];
    }
    // Linear stringency -> beta model (fitted slope/intercept per agent).
    beta[kTimeIndependentArrayIdx] = stringency_level_tmk *
        kBetaSlope + kbetaIntercept;

    // SIR deltas: dS removes new infections and vaccinations, dR adds
    // recoveries and vaccinations, dI balances the two.
    float dS_t = -(neighborhood_SI_over_N * beta[
        kTimeIndependentArrayIdx] *
        (1 - susceptible_fraction_vaccinated) + vaccinated_t);
    float dR_t = kGamma * infected[kArrayIdxPrevTime] + vaccinated_t;
    float dI_t = - dS_t - dR_t;

    // Clamp each compartment at zero to absorb numerical under-shoot.
    susceptible[kArrayIdxCurrentTime] = max(
        0.0,
        susceptible[kArrayIdxPrevTime] + dS_t);
    infected[kArrayIdxCurrentTime] = max(
        0.0,
        infected[kArrayIdxPrevTime] + dI_t);
    recovered[kArrayIdxCurrentTime] = max(
        0.0,
        recovered[kArrayIdxPrevTime] + dR_t);

    // Vaccinated is cumulative across timesteps.
    vaccinated[kArrayIdxCurrentTime] = vaccinated_t +
        vaccinated[kArrayIdxPrevTime];
    // Deaths are a fixed fraction of recoveries that did not come from
    // vaccination (vaccinated individuals are assumed not to die here).
    float recovered_but_not_vaccinated = recovered[kArrayIdxCurrentTime] -
        vaccinated[kArrayIdxCurrentTime];
    deaths[kArrayIdxCurrentTime] = recovered_but_not_vaccinated *
        kDeathRate;
}
|
||||
|
||||
// CUDA version of the softplus() in
|
||||
// "ai_economist.foundation.scenarios.covid19_env.py"
|
||||
// Numerically stable softplus with beta = 1 and threshold = 20: for
// large inputs log(1 + exp(x)) is effectively x, so return x directly
// and avoid overflowing exp().
__device__ float softplus(float x) {
    const float kBeta = 1.0;
    const float kThreshold = 20.0;
    if (kBeta * x >= kThreshold) {
        return x;  // linear regime
    }
    return 1.0 / kBeta * log(1.0 + exp(kBeta * x));
}
|
||||
|
||||
// Map one (env, agent)'s convolved stringency signal to an unemployment
// rate: dot-product of the agent's flattened (kNumFilters, kFilterLen)
// signal slice with the learned filter weights, passed through softplus,
// plus a per-agent bias.
__device__ float signal2unemployment(
    const int kEnvId,
    const int kAgentId,
    float* signal,
    const float* kUnemploymentConvolutionalFilters,
    const float kUnemploymentBias,
    const int kNumAgents,
    const int kFilterLen,
    const int kNumFilters
) {
    float unemployment = 0.0;
    // Offset of this agent's slice inside the
    // (num_envs, kNumAgents - 1, kNumFilters, kFilterLen) signal array.
    const int kArrayIndexOffset = kEnvId * (kNumAgents - 1) * kNumFilters *
        kFilterLen + kAgentId * kNumFilters * kFilterLen;
    for (int index = 0; index < (kFilterLen * kNumFilters); index ++) {
        unemployment += signal[kArrayIndexOffset + index] *
            kUnemploymentConvolutionalFilters[index];
    }
    // NOTE(review): the bias is added *outside* the softplus, so a negative
    // bias can yield a negative rate — confirm against the Python model.
    return softplus(unemployment) + kUnemploymentBias;
}
|
||||
|
||||
// CUDA version of the unemployment_step() in
|
||||
// "ai_economist.foundation.scenarios.covid19_env.py"
|
||||
// Unemployment model update for one (env, agent) pair: maintain a sliding
// window of stringency-level deltas, convolve it with per-agent filter
// weights, and convert the result to an unemployed head-count.
__device__ void cuda_unemployment_step(
    float* unemployed,
    int* stringency_level,
    int* delta_stringency_level,
    const float* kGroupedConvolutionalFilterWeights,
    const float* kUnemploymentConvolutionalFilters,
    const float* kUnemploymentBias,
    float* convolved_signal,
    const int kFilterLen,
    const int kNumFilters,
    const float kStatePopulation,
    const int kNumAgents,
    const int kEnvId,
    const int kAgentId,
    int timestep,
    const int kArrayIdxCurrentTime,
    const int kArrayIdxPrevTime
) {
    // Shift array by kNumAgents - 1: drop the oldest delta, making room
    // at the end of this agent's kFilterLen-long window.
    for (int idx = 0; idx < kFilterLen - 1; idx ++) {
        delta_stringency_level[
            kEnvId * kFilterLen * (kNumAgents - 1) + idx *
            (kNumAgents - 1) + kAgentId
        ] =
        delta_stringency_level[
            kEnvId * kFilterLen * (kNumAgents - 1) + (idx + 1) *
            (kNumAgents - 1) + kAgentId
        ];
    }

    // Append this step's change in stringency level.
    delta_stringency_level[
        kEnvId * kFilterLen * (kNumAgents - 1) + (kFilterLen - 1) *
        (kNumAgents - 1) + kAgentId
    ] = stringency_level[kArrayIdxCurrentTime] -
        stringency_level[kArrayIdxPrevTime];

    // convolved_signal refers to the convolution between the filter weights
    // and the delta stringency levels
    for (int filter_idx = 0; filter_idx < kNumFilters; filter_idx ++) {
        for (int idx = 0; idx < kFilterLen; idx ++) {
            convolved_signal[
                kEnvId * (kNumAgents - 1) * kNumFilters * kFilterLen +
                kAgentId * kNumFilters * kFilterLen +
                filter_idx * kFilterLen +
                idx
            ] =
            delta_stringency_level[kEnvId * kFilterLen * (kNumAgents - 1) +
                idx * (kNumAgents - 1) + kAgentId] *
            kGroupedConvolutionalFilterWeights[kAgentId * kNumFilters +
                filter_idx];
        }
    }

    // Collapse the convolved signal into a single unemployment rate
    // (in percent) for this agent.
    float unemployment_rate = signal2unemployment(
        kEnvId,
        kAgentId,
        convolved_signal,
        kUnemploymentConvolutionalFilters,
        kUnemploymentBias[kAgentId],
        kNumAgents,
        kFilterLen,
        kNumFilters);

    // Percent of state population -> absolute number of unemployed.
    unemployed[kArrayIdxCurrentTime] =
        unemployment_rate * kStatePopulation / 100.0;
}
|
||||
|
||||
// CUDA version of the economy_step() in
|
||||
// "ai_economist.foundation.scenarios.covid19_env.py"
|
||||
// Economy update for one (env, agent) pair: derive the workforce that is
// incapacitated / unable to work, compute productivity from the remaining
// workers, and add the current subsidy.
__device__ void cuda_economy_step(
    float* infected,
    float* deaths,
    float* unemployed,
    float* incapacitated,
    float* cant_work,
    float* num_people_that_can_work,
    const float kStatePopulation,
    const float kInfectionTooSickToWorkRate,
    const float kPopulationBetweenAge18And65,
    const float kDailyProductionPerWorker,
    float* productivity,
    float* subsidy,
    float* postsubsidy_productivity,
    int timestep,
    const int kArrayIdxCurrentTime,
    int kTimeIndependentArrayIdx
) {
    // People too sick to work plus the deceased.
    incapacitated[kTimeIndependentArrayIdx] =
        kInfectionTooSickToWorkRate * infected[kArrayIdxCurrentTime] +
        deaths[kArrayIdxCurrentTime];
    // Working-age share of the incapacitated, plus the unemployed.
    cant_work[kTimeIndependentArrayIdx] =
        incapacitated[kTimeIndependentArrayIdx] *
        kPopulationBetweenAge18And65 + unemployed[kArrayIdxCurrentTime];
    // NOTE(review): the float product is truncated to int here, dropping the
    // fractional worker count — confirm this matches the Python reference
    // (which presumably keeps it as a float).
    int num_workers = static_cast<int>(kStatePopulation) * kPopulationBetweenAge18And65;
    num_people_that_can_work[kTimeIndependentArrayIdx] = max(
        0.0,
        num_workers - cant_work[kTimeIndependentArrayIdx]);
    productivity[kArrayIdxCurrentTime] =
        num_people_that_can_work[kTimeIndependentArrayIdx] *
        kDailyProductionPerWorker;

    // Productivity after applying the planner's subsidy for this step.
    postsubsidy_productivity[kArrayIdxCurrentTime] =
        productivity[kArrayIdxCurrentTime] +
        subsidy[kArrayIdxCurrentTime];
}
|
||||
|
||||
// CUDA version of crra_nonlinearity() in
|
||||
// "ai_economist.foundation.scenarios.covid19_env.py"
|
||||
// CRRA (constant relative risk aversion) utility applied to a daily
// quantity: annualize, clamp to [0.1, 3.0], apply the CRRA transform
// with curvature kEta, then convert back to a daily value.
__device__ float crra_nonlinearity(
    float x,
    const float kEta,
    const int kNumDaysInAnYear
) {
    float annual_x = kNumDaysInAnYear * x;
    float annual_x_clipped = max(0.1, min(3.0, annual_x));
    float annual_crra =
        1 + (pow(annual_x_clipped, (1 - kEta)) - 1) / (1 - kEta);
    return annual_crra / kNumDaysInAnYear;
}
|
||||
|
||||
// CUDA version of min_max_normalization() in
|
||||
// "ai_economist.foundation.scenarios.covid19_env.py"
|
||||
// Linearly rescale x relative to the bounds [kMinX, kMaxX]; kEpsilon in
// the denominator guards against kMaxX == kMinX.
__device__ float min_max_normalization(
    float x,
    const float kMinX,
    const float kMaxX
) {
    const float kRange = kMaxX - kMinX + kEpsilon;
    return (x - kMinX) / kRange;
}
|
||||
|
||||
// CUDA version of get_rew() in
|
||||
// "ai_economist.foundation.scenarios.covid19_env.py"
|
||||
// Weighted average of a health index and an economic index.
__device__ float get_rew(
    const float kHealthIndexWeightage,
    float health_index,
    const float kEconomicIndexWeightage,
    float economic_index
) {
    const float kWeightedSum = kHealthIndexWeightage * health_index
        + kEconomicIndexWeightage * economic_index;
    return kWeightedSum /
        (kHealthIndexWeightage + kEconomicIndexWeightage);
}
|
||||
|
||||
// CUDA version of scenario_step() in
|
||||
// "ai_economist.foundation.scenarios.covid19_env.py"
|
||||
// One simulation step for all envs/agents. Launch layout: one block per
// env (blockIdx.x), one thread per agent (threadIdx.x). Threads
// 0 .. kNumAgents-2 are the US-state agents; thread kNumAgents-1 is the
// planner, which here only writes its time observation.
// Runs the SIR, unemployment and economy sub-steps, then fills the
// agent ("a") and planner ("p") observation arrays in place.
__global__ void CudaCovidAndEconomySimulationStep(
    float* susceptible,
    float* infected,
    float* recovered,
    float* deaths,
    float* vaccinated,
    float* unemployed,
    float* subsidy,
    float* productivity,
    int* stringency_level,
    const int kNumStringencyLevels,
    float* postsubsidy_productivity,
    int* num_vaccines_available_t,
    const int* kRealWorldStringencyPolicyHistory,
    const int kBetaDelay,
    const float* kBetaSlopes,
    const float* kbetaIntercepts,
    float* beta,
    const float kGamma,
    const float kDeathRate,
    float* incapacitated,
    float* cant_work,
    float* num_people_that_can_work,
    const int* us_kStatePopulation,
    const float kInfectionTooSickToWorkRate,
    const float kPopulationBetweenAge18And65,
    const int kFilterLen,
    const int kNumFilters,
    int* delta_stringency_level,
    const float* kGroupedConvolutionalFilterWeights,
    const float* kUnemploymentConvolutionalFilters,
    const float* kUnemploymentBias,
    float* signal,
    const float kDailyProductionPerWorker,
    const float* maximum_productivity,
    float* obs_a_world_agent_state,
    float* obs_a_world_agent_postsubsidy_productivity,
    float* obs_a_world_lagged_stringency_level,
    float* obs_a_time,
    float* obs_p_world_agent_state,
    float* obs_p_world_agent_postsubsidy_productivity,
    float* obs_p_world_lagged_stringency_level,
    float* obs_p_time,
    int * env_timestep_arr,
    const int kNumAgents,
    const int kEpisodeLength
) {
    const int kEnvId = blockIdx.x;
    const int kAgentId = threadIdx.x;

    // The step kernel must never run at t == 0 (that is the reset state).
    assert(env_timestep_arr[kEnvId] > 0 &&
        env_timestep_arr[kEnvId] <= kEpisodeLength);
    assert (kAgentId <= kNumAgents - 1);
    // Number of per-agent SIR-state features exposed in observations.
    const int kNumFeatures = 6;

    if (kAgentId < (kNumAgents - 1)) {
        // Indices for time-dependent and time-independent arrays
        // Time dependent arrays have shapes (num_envs,
        // kEpisodeLength + 1, kNumAgents - 1)
        // Time independent arrays have shapes (num_envs, kNumAgents - 1)
        const int kArrayIndexOffset = kEnvId * (kEpisodeLength + 1) *
            (kNumAgents - 1);
        int kArrayIdxCurrentTime = kArrayIndexOffset +
            env_timestep_arr[kEnvId] * (kNumAgents - 1) + kAgentId;
        int kArrayIdxPrevTime = kArrayIndexOffset +
            (env_timestep_arr[kEnvId] - 1) * (kNumAgents - 1) + kAgentId;
        const int kTimeIndependentArrayIdx = kEnvId *
            (kNumAgents - 1) + kAgentId;

        const float kStatePopulation = static_cast<float>(us_kStatePopulation[kAgentId]);

        // Disease dynamics for this agent's state.
        cuda_sir_step(
            susceptible,
            infected,
            recovered,
            vaccinated,
            deaths,
            num_vaccines_available_t,
            kRealWorldStringencyPolicyHistory,
            kStatePopulation,
            kNumAgents,
            kBetaDelay,
            kBetaSlopes[kAgentId],
            kbetaIntercepts[kAgentId],
            stringency_level,
            beta,
            kGamma,
            kDeathRate,
            kEnvId,
            kAgentId,
            env_timestep_arr[kEnvId],
            kEpisodeLength,
            kArrayIdxCurrentTime,
            kArrayIdxPrevTime,
            kTimeIndependentArrayIdx);

        // Unemployment response to stringency-level changes.
        cuda_unemployment_step(
            unemployed,
            stringency_level,
            delta_stringency_level,
            kGroupedConvolutionalFilterWeights,
            kUnemploymentConvolutionalFilters,
            kUnemploymentBias,
            signal,
            kFilterLen,
            kNumFilters,
            kStatePopulation,
            kNumAgents,
            kEnvId,
            kAgentId,
            env_timestep_arr[kEnvId],
            kArrayIdxCurrentTime,
            kArrayIdxPrevTime);

        // Productivity and subsidy accounting.
        cuda_economy_step(
            infected,
            deaths,
            unemployed,
            incapacitated,
            cant_work,
            num_people_that_can_work,
            kStatePopulation,
            kInfectionTooSickToWorkRate,
            kPopulationBetweenAge18And65,
            kDailyProductionPerWorker,
            productivity,
            subsidy,
            postsubsidy_productivity,
            env_timestep_arr[kEnvId],
            kArrayIdxCurrentTime,
            kTimeIndependentArrayIdx);

        // CUDA version of generate observations
        // Agents' observations
        // obs_a_world_agent_state layout: (num_envs, kNumFeatures,
        // kNumAgents - 1); features are population-normalized S, I, R,
        // deaths, vaccinated, unemployed (in that order).
        int kFeatureArrayIndexOffset = kEnvId * kNumFeatures *
            (kNumAgents - 1) + kAgentId;
        obs_a_world_agent_state[
            kFeatureArrayIndexOffset + 0 * (kNumAgents - 1)
        ] = susceptible[kArrayIdxCurrentTime] / kStatePopulation;
        obs_a_world_agent_state[
            kFeatureArrayIndexOffset + 1 * (kNumAgents - 1)
        ] = infected[kArrayIdxCurrentTime] / kStatePopulation;
        obs_a_world_agent_state[
            kFeatureArrayIndexOffset + 2 * (kNumAgents - 1)
        ] = recovered[kArrayIdxCurrentTime] / kStatePopulation;
        obs_a_world_agent_state[
            kFeatureArrayIndexOffset + 3 * (kNumAgents - 1)
        ] = deaths[kArrayIdxCurrentTime] / kStatePopulation;
        obs_a_world_agent_state[
            kFeatureArrayIndexOffset + 4 * (kNumAgents - 1)
        ] = vaccinated[kArrayIdxCurrentTime] / kStatePopulation;
        obs_a_world_agent_state[
            kFeatureArrayIndexOffset + 5 * (kNumAgents - 1)
        ] = unemployed[kArrayIdxCurrentTime] / kStatePopulation;

        // The planner sees the same per-agent state features.
        for (int feature_id = 0; feature_id < kNumFeatures; feature_id ++) {
            const int kIndex = feature_id * (kNumAgents - 1);
            obs_p_world_agent_state[kFeatureArrayIndexOffset +
                kIndex
            ] = obs_a_world_agent_state[kFeatureArrayIndexOffset +
                kIndex];
        }

        // Post-subsidy productivity, normalized by the agent's maximum.
        obs_a_world_agent_postsubsidy_productivity[
            kTimeIndependentArrayIdx
        ] = postsubsidy_productivity[kArrayIdxCurrentTime] /
            maximum_productivity[kAgentId];
        obs_p_world_agent_postsubsidy_productivity[
            kTimeIndependentArrayIdx
        ] = obs_a_world_agent_postsubsidy_productivity[
            kTimeIndependentArrayIdx
        ];

        // Stringency level lagged by kBetaDelay - 1 steps; before that
        // horizon, fall back to the real-world policy history.
        int t_beta = env_timestep_arr[kEnvId] - kBetaDelay + 1;
        if (t_beta < 0) {
            obs_a_world_lagged_stringency_level[
                kTimeIndependentArrayIdx
            ] = kRealWorldStringencyPolicyHistory[
                env_timestep_arr[kEnvId] * (kNumAgents - 1) + kAgentId
            ] / static_cast<float>(kNumStringencyLevels);
        } else {
            obs_a_world_lagged_stringency_level[
                kTimeIndependentArrayIdx
            ] = stringency_level[
                kArrayIndexOffset +
                t_beta * (kNumAgents - 1) +
                kAgentId
            ] / static_cast<float>(kNumStringencyLevels);
        }
        obs_p_world_lagged_stringency_level[
            kTimeIndependentArrayIdx
        ] = obs_a_world_lagged_stringency_level[
            kTimeIndependentArrayIdx];
        // Below, we assume observation scaling = True
        // (otherwise, 'obs_a_time[kTimeIndependentArrayIdx] =
        // static_cast<float>(env_timestep_arr[kEnvId])
        obs_a_time[kTimeIndependentArrayIdx] =
            env_timestep_arr[kEnvId] / static_cast<float>(kEpisodeLength);
    } else if (kAgentId == kNumAgents - 1) {
        // Planner thread: only the scaled time observation.
        obs_p_time[kEnvId] = env_timestep_arr[kEnvId] /
            static_cast<float>(kEpisodeLength);
    }
}
|
||||
|
||||
// CUDA version of the compute_reward() in
|
||||
// "ai_economist.foundation.scenarios.covid19_env.py"
|
||||
// Reward kernel: one block per env, one thread per agent. Threads
// 0 .. kNumAgents-2 compute the per-agent (state) rewards; thread
// kNumAgents-1 computes the planner reward. After a barrier, thread 0
// rolls the env over (done flag + timestep reset) at episode end.
__global__ void CudaComputeReward(
    float* rewards_a,
    float* rewards_p,
    const int kNumDaysInAnYear,
    const int kValueOfLife,
    const float kRiskFreeInterestRate,
    const float kEconomicRewardCrraEta,
    const float* kMinMarginalAgentHealthIndex,
    const float* kMaxMarginalAgentHealthIndex,
    const float* kMinMarginalAgentEconomicIndex,
    const float* kMaxMarginalAgentEconomicIndex,
    const float kMinMarginalPlannerHealthIndex,
    const float kMaxMarginalPlannerHealthIndex,
    const float kMinMarginalPlannerEconomicIndex,
    const float kMaxMarginalPlannerEconomicIndex,
    const float* kWeightageOnMarginalAgentHealthIndex,
    const float* kWeightageOnMarginalPlannerHealthIndex,
    const float kWeightageOnMarginalAgentEconomicIndex,
    const float kWeightageOnMarginalPlannerEconomicIndex,
    const float* kAgentsHealthNorm,
    const float* kAgentsEconomicNorm,
    const float kPlannerHealthNorm,
    const float kPlannerEconomicNorm,
    float* deaths,
    float* subsidy,
    float* postsubsidy_productivity,
    int* env_done_arr,
    int* env_timestep_arr,
    const int kNumAgents,
    const int kEpisodeLength
) {
    const int kEnvId = blockIdx.x;
    const int kAgentId = threadIdx.x;

    assert(env_timestep_arr[kEnvId] > 0 &&
        env_timestep_arr[kEnvId] <= kEpisodeLength);
    assert (kAgentId <= kNumAgents - 1);

    const int kArrayIndexOffset = kEnvId * (kEpisodeLength + 1) *
        (kNumAgents - 1);
    if (kAgentId < (kNumAgents - 1)) {
        // Agents' rewards
        // Indices for time-dependent and time-independent arrays
        // Time dependent arrays have shapes (num_envs,
        // kEpisodeLength + 1, kNumAgents - 1)
        // Time independent arrays have shapes (num_envs, kNumAgents - 1)
        int kArrayIdxCurrentTime = kArrayIndexOffset +
            env_timestep_arr[kEnvId] * (kNumAgents - 1) + kAgentId;
        int kArrayIdxPrevTime = kArrayIndexOffset +
            (env_timestep_arr[kEnvId] - 1) * (kNumAgents - 1) + kAgentId;
        const int kTimeIndependentArrayIdx = kEnvId *
            (kNumAgents - 1) + kAgentId;

        // New deaths this step.
        float marginal_deaths = deaths[kArrayIdxCurrentTime] -
            deaths[kArrayIdxPrevTime];

        // Note: changing the order of operations to prevent overflow
        float marginal_agent_health_index = - marginal_deaths /
            (kAgentsHealthNorm[kAgentId] /
            static_cast<float>(kValueOfLife));

        float marginal_agent_economic_index = crra_nonlinearity(
            postsubsidy_productivity[kArrayIdxCurrentTime] /
            kAgentsEconomicNorm[kAgentId],
            kEconomicRewardCrraEta,
            kNumDaysInAnYear);

        // Normalize both indices into [0, 1] before mixing.
        marginal_agent_health_index = min_max_normalization(
            marginal_agent_health_index,
            kMinMarginalAgentHealthIndex[kAgentId],
            kMaxMarginalAgentHealthIndex[kAgentId]);
        marginal_agent_economic_index = min_max_normalization(
            marginal_agent_economic_index,
            kMinMarginalAgentEconomicIndex[kAgentId],
            kMaxMarginalAgentEconomicIndex[kAgentId]);

        // NOTE(review): the economic-index weightage passed here is
        // kWeightageOnMarginalPlannerHealthIndex[kAgentId] — the name
        // suggests a possible argument mix-up; verify against the Python
        // compute_reward().
        rewards_a[kTimeIndependentArrayIdx] = get_rew(
            kWeightageOnMarginalAgentHealthIndex[kAgentId],
            marginal_agent_health_index,
            kWeightageOnMarginalPlannerHealthIndex[kAgentId],
            marginal_agent_economic_index);
    } else if (kAgentId == kNumAgents - 1) {
        // Planner's rewards
        // Aggregate new deaths across all state agents.
        float total_marginal_deaths = 0;
        for (int ag_id = 0; ag_id < (kNumAgents - 1); ag_id ++) {
            total_marginal_deaths += (
                deaths[kArrayIndexOffset + env_timestep_arr[kEnvId] *
                (kNumAgents - 1) + ag_id] -
                deaths[kArrayIndexOffset + (env_timestep_arr[kEnvId] - 1) *
                (kNumAgents - 1) + ag_id]);
        }
        // Note: changing the order of operations to prevent overflow
        float marginal_planner_health_index = -total_marginal_deaths /
            (kPlannerHealthNorm / static_cast<float>(kValueOfLife));

        // Aggregate subsidies and post-subsidy productivity.
        float total_subsidy = 0.0;
        float total_postsubsidy_productivity = 0.0;
        for (int ag_id = 0; ag_id < (kNumAgents - 1); ag_id ++) {
            total_subsidy += subsidy[kArrayIndexOffset +
                env_timestep_arr[kEnvId] * (kNumAgents - 1) + ag_id];
            total_postsubsidy_productivity +=
                postsubsidy_productivity[kArrayIndexOffset +
                env_timestep_arr[kEnvId] * (kNumAgents - 1) + ag_id];
        }

        // Subsidies are financed at the risk-free rate.
        float cost_of_subsidy = (1 + kRiskFreeInterestRate) *
            total_subsidy;
        float marginal_planner_economic_index = crra_nonlinearity(
            (total_postsubsidy_productivity - cost_of_subsidy) /
            kPlannerEconomicNorm,
            kEconomicRewardCrraEta,
            kNumDaysInAnYear);

        marginal_planner_health_index = min_max_normalization(
            marginal_planner_health_index,
            kMinMarginalPlannerHealthIndex,
            kMaxMarginalPlannerHealthIndex);
        marginal_planner_economic_index = min_max_normalization(
            marginal_planner_economic_index,
            kMinMarginalPlannerEconomicIndex,
            kMaxMarginalPlannerEconomicIndex);

        // NOTE(review): the weightages used here are
        // kWeightageOnMarginalAgentEconomicIndex (for the *health* index)
        // and kWeightageOnMarginalPlannerEconomicIndex — the names suggest
        // the planner-health weightage was intended; verify against the
        // Python compute_reward().
        rewards_p[kEnvId] = get_rew(
            kWeightageOnMarginalAgentEconomicIndex,
            marginal_planner_health_index,
            kWeightageOnMarginalPlannerEconomicIndex,
            marginal_planner_economic_index);
    }

    // Wait here for all agents to finish computing rewards
    __syncthreads();

    // Use only agent 0's thread to set done_arr
    if (kAgentId == 0) {
        if (env_timestep_arr[kEnvId] == kEpisodeLength) {
            env_timestep_arr[kEnvId] = 0;
            env_done_arr[kEnvId] = 1;
        }
    }
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
[SECURITY: REDACTED — an RSA private key was committed here. This key must be
considered compromised: remove it from the repository and its entire history
(e.g. with git filter-repo), revoke/rotate the key with whatever service it
authenticated to, and never commit private key material to version control.]
|
||||
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
@@ -0,0 +1,336 @@
|
||||
# Copyright (c) 2021 salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai_economist.foundation.base.base_env import BaseEnvironment, scenario_registry
|
||||
from ai_economist.foundation.scenarios.utils import rewards, social_metrics
|
||||
|
||||
|
||||
@scenario_registry.add
|
||||
class OneStepEconomy(BaseEnvironment):
|
||||
"""
|
||||
A simple model featuring one "step" of setting taxes and earning income.
|
||||
|
||||
As described in https://arxiv.org/abs/2108.02755:
|
||||
A simplified version of simple_wood_and_stone scenario where both the planner
|
||||
and the agents each make a single decision: the planner setting taxes and the
|
||||
agents choosing labor. Each agent chooses an amount of labor that optimizes
|
||||
its post-tax utility, and this optimal labor depends on its skill and the tax
|
||||
rates, and it does not depend on the labor choices of other agents. Before
|
||||
the agents act, the planner sets the marginal tax rates in order to optimize
|
||||
social welfare.
|
||||
|
||||
Note:
|
||||
This scenario is intended to be used with the 'PeriodicBracketTax' and
|
||||
'SimpleLabor' components.
|
||||
It should use an episode length of 2. In the first step, taxes are set by
|
||||
the planner via 'PeriodicBracketTax'. In the second, agents select how much
|
||||
to work/earn via 'SimpleLabor'.
|
||||
|
||||
Args:
|
||||
agent_reward_type (str): The type of utility function used to compute each
|
||||
agent's reward. Defaults to "coin_minus_labor_cost".
|
||||
isoelastic_eta (float): The shape parameter of the isoelastic function used
|
||||
in the "isoelastic_coin_minus_labor" utility function.
|
||||
labor_exponent (float): The labor exponent parameter used in the
|
||||
"coin_minus_labor_cost" utility function.
|
||||
labor_cost (float): The coefficient used to weight the cost of labor.
|
||||
planner_reward_type (str): The type of social welfare function (SWF) used to
|
||||
compute the planner's reward. Defaults to "inv_income_weighted_utility".
|
||||
mixing_weight_gini_vs_coin (float): Must be between 0 and 1 (inclusive).
|
||||
Controls the weighting of equality and productivity when using SWF
|
||||
"coin_eq_times_productivity", where a value of 0 (default) yields equal
|
||||
weighting, and 1 only considers productivity.
|
||||
"""
|
||||
|
||||
name = "one-step-economy"
|
||||
agent_subclasses = ["BasicMobileAgent", "BasicPlanner"]
|
||||
required_entities = ["Coin"]
|
||||
|
||||
def __init__(
    self,
    *base_env_args,
    agent_reward_type="coin_minus_labor_cost",
    isoelastic_eta=0.23,
    labor_exponent=2.0,
    labor_cost=1.0,
    planner_reward_type="inv_income_weighted_utility",
    mixing_weight_gini_vs_coin=0,
    **base_env_kwargs
):
    super().__init__(*base_env_args, **base_env_kwargs)

    self.num_agents = len(self.world.agents)

    # Utility-function configuration for the mobile agents.
    self.agent_reward_type = agent_reward_type
    self.isoelastic_eta = isoelastic_eta
    self.labor_exponent = labor_exponent
    self.labor_cost = labor_cost

    # Social-welfare configuration for the planner.
    self.planner_reward_type = planner_reward_type
    self.mixing_weight_gini_vs_coin = mixing_weight_gini_vs_coin
    self.planner_starting_coin = 0

    # Most recently computed utility/SWF value per agent id.
    self.curr_optimization_metrics = {str(a.idx): 0 for a in self.all_agents}
|
||||
|
||||
# The following methods must be implemented for each scenario
|
||||
# -----------------------------------------------------------
|
||||
def reset_starting_layout(self):
    """
    Part 1/2 of scenario reset. This method handles resetting the state of the
    environment managed by the scenario (i.e. resource & landmark layout).

    This scenario has no spatial resource layout to reset, so this is
    intentionally a no-op.
    """
|
||||
|
||||
def reset_agent_states(self):
    """
    Part 2/2 of scenario reset. This method handles resetting the state of the
    agents themselves (i.e. inventory, locations, etc.).

    Here, zero out every mobile agent's inventory, escrow, and endogenous
    quantities, and reset the planner's coin to its starting amount.
    """
    self.world.clear_agent_locs()

    for mobile_agent in self.world.agents:
        # Zero every tracked quantity for this agent, preserving the keys.
        for state_key in ("inventory", "escrow", "endogenous"):
            mobile_agent.state[state_key] = dict.fromkeys(
                mobile_agent.state[state_key], 0
            )

    self.world.planner.inventory["Coin"] = self.planner_starting_coin
|
||||
|
||||
def scenario_step(self):
    """
    Update the state of the world according to whatever rules this scenario
    implements.

    This gets called in the 'step' method (of base_env) after going through each
    component step and before generating observations, rewards, etc.

    NOTE: does not take agent actions into account.

    This scenario requires no passive world dynamics, so this is intentionally
    a no-op.
    """
|
||||
|
||||
def generate_observations(self):
    """
    Generate observations associated with this scenario.

    Returns:
        obs (dict): Maps agent.idx -> observation dict. Mobile agents receive
            an empty dict from the scenario; the planner receives
            normalized-per-capita productivity and equality.
    """
    # The scenario itself contributes no observations for mobile agents.
    observations = {str(mobile.idx): {} for mobile in self.world.agents}

    endowments = np.array(
        [mobile.total_endowment("Coin") for mobile in self.world.agents]
    )
    productivity = social_metrics.get_productivity(endowments)
    equality = social_metrics.get_equality(endowments)

    observations[self.world.planner.idx] = {
        # Scaled by 1/1000 to keep the magnitude small for learning.
        "normalized_per_capita_productivity": (
            productivity / self.num_agents / 1000
        ),
        "equality": equality,
    }

    return observations
|
||||
|
||||
def compute_reward(self):
    """
    Apply the reward function(s) associated with this scenario to get the rewards
    from this step.

    Returns:
        rew (dict): Maps agent.idx -> scalar reward. Rewards are the marginal
            change in each agent's (planner's) optimization metric since the
            previous step, so maximizing the cumulative reward maximizes the
            metric at the end of the episode (ignoring discounting).
    """
    updated_metrics = self.get_current_optimization_metrics(
        self.world.agents,
        isoelastic_eta=float(self.isoelastic_eta),
        labor_exponent=float(self.labor_exponent),
        labor_coefficient=float(self.labor_cost),
    )

    # Reward = current metric minus the metric stored at the last step.
    step_rewards = {
        idx: metric - self.curr_optimization_metrics[idx]
        for idx, metric in updated_metrics.items()
    }

    # Cache the fresh metrics for the next step's marginal computation.
    self.curr_optimization_metrics = updated_metrics

    return step_rewards
|
||||
|
||||
# Optional methods for customization
|
||||
# ----------------------------------
|
||||
def additional_reset_steps(self):
    """
    Extra scenario-specific steps that should be performed at the end of the reset
    cycle.

    For each reset cycle...
    First, reset_starting_layout() and reset_agent_states() will be called.
    Second, <component>.reset() will be called for each registered component.
    Lastly, this method will be called to allow for any final customization of
    the reset cycle.

    Here, seed the cached optimization metrics with the post-reset state so
    the first compute_reward() call yields a sensible marginal value.
    """
    utility_kwargs = dict(
        isoelastic_eta=float(self.isoelastic_eta),
        labor_exponent=float(self.labor_exponent),
        labor_coefficient=float(self.labor_cost),
    )
    self.curr_optimization_metrics = self.get_current_optimization_metrics(
        self.world.agents, **utility_kwargs
    )
|
||||
|
||||
def scenario_metrics(self):
    """
    Allows the scenario to generate metrics (collected along with component metrics
    in the 'metrics' property).

    To have the scenario add metrics, this function needs to return a dictionary of
    {metric_key: value} where 'value' is a scalar (no nesting or lists!)

    Here, summarize social metrics, endowments, utilities, and labor cost annealing.
    """
    metrics = dict()

    # Log social/economic indicators
    coin_endowments = np.array(
        [agent.total_endowment("Coin") for agent in self.world.agents]
    )
    pretax_incomes = np.array(
        [agent.state["production"] for agent in self.world.agents]
    )
    metrics["social/productivity"] = social_metrics.get_productivity(
        coin_endowments
    )
    metrics["social/equality"] = social_metrics.get_equality(coin_endowments)

    # Utilities as last cached by compute_reward/additional_reset_steps.
    utilities = np.array(
        [self.curr_optimization_metrics[agent.idx] for agent in self.world.agents]
    )
    # Logged with equality_weight=1.0 regardless of the configured
    # mixing_weight_gini_vs_coin, so this metric is comparable across runs.
    metrics[
        "social_welfare/coin_eq_times_productivity"
    ] = rewards.coin_eq_times_productivity(
        coin_endowments=coin_endowments, equality_weight=1.0
    )
    # NOTE(review): incomes here are pretax production values, not coin
    # endowments (see trailing comment) — confirm this weighting is intended.
    metrics[
        "social_welfare/inv_income_weighted_utility"
    ] = rewards.inv_income_weighted_utility(
        coin_endowments=pretax_incomes, utilities=utilities  # coin_endowments,
    )

    # Log average endowments, endogenous, and utility for agents
    agent_endows = {}
    agent_endogenous = {}
    agent_utilities = []
    for agent in self.world.agents:
        # Per-resource totals include both inventory and escrowed amounts.
        for resource in agent.inventory.keys():
            if resource not in agent_endows:
                agent_endows[resource] = []
            agent_endows[resource].append(
                agent.inventory[resource] + agent.escrow[resource]
            )

        for endogenous, quantity in agent.endogenous.items():
            if endogenous not in agent_endogenous:
                agent_endogenous[endogenous] = []
            agent_endogenous[endogenous].append(quantity)

        agent_utilities.append(self.curr_optimization_metrics[agent.idx])

    for resource, quantities in agent_endows.items():
        metrics["endow/avg_agent/{}".format(resource)] = np.mean(quantities)

    for endogenous, quantities in agent_endogenous.items():
        metrics["endogenous/avg_agent/{}".format(endogenous)] = np.mean(quantities)

    metrics["util/avg_agent"] = np.mean(agent_utilities)

    # Log endowments and utility for the planner
    for resource, quantity in self.world.planner.inventory.items():
        metrics["endow/p/{}".format(resource)] = quantity

    metrics["util/p"] = self.curr_optimization_metrics[self.world.planner.idx]

    return metrics
|
||||
|
||||
def get_current_optimization_metrics(
    self, agents, isoelastic_eta=0.23, labor_exponent=2.0, labor_coefficient=0.1
):
    """
    Compute optimization metrics based on the current state. Used to compute reward.

    Args:
        agents: Mobile agents to compute utilities for.
        isoelastic_eta (float): Shape parameter; used only when
            agent_reward_type == "isoelastic_coin_minus_labor".
        labor_exponent (float): Labor exponent; used only when
            agent_reward_type == "coin_minus_labor_cost".
        labor_coefficient (float): Weight applied to the labor cost term.

    Returns:
        curr_optimization_metric (dict): A dictionary of {agent.idx: metric}
            with an entry for each agent (including the planner) in the env.

    Raises:
        NotImplementedError: If the configured agent or planner reward type
            is not one of the supported options.
    """
    curr_optimization_metric = {}

    coin_endowments = np.array([agent.total_endowment("Coin") for agent in agents])

    pretax_incomes = np.array([agent.state["production"] for agent in agents])

    # Optimization metric for agents:
    for agent in agents:
        if self.agent_reward_type == "isoelastic_coin_minus_labor":
            assert 0.0 <= isoelastic_eta <= 1.0
            curr_optimization_metric[
                agent.idx
            ] = rewards.isoelastic_coin_minus_labor(
                coin_endowment=agent.total_endowment("Coin"),
                total_labor=agent.state["endogenous"]["Labor"],
                isoelastic_eta=isoelastic_eta,
                labor_coefficient=labor_coefficient,
            )
        elif self.agent_reward_type == "coin_minus_labor_cost":
            assert labor_exponent > 1.0
            curr_optimization_metric[agent.idx] = rewards.coin_minus_labor_cost(
                coin_endowment=agent.total_endowment("Coin"),
                total_labor=agent.state["endogenous"]["Labor"],
                labor_exponent=labor_exponent,
                labor_coefficient=labor_coefficient,
            )
        else:
            # Fix: an unrecognized reward type previously fell through
            # silently, producing no agent entries and a confusing KeyError
            # downstream in compute_reward().
            raise NotImplementedError(
                "Unknown agent_reward_type: {}".format(self.agent_reward_type)
            )

    # Optimization metric for the planner:
    if self.planner_reward_type == "coin_eq_times_productivity":
        curr_optimization_metric[
            self.world.planner.idx
        ] = rewards.coin_eq_times_productivity(
            coin_endowments=coin_endowments,
            equality_weight=1 - self.mixing_weight_gini_vs_coin,
        )
    elif self.planner_reward_type == "inv_income_weighted_utility":
        # NOTE(review): income weighting uses pretax incomes rather than
        # coin endowments (see trailing comment) — confirm intended.
        curr_optimization_metric[
            self.world.planner.idx
        ] = rewards.inv_income_weighted_utility(
            coin_endowments=pretax_incomes,  # coin_endowments,
            utilities=np.array(
                [curr_optimization_metric[agent.idx] for agent in agents]
            ),
        )
    else:
        # Fix: replace print + bare raise with an informative exception.
        raise NotImplementedError(
            "Unknown planner_reward_type: {}".format(self.planner_reward_type)
        )
    return curr_optimization_metric
|
||||
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,800 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
|
||||
from ai_economist.foundation.base.base_env import BaseEnvironment, scenario_registry
|
||||
from ai_economist.foundation.scenarios.utils import rewards, social_metrics
|
||||
|
||||
|
||||
@scenario_registry.add
|
||||
class LayoutFromFile(BaseEnvironment):
|
||||
"""
|
||||
World containing stone and wood with stochastic regeneration. Refers to a fixed
|
||||
layout file (see ./map_txt/ for examples) to determine the spatial arrangement of
|
||||
stone, wood, and water tiles.
|
||||
|
||||
Args:
|
||||
planner_gets_spatial_obs (bool): Whether the planner agent receives spatial
|
||||
observations from the world.
|
||||
full_observability (bool): Whether the mobile agents' spatial observation
|
||||
includes the full world view or is instead an egocentric view.
|
||||
mobile_agent_observation_range (int): If not using full_observability,
|
||||
the spatial range (on each side of the agent) that is visible in the
|
||||
spatial observations.
|
||||
env_layout_file (str): Name of the layout file in ./map_txt/ to use.
|
||||
Note: The world dimensions of that layout must match the world dimensions
|
||||
argument used to construct the environment.
|
||||
resource_regen_prob (float): Probability that an empty source tile will
|
||||
regenerate a new resource unit.
|
||||
fixed_four_skill_and_loc (bool): Whether to use a fixed set of build skills and
|
||||
starting locations, with agents grouped into starting locations based on
|
||||
which skill quartile they are in. False, by default.
|
||||
True, for experiments in https://arxiv.org/abs/2004.13332.
|
||||
Note: Requires that the environment uses the "Build" component with
|
||||
skill_dist="pareto".
|
||||
starting_agent_coin (int, float): Amount of coin agents have at t=0. Defaults
|
||||
to zero coin.
|
||||
isoelastic_eta (float): Parameter controlling the shape of agent utility
|
||||
wrt coin endowment.
|
||||
energy_cost (float): Coefficient for converting labor to negative utility.
|
||||
energy_warmup_constant (float): Decay constant that controls the rate at which
|
||||
the effective energy cost is annealed from 0 to energy_cost. Set to 0
|
||||
(default) to disable annealing, meaning that the effective energy cost is
|
||||
always energy_cost. The units of the decay constant depend on the choice of
|
||||
energy_warmup_method.
|
||||
energy_warmup_method (str): How to schedule energy annealing (warmup). If
|
||||
"decay" (default), use the number of completed episodes. If "auto",
|
||||
use the number of timesteps where the average agent reward was positive.
|
||||
planner_reward_type (str): The type of reward used for the planner. Options
|
||||
are "coin_eq_times_productivity" (default),
|
||||
"inv_income_weighted_coin_endowment", and "inv_income_weighted_utility".
|
||||
mixing_weight_gini_vs_coin (float): Degree to which equality is ignored w/
|
||||
"coin_eq_times_productivity". Default is 0, which weights equality and
|
||||
productivity equally. If set to 1, only productivity is rewarded.
|
||||
"""
|
||||
|
||||
# Scenario identifier used when registering/looking up this scenario.
name = "layout_from_file/simple_wood_and_stone"
# Agent classes this scenario is compatible with.
agent_subclasses = ["BasicMobileAgent", "BasicPlanner"]
# Entities that must exist in the world for this scenario to run.
required_entities = ["Wood", "Stone", "Water"]
|
||||
|
||||
def __init__(
    self,
    *base_env_args,
    planner_gets_spatial_info=True,
    full_observability=False,
    mobile_agent_observation_range=5,
    env_layout_file="quadrant_25x25_20each_30clump.txt",
    resource_regen_prob=0.01,
    fixed_four_skill_and_loc=False,
    starting_agent_coin=0,
    isoelastic_eta=0.23,
    energy_cost=0.21,
    energy_warmup_constant=0,
    energy_warmup_method="decay",
    planner_reward_type="coin_eq_times_productivity",
    mixing_weight_gini_vs_coin=0.0,
    **base_env_kwargs,
):
    """Construct the scenario; see the class docstring for argument details."""
    super().__init__(*base_env_args, **base_env_kwargs)

    # Whether agents receive spatial information in their observation tensor
    self._planner_gets_spatial_info = bool(planner_gets_spatial_info)

    # Whether the (non-planner) agents can see the whole world map
    self._full_observability = bool(full_observability)

    # View halfwidth for mobile agents (used when not fully observable).
    self._mobile_agent_observation_range = int(mobile_agent_observation_range)

    # Load in the layout
    path_to_layout_file = Path(f"{Path(__file__).parent}/map_txt/{env_layout_file}")

    with open(path_to_layout_file, "r") as f:
        self.env_layout_string = f.read()
        self.env_layout = self.env_layout_string.split(";")

    # Convert the layout to landmark maps (one binary map per landmark type).
    landmark_lookup = {"W": "Wood", "S": "Stone", "@": "Water"}
    self._source_maps = {
        r: np.zeros(self.world_size) for r in landmark_lookup.values()
    }
    for r, symbol_row in enumerate(self.env_layout):
        for c, symbol in enumerate(symbol_row):
            landmark = landmark_lookup.get(symbol, None)
            if landmark:
                self._source_maps[landmark][r, c] = 1

    # For controlling resource regeneration behavior
    self.layout_specs = dict(
        Wood={
            "regen_weight": float(resource_regen_prob),
            "regen_halfwidth": 0,
            "max_health": 1,
        },
        Stone={
            "regen_weight": float(resource_regen_prob),
            "regen_halfwidth": 0,
            "max_health": 1,
        },
    )
    assert 0 <= self.layout_specs["Wood"]["regen_weight"] <= 1
    assert 0 <= self.layout_specs["Stone"]["regen_weight"] <= 1

    # How much coin agents begin with upon reset
    self.starting_agent_coin = float(starting_agent_coin)
    assert self.starting_agent_coin >= 0.0

    # Controls the diminishing marginal utility of coin.
    # isoelastic_eta=0 means no diminishing utility.
    self.isoelastic_eta = float(isoelastic_eta)
    assert 0.0 <= self.isoelastic_eta <= 1.0

    # The amount that labor is weighted in utility computation
    # (once annealing is finished)
    self.energy_cost = float(energy_cost)
    assert self.energy_cost >= 0

    # Which method to use for calculating the progress of energy annealing
    # If method = 'decay': #completed episodes
    # If method = 'auto' : #timesteps where avg. agent reward > 0
    self.energy_warmup_method = energy_warmup_method.lower()
    assert self.energy_warmup_method in ["decay", "auto"]
    # Decay constant for annealing to full energy cost
    # (if energy_warmup_constant == 0, there is no annealing)
    self.energy_warmup_constant = float(energy_warmup_constant)
    assert self.energy_warmup_constant >= 0
    self._auto_warmup_integrator = 0

    # Which social welfare function to use
    self.planner_reward_type = str(planner_reward_type).lower()

    # How much to weight equality if using SWF=eq*prod:
    #   0 -> SWF = eq * prod
    #   1 -> SWF = prod
    self.mixing_weight_gini_vs_coin = float(mixing_weight_gini_vs_coin)
    assert 0 <= self.mixing_weight_gini_vs_coin <= 1.0

    # Use this to calculate marginal changes and deliver that as reward
    self.init_optimization_metric = {agent.idx: 0 for agent in self.all_agents}
    self.prev_optimization_metric = {agent.idx: 0 for agent in self.all_agents}
    self.curr_optimization_metric = {agent.idx: 0 for agent in self.all_agents}

    """
    Fixed Four Skill and Loc
    ------------------------
    """
    self.agent_starting_pos = {agent.idx: [] for agent in self.world.agents}

    self.fixed_four_skill_and_loc = bool(fixed_four_skill_and_loc)
    if self.fixed_four_skill_and_loc:
        bm = self.get_component("Build")
        assert bm.skill_dist == "pareto"
        pmsm = bm.payment_max_skill_multiplier

        # Temporarily switch to a fixed seed for controlling randomness
        seed_state = np.random.get_state()
        np.random.seed(seed=1)

        # Generate a batch (100000) of num_agents (sorted/clipped) Pareto samples.
        pareto_samples = np.random.pareto(4, size=(100000, self.n_agents))
        clipped_skills = np.minimum(pmsm, (pmsm - 1) * pareto_samples + 1)
        sorted_clipped_skills = np.sort(clipped_skills, axis=1)
        # The skill level of the i-th skill-ranked agent is the average of the
        # i-th ranked samples throughout the batch.
        average_ranked_skills = sorted_clipped_skills.mean(axis=0)
        self._avg_ranked_skill = average_ranked_skills * bm.payment

        np.random.set_state(seed_state)

        # Fill in the starting location associated with each skill rank
        starting_ranked_locs = [
            # Worst group of agents goes in top right
            (0, self.world_size[1] - 1),
            # Second-worst group of agents goes in bottom left
            (self.world_size[0] - 1, 0),
            # Second-best group of agents goes in top left
            (0, 0),
            # Best group of agents goes in bottom right
            # NOTE(review): the row coordinate uses world_size[1]; this only
            # equals world_size[0] - 1 for square maps — confirm for
            # non-square layouts.
            (self.world_size[1] - 1, self.world_size[1] - 1),
        ]
        self._ranked_locs = []

        # Based on skill, assign each agent to one of the location groups
        # Fix: np.int was deprecated in NumPy 1.20 and removed in 1.24;
        # use the builtin int instead.
        skill_groups = np.floor(
            np.arange(self.n_agents) * (4 / self.n_agents),
        ).astype(int)
        n_in_group = np.zeros(4, dtype=int)
        for g in skill_groups:
            # The position within the group is given by the number of agents
            # counted in the group thus far.
            g_pos = n_in_group[g]

            # Top right
            if g == 0:
                r = starting_ranked_locs[g][0] + (g_pos // 4)
                c = starting_ranked_locs[g][1] - (g_pos % 4)
                self._ranked_locs.append((r, c))

            # Bottom left
            elif g == 1:
                r = starting_ranked_locs[g][0] - (g_pos // 4)
                c = starting_ranked_locs[g][1] + (g_pos % 4)
                self._ranked_locs.append((r, c))

            # Top left
            elif g == 2:
                r = starting_ranked_locs[g][0] + (g_pos // 4)
                c = starting_ranked_locs[g][1] + (g_pos % 4)
                self._ranked_locs.append((r, c))

            # Bottom right
            elif g == 3:
                r = starting_ranked_locs[g][0] - (g_pos // 4)
                c = starting_ranked_locs[g][1] - (g_pos % 4)
                self._ranked_locs.append((r, c))

            else:
                raise ValueError

            # Count the agent we just placed.
            n_in_group[g] = n_in_group[g] + 1
|
||||
|
||||
@property
def energy_weight(self):
    """
    Energy annealing progress in [0, 1]. Multiply with self.energy_cost to
    get the effective energy coefficient.
    """
    # Annealing disabled: the full energy cost always applies.
    if self.energy_warmup_constant <= 0.0:
        return 1.0

    # Select the annealing progress counter for the configured method.
    if self.energy_warmup_method == "decay":
        progress = self._completions
    elif self.energy_warmup_method == "auto":
        progress = self._auto_warmup_integrator
    else:
        raise NotImplementedError

    # Exponential approach from 0 toward 1 as progress grows.
    return float(1.0 - np.exp(-progress / self.energy_warmup_constant))
|
||||
|
||||
def get_current_optimization_metrics(self):
    """
    Compute optimization metrics based on the current state. Used to compute reward.

    Agents get isoelastic utility of coin minus (annealed) labor cost; the
    planner gets the configured social welfare function.

    Returns:
        curr_optimization_metric (dict): A dictionary of {agent.idx: metric}
            with an entry for each agent (including the planner) in the env.

    Raises:
        NotImplementedError: If planner_reward_type is not recognized.
    """
    curr_optimization_metric = {}
    # (for agents)
    for agent in self.world.agents:
        curr_optimization_metric[agent.idx] = rewards.isoelastic_coin_minus_labor(
            coin_endowment=agent.total_endowment("Coin"),
            total_labor=agent.state["endogenous"]["Labor"],
            isoelastic_eta=self.isoelastic_eta,
            # Effective labor cost reflects annealing progress.
            labor_coefficient=self.energy_weight * self.energy_cost,
        )
    # (for the planner)
    # Hoisted: previously recomputed inside each SWF branch.
    coin_endowments = np.array(
        [agent.total_endowment("Coin") for agent in self.world.agents]
    )
    if self.planner_reward_type == "coin_eq_times_productivity":
        curr_optimization_metric[
            self.world.planner.idx
        ] = rewards.coin_eq_times_productivity(
            coin_endowments=coin_endowments,
            equality_weight=1 - self.mixing_weight_gini_vs_coin,
        )
    elif self.planner_reward_type == "inv_income_weighted_coin_endowments":
        curr_optimization_metric[
            self.world.planner.idx
        ] = rewards.inv_income_weighted_coin_endowments(
            coin_endowments=coin_endowments
        )
    elif self.planner_reward_type == "inv_income_weighted_utility":
        curr_optimization_metric[
            self.world.planner.idx
        ] = rewards.inv_income_weighted_utility(
            coin_endowments=coin_endowments,
            utilities=np.array(
                [curr_optimization_metric[agent.idx] for agent in self.world.agents]
            ),
        )
    else:
        # Fix: replace print + bare raise with an informative exception.
        raise NotImplementedError(
            "Unknown planner_reward_type: {}".format(self.planner_reward_type)
        )
    return curr_optimization_metric
|
||||
|
||||
# The following methods must be implemented for each scenario
|
||||
# -----------------------------------------------------------
|
||||
|
||||
def reset_starting_layout(self):
    """
    Part 1/2 of scenario reset. This method handles resetting the state of the
    environment managed by the scenario (i.e. resource & landmark layout).

    Here, reset to the layout in the fixed layout file.
    """
    self.world.maps.clear()

    harvestable = ("Stone", "Wood")
    for landmark, layout in self._source_maps.items():
        self.world.maps.set(landmark, layout)
        # Harvestable tiles double as their own respawn ("source block") map.
        if landmark in harvestable:
            self.world.maps.set(landmark + "SourceBlock", layout)
|
||||
|
||||
def reset_agent_states(self):
    """
    Part 2/2 of scenario reset. This method handles resetting the state of the
    agents themselves (i.e. inventory, locations, etc.).

    Here, empty inventories and place mobile agents in random, accessible
    locations to start. Note: If using fixed_four_skill_and_loc, the starting
    locations will be overridden in self.additional_reset_steps.
    """
    self.world.clear_agent_locs()

    # Empty each mobile agent's holdings, then grant the starting coin.
    for agent in self.world.agents:
        agent.state["inventory"] = dict.fromkeys(agent.inventory, 0)
        agent.state["escrow"] = dict.fromkeys(agent.inventory, 0)
        agent.state["endogenous"] = dict.fromkeys(agent.endogenous, 0)
        agent.state["inventory"]["Coin"] = float(self.starting_agent_coin)

    planner = self.world.planner
    planner.state["inventory"] = dict.fromkeys(planner.inventory, 0)
    planner.state["escrow"] = dict.fromkeys(planner.escrow, 0)

    # Drop each agent at a random accessible location; give up after ~200
    # failed samples (matches the original retry budget).
    for agent in self.world.agents:
        for _ in range(202):
            row = np.random.randint(0, self.world_size[0])
            col = np.random.randint(0, self.world_size[1])
            if self.world.can_agent_occupy(row, col, agent):
                break
        else:
            raise TimeoutError
        self.world.set_agent_loc(agent, row, col)
|
||||
|
||||
def scenario_step(self):
    """
    Update the state of the world according to whatever rules this scenario
    implements.

    This gets called in the 'step' method (of base_env) after going through each
    component step and before generating observations, rewards, etc.

    In this class of scenarios, the scenario step handles stochastic resource
    regeneration.
    """

    resources = ["Wood", "Stone"]

    for resource in resources:
        # Kernel side length: 1 + 2 * halfwidth (halfwidth 0 -> 1x1 kernel).
        d = 1 + (2 * self.layout_specs[resource]["regen_halfwidth"])
        # Normalized kernel scaled by the regeneration probability.
        kernel = (
            self.layout_specs[resource]["regen_weight"] * np.ones((d, d)) / (d ** 2)
        )

        resource_map = self.world.maps.get(resource)
        resource_source_blocks = self.world.maps.get(resource + "SourceBlock")
        # A tile can spawn a resource if it is empty, already holds this
        # resource, or is a source block...
        spawnable = (
            self.world.maps.empty + resource_map + resource_source_blocks
        ) > 0
        # ...AND it must be a source block.
        spawnable *= resource_source_blocks > 0

        # Source blocks count as healthy even when currently harvested.
        health = np.maximum(resource_map, resource_source_blocks)
        # Respawn probability at a tile comes from convolving neighboring
        # health with the kernel.
        respawn = np.random.rand(*health.shape) < signal.convolve2d(
            health, kernel, "same"
        )
        respawn *= spawnable

        # Add respawned units, capped at the resource's max health.
        self.world.maps.set(
            resource,
            np.minimum(
                resource_map + respawn, self.layout_specs[resource]["max_health"]
            ),
        )
|
||||
|
||||
def generate_observations(self):
    """
    Generate observations associated with this scenario.

    A scenario does not need to produce observations and can provide observations
    for only some agent types; however, for a given agent type, it should either
    always or never yield an observation. If it does yield an observation,
    that observation should always have the same structure/sizes!

    Returns:
        obs (dict): A dictionary of {agent.idx: agent_obs_dict}. In words,
            return a dictionary with an entry for each agent (which can include
            the planner) for which this scenario provides an observation. For each
            entry, the key specifies the index of the agent and the value contains
            its associated observation dictionary.

    Here, non-planner agents receive spatial observations (depending on the env
    config) as well as the contents of their inventory and endogenous quantities.
    The planner also receives spatial observations (again, depending on the env
    config) as well as the inventory of each of the mobile agents.
    """
    obs = {}
    curr_map = self.world.maps.state

    owner_map = self.world.maps.owner_state
    loc_map = self.world.loc_map
    # Stack ownership channels with the agent-location channel.
    agent_idx_maps = np.concatenate([owner_map, loc_map[None, :, :]], axis=0)
    # Shift so that raw index -1 ("no agent") becomes 1, then zero it out;
    # agent i is thereafter encoded as i + 2.
    agent_idx_maps += 2
    agent_idx_maps[agent_idx_maps == 1] = 0

    # Per-agent location, normalized by world dimensions.
    agent_locs = {
        str(agent.idx): {
            "loc-row": agent.loc[0] / self.world_size[0],
            "loc-col": agent.loc[1] / self.world_size[1],
        }
        for agent in self.world.agents
    }
    # Per-agent inventory, scaled by the env's inventory scale factor.
    agent_invs = {
        str(agent.idx): {
            "inventory-" + k: v * self.inv_scale for k, v in agent.inventory.items()
        }
        for agent in self.world.agents
    }

    obs[self.world.planner.idx] = {
        "inventory-" + k: v * self.inv_scale
        for k, v in self.world.planner.inventory.items()
    }
    if self._planner_gets_spatial_info:
        obs[self.world.planner.idx].update(
            dict(map=curr_map, idx_map=agent_idx_maps)
        )

    # Mobile agents see the full map. Convey location info via one-hot map channels.
    if self._full_observability:
        for agent in self.world.agents:
            my_map = np.array(agent_idx_maps)
            # Mark this agent's own cells with 1 (self vs. others).
            my_map[my_map == int(agent.idx) + 2] = 1
            sidx = str(agent.idx)
            obs[sidx] = {"map": curr_map, "idx_map": my_map}
            obs[sidx].update(agent_invs[sidx])

    # Mobile agents only see within a window around their position
    else:
        w = (
            self._mobile_agent_observation_range
        )  # View halfwidth (only applicable without full observability)

        # Pad the map so edge windows are well-defined; the extra channel
        # (constant 1 beyond the map) marks out-of-bounds cells.
        padded_map = np.pad(
            curr_map,
            [(0, 1), (w, w), (w, w)],
            mode="constant",
            constant_values=[(0, 1), (0, 0), (0, 0)],
        )

        padded_idx = np.pad(
            agent_idx_maps,
            [(0, 0), (w, w), (w, w)],
            mode="constant",
            constant_values=[(0, 0), (0, 0), (0, 0)],
        )

        for agent in self.world.agents:
            # Shift the agent's location into padded-map coordinates.
            r, c = [c + w for c in agent.loc]
            visible_map = padded_map[
                :, (r - w) : (r + w + 1), (c - w) : (c + w + 1)
            ]
            visible_idx = np.array(
                padded_idx[:, (r - w) : (r + w + 1), (c - w) : (c + w + 1)]
            )

            # Mark this agent's own cells with 1 (self vs. others).
            visible_idx[visible_idx == int(agent.idx) + 2] = 1

            sidx = str(agent.idx)

            obs[sidx] = {"map": visible_map, "idx_map": visible_idx}
            obs[sidx].update(agent_locs[sidx])
            obs[sidx].update(agent_invs[sidx])

            # Agent-wise planner info (gets crunched into the planner obs in the
            # base scenario code)
            obs["p" + sidx] = agent_invs[sidx]
            if self._planner_gets_spatial_info:
                obs["p" + sidx].update(agent_locs[sidx])

    return obs
|
||||
|
||||
def compute_reward(self):
    """
    Apply the reward function(s) associated with this scenario to get the rewards
    from this step.

    Returns:
        rew (dict): Maps agent.idx -> scalar reward for every agent (including
            the planner). Rewards are the marginal change in utility (agents)
            or social welfare (planner) this timestep, so maximizing cumulative
            reward maximizes the terminal-state metric (ignoring discounting).
    """
    # curr_optimization_metric has not been refreshed yet at this point in
    # the env step, so it still holds last step's utilities.
    prev_utilities = deepcopy(self.curr_optimization_metric)

    # Refresh the stored metrics for the current state.
    self.curr_optimization_metric = self.get_current_optimization_metrics()

    # Reward = current metric minus previous metric, per agent.
    rew = {}
    for idx, curr_value in self.curr_optimization_metric.items():
        rew[idx] = float(curr_value - prev_utilities[idx])

    # Store the previous objective values.
    self.prev_optimization_metric.update(prev_utilities)

    # Automatic Energy Cost Annealing: count timesteps where the average
    # mobile-agent reward was positive.
    average_agent_reward = np.mean([rew[a.idx] for a in self.world.agents])
    if average_agent_reward > 0:
        self._auto_warmup_integrator += 1

    return rew
|
||||
|
||||
# Optional methods for customization
|
||||
# ----------------------------------
|
||||
|
||||
def additional_reset_steps(self):
    """
    Scenario-specific finalization performed at the end of each reset cycle.

    Reset order: reset_starting_layout() and reset_agent_states() run first,
    then <component>.reset() for every registered component, and finally this
    method for any last customization.

    Here, the optimization-metric trackers are re-initialized. When
    fixed_four_skill_and_loc is enabled, each agent is also assigned (in a
    permuted order) to one of the four fixed skill/location combinations so
    that all four combinations are used.
    """
    if self.fixed_four_skill_and_loc:
        self.world.clear_agent_locs()
        # Random agent order permutes the agent -> skill/loc assignment.
        for rank, agent in enumerate(self.world.get_random_order_agents()):
            self.world.set_agent_loc(agent, *self._ranked_locs[rank])
            agent.state["build_payment"] = self._avg_ranked_skill[rank]

    # Snapshot the current objectives into all three metric trackers.
    metrics_now = self.get_current_optimization_metrics()

    self.curr_optimization_metric = deepcopy(metrics_now)
    self.init_optimization_metric = deepcopy(metrics_now)
    self.prev_optimization_metric = deepcopy(metrics_now)
|
||||
|
||||
def scenario_metrics(self):
    """
    Produce scenario-level metrics (merged with component metrics in the
    'metrics' property).

    Returns a flat dict of {metric_key: scalar_value} -- no nesting or lists
    -- summarizing social metrics, endowments, utilities, and the state of
    labor-cost annealing.
    """
    metrics = {}

    coin_endowments = np.array(
        [agent.total_endowment("Coin") for agent in self.world.agents]
    )

    # Aggregate social metrics over the mobile agents' coin.
    metrics["social/productivity"] = social_metrics.get_productivity(
        coin_endowments
    )
    metrics["social/equality"] = social_metrics.get_equality(coin_endowments)

    utilities = np.array(
        [self.curr_optimization_metric[agent.idx] for agent in self.world.agents]
    )

    # Candidate planner social-welfare objectives.
    metrics["social_welfare/coin_eq_times_productivity"] = (
        rewards.coin_eq_times_productivity(
            coin_endowments=coin_endowments, equality_weight=1.0
        )
    )
    metrics["social_welfare/inv_income_weighted_coin_endow"] = (
        rewards.inv_income_weighted_coin_endowments(
            coin_endowments=coin_endowments
        )
    )
    metrics["social_welfare/inv_income_weighted_utility"] = (
        rewards.inv_income_weighted_utility(
            coin_endowments=coin_endowments, utilities=utilities
        )
    )

    # Per-agent endowments, endogenous quantities, and utilities.
    for agent in self.all_agents:
        for resource, _ in agent.inventory.items():
            metrics["endow/{}/{}".format(agent.idx, resource)] = (
                agent.total_endowment(resource)
            )

        if agent.endogenous is not None:
            for resource, quantity in agent.endogenous.items():
                metrics["endogenous/{}/{}".format(agent.idx, resource)] = quantity

        metrics["util/{}".format(agent.idx)] = (
            self.curr_optimization_metric[agent.idx]
        )

    # Labor-cost annealing state.
    metrics["labor/weighted_cost"] = self.energy_cost * self.energy_weight
    metrics["labor/warmup_integrator"] = int(self._auto_warmup_integrator)

    return metrics
|
||||
|
||||
|
||||
@scenario_registry.add
class SplitLayout(LayoutFromFile):
    """
    Extends layout_from_file/simple_wood_and_stone to impose a row of water midway
    through the map, uses a fixed set of pareto-distributed building skills (requires a
    Build component), and places agents in the top/bottom depending on skill rank.

    Args:
        water_row (int): Row of the map where the water barrier is placed. Defaults
            to half the world height.
        skill_rank_of_top_agents (int, float, tuple, list): Index/indices specifying
            which agent(s) to place in the top of the map. Indices refer to the skill
            ranking, with 0 referring to the highest-skilled agent. Defaults to only
            the highest-skilled agent in the top.
        planner_gets_spatial_obs (bool): Whether the planner agent receives spatial
            observations from the world.
        full_observability (bool): Whether the mobile agents' spatial observation
            includes the full world view or is instead an egocentric view.
        mobile_agent_observation_range (int): If not using full_observability,
            the spatial range (on each side of the agent) that is visible in the
            spatial observations.
        env_layout_file (str): Name of the layout file in ./map_txt/ to use.
            Note: The world dimensions of that layout must match the world dimensions
            argument used to construct the environment.
        resource_regen_prob (float): Probability that an empty source tile will
            regenerate a new resource unit.
        starting_agent_coin (int, float): Amount of coin agents have at t=0. Defaults
            to zero coin.
        isoelastic_eta (float): Parameter controlling the shape of agent utility
            wrt coin endowment.
        energy_cost (float): Coefficient for converting labor to negative utility.
        energy_warmup_constant (float): Decay constant that controls the rate at which
            the effective energy cost is annealed from 0 to energy_cost. Set to 0
            (default) to disable annealing, meaning that the effective energy cost is
            always energy_cost. The units of the decay constant depend on the choice of
            energy_warmup_method.
        energy_warmup_method (str): How to schedule energy annealing (warmup). If
            "decay" (default), use the number of completed episodes. If "auto",
            use the number of timesteps where the average agent reward was positive.
        planner_reward_type (str): The type of reward used for the planner. Options
            are "coin_eq_times_productivity" (default),
            "inv_income_weighted_coin_endowment", and "inv_income_weighted_utility".
        mixing_weight_gini_vs_coin (float): Degree to which equality is ignored w/
            "coin_eq_times_productivity". Default is 0, which weights equality and
            productivity equally. If set to 1, only productivity is rewarded.
    """

    name = "split_layout/simple_wood_and_stone"

    def __init__(
        self,
        *args,
        water_row=None,
        skill_rank_of_top_agents=None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        # This scenario manages its own skill/location assignment, so the
        # parent's fixed_four_skill_and_loc behavior would conflict with it.
        if self.fixed_four_skill_and_loc:
            raise ValueError(
                "The split layout scenario does not support "
                "fixed_four_skill_and_loc. Set this to False."
            )

        # Augment the fixed layout to include a row of water through the middle.
        # The water row must leave at least one usable row above and below.
        if water_row is None:
            self._water_line = self.world_size[0] // 2
        else:
            self._water_line = int(water_row)
        assert 0 < self._water_line < self.world_size[0] - 1
        # Overwrite the water row in every landmark source map: filled with
        # water, cleared of everything else.
        for landmark, landmark_map in self._source_maps.items():
            landmark_map[self._water_line, :] = 1 if landmark == "Water" else 0
            self._source_maps[landmark] = landmark_map

        # Controls logic for which agents (by skill rank) get placed on the top.
        if skill_rank_of_top_agents is None:
            skill_rank_of_top_agents = [0]

        # Normalize to a list of unique integer skill ranks.
        if isinstance(skill_rank_of_top_agents, (int, float)):
            self.skill_rank_of_top_agents = [int(skill_rank_of_top_agents)]
        elif isinstance(skill_rank_of_top_agents, (tuple, list)):
            self.skill_rank_of_top_agents = list(set(skill_rank_of_top_agents))
        else:
            raise TypeError(
                "skill_rank_of_top_agents must be a scalar "
                "index, or a list of scalar indices."
            )
        # Each rank must be a valid agent index, and at least one agent must
        # end up on each side of the water line.
        for rank in self.skill_rank_of_top_agents:
            assert 0 <= rank < self.n_agents
        assert 0 < len(self.skill_rank_of_top_agents) < self.n_agents

        # Set the skill associated with each skill rank.
        bm = self.get_component("Build")
        assert bm.skill_dist == "pareto"
        pmsm = bm.payment_max_skill_multiplier
        # Generate a batch (100000) of num_agents (sorted/clipped) Pareto samples.
        pareto_samples = np.random.pareto(4, size=(100000, self.n_agents))
        clipped_skills = np.minimum(pmsm, (pmsm - 1) * pareto_samples + 1)
        sorted_clipped_skills = np.sort(clipped_skills, axis=1)
        # The skill level of the i-th skill-ranked agent is the average of the
        # i-th ranked samples throughout the batch.
        average_ranked_skills = sorted_clipped_skills.mean(axis=0)
        self._avg_ranked_skill = average_ranked_skills * bm.payment
        # Reverse the order so index 0 is the highest-skilled.
        self._avg_ranked_skill = self._avg_ranked_skill[::-1]

    def additional_reset_steps(self):
        """
        Extra scenario-specific steps that should be performed at the end of the reset
        cycle.

        For each reset cycle...
            First, reset_starting_layout() and reset_agent_states() will be called.

            Second, <component>.reset() will be called for each registered component.

            Lastly, this method will be called to allow for any final customization of
            the reset cycle.

        For this scenario, this method resets optimization metric trackers. This is
        where each agent gets assigned to one of the skills and the starting
        locations are reset according to self.skill_rank_of_top_agents.
        """
        self.world.clear_agent_locs()
        # Iterate agents in random order; enumeration index i doubles as the
        # skill rank assigned to that agent (0 = highest-skilled).
        for i, agent in enumerate(self.world.get_random_order_agents()):
            agent.state["build_payment"] = self._avg_ranked_skill[i]
            # Agents with a "top" skill rank spawn above the water line;
            # everyone else spawns below it.
            if i in self.skill_rank_of_top_agents:
                r_min, r_max = 0, self._water_line
            else:
                r_min, r_max = self._water_line + 1, self.world_size[0]

            # Rejection-sample a free cell on the assigned side of the map.
            r = np.random.randint(r_min, r_max)
            c = np.random.randint(0, self.world_size[1])
            n_tries = 0
            while not self.world.can_agent_occupy(r, c, agent):
                r = np.random.randint(r_min, r_max)
                c = np.random.randint(0, self.world_size[1])
                n_tries += 1
                # Give up if no free cell is found after 200 attempts.
                if n_tries > 200:
                    raise TimeoutError
            self.world.set_agent_loc(agent, r, c)

        # compute current objectives
        curr_optimization_metric = self.get_current_optimization_metrics()

        self.curr_optimization_metric = deepcopy(curr_optimization_metric)
        self.init_optimization_metric = deepcopy(curr_optimization_metric)
        self.prev_optimization_metric = deepcopy(curr_optimization_metric)
|
||||
@@ -0,0 +1 @@
|
||||
WWWWW W @ W ;WW W @ W W WW; W @ W W;SW S S @ ; @ W W ; SS @ ; S @ ; @ ; @ ;S S @ ; @ ; @ ;@@@@@@@@@@@@@@@@@@@@@@@@@; S @ ;S S @ ; S @ ;SS @ ; SS @ ;SS @ ; S @ ; @ ; @ ; @ ; @ ;S @ ;
|
||||
@@ -0,0 +1 @@
|
||||
WWW @ ;WSS @ ;WWW @ ;WWW @ ;WSS ;SWS ; @ ;@@@@ @@@@ @@@; @ ;WWW @ S ; WW SS ;WWW SS;W W @ ;W @ ; @ S ;
|
||||
@@ -0,0 +1 @@
|
||||
WW W @ ; SWWW @ ;SSWW @ ;WSSSW @ ;WSSWW ;WS WS ; WWS ;SWW S @ ; S W @ ; WS W @ ; @ ; @ ;@@@@ @@@@@@@@@@@@ @@@; @ ; @ ; W @ SSSSS; WW @ SSS ; @ S SS; WW W @ SSS S;W WW SSSS ;WWW S SS; WWWW S S ;WW W @ S ; W @ S ; W @ S SSS;
|
||||
@@ -0,0 +1 @@
|
||||
W WW ; W ; W W W ; W ; W W ; W W W ; W W ; WW ; ; ; ; ; ; ; ; ; SS W ; WW W; ; S S S; W WW S ; S S W ; W WS S S ; ; ; ; ; ; ; ; ; S ;S S ; ; SS ; ; S SS S ; SS S S ;S SS S S ;SSSS S SS ;
|
||||
@@ -0,0 +1 @@
|
||||
WWWWW W @ W ;WW W @ W W WW; W @ W W;SW S S @ ; @ W W ; SS ; S ; ; ;S S @ ; @ ; @ ;@@@@@ @@@@@@@ @@@@@; S @ ;S S @ ; S @ ;SS ; SS ;SS ; S ; @ ; @ ; @ ; @ ;S @ ;
|
||||
@@ -0,0 +1 @@
|
||||
WWWWW W W ;WW W W W WW; W W W;SW S S ; W W ; SS ; S ; ; ;S S ; ; ; ; S ;S S ; S ;SS ; SS ;SS ; S ; ; ; ; ;S ;
|
||||
@@ -0,0 +1 @@
|
||||
WWWWWWWW WW @@ WW ;WWWWWWWW WW @@ WW ;WWW W @@ WW W WWW; W @@ WW WW; W @@ WW WW;SSW S SS @@ ; @@ W W ; @@ W W ; SSS ; SSS ; S ; ; ; ;SS SS @@ ;SS SS @@ ; @@ ; @@ ; @@ ;@@@@@@@@ @@@@@@@@@@@@ @@@@@@@@;@@@@@@@@ @@@@@@@@@@@@ @@@@@@@@; S @@ ;SS SS @@ ;SS SS @@ ; S @@ ; S @@ ;SSS ; SSS ; SSS ;SSS ; SS ; SS ; @@ ; @@ ; @@ ; @@ ; @@ ; @@ ;SS @@ ;SS @@ ;
|
||||
@@ -0,0 +1 @@
|
||||
WWWWWWWW WW WW ;WWWWWWWW WW WW ;WWW W WW W WWW; W WW WW; W WW WW;SSW S SS ; W W ; W W ; SSS ; SSS ; S ; ; ; ;SS SS ;SS SS ; ; ; ; ; ; S ;SS SS ;SS SS ; S ; S ;SSS ; SSS ; SSS ;SSS ; SS ; SS ; ; ; ; ; ; ;SS ;SS ;
|
||||
@@ -0,0 +1 @@
|
||||
WWWW@WW W; WW @ WWW;SW S@ S ;S @ ;@@ @@@ @@; S @ ;S S @ ;S @ ; S@ ;
|
||||
@@ -0,0 +1 @@
|
||||
WW W ; A WW A ;WW WA W ; A A ; ; AA A ; A ; ; ; ; A ; A SS S;SA S AS;S SA SA
|
||||
@@ -0,0 +1 @@
|
||||
SSSSS; SS SSSS; SS SSSS; S SSSS; SS ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; W; W WW; WW ;W WW W W WWW;WWWWW WW W WWWWW;
|
||||
5
ai_economist/foundation/scenarios/utils/__init__.py
Normal file
5
ai_economist/foundation/scenarios/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
133
ai_economist/foundation/scenarios/utils/rewards.py
Normal file
133
ai_economist/foundation/scenarios/utils/rewards.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai_economist.foundation.scenarios.utils import social_metrics
|
||||
|
||||
|
||||
def isoelastic_coin_minus_labor(
    coin_endowment, total_labor, isoelastic_eta, labor_coefficient
):
    """Agent utility, concave increasing in coin and linearly decreasing in labor.

    Args:
        coin_endowment (float, ndarray): The amount of coin owned by the agent(s).
        total_labor (float, ndarray): The amount of labor performed by the agent(s).
        isoelastic_eta (float): Constant describing the shape of the utility profile
            with respect to coin endowment. Must be between 0 and 1. 0 yields utility
            that increases linearly with coin. 1 yields utility that increases with
            log(coin). Utility from coin uses:
            https://en.wikipedia.org/wiki/Isoelastic_utility
        labor_coefficient (float): Constant describing the disutility experienced per
            unit of labor performed. Disutility from labor equals:
            labor_coefficient * total_labor

    Returns:
        Agent utility (float) or utilities (ndarray).
    """
    # https://en.wikipedia.org/wiki/Isoelastic_utility
    assert np.all(coin_endowment >= 0)
    assert 0 <= isoelastic_eta <= 1.0

    # Utility from coin endowment
    if isoelastic_eta == 1.0:
        # Log-utility limit. Clamp the endowment at 1 so log() never sees an
        # argument < 1 (which would be negative or undefined at 0).
        # BUG FIX: the original used np.max(1, coin_endowment), which passes
        # the endowment as the `axis` argument and raises at runtime;
        # np.maximum is the elementwise clamp that was intended.
        util_c = np.log(np.maximum(1.0, coin_endowment))
    else:  # 0 <= isoelastic_eta < 1
        util_c = (coin_endowment ** (1 - isoelastic_eta) - 1) / (1 - isoelastic_eta)

    # Disutility from labor (linear in labor)
    util_l = total_labor * labor_coefficient

    # Net utility
    util = util_c - util_l

    return util
|
||||
|
||||
|
||||
def coin_minus_labor_cost(
    coin_endowment, total_labor, labor_exponent, labor_coefficient
):
    """Agent utility, linearly increasing in coin and decreasing as a power of labor.

    Args:
        coin_endowment (float, ndarray): The amount of coin owned by the agent(s).
        total_labor (float, ndarray): The amount of labor performed by the agent(s).
        labor_exponent (float): Constant describing the shape of the utility profile
            with respect to total labor. Must be between >1.
        labor_coefficient (float): Constant describing the disutility experienced per
            unit of labor performed. Disutility from labor equals:
            labor_coefficient * total_labor.

    Returns:
        Agent utility (float) or utilities (ndarray).
    """
    # https://en.wikipedia.org/wiki/Isoelastic_utility
    assert np.all(coin_endowment >= 0)
    assert labor_exponent > 1

    # Coin contributes linearly; labor disutility grows super-linearly.
    labor_disutility = (total_labor ** labor_exponent) * labor_coefficient
    net_utility = coin_endowment - labor_disutility

    return net_utility
|
||||
|
||||
|
||||
def coin_eq_times_productivity(coin_endowments, equality_weight):
    """Social welfare, measured as productivity scaled by the degree of coin equality.

    Args:
        coin_endowments (ndarray): The array of coin endowments for each of the
            agents in the simulated economy.
        equality_weight (float): Constant that determines how productivity is scaled
            by coin equality. Must be between 0 (SW = prod) and 1 (SW = prod * eq).

    Returns:
        Product of coin equality and productivity (float).
    """
    n_agents = len(coin_endowments)
    # Average (per-capita) productivity.
    avg_productivity = social_metrics.get_productivity(coin_endowments) / n_agents
    # Interpolate the equality factor between 1 (weight 0) and the measured
    # equality index (weight 1).
    eq_factor = equality_weight * social_metrics.get_equality(coin_endowments) + (
        1 - equality_weight
    )
    return eq_factor * avg_productivity
|
||||
|
||||
|
||||
def inv_income_weighted_coin_endowments(coin_endowments):
    """Social welfare, as weighted average endowment (weighted by inverse endowment).

    Args:
        coin_endowments (ndarray): The array of coin endowments for each of the
            agents in the simulated economy.

    Returns:
        Weighted average coin endowment (float).
    """
    # Weight each agent by the inverse of its endowment (floored at 1 to
    # avoid division by zero), then normalize the weights to sum to 1.
    inverse_incomes = 1 / np.maximum(coin_endowments, 1)
    normalized_weights = inverse_incomes / np.sum(inverse_incomes)
    return np.sum(coin_endowments * normalized_weights)
|
||||
|
||||
|
||||
def inv_income_weighted_utility(coin_endowments, utilities):
    """Social welfare, as weighted average utility (weighted by inverse endowment).

    Args:
        coin_endowments (ndarray): The array of coin endowments for each of the
            agents in the simulated economy.
        utilities (ndarray): The array of utilities for each of the agents in the
            simulated economy.

    Returns:
        Weighted average utility (float).
    """
    # Inverse-endowment weights (floored at 1 to avoid division by zero),
    # normalized to sum to 1, applied to the utilities.
    weights = 1 / np.maximum(coin_endowments, 1)
    weights = weights / np.sum(weights)
    return np.sum(utilities * weights)
|
||||
75
ai_economist/foundation/scenarios/utils/social_metrics.py
Normal file
75
ai_economist/foundation/scenarios/utils/social_metrics.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_gini(endowments):
    """Returns the normalized Gini index describing the distribution of endowments.

    https://en.wikipedia.org/wiki/Gini_coefficient

    Args:
        endowments (ndarray): The array of endowments for each of the agents in the
            simulated economy.

    Returns:
        Normalized Gini index (float): 1 means one agent owns everything
        (perfect inequality); 0 means all agents own the same amount
        (perfect equality).

    Note:
        Uses an exact but O(n^2) method for fewer than 30 agents, and a much
        faster sorted-cumulative-share approximation for 30 or more agents;
        the two produce approximately equivalent results.
    """
    n_agents = len(endowments)

    if n_agents < 30:  # Slower. Accurate for all n.
        # Mean absolute pairwise difference, normalized by twice the total
        # endowment, then rescaled so the maximum value is exactly 1.
        pairwise_gaps = np.abs(
            endowments.reshape((n_agents, 1)) - endowments.reshape((1, n_agents))
        )
        total_gap = np.sum(pairwise_gaps)
        denom = 2 * n_agents * endowments.sum(axis=0)
        raw_gini = total_gap / (denom + 1e-10)  # epsilon guards zero endowments
        return raw_gini / ((n_agents - 1) / n_agents)

    # Much faster. Slightly overestimated for low n.
    sorted_endows = np.sort(endowments)
    cumulative_share = np.cumsum(sorted_endows) / (np.sum(sorted_endows) + 1e-10)
    return 1 - (2 / (n_agents + 1)) * np.sum(cumulative_share)
|
||||
|
||||
|
||||
def get_equality(endowments):
    """Returns the complement of the normalized Gini index (equality = 1 - Gini).

    Args:
        endowments (ndarray): The array of endowments for each of the agents in the
            simulated economy.

    Returns:
        Normalized equality index (float): 0 means one agent owns everything
        (perfect inequality); 1 means all agents own the same amount
        (perfect equality).
    """
    gini_index = get_gini(endowments)
    return 1 - gini_index
|
||||
|
||||
|
||||
def get_productivity(coin_endowments):
    """Returns the total coin inside the simulated economy.

    Args:
        coin_endowments (ndarray): The array of coin endowments for each of the
            agents in the simulated economy.

    Returns:
        Total coin endowment (float).
    """
    # Productivity is simply the economy-wide coin total.
    total_coin = np.sum(coin_endowments)
    return total_coin
|
||||
123
ai_economist/foundation/utils.py
Normal file
123
ai_economist/foundation/utils.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from hashlib import sha512
|
||||
|
||||
import lz4.frame
|
||||
from Crypto.PublicKey import RSA
|
||||
|
||||
from ai_economist.foundation.base.base_env import BaseEnvironment
|
||||
|
||||
|
||||
def save_episode_log(game_object, filepath, compression_level=16):
    """Save a lz4 compressed version of the dense log stored
    in the provided game object"""
    assert isinstance(game_object, BaseEnvironment)
    # Clamp the compression level into lz4's supported [0, 16] range.
    compression_level = min(16, max(0, int(compression_level)))

    with lz4.frame.open(
        filepath, mode="wb", compression_level=compression_level
    ) as log_file:
        # Serialize the dense log as UTF-8 JSON before compressing.
        serialized = json.dumps(
            game_object.previous_episode_dense_log, ensure_ascii=False
        )
        log_file.write(bytes(serialized.encode("utf-8")))
|
||||
|
||||
|
||||
def load_episode_log(filepath):
    """Load the dense log saved at provided filepath"""
    # Decompress the lz4 frame, then parse the JSON payload.
    with lz4.frame.open(filepath, mode="rb") as log_file:
        raw_bytes = log_file.read()
        return json.loads(raw_bytes)
|
||||
|
||||
|
||||
def verify_activation_code():
    """
    Validate the user's activation code.
    If the activation code is valid, also save it in a text file for future reference.
    If the activation code is invalid, simply exit the program
    """
    # Directory of this module; both the stored RSA key and the saved
    # activation-code file live beneath it.
    path_to_activation_code_dir = os.path.dirname(os.path.abspath(__file__))

    def validate_activation_code(code, msg=b"covid19 code activation"):
        # Check `code` against the stored RSA key: the key's private exponent
        # "signs" a fixed message hash, and a valid code is the hex exponent
        # that maps that signature back to the original hash.
        filepath = os.path.abspath(
            os.path.join(
                path_to_activation_code_dir,
                "scenarios/covid19/key_to_check_activation_code_against",
            )
        )
        with open(filepath, "r") as fp:
            key_pair = RSA.import_key(fp.read())

        hashed_msg = int.from_bytes(sha512(msg).digest(), byteorder="big")
        signature = pow(hashed_msg, key_pair.d, key_pair.n)
        try:
            exp_from_code = int(code, 16)  # the code must be hexadecimal
            hashed_msg_from_signature = pow(signature, exp_from_code, key_pair.n)

            return hashed_msg == hashed_msg_from_signature
        except ValueError:
            # int(code, 16) failed: the code is not valid hex.
            return False

    activation_code_filename = "activation_code.txt"

    filepath = os.path.join(path_to_activation_code_dir, activation_code_filename)
    if activation_code_filename in os.listdir(path_to_activation_code_dir):
        # A code was saved by a previous run; try to reuse it.
        print("Using the activation code already present in '{}'".format(filepath))
        with open(filepath, "r") as fp:
            activation_code = fp.read()
            fp.close()
        if validate_activation_code(activation_code):
            return  # already activated
        print(
            "The activation code saved in '{}' is incorrect! "
            "Please correct the activation code and try again.".format(filepath)
        )
        sys.exit(0)
    else:
        # No saved code: prompt the user, allowing a fixed number of attempts.
        print(
            "In order to run this simulation, you will need an activation code.\n"
            "Please fill out the form at "
            "https://forms.gle/dJ2gKDBqLDko1g7m7 and we will send you an "
            "activation code to the provided email address.\n"
        )
        num_attempts = 5
        attempt_num = 0
        while attempt_num < num_attempts:
            activation_code = input(
                f"Whenever you are ready, "
                "please enter the activation code: "
                f"(attempt {attempt_num + 1} / {num_attempts})"
            )
            attempt_num += 1
            if validate_activation_code(activation_code):
                # Persist the valid code so future runs skip the prompt.
                print(
                    "Saving the activation code in '{}' for future "
                    "use.".format(filepath)
                )
                with open(
                    os.path.join(path_to_activation_code_dir, activation_code_filename),
                    "w",
                ) as fp:
                    fp.write(activation_code)
                    fp.close()
                return
            print("Incorrect activation code. Please try again.")
        # All attempts exhausted; refuse to run.
        print(
            "You have had {} attempts to provide the activate code. Unfortunately, "
            "none of the activation code(s) you provided could be validated. "
            "Exiting...".format(num_attempts)
        )
        sys.exit(0)
|
||||
120
ai_economist/real_business_cycle/README.md
Normal file
120
ai_economist/real_business_cycle/README.md
Normal file
@@ -0,0 +1,120 @@
|
||||
# Real Business Cycle (RBC)
|
||||
This directory implements a **Real-Business-Cycle** (RBC) simulation with many heterogeneous, interacting strategic agents of various types, such as **consumers, firms, and the government**. For details, please refer to this paper "Finding General Equilibria in Many-Agent Economic Simulations using Deep Reinforcement Learning (ArXiv link forthcoming)". We also provide training code that uses deep multi-agent reinforcement learning to determine optimal economic policies and dynamics in these many agent environments. Below are instructions required to launch the training runs.
|
||||
|
||||
**Note: The experiments require a GPU to run!**
|
||||
|
||||
## Dependencies
|
||||
|
||||
- torch>=1.9.0
|
||||
- pycuda==2021.1
|
||||
- matplotlib==3.2.1
|
||||
|
||||
## Running Local Jobs
|
||||
To run a hyperparameter sweep of jobs on a local machine, use (see file for command line arguments and hyperparameter sweep dictionaries)
|
||||
|
||||
```
|
||||
python train_multi_exps.py
|
||||
```
|
||||
|
||||
## Configuration Dictionaries
|
||||
|
||||
Configuration dictionaries are currently specified in Python code, and then written as `hparams.yaml` in the job directory. For examples, see the file `constants.py`. The dictionaries contain "agents", "world", and "train" dictionaries which contain various hyperparameters.
|
||||
|
||||
## Hyperparameter Sweeps
|
||||
|
||||
The file `train_multi_exps.py` allows hyperparameter sweeps. These are specified in `*_param_sweeps` dictionaries in the file. For each hyperparameter, specify a list of one or more choices. The Cartesian product of all choices will be used.
|
||||
|
||||
## Approximate Best Response Training
|
||||
|
||||
To run a single approximate best-response (BR) training job on checkpoint policies, run `python train_bestresponse.py ROLLOUT_DIR NUM_EPISODES_TO_TRAIN --ep-strs ep1 ep2 --agent-type all`. The `--ep-strs` argument specifies which episodes to run on (for example, policies from episode 0, 10000, and 200000). These must be episodes for which policies were saved. It is possible to specify a single agent type.
|
||||
|
||||
|
||||
## What Will Be Saved?
|
||||
|
||||
A large amount of data will be saved -- one can set the hyperparameter `train.save_dense_every` in the configuration dictionary (`hparams.yaml`/`constants.py`) to reduce this.
|
||||
|
||||
At the top level, an experiment directory stores the results of many runs in a hyperparameter sweep. Example structure:
|
||||
|
||||
```
|
||||
experiment/experimentname/
|
||||
rollout-999999-99999/
|
||||
brconsumer/
|
||||
...
|
||||
brfirm/
|
||||
episode_XXXX_consumer.npz
|
||||
episode_XXXX_government.npz
|
||||
episode_XXXX_firm.npz
|
||||
saved_models/
|
||||
consumer_policy_XXX.pt
|
||||
firm_policy_XXX.pt
|
||||
government_policy_XXX.pt.
|
||||
brgovernment/
|
||||
...
|
||||
hparams.yaml
|
||||
action_arrays.pickle
|
||||
episode_XXXX_consumer.npz
|
||||
episode_XXXX_government.npz
|
||||
episode_XXXX_firm.npz
|
||||
saved_models/
|
||||
consumer_policy_XXX.pt
|
||||
firm_policy_XXX.pt
|
||||
government_policy_XXX.pt.
|
||||
|
||||
rollout-777777-77777/
|
||||
...
|
||||
```
|
||||
|
||||
Files:
|
||||
|
||||
`rollout-XXXXXX-XXX`: subdirectory containing all output for a single run.
|
||||
|
||||
`hparams.yaml`: configuration dictionary with hyperparameters
|
||||
|
||||
`action_arrays.pickle`: contains saved action arrays (allowing mapping action indices to the actual action, e.g. index 1 is price 1000.0, etc.)
|
||||
|
||||
`episode_XXXX_AGENTTYPE.npz`: Contains dense rollouts stored as the output of a numpy.savez call. When loaded, can be treated like a dictionary of numpy arrays. Has keys: `['states', 'actions', 'rewards', 'action_array', 'aux_array']` (view keys by using `.files`). `states`, `actions`, `rewards`, and `aux_array` all refer to saved copies of CUDA arrays (described below). `action_array` is a small array mapping action indices to the actual action.
|
||||
|
||||
`saved_models/AGENTTYPE_policy_XXX.pt`: a saved PyTorch state dict of the policy network, after episode XXX.
|
||||
|
||||
## Structure Of Arrays
|
||||
|
||||
`states` for any given agent type is an array storing observed states. It has shape `batch_size, ep_length, num_agents, agent_total_state_dim`.
|
||||
|
||||
`actions` is an array consisting of the action _indices_ (integers). For firms and government, it is of shape `batch_size, ep_length, num_agents`. For consumers, it is of shape `batch_size, ep_length, num_agents, num_action_heads`.
|
||||
|
||||
`rewards` stores total rewards, and is of shape `batch_size, ep_length, num_agents`.
|
||||
|
||||
The `aux_array` stores additional information and may differ per agent type. The consumer `aux_array` stores _actual_ consumption of each firm's good (as opposed to attempted consumption). The firm `aux_array` stores the amount bought by the export market.
|
||||
|
||||
## State Array Layout:
|
||||
|
||||
States observed by each agent consist of a global state, plus additional state dimensions per agent.
|
||||
|
||||
Global state: total dimension 4 * num_firms + 2 + 1
|
||||
- prices: 1 per firm
|
||||
- wages: 1 per firm
|
||||
- inventories: 1 per firm
|
||||
- overdemanded flag: 1 per firm
|
||||
- time
|
||||
|
||||
Consumer additional state variables: total dimension global state + 2
|
||||
- budget
|
||||
- theta
|
||||
|
||||
Firm additional state variables: total dimension global state + 3 + num_firms
|
||||
- budget
|
||||
- capital
|
||||
- production alpha
|
||||
- one-hot representation identifying which firm
|
||||
|
||||
## What Gets Loaded And Written By BR Code?
|
||||
|
||||
The best response code loads in the `hparams.yaml` file, and the policies at a given time step (i.e. `saved_models/...policy_XXX.pt`). It then trains one of the policies while keeping the others fixed. Results are written to directories `brfirm`, `brconsumer`, `brgovernment` and contain dense rollouts and saved policy checkpoints, but from the best response training.
|
||||
|
||||
## Which Hyperparameters Are Managed And Where?
|
||||
|
||||
Initial values of state variables (budgets, initial wages, levels of capital, and so on) are set by the code in the method `__init_cuda_data_structs`. Some of these can be controlled from the hyperparameter dict; others are currently hardcoded.
|
||||
|
||||
Other hyperparameters are specified in the configuration dictionary.
|
||||
|
||||
Finally, the technology parameter (A) of the production function is currently hardcoded in the function call in `rbc/cuda/firm_rbc.cu`.
|
||||
242
ai_economist/real_business_cycle/experiment_utils.py
Normal file
242
ai_economist/real_business_cycle/experiment_utils.py
Normal file
@@ -0,0 +1,242 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import hashlib
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import struct
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
# defaults
|
||||
_NUM_FIRMS = 10
|
||||
|
||||
|
||||
def _bigint_from_bytes(num_bytes):
|
||||
"""
|
||||
See https://github.com/openai/gym/blob/master/gym/utils/seeding.py.
|
||||
"""
|
||||
sizeof_int = 4
|
||||
padding = sizeof_int - len(num_bytes) % sizeof_int
|
||||
num_bytes += b"\0" * padding
|
||||
int_count = int(len(num_bytes) / sizeof_int)
|
||||
unpacked = struct.unpack("{}I".format(int_count), num_bytes)
|
||||
accum = 0
|
||||
for i, val in enumerate(unpacked):
|
||||
accum += 2 ** (sizeof_int * 8 * i) * val
|
||||
return accum
|
||||
|
||||
|
||||
def seed_from_base_seed(base_seed):
    """
    Hash *base_seed* down to a 32-bit seed to reduce correlation between
    seeds derived from nearby base values.

    Equivalent to the previous ``_bigint_from_bytes(digest[:4])`` call:
    the first 4 bytes of the SHA-512 digest of ``str(base_seed)`` read as
    a little-endian unsigned integer (range [0, 2**32)).
    """
    max_bytes = 4
    digest = hashlib.sha512(str(base_seed).encode("utf8")).digest()
    return int.from_bytes(digest[:max_bytes], "little")
|
||||
|
||||
|
||||
def hash_from_dict(d):
    """Return a 32-bit hash of config dict *d*, ignoring train.base_seed.

    The input dict is not modified; two configs that differ only in their
    base seed hash identically.
    """
    stripped = deepcopy(d)
    stripped["train"].pop("base_seed")
    canonical = json.dumps(stripped, sort_keys=True)
    digest = hashlib.sha256(canonical.encode("utf8")).hexdigest()
    return int(digest[:8], 16)
|
||||
|
||||
|
||||
def cfg_dict_from_yaml(
    hparams_path,
    consumption_choices,
    work_choices,
    price_and_wage,
    tax_choices,
    group_name=None,
):
    """
    Load the hparams YAML at *hparams_path* and splice in runtime values.

    Adds the discretized action arrays, a seed derived from
    ``train.base_seed``, a config hash, and the save directory (the
    directory containing the YAML file).

    hparams_path may be a ``pathlib.Path`` or a plain string (previously a
    string would crash on ``.absolute()``); it is coerced to a Path here.
    Returns the completed config dict.
    """
    hparams_path = Path(hparams_path)
    with open(hparams_path) as f:
        d = yaml.safe_load(f)

    if group_name is not None:
        d["metadata"]["group_name"] = group_name
    # Hash is computed before the (non-JSON-serializable) arrays are added.
    d["metadata"]["hparamhash"] = hash_from_dict(d)
    d["agents"][
        "consumer_consumption_actions_array"
    ] = consumption_choices  # Note: hardcoded
    d["agents"]["consumer_work_actions_array"] = work_choices  # Note: hardcoded
    d["agents"]["firm_actions_array"] = price_and_wage  # Note: hardcoded
    d["agents"]["government_actions_array"] = tax_choices
    d["train"]["save_dir"] = str(hparams_path.absolute().parent)
    d["train"]["seed"] = seed_from_base_seed(d["train"]["base_seed"])
    return d
|
||||
|
||||
|
||||
def run_experiment_batch_parallel(
    experiment_dir,
    consumption_choices,
    work_choices,
    price_and_wage,
    tax_choices,
    group_name=None,
    consumers_only=False,
    no_firms=False,
    default_firm_action=None,
    default_government_action=None,
):
    """Load hparams from *experiment_dir*, build a run manager, and train.

    With ``consumers_only`` both firms and government are frozen at their
    default actions; with ``no_firms`` only firms are frozen.
    """
    hparams_path = Path(experiment_dir) / Path("hparams.yaml")
    hparams_dict = cfg_dict_from_yaml(
        hparams_path,
        consumption_choices,
        work_choices,
        price_and_wage,
        tax_choices,
        group_name=group_name,
    )
    print(f"hparams_dict {hparams_dict}")
    # import this here so rest of file still imports without cuda installed
    from rbc.cuda_manager import ConsumerFirmRunManagerBatchParallel

    # Decide which agent types (if any) stay frozen at a default action.
    frozen_kwargs = {}
    if consumers_only:
        frozen_kwargs = {
            "freeze_firms": default_firm_action,
            "freeze_govt": default_government_action,
        }
    elif no_firms:
        frozen_kwargs = {"freeze_firms": default_firm_action}
    manager = ConsumerFirmRunManagerBatchParallel(hparams_dict, **frozen_kwargs)
    manager.train()
|
||||
|
||||
|
||||
def compare_global_states_within_type(states, global_state_size):
    """True iff every agent in each batch/step sees the same global state.

    *states* is indexed [batch, step, agent, state_dim]; the first
    ``global_state_size`` entries of the state are compared against
    agent 0's copy via broadcasting.
    """
    reference = states[:, :, 0:1, :global_state_size]
    everyone = states[:, :, :, :global_state_size]
    return np.isclose(everyone, reference).all()
|
||||
|
||||
|
||||
def compare_global_states_across_types(
    consumer_states, firm_states, government_states, global_state_size
):
    """Check each agent type's global-state slice against consumer agent 0.

    Returns a 3-tuple of bools in the order (firms match, governments
    match, consumers match), all compared to the first consumer's global
    state via broadcasting.
    """
    reference = consumer_states[:, :, 0:1, :global_state_size]

    def _matches(states):
        # broadcast the single reference agent against all agents of a type
        return np.isclose(states[:, :, :, :global_state_size], reference).all()

    return (
        _matches(firm_states),
        _matches(government_states),
        _matches(consumer_states),
    )
|
||||
|
||||
|
||||
def check_no_negative_stocks(state, stock_offset, stock_size):
    """True iff every stock entry is non-negative (1e-3 float tolerance).

    Stocks occupy ``stock_size`` consecutive entries starting at
    ``stock_offset`` in the last axis of *state*.
    """
    stock_slice = slice(stock_offset, stock_offset + stock_size)
    return (state[:, :, :, stock_slice] >= -1.0e-3).all()
|
||||
|
||||
|
||||
# Hyperparameter sweep grids: each key maps to the list of values to try.
# sweep_cfg_generator takes the Cartesian product across all active keys.

# Sweeps applied to the "train" section of the config.
train_param_sweeps = {
    "lr": [0.005, 0.001],
    "entropy": [0.01],
    "base_seed": [2596],
    "batch_size": [64],
    "clip_grad_norm": [1.0, 2.0, 5.0],
}

# Other param sweeps
# Sweeps applied to the "agents" section of the config (none active).
agent_param_sweeps = {
    # "consumer_noponzi_eta": [0.1,0.05]
}

# Sweeps applied to the "world" section of the config (none active).
world_param_sweeps = {
    # "interest_rate": [0.02, 0.0]
}
|
||||
|
||||
|
||||
def add_all(d, keys_list, target_val):
    """Map every key in *keys_list* to *target_val* in dict *d*, in place."""
    d.update(dict.fromkeys(keys_list, target_val))
|
||||
|
||||
|
||||
def sweep_cfg_generator(
    base_cfg,
    tr_param_sweeps=None,
    ag_param_sweeps=None,
    wld_param_sweeps=None,
    seed_from_timestamp=False,
    group_name=None,
):
    """Yield one config per point in the Cartesian product of the sweeps.

    Each sweep dict maps parameter name -> list of values; train sweeps go
    into cfg["train"], agent sweeps into cfg["agents"], world sweeps into
    cfg["world"]. Each yielded config is a deep copy of *base_cfg* with the
    swept values substituted. Optionally perturbs the base seed with the
    current timestamp and stamps a group name into the metadata.
    """
    # (section name, sweep dict) in the same order as the original
    # train/agents/world handling, so product iteration order is preserved.
    sections = [
        ("train", {} if tr_param_sweeps is None else tr_param_sweeps),
        ("agents", {} if ag_param_sweeps is None else ag_param_sweeps),
        ("world", {} if wld_param_sweeps is None else wld_param_sweeps),
    ]
    for _, sweep in sections:
        assert isinstance(sweep, dict)

    key_dict = {}  # reverse lookup: parameter name -> config section
    all_keys = ()
    all_values = ()
    for section, sweep in sections:
        if sweep:
            keys, values = zip(*sweep.items())
        else:
            keys, values = (), ()
        for key in keys:
            key_dict[key] = section
        all_keys += keys
        all_values += values

    for combination in itertools.product(*all_values):
        out = deepcopy(base_cfg)
        for key, value in zip(all_keys, combination):
            out[key_dict[key]][key] = value
        if seed_from_timestamp:
            # time.time() returns float, multiply 1000 for higher resolution
            out["train"]["base_seed"] += int(time.time() * 1000)
        if group_name is not None:
            out["metadata"]["group"] = group_name
        yield out
|
||||
|
||||
|
||||
def create_job_dir(experiment_dir, job_name_base, cfg=None, action_arrays=None):
    """Create a uniquely-named job directory under *experiment_dir*.

    The directory name is ``{job_name_base}-{timestamp}`` with dots replaced
    (so the float timestamp yields a filesystem-friendly name). If *cfg* is
    given, its metadata is stamped with the dir/group names and it is written
    to ``hparams.yaml``; if *action_arrays* is given it is pickled alongside.

    Previously a ``None`` cfg (the documented default) crashed on
    ``cfg["metadata"]``; it is now skipped. Returns the created Path.
    """
    unique_id = time.time()
    dirname = f"{job_name_base}-{unique_id}".replace(".", "-")
    dir_path = Path(experiment_dir) / Path(dirname)
    os.makedirs(str(dir_path), exist_ok=True)
    if cfg is not None:
        cfg["metadata"]["dirname"] = dirname
        cfg["metadata"]["group"] = str(Path(experiment_dir).name)
        with open(dir_path / Path("hparams.yaml"), "w") as f:
            f.write(yaml.dump(cfg))

    if action_arrays is not None:
        with open(dir_path / Path("action_arrays.pickle"), "wb") as f:
            pickle.dump(action_arrays, f)
    return dir_path
|
||||
5
ai_economist/real_business_cycle/rbc/__init__.py
Normal file
5
ai_economist/real_business_cycle/rbc/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) 2020, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
638
ai_economist/real_business_cycle/rbc/constants.py
Normal file
638
ai_economist/real_business_cycle/rbc/constants.py
Normal file
@@ -0,0 +1,638 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
_NP_DTYPE = np.float32
|
||||
|
||||
|
||||
def all_agents_export_experiment_template(
    NUMFIRMS, NUMCONSUMERS, NUMGOVERNMENTS, episodes_const=30000
):
    """Build the default config for the full export-market experiment.

    episodes_const sets the entropy/theta annealing decay length. Returns a
    7-tuple (cfg_dict, consumption_choices, work_choices, price_and_wage,
    tax_choices, None, None); the action arrays are the discretized joint
    action grids and the trailing Nones are unused placeholder slots.
    """
    # Per-firm consumption levels a consumer may attempt: 0, 1, ..., 10.
    consumption_choices = [
        np.array([0.0 + 1.0 * c for c in range(11)], dtype=_NP_DTYPE)
    ]
    # Hours-worked choices: 0, 260, 520, 780, 1040.
    work_choices = [
        np.array([0.0 + 20 * 13 * h for h in range(5)], dtype=_NP_DTYPE)
    ]  # specify dtype -- be consistent?

    # Expand per-dimension choice lists into arrays of joint actions.
    consumption_choices = np.array(
        list(itertools.product(*consumption_choices)), dtype=_NP_DTYPE
    )
    work_choices = np.array(list(itertools.product(*work_choices)), dtype=_NP_DTYPE)

    # Firm action grid: price (0..2500 step 500), wage, capital (fixed 0.1).
    price_choices = np.array([0.0 + 500.0 * c for c in range(6)], dtype=_NP_DTYPE)
    wage_choices = np.array([0.0, 11.0, 22.0, 33.0, 44.0], dtype=_NP_DTYPE)
    capital_choices = np.array([0.1], dtype=_NP_DTYPE)
    price_and_wage = np.array(
        list(itertools.product(price_choices, wage_choices, capital_choices)),
        dtype=_NP_DTYPE,
    )

    # government action discretization: tax rates 0.0..1.0 in steps of 0.2
    income_taxation_choices = np.array(
        [0.0 + 0.2 * c for c in range(6)], dtype=_NP_DTYPE
    )
    corporate_taxation_choices = np.array(
        [0.0 + 0.2 * c for c in range(6)], dtype=_NP_DTYPE
    )
    tax_choices = np.array(
        list(itertools.product(income_taxation_choices, corporate_taxation_choices)),
        dtype=_NP_DTYPE,
    )
    global_state_dim = (
        NUMFIRMS  # prices
        + NUMFIRMS  # wages
        + NUMFIRMS  # stocks
        + NUMFIRMS  # was good overdemanded
        + 2 * NUMGOVERNMENTS  # tax rates
        + 1
    )  # time

    global_state_digit_dims = list(
        range(2 * NUMFIRMS, 3 * NUMFIRMS)
    )  # stocks are the only global state var that can get huge
    consumer_state_dim = (
        global_state_dim + 1 + 1
    )  # budget # theta, the disutility of work

    firm_state_dim = (
        global_state_dim
        + 1  # budget
        + 1  # capital
        + 1  # production alpha
        + NUMFIRMS  # onehot specifying which firm
    )

    # Curriculum schedule (in episodes) for this long-run variant.
    episodes_to_anneal_firm = 100000
    episodes_to_anneal_government = 100000
    government_phase1_start = 100000
    government_state_dim = global_state_dim
    DEFAULT_CFG_DICT = {
        # actions_array key will be added below
        "agents": {
            "num_consumers": NUMCONSUMERS,
            "num_firms": NUMFIRMS,
            "num_governments": NUMGOVERNMENTS,
            "global_state_dim": global_state_dim,
            "consumer_state_dim": consumer_state_dim,
            # action vectors are how much consume from each firm,
            # how much to work, and which firm to choose
            "consumer_action_dim": NUMFIRMS + 1 + 1,
            "consumer_num_consume_actions": consumption_choices.shape[0],
            "consumer_num_work_actions": work_choices.shape[0],
            "consumer_num_whichfirm_actions": NUMFIRMS,
            "firm_state_dim": firm_state_dim,  # what are observations?
            # actions are price and wage for own firm, and capital choices
            "firm_action_dim": 3,
            "firm_num_actions": price_and_wage.shape[0],
            "government_state_dim": government_state_dim,
            "government_action_dim": 2,
            "government_num_actions": tax_choices.shape[0],
            "max_possible_consumption": float(consumption_choices.max()),
            "max_possible_hours_worked": float(work_choices.max()),
            "max_possible_wage": float(wage_choices.max()),
            "max_possible_price": float(price_choices.max()),
            # these are dims which, due to being on a large scale,
            # have to be expanded to a digit representation
            "consumer_digit_dims": global_state_digit_dims
            + [global_state_dim],  # global state + consumer budget
            # global state + firm budget (do we need capital?)
            "firm_digit_dims": global_state_digit_dims + [global_state_dim],
            # govt only has global state
            "government_digit_dims": global_state_digit_dims,
            "firm_reward_scale": 10000,
            "government_reward_scale": 100000,
            "consumer_reward_scale": 50.0,
            "firm_anneal_wages": {
                "anneal_on": True,
                "start": 22.0,
                "increase_const": float(wage_choices.max() - 22.0)
                / (episodes_to_anneal_firm),
                "decrease_const": (22.0) / episodes_to_anneal_firm,
            },
            "firm_anneal_prices": {
                "anneal_on": True,
                "start": 1000.0,
                "increase_const": float(price_choices.max() - 1000.00)
                / episodes_to_anneal_firm,
                "decrease_const": (1000.0) / episodes_to_anneal_firm,
            },
            "government_anneal_taxes": {
                "anneal_on": True,
                "start": 0.0,
                "increase_const": 1.0 / episodes_to_anneal_government,
            },
            "firm_begin_anneal_action": 0,
            "government_begin_anneal_action": government_phase1_start,
            "consumer_anneal_theta": {
                "anneal_on": True,
                "exp_decay_length_in_steps": episodes_const,
            },
            "consumer_anneal_entropy": {
                "anneal_on": True,
                "exp_decay_length_in_steps": episodes_const,
                "coef_floor": 0.1,
            },
            "firm_anneal_entropy": {
                "anneal_on": True,
                "exp_decay_length_in_steps": episodes_const,
                "coef_floor": 0.1,
            },
            "govt_anneal_entropy": {
                "anneal_on": True,
                "exp_decay_length_in_steps": episodes_const,
                "coef_floor": 0.1,
            },
            "consumer_noponzi_eta": 0.0,
            "consumer_penalty_scale": 1.0,
            "firm_noponzi_eta": 0.0,
            # Firms/government start training only after their anneal phases.
            "firm_training_start": episodes_to_anneal_firm,
            "government_training_start": government_phase1_start
            + episodes_to_anneal_government,
            "consumer_training_start": 0,
            "government_counts_firm_reward": 0,
            "should_boost_firm_reward": False,
            "firm_reward_for_government_factor": 0.0025,
        },
        "world": {
            "maxtime": 10,
            "initial_firm_endowment": 22.0 * 1000 * NUMCONSUMERS,
            "initial_consumer_endowment": 2000,
            "initial_stocks": 0.0,
            "initial_prices": 1000.0,
            "initial_wages": 22.0,
            "interest_rate": 0.1,
            "consumer_theta": 0.01,
            "crra_param": 0.1,
            "production_alpha": "fixed_array",  # only works for exactly 10 firms, kluge
            "initial_capital": "twolevel",
            "paretoscaletheta": 4.0,
            "importer_price": 500.0,
            "importer_quantity": 100.0,
            "use_importer": 1,
        },
        "train": {
            "batch_size": 8,
            "base_seed": 1234,
            "save_dense_every": 2000,
            "save_model_every": 10000,
            "num_episodes": 500000,
            "infinite_episodes": False,
            "lr": 0.01,
            "gamma": 0.9999,
            "entropy": 0.0,
            "value_loss_weight": 1.0,
            "digit_representation_size": 10,
            "lagr_num_steps": 1,
            "boost_firm_reward_factor": 1.0,
        },
    }
    return (
        DEFAULT_CFG_DICT,
        consumption_choices,
        work_choices,
        price_and_wage,
        tax_choices,
        None,
        None,
    )
|
||||
|
||||
|
||||
def all_agents_short_export_experiment_template(
    NUMFIRMS, NUMCONSUMERS, NUMGOVERNMENTS, episodes_const=10000
):
    """Shorter-schedule variant of all_agents_export_experiment_template.

    Same action grids and state layout; only the annealing schedule
    (30k vs 100k episodes), the default episodes_const, and num_episodes
    (200k vs 500k) differ. Returns the same 7-tuple shape.
    """
    # Per-firm consumption levels a consumer may attempt: 0, 1, ..., 10.
    consumption_choices = [
        np.array([0.0 + 1.0 * c for c in range(11)], dtype=_NP_DTYPE)
    ]
    # Hours-worked choices: 0, 260, 520, 780, 1040.
    work_choices = [
        np.array([0.0 + 20 * 13 * h for h in range(5)], dtype=_NP_DTYPE)
    ]  # specify dtype -- be consistent?

    # Expand per-dimension choice lists into arrays of joint actions.
    consumption_choices = np.array(
        list(itertools.product(*consumption_choices)), dtype=_NP_DTYPE
    )
    work_choices = np.array(list(itertools.product(*work_choices)), dtype=_NP_DTYPE)

    # Firm action grid: price (0..2500 step 500), wage, capital (fixed 0.1).
    price_choices = np.array([0.0 + 500.0 * c for c in range(6)], dtype=_NP_DTYPE)
    wage_choices = np.array([0.0, 11.0, 22.0, 33.0, 44.0], dtype=_NP_DTYPE)
    capital_choices = np.array([0.1], dtype=_NP_DTYPE)
    price_and_wage = np.array(
        list(itertools.product(price_choices, wage_choices, capital_choices)),
        dtype=_NP_DTYPE,
    )

    # government action discretization: tax rates 0.0..1.0 in steps of 0.2
    income_taxation_choices = np.array(
        [0.0 + 0.2 * c for c in range(6)], dtype=_NP_DTYPE
    )
    corporate_taxation_choices = np.array(
        [0.0 + 0.2 * c for c in range(6)], dtype=_NP_DTYPE
    )
    tax_choices = np.array(
        list(itertools.product(income_taxation_choices, corporate_taxation_choices)),
        dtype=_NP_DTYPE,
    )
    global_state_dim = (
        NUMFIRMS  # prices
        + NUMFIRMS  # wages
        + NUMFIRMS  # stocks
        + NUMFIRMS  # was good overdemanded
        + 2 * NUMGOVERNMENTS  # tax rates
        + 1
    )  # time

    global_state_digit_dims = list(
        range(2 * NUMFIRMS, 3 * NUMFIRMS)
    )  # stocks are the only global state var that can get huge
    consumer_state_dim = (
        global_state_dim + 1 + 1
    )  # budget # theta, the disutility of work

    firm_state_dim = (
        global_state_dim
        + 1  # budget
        + 1  # capital
        + 1  # production alpha
        + NUMFIRMS  # onehot specifying which firm
    )

    # Shortened curriculum schedule (in episodes).
    episodes_to_anneal_firm = 30000
    episodes_to_anneal_government = 30000
    government_phase1_start = 30000
    government_state_dim = global_state_dim
    DEFAULT_CFG_DICT = {
        # actions_array key will be added below
        "agents": {
            "num_consumers": NUMCONSUMERS,
            "num_firms": NUMFIRMS,
            "num_governments": NUMGOVERNMENTS,
            "global_state_dim": global_state_dim,
            "consumer_state_dim": consumer_state_dim,
            # action vectors are how much consume from each firm,
            # how much to work, and which firm to choose
            "consumer_action_dim": NUMFIRMS + 1 + 1,
            "consumer_num_consume_actions": consumption_choices.shape[0],
            "consumer_num_work_actions": work_choices.shape[0],
            "consumer_num_whichfirm_actions": NUMFIRMS,
            "firm_state_dim": firm_state_dim,  # what are observations?
            # actions are price and wage for own firm, and capital choices
            "firm_action_dim": 3,
            "firm_num_actions": price_and_wage.shape[0],
            "government_state_dim": government_state_dim,
            "government_action_dim": 2,
            "government_num_actions": tax_choices.shape[0],
            "max_possible_consumption": float(consumption_choices.max()),
            "max_possible_hours_worked": float(work_choices.max()),
            "max_possible_wage": float(wage_choices.max()),
            "max_possible_price": float(price_choices.max()),
            # these are dims which, due to being on a large scale,
            # have to be expanded to a digit representation
            "consumer_digit_dims": global_state_digit_dims
            + [global_state_dim],  # global state + consumer budget
            "firm_digit_dims": global_state_digit_dims
            + [global_state_dim],  # global state + firm budget (do we need capital?)
            # govt only has global state
            "government_digit_dims": global_state_digit_dims,
            "firm_reward_scale": 10000,
            "government_reward_scale": 100000,
            "consumer_reward_scale": 50.0,
            "firm_anneal_wages": {
                "anneal_on": True,
                "start": 22.0,
                "increase_const": float(wage_choices.max() - 22.0)
                / (episodes_to_anneal_firm),
                "decrease_const": (22.0) / episodes_to_anneal_firm,
            },
            "firm_anneal_prices": {
                "anneal_on": True,
                "start": 1000.0,
                "increase_const": float(price_choices.max() - 1000.00)
                / episodes_to_anneal_firm,
                "decrease_const": (1000.0) / episodes_to_anneal_firm,
            },
            "government_anneal_taxes": {
                "anneal_on": True,
                "start": 0.0,
                "increase_const": 1.0 / episodes_to_anneal_government,
            },
            "firm_begin_anneal_action": 0,
            "government_begin_anneal_action": government_phase1_start,
            "consumer_anneal_theta": {
                "anneal_on": True,
                "exp_decay_length_in_steps": episodes_const,
            },
            "consumer_anneal_entropy": {
                "anneal_on": True,
                "exp_decay_length_in_steps": episodes_const,
                "coef_floor": 0.1,
            },
            "firm_anneal_entropy": {
                "anneal_on": True,
                "exp_decay_length_in_steps": episodes_const,
                "coef_floor": 0.1,
            },
            "govt_anneal_entropy": {
                "anneal_on": True,
                "exp_decay_length_in_steps": episodes_const,
                "coef_floor": 0.1,
            },
            "consumer_noponzi_eta": 0.0,
            "consumer_penalty_scale": 1.0,
            "firm_noponzi_eta": 0.0,
            # Firms/government start training only after their anneal phases.
            "firm_training_start": episodes_to_anneal_firm,
            "government_training_start": government_phase1_start
            + episodes_to_anneal_government,
            "consumer_training_start": 0,
            "government_counts_firm_reward": 0,
            "should_boost_firm_reward": False,
            "firm_reward_for_government_factor": 0.0025,
        },
        "world": {
            "maxtime": 10,
            "initial_firm_endowment": 22.0 * 1000 * NUMCONSUMERS,
            "initial_consumer_endowment": 2000,
            "initial_stocks": 0.0,
            "initial_prices": 1000.0,
            "initial_wages": 22.0,
            "interest_rate": 0.1,
            "consumer_theta": 0.01,
            "crra_param": 0.1,
            "production_alpha": "fixed_array",  # only works for exactly 10 firms, kluge
            "initial_capital": "twolevel",
            "paretoscaletheta": 4.0,
            "importer_price": 500.0,
            "importer_quantity": 100.0,
            "use_importer": 1,
        },
        "train": {
            "batch_size": 8,
            "base_seed": 1234,
            "save_dense_every": 2000,
            "save_model_every": 10000,
            "num_episodes": 200000,
            "infinite_episodes": False,
            "lr": 0.01,
            "gamma": 0.9999,
            "entropy": 0.0,
            "value_loss_weight": 1.0,
            "digit_representation_size": 10,
            "lagr_num_steps": 1,
            "boost_firm_reward_factor": 1.0,
        },
    }
    return (
        DEFAULT_CFG_DICT,
        consumption_choices,
        work_choices,
        price_and_wage,
        tax_choices,
        None,
        None,
    )
|
||||
|
||||
|
||||
def very_short_test_template(
    NUMFIRMS, NUMCONSUMERS, NUMGOVERNMENTS, episodes_const=30000
):
    """Tiny smoke-test variant of the export experiment template.

    Identical layout to the other templates but with 10-episode anneal
    phases, only 100 training episodes, and extra per-type training
    frequencies (train_*_every). Returns the same 7-tuple shape.
    """
    # Per-firm consumption levels a consumer may attempt: 0, 1, ..., 10.
    consumption_choices = [
        np.array([0.0 + 1.0 * c for c in range(11)], dtype=_NP_DTYPE)
    ]
    # Hours-worked choices: 0, 260, 520, 780, 1040.
    work_choices = [
        np.array([0.0 + 20 * 13 * h for h in range(5)], dtype=_NP_DTYPE)
    ]  # specify dtype -- be consistent?

    # Expand per-dimension choice lists into arrays of joint actions.
    consumption_choices = np.array(
        list(itertools.product(*consumption_choices)), dtype=_NP_DTYPE
    )
    work_choices = np.array(list(itertools.product(*work_choices)), dtype=_NP_DTYPE)

    # Firm action grid: price (0..2500 step 500), wage, capital (fixed 0.1).
    price_choices = np.array([0.0 + 500.0 * c for c in range(6)], dtype=_NP_DTYPE)
    wage_choices = np.array([0.0, 11.0, 22.0, 33.0, 44.0], dtype=_NP_DTYPE)
    capital_choices = np.array([0.1], dtype=_NP_DTYPE)
    price_and_wage = np.array(
        list(itertools.product(price_choices, wage_choices, capital_choices)),
        dtype=_NP_DTYPE,
    )

    # government action discretization: tax rates 0.0..1.0 in steps of 0.2
    income_taxation_choices = np.array(
        [0.0 + 0.2 * c for c in range(6)], dtype=_NP_DTYPE
    )
    corporate_taxation_choices = np.array(
        [0.0 + 0.2 * c for c in range(6)], dtype=_NP_DTYPE
    )
    tax_choices = np.array(
        list(itertools.product(income_taxation_choices, corporate_taxation_choices)),
        dtype=_NP_DTYPE,
    )
    global_state_dim = (
        NUMFIRMS  # prices
        + NUMFIRMS  # wages
        + NUMFIRMS  # stocks
        + NUMFIRMS  # was good overdemanded
        + 2 * NUMGOVERNMENTS  # tax rates
        + 1
    )  # time

    global_state_digit_dims = list(
        range(2 * NUMFIRMS, 3 * NUMFIRMS)
    )  # stocks are the only global state var that can get huge
    consumer_state_dim = (
        global_state_dim + 1 + 1
    )  # budget # theta, the disutility of work

    firm_state_dim = (
        global_state_dim
        + 1  # budget
        + 1  # capital
        + 1  # production alpha
        + NUMFIRMS  # onehot specifying which firm
    )

    # Minimal curriculum schedule: everything anneals in 10 episodes.
    episodes_to_anneal_firm = 10
    episodes_to_anneal_government = 10
    government_phase1_start = 10
    government_state_dim = global_state_dim
    DEFAULT_CFG_DICT = {
        # actions_array key will be added below
        "agents": {
            "num_consumers": NUMCONSUMERS,
            "num_firms": NUMFIRMS,
            "num_governments": NUMGOVERNMENTS,
            "global_state_dim": global_state_dim,
            "consumer_state_dim": consumer_state_dim,
            # action vectors are how much consume from each firm,
            # how much to work, and which firm to choose
            "consumer_action_dim": NUMFIRMS + 1 + 1,
            "consumer_num_consume_actions": consumption_choices.shape[0],
            "consumer_num_work_actions": work_choices.shape[0],
            "consumer_num_whichfirm_actions": NUMFIRMS,
            "firm_state_dim": firm_state_dim,  # what are observations?
            # actions are price and wage for own firm, and capital choices
            "firm_action_dim": 3,
            "firm_num_actions": price_and_wage.shape[0],
            "government_state_dim": government_state_dim,
            "government_action_dim": 2,
            "government_num_actions": tax_choices.shape[0],
            "max_possible_consumption": float(consumption_choices.max()),
            "max_possible_hours_worked": float(work_choices.max()),
            "max_possible_wage": float(wage_choices.max()),
            "max_possible_price": float(price_choices.max()),
            # these are dims which, due to being on a large scale,
            # have to be expanded to a digit representation
            "consumer_digit_dims": global_state_digit_dims
            + [global_state_dim],  # global state + consumer budget
            "firm_digit_dims": global_state_digit_dims
            + [global_state_dim],  # global state + firm budget (do we need capital?)
            # govt only has global state
            "government_digit_dims": global_state_digit_dims,
            "firm_reward_scale": 10000,
            "government_reward_scale": 100000,
            "consumer_reward_scale": 50.0,
            "firm_anneal_wages": {
                "anneal_on": True,
                "start": 22.0,
                "increase_const": float(wage_choices.max() - 22.0)
                / (episodes_to_anneal_firm),
                "decrease_const": (22.0) / episodes_to_anneal_firm,
            },
            "firm_anneal_prices": {
                "anneal_on": True,
                "start": 1000.0,
                "increase_const": float(price_choices.max() - 1000.00)
                / episodes_to_anneal_firm,
                "decrease_const": (1000.0) / episodes_to_anneal_firm,
            },
            "government_anneal_taxes": {
                "anneal_on": True,
                "start": 0.0,
                "increase_const": 1.0 / episodes_to_anneal_government,
            },
            "firm_begin_anneal_action": 0,
            "government_begin_anneal_action": government_phase1_start,
            "consumer_anneal_theta": {
                "anneal_on": True,
                "exp_decay_length_in_steps": episodes_const,
            },
            "consumer_anneal_entropy": {
                "anneal_on": True,
                "exp_decay_length_in_steps": episodes_const,
                "coef_floor": 0.1,
            },
            "firm_anneal_entropy": {
                "anneal_on": True,
                "exp_decay_length_in_steps": episodes_const,
                "coef_floor": 0.1,
            },
            "govt_anneal_entropy": {
                "anneal_on": True,
                "exp_decay_length_in_steps": episodes_const,
                "coef_floor": 0.1,
            },
            "consumer_noponzi_eta": 0.0,
            "consumer_penalty_scale": 1.0,
            "firm_noponzi_eta": 0.0,
            # Firms/government start training only after their anneal phases.
            "firm_training_start": episodes_to_anneal_firm,
            "government_training_start": government_phase1_start
            + episodes_to_anneal_government,
            "consumer_training_start": 0,
            "government_counts_firm_reward": 0,
            "should_boost_firm_reward": False,
            "firm_reward_for_government_factor": 0.0025,
            # Update frequency (in episodes) per agent type; unique to this
            # smoke-test template.
            "train_firms_every": 2,
            "train_consumers_every": 1,
            "train_government_every": 5,
        },
        "world": {
            "maxtime": 10,
            "initial_firm_endowment": 22.0 * 1000 * NUMCONSUMERS,
            "initial_consumer_endowment": 2000,
            "initial_stocks": 0.0,
            "initial_prices": 1000.0,
            "initial_wages": 22.0,
            "interest_rate": 0.1,
            "consumer_theta": 0.01,
            "crra_param": 0.1,
            "production_alpha": "fixed_array",  # only works for exactly 10 firms, kluge
            "initial_capital": "twolevel",
            "paretoscaletheta": 4.0,
            "importer_price": 500.0,
            "importer_quantity": 100.0,
            "use_importer": 1,
        },
        "train": {
            "batch_size": 8,
            "base_seed": 1234,
            "save_dense_every": 2000,
            "save_model_every": 10000,
            "num_episodes": 100,
            "infinite_episodes": False,
            "lr": 0.01,
            "gamma": 0.9999,
            "entropy": 0.0,
            "value_loss_weight": 1.0,
            "digit_representation_size": 10,
            "lagr_num_steps": 1,
            "boost_firm_reward_factor": 1.0,
        },
    }
    return (
        DEFAULT_CFG_DICT,
        consumption_choices,
        work_choices,
        price_and_wage,
        tax_choices,
        None,
        None,
    )
|
||||
|
||||
|
||||
def global_state_scaling_factors(cfg_dict):
    """Build the per-dimension normalization vector for the global state.

    The layout mirrors the observation: prices, wages, digit-expanded
    stocks, overdemand flags, tax rates, time. Entries with scale 1.0 are
    already in a reasonable numeric range. Returns a CPU torch tensor.
    """
    agents_cfg = cfg_dict["agents"]
    max_wage = agents_cfg["max_possible_wage"]
    max_price = agents_cfg["max_possible_price"]
    num_firms = agents_cfg["num_firms"]
    num_governments = agents_cfg["num_governments"]
    maxtime = cfg_dict["world"]["maxtime"]

    digit_size = cfg_dict["train"]["digit_representation_size"]

    scales = [max_price] * num_firms                # prices
    scales += [max_wage] * num_firms                # wages
    scales += [1.0] * (num_firms * digit_size)      # stocks (digit-expanded)
    scales += [1.0] * num_firms                     # overdemanded flags
    scales += [1.0] * (2 * num_governments)         # tax rates
    scales += [maxtime]                             # time step
    return torch.tensor(scales)
|
||||
|
||||
|
||||
def consumer_state_scaling_factors(cfg_dict):
    """Normalization vector for the full consumer observation, on GPU.

    Global-state scales followed by the consumer extras: the digit-expanded
    budget (unit scale) and theta (scaled by the world's consumer_theta).
    """
    digit_size = cfg_dict["train"]["digit_representation_size"]
    extras = [1.0] * digit_size + [cfg_dict["world"]["consumer_theta"]]
    consumer_scales = torch.tensor(extras)
    combined = torch.cat((global_state_scaling_factors(cfg_dict), consumer_scales))
    return combined.cuda()
|
||||
|
||||
|
||||
def firm_state_scaling_factors(cfg_dict):
    """Normalization vector for the full firm observation, on GPU.

    Global-state scales followed by the firm extras: digit-expanded budget
    (unit scale), capital (scaled by 10000), production alpha, and the
    one-hot firm identifier (unit scale).
    """
    num_firms = cfg_dict["agents"]["num_firms"]
    digit_size = cfg_dict["train"]["digit_representation_size"]
    # budget, capital, alpha, one-hot
    extras = [1.0] * digit_size + [10000.0, 1.0] + [1.0] * num_firms
    combined = torch.cat(
        (global_state_scaling_factors(cfg_dict), torch.tensor(extras))
    )
    return combined.cuda()
|
||||
|
||||
|
||||
def govt_state_scaling_factors(cfg_dict):
    """Normalization vector for the government observation, on GPU.

    The government observes only the global state, so this is just the
    global scaling vector moved to the GPU.
    """
    return global_state_scaling_factors(cfg_dict).cuda()
|
||||
912
ai_economist/real_business_cycle/rbc/cuda/firm_rbc.cu
Normal file
912
ai_economist/real_business_cycle/rbc/cuda/firm_rbc.cu
Normal file
@@ -0,0 +1,912 @@
|
||||
// Copyright (c) 2021, salesforce.com, inc.
|
||||
// All rights reserved.
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
// For full license text, see the LICENSE file in the repo root
|
||||
// or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
|
||||
// Real Business Chain implementation in CUDA C
|
||||
|
||||
#include <curand_kernel.h>
|
||||
#include <math.h>
|
||||
|
||||
// Roles an agent thread can play within one environment block.
typedef enum {
  kConsumerType,
  kFirmType,
  kGovernmentType,
} AgentType;

// Compile-time simulation parameters. NOTE(review): the M_* tokens look like
// macros substituted by the host-side launcher before compilation — confirm.
const size_t kBatchSize = M_BATCHSIZE;          // parallel environment copies
const size_t kNumConsumers = M_NUMCONSUMERS;
const bool kCountFirmReward = M_COUNTFIRMREWARD;
const size_t kNumFirms = M_NUMFIRMS;
const size_t kNumGovts = M_NUMGOVERNMENTS;
const float kMaxTime = M_MAXTIME;               // episode length in steps
const float kCrraParam = M_CRRA_PARAM;          // consumer utility curvature
const float kInterestRate = M_INTERESTRATE;
const size_t kNumAgents = kNumConsumers + kNumFirms + kNumGovts;
const bool kIncentivizeFirmActivity = M_SHOULDBOOSTFIRMREWARD;
const float kFirmBoostRewardFactor = M_BOOSTFIRMREWARDFACTOR;
const bool kUseImporter = M_USEIMPORTER;        // external buyer of surplus
const float kImporterPrice = M_IMPORTERPRICE;
const float kImporterQuantity = M_IMPORTERQUANTITY;
const float kLaborFloor = M_LABORFLOOR;         // hours below this count as 0

// Global state =
const size_t kNumPrices = kNumFirms; // - prices,
const size_t kNumWages = kNumFirms; // - wages,
const size_t kNumInventories = kNumFirms; // - stocks,
const size_t kNumOverdemandFlags = kNumFirms; // - good overdemanded flag,
const size_t kNumCorporateTaxes = kNumGovts; // - corporate tax rate
const size_t kNumIncomeTaxes = kNumGovts; // - income tax rate
const size_t kNumTimeDimensions = 1; // - time step
const size_t kGlobalStateSize = kNumPrices + kNumWages + kNumInventories + kNumOverdemandFlags + kNumCorporateTaxes + kNumIncomeTaxes + kNumTimeDimensions;

// Offsets of each section within the global state vector.
const size_t kIdxPricesOffset = 0;
const size_t kIdxWagesOffset = kNumPrices;
const size_t kIdxStockOffset = kIdxWagesOffset + kNumInventories;
const size_t kIdxOverdemandOffset = kIdxStockOffset + kNumOverdemandFlags;
// NOTE(review): the three tail offsets below hard-code one slot per tax plus
// the time step, which is only consistent with kNumGovts == 1. The
// income/corporate order here also looks swapped relative to the declaration
// order used in kGlobalStateSize — confirm against the host-side layout.
const size_t kIdxIncomeTaxOffset = kGlobalStateSize - 3;
const size_t kIdxCorporateTaxOffset = kGlobalStateSize - 2;
const size_t kIdxTimeOffset = kGlobalStateSize - 1;

// Consumer actions: consume (one head per firm), work, choose which firm to
// work for.
const size_t kActionSizeConsumer = kNumFirms + 1 + 1;

// Consumer state: global state plus budget and theta.
const size_t kStateSizeConsumer = kGlobalStateSize + 1 + 1;
const size_t kIdxConsumerBudgetOffset = 0;

// offset from agent-specific state part of array
const size_t kIdxConsumerThetaOffset = 1;

// UNUSED for consumer. Actions are floats.
// __constant__ float cs_index_to_action[num_actions_consumer *
// kActionSizeConsumer]; const size_t kActionSizeConsumer = kNumFirms +
// kNumFirms; // consume + work
/*const size_t num_actions_consumer =
    NUMACTIONSkConsumerType; // depends on discretization*/

// Firm actions: set wage, set price, invest in capital
const size_t kActionSizeFirm = 3;

// Number of actions depends on discretization of continuous action space.
const size_t kNumActionsFirm = M_NUMACTIONSFIRM;

// Firm state: global state plus budget, capital, production alpha, and a
// one-hot firm ID.
const size_t kStateSizeFirm = kGlobalStateSize + 1 + 1 + 1 + kNumFirms;

// offset from agent-specific state part of array
const size_t kIdxFirmBudgetOffset = 0;
const size_t kIdxFirmCapitalOffset = 1;
const size_t kIdxFirmAlphaOffset = 2;
const size_t kIdxFirmOnehotOffset = 3;

// Discrete-action lookup table for firms (filled in by the host).
// Constant memory available from ALL threads.
// See https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#constant
__constant__ float kFirmIndexToAction[kNumActionsFirm * kActionSizeFirm];

// Government actions: corporate + income tax rates.
const size_t kGovtActionSize = 2;
const size_t kNumActionsGovernment = M_NUMACTIONSGOVERNMENT;
const size_t kGovtStateSize = kGlobalStateSize;
__constant__ float
    kGovernmentIndexToAction[kNumActionsGovernment * kGovtActionSize];

// One RNG state for each thread. Each thread is assigned to an agent in an env.
__device__ curandState_t
    *rng_state_arr[kBatchSize * kNumAgents]; // not sure best way to do this

// Offsets into the consumer action vector.
const size_t kIdxConsumerDemandedOffset = 0;
const size_t kIdxConsumerWorkedOffset = kNumFirms;
const size_t kIdxConsumerWhichFirmOffset = kNumFirms + 1;

// currently 1 govt
const size_t kIdxThisThreadGovtId = 0;
|
||||
|
||||
extern "C" {
|
||||
|
||||
// ------------------
|
||||
// CUDA C Utilities
|
||||
// ------------------
|
||||
// Copy `num_elems` consecutive floats from `start_point` into `destination`.
__device__ void CopyFloatArraySlice(float *start_point, int num_elems,
                                    float *destination) {
  int idx = 0;
  while (idx < num_elems) {
    destination[idx] = start_point[idx];
    idx++;
  }
}
|
||||
|
||||
// Copy `num_elems` consecutive ints from `start_point` into `destination`,
// widening each value to float on the way.
__device__ void CopyIntArraySlice(int *start_point, int num_elems,
                                  float *destination) {
  int idx = 0;
  while (idx < num_elems) {
    destination[idx] = start_point[idx];
    idx++;
  }
}
|
||||
|
||||
// unfortunately, you can't do templates with extern "C" linkage required for
|
||||
// CUDA, so we have to define different functions for each case.
|
||||
// Return a pointer into `array` at multi-index `index` of a row-major 3-D
// int tensor of shape `sizes` (x slowest-varying, z fastest).
__device__ int *GetPointerFromMultiIndexFor3DIntTensor(int *array,
                                                       const dim3 &sizes,
                                                       const dim3 &index) {
  const unsigned int plane = sizes.z * sizes.y;
  const unsigned int flat_index =
      index.x * plane + index.y * sizes.z + index.z;
  return array + flat_index;
}
|
||||
|
||||
// Return a pointer into `array` at multi-index `index` of a row-major 3-D
// float tensor of shape `sizes` (x slowest-varying, z fastest).
__device__ float *GetPointerFromMultiIndexFor3DFloatTensor(float *array,
                                                           const dim3 &sizes,
                                                           const dim3 &index) {
  const unsigned int plane = sizes.z * sizes.y;
  const unsigned int flat_index =
      index.x * plane + index.y * sizes.z + index.z;
  return array + flat_index;
}
|
||||
|
||||
// Return a pointer into `array` at `multi_index` of a row-major 4-D float
// tensor of shape `sizes`. Both arrays must hold exactly 4 entries.
__device__ float *
GetPointerFromMultiIndexFor4DTensor(float *array, const size_t *sizes,
                                    const size_t *multi_index) {
  // Horner-style row-major flattening:
  // ((i0 * s1 + i1) * s2 + i2) * s3 + i3
  unsigned int flat_index = multi_index[0];
  for (int d = 1; d < 4; d++) {
    flat_index = flat_index * sizes[d] + multi_index[d];
  }
  return array + flat_index;
}
|
||||
|
||||
// Kernel: allocate and seed one curand RNG state per (env, agent) thread.
// Launch with one block per environment copy and >= kNumAgents threads.
__global__ void CudaInitKernel(int seed) {
  // we want to reset random seeds for all firms and consumers
  int tidx = threadIdx.x;
  // Flat index of this (env, agent) pair into the global RNG-state array.
  const int kThisThreadGlobalArrayIdx = blockIdx.x * kNumAgents + threadIdx.x;

  if (tidx < kNumAgents) {
    // Device-side heap allocation; `new` may return null on heap exhaustion.
    curandState_t *s = new curandState_t;
    if (s != 0) {
      // Same seed, distinct subsequence per thread -> independent streams.
      curand_init(seed, kThisThreadGlobalArrayIdx, 0, s);
    }
    // NOTE(review): a failed allocation still stores the null pointer, and
    // later kernels dereference the stored pointer unchecked — confirm the
    // device heap cannot be exhausted at the configured batch size.
    rng_state_arr[kThisThreadGlobalArrayIdx] = s;
  }
}
|
||||
|
||||
// Kernel: release the per-thread RNG states allocated by CudaInitKernel.
// Launch with the same grid/block configuration as CudaInitKernel.
__global__ void CudaFreeRand() {
  int tidx = threadIdx.x;
  // Flat index of this (env, agent) pair into the global RNG-state array.
  const int kThisThreadGlobalArrayIdx = blockIdx.x * kNumAgents + threadIdx.x;

  if (tidx < kNumAgents) {
    curandState_t *s = rng_state_arr[kThisThreadGlobalArrayIdx];
    // `delete` on a null pointer is a no-op, so failed allocations are safe.
    delete s;
  }
}
|
||||
|
||||
// Binary search over the cumulative distribution `distr` on [l, r]: returns
// the index of an exact match, otherwise the insertion point for `p`,
// clamped to `r`. Used to invert a CDF for categorical sampling.
__device__ int SearchIndex(float *distr, float p, int l, int r) {
  int lo = l;
  int hi = r;

  while (lo <= hi) {
    const int mid = lo + (hi - lo) / 2;
    const float value = distr[mid];
    if (value == p) {
      return mid;
    }
    if (value < p) {
      lo = mid + 1;
    } else {
      hi = mid - 1;
    }
  }
  return (lo > r) ? r : lo;
}
|
||||
|
||||
// --------------------
|
||||
// Simulation Utilities
|
||||
// --------------------
|
||||
// Map a within-block agent id to its role: consumers occupy the first
// kNumConsumers slots, firms the next kNumFirms, governments the rest.
__device__ AgentType GetAgentType(const int agent_id) {
  if (agent_id < kNumConsumers) {
    return kConsumerType;
  }
  if (agent_id < (kNumConsumers + kNumFirms)) {
    return kFirmType;
  }
  return kGovernmentType;
}
|
||||
|
||||
// Constant-relative-risk-aversion utility of `consumption`, shifted by +1 so
// zero consumption yields zero utility:
//   u(c) = ((c + 1)^(1 - eta) - 1) / (1 - eta).
// BUGFIX: for eta == 1 the original expression evaluates to 0/0 (NaN); the
// standard CRRA limit in that case is log utility, u(c) = log(c + 1).
__device__ float GetCRRAUtil(float consumption, float crra_param) {
  if (crra_param == 1.0f) {
    return logf(consumption + 1.0f);  // limit of the CRRA form as eta -> 1
  }
  return (powf(consumption + 1, 1.0 - crra_param) - 1.0) / (1.0 - crra_param);
}
|
||||
|
||||
// Kernel: reset every agent's state to its checkpointed (initial) values.
// Launch with one block per environment copy and >= kNumAgents threads;
// each thread restores exactly one agent's state slice.
__global__ void CudaResetEnv(float *cs_state_arr, float *fm_state_arr,
                             float *govt_state_arr, float *cs_state_ckpt_arr,
                             float *fm_state_ckpt_arr,
                             float *govt_state_ckpt_arr,
                             float theta_anneal_factor) {
  /*
  Resets the environment by writing the initial state (checkpoint) into the
  state array for this agent (thread).
  */

  const int kBlockId = blockIdx.x;              // environment (batch) index
  const int kWithinBlockAgentId = threadIdx.x;  // agent index within the env

  // Extra threads beyond the agent count have no work.
  if (kWithinBlockAgentId >= kNumAgents) {
    return;
  }

  AgentType ThisThreadAgentType = GetAgentType(kWithinBlockAgentId);

  // Per-role source/destination arrays and indexing shapes, selected below.
  float *state_arr;
  float *ckpt_arr;
  dim3 my_state_shape, my_state_idx;
  size_t my_state_size;

  if (ThisThreadAgentType == kConsumerType) {
    // This thread/agent is a consumer.
    my_state_size = kStateSizeConsumer;
    my_state_shape = {kBatchSize, kNumConsumers, kStateSizeConsumer};
    my_state_idx = {kBlockId, kWithinBlockAgentId, 0};
    state_arr = cs_state_arr;
    ckpt_arr = cs_state_ckpt_arr;
  } else if (ThisThreadAgentType == kFirmType) {
    // This thread/agent is a firm; firm ids start after the consumers.
    my_state_size = kStateSizeFirm;
    my_state_shape = {kBatchSize, kNumFirms, kStateSizeFirm};
    my_state_idx = {kBlockId,
                    (unsigned int)(kWithinBlockAgentId - kNumConsumers), 0};
    state_arr = fm_state_arr;
    ckpt_arr = fm_state_ckpt_arr;
  } else {
    // This thread/agent is government; ids start after consumers + firms.
    my_state_size = kGovtStateSize;
    my_state_shape = {kBatchSize, kNumGovts, kGovtStateSize};
    my_state_idx = {
        kBlockId,
        (unsigned int)(kWithinBlockAgentId - kNumConsumers - kNumFirms), 0};
    state_arr = govt_state_arr;
    ckpt_arr = govt_state_ckpt_arr;
  }

  // Pointers to this agent's live state slice and its checkpoint slice.
  float *my_state_arr = GetPointerFromMultiIndexFor3DFloatTensor(
      state_arr, my_state_shape, my_state_idx);

  float *my_ckpt_ptr = GetPointerFromMultiIndexFor3DFloatTensor(
      ckpt_arr, my_state_shape, my_state_idx);

  // Restore the full checkpointed state.
  CopyFloatArraySlice(my_ckpt_ptr, my_state_size, my_state_arr);

  // anneal theta: rescale the consumer's labor-disutility coefficient on
  // each reset (annealing schedule controlled by the host).
  if (ThisThreadAgentType == kConsumerType) {
    my_state_arr[kGlobalStateSize + kIdxConsumerThetaOffset] *=
        theta_anneal_factor;
  }
}
|
||||
|
||||
// Copy the `index`-th discrete action template from the lookup table
// `index_to_action_arr` into agent `agent_idx`'s slot of `action_arr`.
// Usable for either firms or governments; each thread is one agent.
__device__ void GetAction(float *action_arr, float *index_to_action_arr,
                          int index, int agent_idx, int agent_action_size) {
  float *dst = action_arr + agent_idx * agent_action_size;
  const float *src = index_to_action_arr + index * agent_action_size;
  for (int i = 0; i < agent_action_size; i++) {
    dst[i] = src[i];
  }
}
|
||||
|
||||
// Kernel: sample one discrete action per firm and per government from the
// policy-supplied categorical distributions (`fm_distr` / `govt_distr`),
// writing both the chosen action index and the decoded action vector.
// Consumer actions are sampled in Pytorch, so consumer threads return early.
__global__ void CudaSampleFirmAndGovernmentActions(
    float *fm_distr, int *fm_action_indices_arr, float *fm_actions_arr,
    float *govt_distr, int *govt_action_indices_arr, float *govt_actions_arr) {
  const int kWithinBlockAgentId = threadIdx.x;

  // Unused threads should not do anything.
  if (threadIdx.x >= kNumAgents) {
    return;
  }

  AgentType ThisThreadAgentType = GetAgentType(kWithinBlockAgentId);

  // Index into rand states array; take a private working copy of the state.
  int kThisThreadGlobalArrayIdx = blockIdx.x * kNumAgents + threadIdx.x;
  curandState_t rng_state = *rng_state_arr[kThisThreadGlobalArrayIdx];

  // float cs_cum_dist[num_actions_consumer];
  float fm_cum_dist[kNumActionsFirm];
  float govt_cum_dist[kNumActionsGovernment];

  // Role-dependent views selected below.
  float *my_cumul_dist;
  float *my_dist;
  int *my_indices;
  float *my_actions;
  float *index_to_action;
  int this_thread_global_array_idx;
  size_t my_num_actions;
  int my_action_size;

  // Consumers have multiple action heads, hence sampling is more complicated.
  if (ThisThreadAgentType == kConsumerType) {
    return;
  } else if (ThisThreadAgentType == kFirmType) {
    // on firm thread
    my_cumul_dist = fm_cum_dist;
    my_dist = fm_distr;
    my_indices = fm_action_indices_arr;
    my_actions = fm_actions_arr;
    my_num_actions = kNumActionsFirm;
    index_to_action = kFirmIndexToAction;
    my_action_size = kActionSizeFirm;
    this_thread_global_array_idx =
        (blockIdx.x * kNumFirms) + (threadIdx.x - kNumConsumers);
  } else {
    // on government thread
    my_cumul_dist = govt_cum_dist;
    my_dist = govt_distr;
    my_indices = govt_action_indices_arr;
    my_actions = govt_actions_arr;
    my_num_actions = kNumActionsGovernment;
    index_to_action = kGovernmentIndexToAction;
    my_action_size = kGovtActionSize;
    this_thread_global_array_idx =
        (blockIdx.x * kNumGovts) + (threadIdx.x - kNumConsumers - kNumFirms);
  }

  // Compute CDF of this agent's action distribution.
  my_cumul_dist[0] = my_dist[this_thread_global_array_idx * my_num_actions];
  for (int i = 1; i < my_num_actions; i++) {
    my_cumul_dist[i] =
        my_dist[this_thread_global_array_idx * my_num_actions + i] +
        my_cumul_dist[i - 1];
  }

  // Given sampled action which is a float in [0, 1], find the corresponding
  // discrete action by inverting the CDF.
  float sampled_float = curand_uniform(&rng_state);

  // BUGFIX: persist the advanced RNG state AFTER drawing from it. The
  // original wrote the unmodified state back immediately after reading it,
  // so the stored state never advanced and every launch re-drew the same
  // uniform sample for each thread.
  *rng_state_arr[kThisThreadGlobalArrayIdx] = rng_state;

  const int index =
      SearchIndex(my_cumul_dist, sampled_float, 0, (int)(my_num_actions - 1));
  my_indices[this_thread_global_array_idx] = index;
  GetAction(my_actions, index_to_action, index, this_thread_global_array_idx,
            my_action_size);
}
|
||||
|
||||
// Cobb-Douglas production: technology * capital^(1-alpha) * labor^alpha,
// where labor hours below kLaborFloor are treated as zero.
__device__ float GetFirmProduction(float technology, float capital, float hours,
                                   float alpha) {
  const float labor = (hours < kLaborFloor) ? 0.0 : hours;
  return technology * powf(capital, 1.0 - alpha) * powf(labor, alpha);
}
|
||||
|
||||
// --------------------
|
||||
// Simulation Logic
|
||||
// --------------------
|
||||
__global__ void
|
||||
CudaStep(float *cs_state_arr, float *cs_actions_arr, float *cs_rewards_arr,
|
||||
float *cs_state_arr_batch, float *cs_rewards_arr_batch,
|
||||
|
||||
float *fm_state_arr, int *fm_action_indices_arr, float *fm_actions_arr,
|
||||
float *fm_rewards_arr, float *fm_state_arr_batch,
|
||||
int *fm_actions_arr_batch, float *fm_rewards_arr_batch,
|
||||
|
||||
float *govt_state_arr, int *govt_action_indices_arr,
|
||||
float *govt_actions_arr, float *govt_rewards_arr,
|
||||
float *govt_state_arr_batch, int *govt_actions_arr_batch,
|
||||
float *govt_rewards_arr_batch,
|
||||
float *consumer_aux_batch,
|
||||
float *firm_aux_batch,
|
||||
int iter) {
|
||||
// This function should be called with 1 block per copy of the environment.
|
||||
// Within a block, each agent should have a thread.
|
||||
const int kWithinBlockAgentId = threadIdx.x;
|
||||
|
||||
// return if we're on an extra thread not corresponding to an agent
|
||||
if (kWithinBlockAgentId >= kNumAgents) {
|
||||
return;
|
||||
}
|
||||
|
||||
// -------------------------------------
|
||||
// Start of variables and pointers defs.
|
||||
// -------------------------------------
|
||||
|
||||
// __shared__ variables are block-local: can be seen by each thread ** in the
|
||||
// block **
|
||||
__shared__ float gross_demand_arr[kNumFirms];
|
||||
__shared__ int num_consumer_demand_arr[kNumFirms];
|
||||
__shared__ float hours_worked_arr[kNumFirms];
|
||||
__shared__ float total_actually_consumer_arr[kNumFirms];
|
||||
__shared__ float bought_by_importer_arr[kNumFirms];
|
||||
__shared__ float next_global_state_arr[kGlobalStateSize];
|
||||
__shared__ float tax_revenue_arr[kNumGovts];
|
||||
__shared__ float total_utility_arr[kNumGovts];
|
||||
__shared__ bool need_to_ration_this_good_arr[kNumFirms]; // whether or not to
|
||||
// ration good i
|
||||
|
||||
float net_demand_arr[kNumFirms]; // amount demanded after budget constraints
|
||||
// by a consumer (ignore for non-consumers)
|
||||
|
||||
int num_iter = (int)kMaxTime;
|
||||
AgentType ThisThreadAgentType = GetAgentType(kWithinBlockAgentId);
|
||||
float this_agent_reward = 0.0;
|
||||
|
||||
// pointer to start of state vector
|
||||
// state vector consists of global state, then
|
||||
// agent-specific state global part is of same size for
|
||||
// all agents, but needs to be sliced out of different
|
||||
// arrays depending on agent type
|
||||
float *my_global_state_ptr;
|
||||
|
||||
float *my_action_arr;
|
||||
|
||||
// pointer to start of state vector in batch history
|
||||
float *batch_state_ptr;
|
||||
|
||||
// sizes and indices for strided array access
|
||||
dim3 my_state_shape, my_state_idx, action_shape;
|
||||
|
||||
// shape for batched array of scalars (action ind and reward)
|
||||
dim3 batch_single_shape, single_idx;
|
||||
|
||||
// pointer to action index
|
||||
int *batch_action_value_ptr;
|
||||
|
||||
// pointer to batch reward
|
||||
float *batch_reward_value_ptr;
|
||||
|
||||
// pointers to action index for current arrays
|
||||
int *my_action_value_ptr;
|
||||
|
||||
// pointers to reward index for current arrays
|
||||
float *my_reward_value_ptr;
|
||||
|
||||
float *my_aux_batch_ptr;
|
||||
|
||||
// -----------------------------------
|
||||
// End of variables and pointers defs.
|
||||
// -----------------------------------
|
||||
|
||||
if (ThisThreadAgentType == kConsumerType) {
|
||||
// get current state
|
||||
my_state_shape = {kBatchSize, kNumConsumers, kStateSizeConsumer};
|
||||
my_state_idx = {blockIdx.x, threadIdx.x, 0};
|
||||
my_global_state_ptr = GetPointerFromMultiIndexFor3DFloatTensor(
|
||||
cs_state_arr, my_state_shape, my_state_idx); // index)
|
||||
|
||||
// get current action
|
||||
action_shape = {kBatchSize, kNumConsumers, kActionSizeConsumer};
|
||||
my_action_arr = GetPointerFromMultiIndexFor3DFloatTensor(
|
||||
cs_actions_arr, action_shape, my_state_idx);
|
||||
|
||||
// index into the episode history and save prev state into it
|
||||
size_t my_batch_state_shape[] = {kBatchSize, num_iter, kNumConsumers,
|
||||
kStateSizeConsumer};
|
||||
size_t my_batch_state_idx[] = {blockIdx.x, iter, threadIdx.x, 0};
|
||||
batch_state_ptr = GetPointerFromMultiIndexFor4DTensor(
|
||||
cs_state_arr_batch, my_batch_state_shape, my_batch_state_idx);
|
||||
CopyFloatArraySlice(my_global_state_ptr, kStateSizeConsumer,
|
||||
batch_state_ptr);
|
||||
|
||||
size_t my_aux_batch_shape[] = {kBatchSize, num_iter, kNumConsumers, kNumFirms};
|
||||
my_aux_batch_ptr = GetPointerFromMultiIndexFor4DTensor(
|
||||
consumer_aux_batch, my_aux_batch_shape, my_batch_state_idx
|
||||
);
|
||||
|
||||
|
||||
// Extract pointers to rewards, batch and current
|
||||
batch_single_shape = {kBatchSize, (unsigned int)num_iter, kNumConsumers};
|
||||
single_idx = {blockIdx.x, (unsigned int)iter, threadIdx.x};
|
||||
|
||||
batch_reward_value_ptr = GetPointerFromMultiIndexFor3DFloatTensor(
|
||||
cs_rewards_arr_batch, batch_single_shape, single_idx);
|
||||
|
||||
my_reward_value_ptr =
|
||||
&(cs_rewards_arr[blockIdx.x * kNumConsumers + threadIdx.x]);
|
||||
}
|
||||
|
||||
if (ThisThreadAgentType == kFirmType) {
|
||||
// get current state
|
||||
size_t this_thread_firm_id = (threadIdx.x - kNumConsumers);
|
||||
my_state_shape = {kBatchSize, kNumFirms, kStateSizeFirm};
|
||||
my_state_idx = {blockIdx.x, (unsigned int)this_thread_firm_id, 0};
|
||||
my_global_state_ptr = GetPointerFromMultiIndexFor3DFloatTensor(
|
||||
fm_state_arr, my_state_shape, my_state_idx);
|
||||
|
||||
// get current action
|
||||
action_shape = {kBatchSize, kNumFirms, kActionSizeFirm};
|
||||
my_action_arr = GetPointerFromMultiIndexFor3DFloatTensor(
|
||||
fm_actions_arr, action_shape, my_state_idx);
|
||||
|
||||
// index into the episode history and save prev state into it
|
||||
size_t my_batch_state_shape[] = {kBatchSize, num_iter, kNumFirms,
|
||||
kStateSizeFirm};
|
||||
size_t my_batch_state_idx[] = {blockIdx.x, iter, this_thread_firm_id, 0};
|
||||
batch_state_ptr = GetPointerFromMultiIndexFor4DTensor(
|
||||
fm_state_arr_batch, my_batch_state_shape, my_batch_state_idx);
|
||||
CopyFloatArraySlice(my_global_state_ptr, kStateSizeFirm, batch_state_ptr);
|
||||
|
||||
dim3 my_aux_batch_shape = {kBatchSize, num_iter, kNumFirms};
|
||||
dim3 aux_batch_idx = {blockIdx.x, iter, this_thread_firm_id};
|
||||
my_aux_batch_ptr = GetPointerFromMultiIndexFor3DFloatTensor(
|
||||
firm_aux_batch, my_aux_batch_shape, aux_batch_idx
|
||||
);
|
||||
// extract pointers to action indices and rewards, batch and current
|
||||
batch_single_shape = {kBatchSize, (unsigned int)num_iter, kNumFirms};
|
||||
single_idx = {blockIdx.x, (unsigned int)iter,
|
||||
(unsigned int)this_thread_firm_id};
|
||||
batch_action_value_ptr = GetPointerFromMultiIndexFor3DIntTensor(
|
||||
fm_actions_arr_batch, batch_single_shape, single_idx);
|
||||
batch_reward_value_ptr = GetPointerFromMultiIndexFor3DFloatTensor(
|
||||
fm_rewards_arr_batch, batch_single_shape, single_idx);
|
||||
|
||||
const int kThisThreadFirmIdx = blockIdx.x * kNumFirms + this_thread_firm_id;
|
||||
my_action_value_ptr = &(fm_action_indices_arr[kThisThreadFirmIdx]);
|
||||
my_reward_value_ptr = &(fm_rewards_arr[kThisThreadFirmIdx]);
|
||||
}
|
||||
|
||||
if (ThisThreadAgentType == kGovernmentType) {
|
||||
|
||||
int this_thread_govt_id = (threadIdx.x - kNumConsumers - kNumFirms);
|
||||
|
||||
my_state_shape = {kBatchSize, kNumGovts, kGovtStateSize};
|
||||
my_state_idx = {blockIdx.x, (unsigned int)this_thread_govt_id, 0};
|
||||
my_global_state_ptr = GetPointerFromMultiIndexFor3DFloatTensor(
|
||||
govt_state_arr, my_state_shape, my_state_idx); // index)
|
||||
|
||||
// get current action
|
||||
action_shape = {kBatchSize, kNumGovts, kGovtActionSize};
|
||||
my_action_arr = GetPointerFromMultiIndexFor3DFloatTensor(
|
||||
govt_actions_arr, action_shape, my_state_idx);
|
||||
|
||||
// index into the episode history and save prev state into it
|
||||
size_t my_batch_state_shape[] = {kBatchSize, num_iter, kNumGovts,
|
||||
kGovtStateSize};
|
||||
size_t my_batch_state_idx[] = {blockIdx.x, iter, this_thread_govt_id, 0};
|
||||
batch_state_ptr = GetPointerFromMultiIndexFor4DTensor(
|
||||
govt_state_arr_batch, my_batch_state_shape, my_batch_state_idx);
|
||||
CopyFloatArraySlice(my_global_state_ptr, kGovtStateSize, batch_state_ptr);
|
||||
|
||||
// extract pointers to action indices and rewards, batch and current
|
||||
batch_single_shape = {kBatchSize, (unsigned int)num_iter, kNumGovts};
|
||||
single_idx = {blockIdx.x, (unsigned int)iter,
|
||||
(unsigned int)this_thread_govt_id};
|
||||
batch_action_value_ptr = GetPointerFromMultiIndexFor3DIntTensor(
|
||||
govt_actions_arr_batch, batch_single_shape, single_idx);
|
||||
batch_reward_value_ptr = GetPointerFromMultiIndexFor3DFloatTensor(
|
||||
govt_rewards_arr_batch, batch_single_shape, single_idx);
|
||||
|
||||
const int kThisThreadGovtIdx = blockIdx.x * kNumGovts + this_thread_govt_id;
|
||||
my_action_value_ptr = &(govt_action_indices_arr[kThisThreadGovtIdx]);
|
||||
my_reward_value_ptr = &(govt_rewards_arr[kThisThreadGovtIdx]);
|
||||
}
|
||||
|
||||
// ----------------------------------------------
|
||||
// State pointers and variables
|
||||
// Create pointers to agent-specific state that will be updated by the
|
||||
// simulation logic.
|
||||
// ----------------------------------------------
|
||||
float *my_state_arr = &(my_global_state_ptr[kGlobalStateSize]);
|
||||
float *prices_arr = &(my_global_state_ptr[kIdxPricesOffset]);
|
||||
float *wages_arr = &(my_global_state_ptr[kIdxWagesOffset]);
|
||||
float *available_stock_arr = &(my_global_state_ptr[kIdxStockOffset]);
|
||||
float time = my_global_state_ptr[kIdxTimeOffset];
|
||||
float income_tax_rate = my_global_state_ptr[kIdxIncomeTaxOffset];
|
||||
float corporate_tax_rate = my_global_state_ptr[kIdxCorporateTaxOffset];
|
||||
|
||||
// -------------------------------
|
||||
// Safely initialize shared memory
|
||||
// -------------------------------
|
||||
if (ThisThreadAgentType == kFirmType) {
|
||||
int this_thread_firm_id = threadIdx.x - kNumConsumers;
|
||||
gross_demand_arr[this_thread_firm_id] = 0.0;
|
||||
num_consumer_demand_arr[this_thread_firm_id] = 0;
|
||||
hours_worked_arr[this_thread_firm_id] = 0.0;
|
||||
total_actually_consumer_arr[this_thread_firm_id] = 0.0;
|
||||
need_to_ration_this_good_arr[this_thread_firm_id] = false;
|
||||
}
|
||||
|
||||
if (ThisThreadAgentType == kGovernmentType) {
|
||||
tax_revenue_arr[0] = 0.0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
// -------------------------------------
|
||||
// End - Safely initialize shared memory
|
||||
// -------------------------------------
|
||||
|
||||
// -------------------------------------
|
||||
// Process actions
|
||||
// -------------------------------------
|
||||
if (ThisThreadAgentType == kConsumerType) {
|
||||
// amount demanded is just the first part of the action vector
|
||||
float *this_agent_gross_demand_arr =
|
||||
&(my_action_arr[kIdxConsumerDemandedOffset]);
|
||||
const float this_agent_hours_worked =
|
||||
my_action_arr[kIdxConsumerWorkedOffset];
|
||||
const int worked_for_this_firm_id =
|
||||
(int)my_action_arr[kIdxConsumerWhichFirmOffset];
|
||||
|
||||
// here, need to scale demands to meet the budget. put them in a local array
|
||||
// *budgetDemanded logic should be: compute total expenditure given prices.
|
||||
// if less than budget, copy existing demands
|
||||
float __cost_of_demand = 0.0;
|
||||
for (int i = 0; i < kNumFirms; i++) {
|
||||
__cost_of_demand += this_agent_gross_demand_arr[i] * prices_arr[i];
|
||||
}
|
||||
|
||||
// Scale demand to ensure that total demand at most equals total supply
|
||||
// we want: my_state_arr being 0 always sends __scale_factor to 0
|
||||
float __scale_factor = 1.0;
|
||||
|
||||
if ((__cost_of_demand > 0.0) && (__cost_of_demand > my_state_arr[0])) {
|
||||
__scale_factor = my_state_arr[0] / __cost_of_demand;
|
||||
}
|
||||
|
||||
// otherwise scale all demands down to meet budget
|
||||
// copy them into a demanded array
|
||||
for (int i = 0; i < kNumFirms; i++) {
|
||||
net_demand_arr[i] = __scale_factor * this_agent_gross_demand_arr[i];
|
||||
}
|
||||
|
||||
// adding up demand across threads **in the block**
|
||||
// somehow store amount demanded per firm in an array demanded (copy from
|
||||
// action) also store amount worked per firm in array worked
|
||||
|
||||
// Every thread executes atomicAdd_block in a memory-safe way.
|
||||
for (int i = 0; i < kNumFirms; i++) {
|
||||
// sum across threads in block
|
||||
atomicAdd_block(&(gross_demand_arr[i]), net_demand_arr[i]);
|
||||
|
||||
// increment count of consumers who want good i
|
||||
if (net_demand_arr[i] > 0) {
|
||||
atomicAdd_block(&(num_consumer_demand_arr[i]), 1);
|
||||
}
|
||||
}
|
||||
|
||||
// increment total hours worked for firm i
|
||||
atomicAdd_block(&(hours_worked_arr[worked_for_this_firm_id]),
|
||||
this_agent_hours_worked);
|
||||
}
|
||||
|
||||
// wait for everyone to finish tallying up their adding
|
||||
__syncthreads();
|
||||
|
||||
if (ThisThreadAgentType == kFirmType) {
|
||||
// check each firm if rationing needed
|
||||
int this_thread_firm_id = threadIdx.x - kNumConsumers;
|
||||
need_to_ration_this_good_arr[this_thread_firm_id] =
|
||||
((gross_demand_arr[this_thread_firm_id] > 0.0) && (gross_demand_arr[this_thread_firm_id] >
|
||||
available_stock_arr[this_thread_firm_id]));
|
||||
}
|
||||
|
||||
// wait for single thread to finish checking demands
|
||||
__syncthreads();
|
||||
|
||||
// ----------------------------------------
|
||||
// Consumers: Rationing demand + Utility
|
||||
// ----------------------------------------
|
||||
// Logic:
|
||||
// case 1: no overdemand
|
||||
// case 2: overdemand, but some want less than 1/N -- fill everyone up to
|
||||
// max(theirs, 1/N) case 3: overdemand, everyone wants more -- fill everyone
|
||||
// up to max(theirs, 1/N)
|
||||
float net_consumed_arr[kNumFirms]; // per consumer thread
|
||||
// always add negligible positive money to avoid budgets becoming small
|
||||
// negative numbers otherwise, when computing proportions, one may end up with
|
||||
// negative stocks.
|
||||
float cs_budget_delta = 0.01;
|
||||
float fm_budget_delta = 0.01;
|
||||
float capital_delta = 0.0;
|
||||
|
||||
if (ThisThreadAgentType == kConsumerType) {
|
||||
// find out how much consumed
|
||||
for (int i = 0; i < kNumFirms; i++) {
|
||||
float __ration_factor = 1.0;
|
||||
|
||||
if (need_to_ration_this_good_arr[i]) {
|
||||
// overdemanded
|
||||
__ration_factor = available_stock_arr[i] / gross_demand_arr[i];
|
||||
}
|
||||
|
||||
net_consumed_arr[i] = __ration_factor * net_demand_arr[i];
|
||||
|
||||
atomicAdd_block(&(total_actually_consumer_arr[i]), net_consumed_arr[i]);
|
||||
}
|
||||
|
||||
// store amount actually consumed for this consumer
|
||||
CopyFloatArraySlice(net_consumed_arr, kNumFirms, my_aux_batch_ptr);
|
||||
|
||||
// ----------------------------------------
|
||||
// Compute consumer utility
|
||||
// ----------------------------------------
|
||||
float hours_worked = my_action_arr[kIdxConsumerWorkedOffset];
|
||||
int worked_for_this_firm_id =
|
||||
(int)my_action_arr[kIdxConsumerWhichFirmOffset];
|
||||
|
||||
// budget is first elem of consumer state, theta second
|
||||
float __theta = my_state_arr[1];
|
||||
|
||||
float __this_consumer_util = 0.0;
|
||||
float __total_hours_worked = 0.0;
|
||||
float __gross_income = 0.0;
|
||||
|
||||
// Compute expenses
|
||||
// Each consumer can consume from each firm, so loop over them.
|
||||
for (int i = 0; i < kNumFirms; i++) {
|
||||
__this_consumer_util += GetCRRAUtil(net_consumed_arr[i], kCrraParam);
|
||||
cs_budget_delta -= prices_arr[i] * net_consumed_arr[i];
|
||||
}
|
||||
|
||||
// Compute income
|
||||
__total_hours_worked += hours_worked;
|
||||
__gross_income += wages_arr[worked_for_this_firm_id] * hours_worked;
|
||||
float __income_tax_paid = income_tax_rate * __gross_income;
|
||||
cs_budget_delta += (__gross_income - __income_tax_paid);
|
||||
|
||||
// Update tax revenue (government)
|
||||
atomicAdd_block(&(tax_revenue_arr[kIdxThisThreadGovtId]),
|
||||
__income_tax_paid);
|
||||
|
||||
// Compute reward
|
||||
this_agent_reward +=
|
||||
__this_consumer_util - (__theta / 2.0) * (__total_hours_worked);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
// ----------------------------------------
|
||||
// Firms Exports: Add external consumption.
|
||||
// ----------------------------------------
|
||||
if (ThisThreadAgentType == kFirmType ) {
|
||||
const int this_thread_firm_id = threadIdx.x - kNumConsumers;
|
||||
if (kUseImporter) {
|
||||
// sell remaining goods, if any, to importer, if price is high enough.
|
||||
float __this_firm_price = prices_arr[this_thread_firm_id];
|
||||
float __stock_after_consumers = available_stock_arr[this_thread_firm_id] - total_actually_consumer_arr[this_thread_firm_id];
|
||||
|
||||
if (__this_firm_price >= kImporterPrice) {
|
||||
bought_by_importer_arr[this_thread_firm_id] = fmaxf(fminf(__stock_after_consumers, kImporterQuantity), 0.0); // floor to zero to avoid small negative floats
|
||||
}
|
||||
else {
|
||||
bought_by_importer_arr[this_thread_firm_id] = 0.0;
|
||||
}
|
||||
}
|
||||
else {
|
||||
bought_by_importer_arr[this_thread_firm_id] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------
|
||||
// Firms: Rationing demand + Utility
|
||||
// ----------------------------------------
|
||||
if (ThisThreadAgentType == kFirmType) {
|
||||
const int this_thread_firm_id = threadIdx.x - kNumConsumers;
|
||||
|
||||
float __this_firm_revenue =
|
||||
(total_actually_consumer_arr[this_thread_firm_id] + bought_by_importer_arr[this_thread_firm_id]) *
|
||||
prices_arr[this_thread_firm_id];
|
||||
float __wages_paid =
|
||||
hours_worked_arr[this_thread_firm_id] * wages_arr[this_thread_firm_id];
|
||||
|
||||
// Firms can invest in new capital. This increases their production factor
|
||||
// (see GetFirmProduction).
|
||||
// here, after consumers consume, if price is >= than importer price, importer consumes up to their maximum of the goods, at the importer price
|
||||
|
||||
float __gross_income = __this_firm_revenue - __wages_paid;
|
||||
capital_delta = fmaxf(my_action_arr[2] * __gross_income, 0.0);
|
||||
float __gross_profit = __gross_income - capital_delta;
|
||||
float __corp_tax_paid = corporate_tax_rate * fmaxf(__gross_profit, 0.0);
|
||||
fm_budget_delta = (__gross_profit - __corp_tax_paid);
|
||||
if (kIncentivizeFirmActivity) {
|
||||
if ((fm_budget_delta + my_state_arr[0]) > 0.0) { // if positive budget
|
||||
this_agent_reward += (kFirmBoostRewardFactor * __this_firm_revenue);
|
||||
}
|
||||
}
|
||||
this_agent_reward += (__gross_profit - __corp_tax_paid);
|
||||
|
||||
atomicAdd_block(&(tax_revenue_arr[0]), __corp_tax_paid);
|
||||
|
||||
float __production = GetFirmProduction(0.01, my_state_arr[kIdxFirmCapitalOffset],
|
||||
hours_worked_arr[this_thread_firm_id], my_state_arr[kIdxFirmAlphaOffset]);
|
||||
|
||||
// -------------------
|
||||
// Update global state
|
||||
// -------------------
|
||||
// update prices in global state
|
||||
next_global_state_arr[kIdxPricesOffset + this_thread_firm_id] =
|
||||
my_action_arr[0];
|
||||
// update wages in global state
|
||||
next_global_state_arr[kIdxWagesOffset + this_thread_firm_id] =
|
||||
my_action_arr[1];
|
||||
// update stocks in global state
|
||||
next_global_state_arr[kIdxStockOffset + this_thread_firm_id] =
|
||||
available_stock_arr[this_thread_firm_id] -
|
||||
total_actually_consumer_arr[this_thread_firm_id] - bought_by_importer_arr[this_thread_firm_id] + __production;
|
||||
|
||||
*my_aux_batch_ptr = bought_by_importer_arr[this_thread_firm_id];
|
||||
|
||||
// update overdemanded in global state
|
||||
next_global_state_arr[kIdxOverdemandOffset + this_thread_firm_id] =
|
||||
need_to_ration_this_good_arr[this_thread_firm_id] ? 1.0 : 0.0;
|
||||
}
|
||||
|
||||
// -----------------
|
||||
// Move time forward
|
||||
// -----------------
|
||||
// Let first firm tick time
|
||||
if (ThisThreadAgentType == kFirmType) {
|
||||
const int this_thread_firm_id = threadIdx.x - kNumConsumers;
|
||||
if (this_thread_firm_id == 0) {
|
||||
next_global_state_arr[kIdxTimeOffset] =
|
||||
my_global_state_ptr[kIdxTimeOffset] + 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// ----------------------------------------
|
||||
// Subsidies
|
||||
// ----------------------------------------
|
||||
// need to redistribute tax revenues
|
||||
if (ThisThreadAgentType == kConsumerType) {
|
||||
float __redistribution = tax_revenue_arr[0] / kNumConsumers;
|
||||
cs_budget_delta += __redistribution;
|
||||
}
|
||||
|
||||
// ----------------------------------------
|
||||
// Social welfare
|
||||
// ----------------------------------------
|
||||
// After this point consumers and firms know their final reward, so can inform
|
||||
// the government thread via shared memory
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// ----------------------------------------
|
||||
// Government sets taxes for the next round
|
||||
// ----------------------------------------
|
||||
if (ThisThreadAgentType == kGovernmentType) {
|
||||
next_global_state_arr[kIdxIncomeTaxOffset] = my_action_arr[0];
|
||||
next_global_state_arr[kIdxCorporateTaxOffset] = my_action_arr[1];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// -----------------------------------------------
|
||||
// Copy next_global_state_arr into my global state
|
||||
// -----------------------------------------------
|
||||
// All agents need to see the updated global state
|
||||
CopyFloatArraySlice(next_global_state_arr, kGlobalStateSize,
|
||||
my_global_state_ptr);
|
||||
|
||||
// -----------------------------------------------
|
||||
// Update budgets
|
||||
// -----------------------------------------------
|
||||
// Update budget (same for all agents)
|
||||
if (ThisThreadAgentType == kConsumerType) {
|
||||
my_state_arr[0] += cs_budget_delta;
|
||||
}
|
||||
if (ThisThreadAgentType == kFirmType) {
|
||||
my_state_arr[0] += fm_budget_delta;
|
||||
}
|
||||
|
||||
// Add interest rate on savings
|
||||
if ((ThisThreadAgentType == kConsumerType) ||
|
||||
(ThisThreadAgentType == kFirmType)) {
|
||||
if (my_state_arr[0] > 0.0) {
|
||||
my_state_arr[0] += my_state_arr[0] * kInterestRate;
|
||||
}
|
||||
}
|
||||
|
||||
// Add new capital
|
||||
if (ThisThreadAgentType == kFirmType) {
|
||||
my_state_arr[kIdxFirmCapitalOffset] += capital_delta;
|
||||
}
|
||||
|
||||
// Add new capital
|
||||
if ((ThisThreadAgentType == kFirmType) ||
|
||||
(ThisThreadAgentType == kGovernmentType)) {
|
||||
*batch_action_value_ptr = *my_action_value_ptr;
|
||||
}
|
||||
|
||||
// Update rewards in global state
|
||||
*my_reward_value_ptr = this_agent_reward;
|
||||
*batch_reward_value_ptr = this_agent_reward;
|
||||
}
|
||||
|
||||
// ************************
|
||||
// End of extern "C" block.
|
||||
// ************************
|
||||
}
|
||||
1930
ai_economist/real_business_cycle/rbc/cuda_manager.py
Normal file
1930
ai_economist/real_business_cycle/rbc/cuda_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
114
ai_economist/real_business_cycle/rbc/networks.py
Normal file
114
ai_economist/real_business_cycle/rbc/networks.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class IndependentPolicyNet(nn.Module):
    """
    Policy network with one softmax head per action sub-space.

    The joint policy factorizes across the heads:
    $pi(a | s) = pi_1(a_1 | s) pi_2(a_2 | s)...$
    A shared two-layer MLP trunk feeds each action head and a scalar
    value head.
    """

    def __init__(self, state_size, action_size_list, norm_consts=None):
        """
        :param state_size: dimensionality of the flat state input
        :param action_size_list: number of discrete actions for each head
        :param norm_consts: optional (center, scale) tensors used to
            normalize inputs; when omitted, identity normalization
            (zeros/ones) is allocated on the GPU.
        """
        super().__init__()

        self.state_size = state_size
        self.action_size_list = action_size_list
        if norm_consts is None:
            # NOTE(review): default constants live on CUDA, so a GPU is
            # required unless norm_consts is supplied explicitly.
            self.norm_center = torch.zeros(self.state_size).cuda()
            self.norm_scale = torch.ones(self.state_size).cuda()
        else:
            self.norm_center, self.norm_scale = norm_consts
        # Shared trunk. Layer creation order is kept stable so random
        # initialization is reproducible under a fixed seed.
        self.fc1 = nn.Linear(state_size, 128)
        self.fc2 = nn.Linear(128, 128)
        # One policy head per action sub-space.
        self.action_heads = nn.ModuleList(
            [nn.Linear(128, action_size) for action_size in action_size_list]
        )
        # Scalar value head.
        self.fc4 = nn.Linear(128, 1)

    def forward(self, x):
        """Return ([action probs per head], state values) for a batch."""
        assert x.shape[-1] == self.state_size  # last dim must be the state

        # Normalize the input; broadcast the constants over batch dims.
        broadcast_shape = tuple(1 for _ in x.shape[:-1]) + (x.shape[-1],)
        centered = x - self.norm_center.view(broadcast_shape)
        x = centered / self.norm_scale.view(broadcast_shape)

        # Shared trunk.
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        # Independent softmax distribution from every head, plus the value.
        probs = [F.softmax(head(hidden), dim=-1) for head in self.action_heads]
        return probs, self.fc4(hidden)
|
||||
|
||||
|
||||
class PolicyNet(nn.Module):
    """
    The policy network class to output action probabilities and the value
    function.

    A two-layer MLP trunk is shared between a single softmax policy head
    and a scalar value head. Inputs are normalized with (center, scale)
    constants before the trunk.
    """

    def __init__(self, state_size, action_size, norm_consts=None):
        """
        :param state_size: dimensionality of the flat state input
        :param action_size: number of discrete actions
        :param norm_consts: optional (center, scale) tensors used to
            normalize inputs; when omitted, identity normalization
            (zeros/ones) is allocated on the GPU.
        """
        super().__init__()

        self.state_size = state_size
        self.action_size = action_size
        if norm_consts is not None:
            self.norm_center, self.norm_scale = norm_consts
        else:
            # NOTE(review): default constants live on CUDA, so a GPU is
            # required unless norm_consts is supplied explicitly.
            self.norm_center = torch.zeros(self.state_size).cuda()
            self.norm_scale = torch.ones(self.state_size).cuda()
        self.fc1 = nn.Linear(state_size, 128)
        self.fc2 = nn.Linear(128, 128)
        # policy network head
        self.fc3 = nn.Linear(128, action_size)
        # value network head
        self.fc4 = nn.Linear(128, 1)

    def forward(self, x, actions_mask=None):
        """Return (action probs, state values) for a batch of states.

        ``actions_mask``, if given, should hold large negative constants
        for actions that shouldn't be allowed; it is added to the logits
        before the softmax so masked actions get ~zero probability.
        """
        # Consistency fix: validate the input like IndependentPolicyNet does.
        assert x.shape[-1] == self.state_size  # last dim must be the state

        # Normalize the model input
        new_shape = tuple(1 for _ in x.shape[:-1]) + (x.shape[-1],)
        view_center = self.norm_center.view(new_shape)
        view_scale = self.norm_scale.view(new_shape)
        x = (x - view_center) / view_scale
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        if actions_mask is not None:
            probs = F.softmax(self.fc3(x) + actions_mask, dim=-1)
        else:
            probs = F.softmax(self.fc3(x), dim=-1)
        vals = self.fc4(x)
        return probs, vals
|
||||
|
||||
|
||||
class DeterministicPolicy:
    """
    A policy that always emits the same one-hot (delta) action
    distribution, regardless of the input state.
    """

    def __init__(self, state_size, action_size, action_choice):
        """
        :param state_size: dimensionality of the (ignored) state input
        :param action_size: number of discrete actions
        :param action_choice: index of the action that gets probability 1
        """
        self.state_size = state_size
        self.action_size = action_size
        self.action_choice = action_choice
        # Fixed one-hot distribution over actions, pinned on the GPU.
        self.actions_out = torch.zeros(action_size, device="cuda")
        self.actions_out[self.action_choice] = 1.0

    def __call__(self, x, actions_mask=None):
        # actions_mask is accepted for interface parity with PolicyNet
        # but deliberately ignored.
        return self.forward(x)

    def forward(self, x):
        """Tile the fixed one-hot distribution over the batch dims of x.

        Returns (distributions, None); there is no value head.
        """
        batch_dims = x.shape[:-1]
        return self.actions_out.repeat(*(batch_dims + (1,))), None
|
||||
110
ai_economist/real_business_cycle/rbc/util.py
Normal file
110
ai_economist/real_business_cycle/rbc/util.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def dict_merge(dct, merge_dct):
    """Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
    updating only top-level keys, dict_merge recurses down into dicts nested
    to an arbitrary depth, updating keys. The ``merge_dct`` is merged into
    ``dct`` in place.
    :param dct: dict onto which the merge is executed
    :param merge_dct: dct merged into dct
    :return: None
    """
    for k, v in merge_dct.items():
        # dct does not have k yet (and dct is not a list). Add it with value v.
        if (k not in dct) and (not isinstance(dct, list)):
            dct[k] = v
        else:
            # dct[k] is a container and the incoming value is a dict: recurse.
            if isinstance(dct[k], (dict, list)) and isinstance(v, dict):
                dict_merge(dct[k], merge_dct[k])
            else:
                # Both existing and incoming values are tuples or lists:
                # merge element-wise when lengths match.
                if isinstance(dct[k], (tuple, list)) and isinstance(v, (tuple, list)):
                    # Lengths don't match. Overwrite with v wholesale.
                    if len(dct[k]) != len(v):
                        dct[k] = v
                    else:
                        # Element-wise: recurse on dict pairs, otherwise
                        # overwrite the slot with the incoming element.
                        # NOTE(review): the index assignment assumes dct[k]
                        # is mutable (a list); a tuple here would raise —
                        # confirm callers never merge equal-length tuples.
                        for i, (d_val, v_val) in enumerate(zip(dct[k], v)):
                            if isinstance(d_val, dict) and isinstance(v_val, dict):
                                dict_merge(d_val, v_val)
                            else:
                                dct[k][i] = v_val
                else:
                    # Scalar (or mismatched-type) value: overwrite.
                    dct[k] = v
|
||||
|
||||
|
||||
def min_max_consumer_budget_delta(hparams_dict):
    """Return (min, max) single-round budget change for one consumer."""
    agent_cfg = hparams_dict["agents"]
    wage_cap = agent_cfg["max_possible_wage"]
    hours_cap = agent_cfg["max_possible_hours_worked"]
    price_cap = agent_cfg["max_possible_price"]
    per_firm_consumption_cap = agent_cfg["max_possible_consumption"]
    n_firms = agent_cfg["num_firms"]

    # Worst case: buy the maximum from every firm at the maximum price
    # (negative budget from consuming only).
    lower = -per_firm_consumption_cap * price_cap * n_firms
    # Best case: work the maximum hours at the maximum wage for every firm
    # (positive budget from only working).
    upper = hours_cap * wage_cap * n_firms
    return lower, upper
|
||||
|
||||
|
||||
def min_max_stock_delta(hparams_dict):
    """Return (min, max) single-round stock change for one firm.

    Assumes capital is 1.0 for the production upper bound.
    """
    agent_cfg = hparams_dict["agents"]
    hours_cap = agent_cfg["max_possible_hours_worked"]
    per_firm_consumption_cap = agent_cfg["max_possible_consumption"]
    n_consumers = agent_cfg["num_consumers"]

    alpha = hparams_dict["world"]["production_alpha"]
    # A string alpha presumably marks a swept/unspecified value; fall
    # back to 0.8 in that case.
    if isinstance(alpha, str):
        alpha = 0.8

    # Best case: every consumer works maximum hours and all of it becomes
    # production. Worst case: every consumer buys the per-firm maximum.
    upper = (hours_cap * n_consumers) ** alpha
    lower = -per_firm_consumption_cap * n_consumers
    return lower, upper
|
||||
|
||||
|
||||
def min_max_firm_budget(hparams_dict):
    """Return (min, max) single-round budget change for one firm."""
    agent_cfg = hparams_dict["agents"]
    wage_cap = agent_cfg["max_possible_wage"]
    hours_cap = agent_cfg["max_possible_hours_worked"]
    per_firm_consumption_cap = agent_cfg["max_possible_consumption"]
    n_consumers = agent_cfg["num_consumers"]
    price_cap = agent_cfg["max_possible_price"]

    # Best case: every consumer buys the maximum at the maximum price.
    upper = per_firm_consumption_cap * price_cap * n_consumers
    # Worst case: pay every consumer maximum hours at the maximum wage.
    lower = -hours_cap * wage_cap * n_consumers
    return lower, upper
|
||||
|
||||
|
||||
def expand_to_digit_form(x, dims_to_expand, max_digits):
    """Replace selected entries of x's last dim with per-digit features.

    Each entry whose index is in ``dims_to_expand`` becomes ``max_digits``
    values, where feature j is (x mod 10^(j+1)) / 10^(j+1); all other
    entries pass through unchanged. No gradients flow through the
    expansion itself, but the output keeps the input's requires_grad flag.
    """
    # Don't backprop through these ops, but preserve x's grad flag.
    had_grad = x.requires_grad
    with torch.no_grad():
        pieces = []
        digit_shape = x.shape[:-1] + (max_digits,)
        for dim in range(x.shape[-1]):
            if dim in dims_to_expand:
                digits = torch.zeros(digit_shape, device=x.device)
                for j in range(max_digits):
                    scale = 10 ** (j + 1)
                    digits[..., j] = (x[..., dim] % scale) / scale
                pieces.append(digits)
            else:
                # Keep a size-1 slice so concatenation preserves order.
                pieces.append(x[..., dim : dim + 1])
        output = torch.cat(pieces, dim=-1)
    output.requires_grad_(had_grad)
    return output
|
||||
|
||||
|
||||
def size_after_digit_expansion(existing_size, dims_to_expand, max_digits):
    """Return the last-dim size after expand_to_digit_form is applied."""
    num_expanded = len(dims_to_expand)
    # Every expanded entry contributes max_digits features instead of one;
    # all remaining entries pass through unchanged.
    return existing_size + num_expanded * (max_digits - 1)
|
||||
108
ai_economist/real_business_cycle/train_bestresponse.py
Normal file
108
ai_economist/real_business_cycle/train_bestresponse.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import argparse
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from experiment_utils import cfg_dict_from_yaml
|
||||
from rbc.cuda_manager import ConsumerFirmRunManagerBatchParallel
|
||||
|
||||
|
||||
def check_if_ep_str_policy_exists(rollout_path, ep_str):
    """Return True iff a saved consumer policy for episode ``ep_str`` exists."""
    policy_file = (
        rollout_path / Path("saved_models") / Path(f"consumer_policy_{ep_str}.pt")
    )
    return policy_file.is_file()
|
||||
|
||||
|
||||
def run_rollout(rollout_path, arguments):
    """Run best-response training against the policies saved in a rollout dir.

    Loads the pickled action arrays and hparams from ``rollout_path``,
    rebuilds the run config, then — for each requested agent type — trains
    a best response against the saved policies for every requested episode
    tag, and writes a summary of reward improvements to
    ``br_<agent_type>_output.txt`` in the rollout directory.

    :param rollout_path: pathlib.Path of a single rollout directory
    :param arguments: parsed argparse namespace (agent_type, ep_strs,
        repeat_runs, num_episodes, checkpoint_model)
    :return: None (results are written to disk / printed)
    """
    # The discrete action grids were pickled at experiment-creation time.
    with open(rollout_path / Path("action_arrays.pickle"), "rb") as f:
        action_arrays = pickle.load(f)

    consumption_choices, work_choices, price_and_wage, tax_choices = (
        action_arrays["consumption_choices"],
        action_arrays["work_choices"],
        action_arrays["price_and_wage"],
        action_arrays["tax_choices"],
    )

    # Rebuild the full config dict from the saved hparams + action grids.
    cfg_dict = cfg_dict_from_yaml(
        rollout_path / Path("hparams.yaml"),
        consumption_choices,
        work_choices,
        price_and_wage,
        tax_choices,
    )

    print(cfg_dict)

    # "all" expands to every agent role in the economy.
    if arguments.agent_type == "all":
        agent_types = ["consumer", "firm", "government"]
    else:
        agent_types = [arguments.agent_type]

    for agent_type in agent_types:
        # ep_str -> list of reward trajectories (one per repeat run).
        ep_rewards = defaultdict(list)
        for _ in range(arguments.repeat_runs):
            for ep_str in arguments.ep_strs:
                # Skip (with a zero placeholder) episodes whose saved
                # policy checkpoints are missing, so the summary below
                # still has an entry for every ep_str.
                if not check_if_ep_str_policy_exists(rollout_path, ep_str):
                    print(f"warning: {rollout_path} {ep_str} policy not found")
                    ep_rewards[ep_str].append([0.0])
                    continue
                # Fresh manager per run: re-initializes the CUDA simulation.
                m = ConsumerFirmRunManagerBatchParallel(cfg_dict)
                rewards_start = m.bestresponse_train(
                    agent_type,
                    arguments.num_episodes,
                    rollout_path,
                    ep_str=ep_str,
                    checkpoint=arguments.checkpoint_model,
                )
                ep_rewards[ep_str].append(rewards_start)

        # Summarize before/after rewards per episode tag.
        with open(rollout_path / Path(f"br_{agent_type}_output.txt"), "w") as f:
            for ep_str in arguments.ep_strs:
                # Rows are repeat runs; column 0 is pre-training reward,
                # column -1 is post-training reward.
                reward_arr = np.array(ep_rewards[ep_str])
                print(
                    f"mean reward (std) on rollout {ep_str}: "
                    f"before BR training {reward_arr[:,0].mean()} "
                    f"({reward_arr[:,0].std()}), "
                    f"after BR training {reward_arr[:,-1].mean()} "
                    f"({reward_arr[:,-1].std()}), "
                    f"mean improvement {(reward_arr[:,-1]-reward_arr[:,0]).mean()} "
                    # NOTE(review): the closing ")" is missing from this
                    # line's output text — cosmetic only.
                    f"({(reward_arr[:,-1]-reward_arr[:,0]).std()}",
                    file=f,
                )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Either a single rollout directory or (with --experiment-dir) a
    # directory containing multiple rollout subdirectories.
    parser.add_argument("rolloutdir", type=str)
    # Number of best-response training episodes per run.
    parser.add_argument("num_episodes", type=int)
    parser.add_argument("--experiment-dir", action="store_true")
    # Episode tags of saved policies to best-respond against.
    parser.add_argument("--ep-strs", nargs="+", default=["0", "latest"])
    # "consumer", "firm", "government", or "all".
    parser.add_argument("--agent-type", type=str, default="all")
    parser.add_argument("--repeat-runs", type=int, default=1)
    parser.add_argument("--checkpoint-model", type=int, default=100)

    args = parser.parse_args()

    if args.experiment_dir:
        # Treat rolloutdir as a parent folder: run every rollout inside it.
        exp_dir = Path(args.rolloutdir)
        for rolloutpath in exp_dir.iterdir():
            if rolloutpath.is_dir():
                run_rollout(rolloutpath, args)
    else:
        rolloutpath = Path(args.rolloutdir)
        run_rollout(rolloutpath, args)
|
||||
119
ai_economist/real_business_cycle/train_multi_exps.py
Normal file
119
ai_economist/real_business_cycle/train_multi_exps.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from experiment_utils import (
|
||||
create_job_dir,
|
||||
run_experiment_batch_parallel,
|
||||
sweep_cfg_generator,
|
||||
)
|
||||
from rbc.constants import all_agents_short_export_experiment_template
|
||||
|
||||
# Hyperparameter sweeps. Each key maps to the list of values to try;
# single-element lists presumably pin that knob to one value — confirm
# against sweep_cfg_generator.

# Optimizer / PPO training settings.
train_param_sweeps = {
    "lr": [0.001],
    "entropy": [0.5],
    "batch_size": [128],
    "clip_grad_norm": [2.0],
    "base_seed": [2345],
    "should_boost_firm_reward": [False],
    "use_ppo": [True],
    "ppo_num_updates": [2, 4],
    "ppo_clip_param": [0.1],
}


# Per-agent-type reward scaling and learning-rate multipliers.
agent_param_sweeps = {
    "consumer_lr_multiple": [1.0],
    "consumer_reward_scale": [5.0],
    "government_reward_scale": [5.0 * 100.0 * 2.0],
    "firm_reward_scale": [30000],
    "government_counts_firm_reward": [1],
    "government_lr_multiple": [0.05],
}


# Economic-world settings (interest rate and the external importer).
world_param_sweeps = {
    "initial_wages": [0.0],
    "interest_rate": [0.0],
    "importer_price": [500.0],
    "importer_quantity": [100.0],
    "use_importer": [1],
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--experiment-dir", type=str, default="experiment/experiment")
    parser.add_argument("--group-name", type=str, default="default_group")
    parser.add_argument("--job-name-base", type=str, default="rollout")
    parser.add_argument("--num-consumers", type=int, default=100)
    parser.add_argument("--num-firms", type=int, default=10)
    parser.add_argument("--num-governments", type=int, default=1)
    # Skip job-dir creation and only run the jobs already on disk.
    parser.add_argument("--run-only", action="store_true")
    parser.add_argument("--seed-from-timestamp", action="store_true")

    args = parser.parse_args()

    # Base config plus the discrete action grids and default (fixed)
    # actions for firms/government, sized for the requested agent counts.
    (
        default_cfg_dict,
        consumption_choices,
        work_choices,
        price_and_wage,
        tax_choices,
        default_firm_action,
        default_government_action,
    ) = all_agents_short_export_experiment_template(
        args.num_firms, args.num_consumers, args.num_governments
    )

    if args.run_only:
        print("Not sweeping over hyperparameter combos...")
    else:
        # Materialize one job directory per hyperparameter combination.
        for new_cfg in sweep_cfg_generator(
            default_cfg_dict,
            tr_param_sweeps=train_param_sweeps,
            ag_param_sweeps=agent_param_sweeps,
            wld_param_sweeps=world_param_sweeps,
            seed_from_timestamp=args.seed_from_timestamp,
            group_name=args.group_name,
        ):
            create_job_dir(
                args.experiment_dir,
                args.job_name_base,
                cfg=new_cfg,
                action_arrays={
                    "consumption_choices": consumption_choices,
                    "work_choices": work_choices,
                    "price_and_wage": price_and_wage,
                    "tax_choices": tax_choices,
                },
            )

    if args.dry_run:
        print("Dry-run -> not actually training...")
    else:
        print("Training multiple experiments locally...")

        # for dirs in experiment dir, run job
        experiment_dirs = [
            f.path for f in os.scandir(args.experiment_dir) if f.is_dir()
        ]
        for experiment in experiment_dirs:
            run_experiment_batch_parallel(
                experiment,
                consumption_choices,
                work_choices,
                price_and_wage,
                tax_choices,
                group_name=args.group_name,
                consumers_only=False,
                no_firms=False,
                default_firm_action=default_firm_action,
                default_government_action=default_government_action,
            )
|
||||
51
ai_economist/real_business_cycle/train_single_exp.py
Normal file
51
ai_economist/real_business_cycle/train_single_exp.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
import argparse
|
||||
|
||||
from experiment_utils import run_experiment_batch_parallel
|
||||
from rbc.constants import all_agents_export_experiment_template
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--experiment-dir", type=str, default="experiment/experiment")
    parser.add_argument("--group-name", type=str, default="default_group")
    parser.add_argument("--job-name-base", type=str, default="rollout")
    parser.add_argument("--num-consumers", type=int, default=100)
    parser.add_argument("--num-firms", type=int, default=10)
    parser.add_argument("--num-governments", type=int, default=1)
    parser.add_argument("--run-only", action="store_true")
    parser.add_argument("--seed-from-timestamp", action="store_true")

    args = parser.parse_args()
    # Base config plus the discrete action grids and default (fixed)
    # actions for firms/government, sized for the requested agent counts.
    (
        default_cfg_dict,
        consumption_choices,
        work_choices,
        price_and_wage,
        tax_choices,
        default_firm_action,
        default_government_action,
    ) = all_agents_export_experiment_template(
        args.num_firms, args.num_consumers, args.num_governments
    )

    if not args.dry_run:
        # for dirs in experiment dir, run job
        experiment = args.experiment_dir
        run_experiment_batch_parallel(
            experiment,
            consumption_choices,
            work_choices,
            price_and_wage,
            tax_choices,
            group_name=args.group_name,
            consumers_only=False,
            no_firms=False,
            default_firm_action=default_firm_action,
            default_government_action=default_government_action,
        )
|
||||
5
ai_economist/training/__init__.py
Normal file
5
ai_economist/training/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
# YAML configuration for the tag continuous environment
|
||||
name: "covid_and_economy_environment"
|
||||
# Environment settings
|
||||
env:
|
||||
collate_agent_step_and_reset_data: True
|
||||
components:
|
||||
- ControlUSStateOpenCloseStatus:
|
||||
action_cooldown_period: 28
|
||||
- FederalGovernmentSubsidy:
|
||||
num_subsidy_levels: 20
|
||||
subsidy_interval: 90
|
||||
max_annual_subsidy_per_person: 20000
|
||||
- VaccinationCampaign:
|
||||
daily_vaccines_per_million_people: 3000
|
||||
delivery_interval: 1
|
||||
vaccine_delivery_start_date: "2021-01-12"
|
||||
economic_reward_crra_eta: 2
|
||||
episode_length: 540
|
||||
flatten_masks: True
|
||||
flatten_observations: False
|
||||
health_priority_scaling_agents: 0.3
|
||||
health_priority_scaling_planner: 0.45
|
||||
infection_too_sick_to_work_rate: 0.1
|
||||
multi_action_mode_agents: False
|
||||
multi_action_mode_planner: False
|
||||
n_agents: 51
|
||||
path_to_data_and_fitted_params: ""
|
||||
pop_between_age_18_65: 0.6
|
||||
risk_free_interest_rate: 0.03
|
||||
world_size: [1, 1]
|
||||
start_date: "2020-03-22"
|
||||
use_real_world_data: False
|
||||
use_real_world_policies: False
|
||||
# Trainer settings
|
||||
trainer:
|
||||
num_envs: 60 # number of environment replicas
|
||||
num_episodes: 1000 # number of episodes to run the training for
|
||||
train_batch_size: 5400 # total batch size used for training per iteration (across all the environments)
|
||||
# Policy network settings
|
||||
policy: # list all the policies below
|
||||
a:
|
||||
to_train: True # flag indicating whether the model needs to be trained
|
||||
algorithm: "PPO" # algorithm used to train the policy
|
||||
vf_loss_coeff: 1 # loss coefficient schedule for the value function loss
|
||||
entropy_coeff: 0.05 # loss coefficient schedule for the entropy loss
|
||||
gamma: 0.98 # discount factor
|
||||
lr: 0.0001 # learning rate
|
||||
model:
|
||||
type: "fully_connected"
|
||||
fc_dims: [256, 256]
|
||||
model_ckpt_filepath: ""
|
||||
p:
|
||||
to_train: True
|
||||
algorithm: "PPO"
|
||||
vf_loss_coeff: 1
|
||||
entropy_coeff: # annealing entropy over time
|
||||
- [0, 0.5]
|
||||
- [50000000, 0.05]
|
||||
gamma: 0.98
|
||||
lr: 0.0001
|
||||
model:
|
||||
type: "fully_connected"
|
||||
fc_dims: [256, 256]
|
||||
model_ckpt_filepath: ""
|
||||
# Checkpoint saving setting
|
||||
saving:
|
||||
metrics_log_freq: 100 # How often (in iterations) to print the metrics
|
||||
model_params_save_freq: 500 # How often (in iterations) to save the model parameters
|
||||
basedir: "/tmp" # base folder used for saving
|
||||
name: "covid19_and_economy" # experiment name
|
||||
tag: "experiments" # experiment tag
|
||||
134
ai_economist/training/training_script.py
Normal file
134
ai_economist/training/training_script.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# Copyright (c) 2021, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root
|
||||
# or https://opensource.org/licenses/BSD-3-Clause
|
||||
|
||||
"""
|
||||
Example training script for the covid19-and-economy environment.
|
||||
Note: This training script only runs on a GPU machine.
|
||||
You will also need to install WarpDrive (https://github.com/salesforce/warp-drive)
|
||||
using `pip install rl-warp-drive`, and Pytorch(https://pytorch.org/)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
import GPUtil
|
||||
|
||||
try:
|
||||
num_gpus_available = len(GPUtil.getAvailable())
|
||||
assert num_gpus_available > 0, "This training script needs a GPU to run!"
|
||||
print(f"Inside training_script.py: {num_gpus_available} GPUs are available.")
|
||||
import torch
|
||||
import yaml
|
||||
from warp_drive.training.trainer import Trainer
|
||||
from warp_drive.utils.env_registrar import EnvironmentRegistrar
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(
|
||||
"This training script requires the 'WarpDrive' package, please run "
|
||||
"'pip install rl-warp-drive' first."
|
||||
) from None
|
||||
except ValueError:
|
||||
raise ValueError("This training script needs a GPU to run!") from None
|
||||
|
||||
from ai_economist.foundation.env_wrapper import FoundationEnvWrapper
|
||||
from ai_economist.foundation.scenarios.covid19.covid19_env import (
|
||||
CovidAndEconomyEnvironment,
|
||||
)
|
||||
|
||||
logging.getLogger().setLevel(logging.ERROR)
|
||||
|
||||
pytorch_cuda_init_success = torch.cuda.FloatTensor(8)
|
||||
_COVID_AND_ECONOMY_ENVIRONMENT = "covid_and_economy_environment"
|
||||
|
||||
# Usage:
|
||||
# >> python ai_economist/training/example_training_script.py
|
||||
# --env covid_and_economy_environment
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--env", "-e", type=str, help="Environment to train.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Read the run configurations specific to each environment.
|
||||
# Note: The run config yamls are located at warp_drive/training/run_configs
|
||||
# ---------------------------------------------------------------------------
|
||||
assert args.env in [_COVID_AND_ECONOMY_ENVIRONMENT], (
|
||||
f"Currently, the only environment supported "
|
||||
f"is {_COVID_AND_ECONOMY_ENVIRONMENT}"
|
||||
)
|
||||
|
||||
config_path = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"run_configs",
|
||||
f"{args.env}.yaml",
|
||||
)
|
||||
with open(config_path, "r", encoding="utf8") as f:
|
||||
run_config = yaml.safe_load(f)
|
||||
|
||||
num_envs = run_config["trainer"]["num_envs"]
|
||||
|
||||
# Create a wrapped environment object via the EnvWrapper
|
||||
# Ensure that use_cuda is set to True (in order to run on the GPU)
|
||||
# ----------------------------------------------------------------
|
||||
if run_config["name"] == _COVID_AND_ECONOMY_ENVIRONMENT:
|
||||
env_registrar = EnvironmentRegistrar()
|
||||
this_file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
env_registrar.add_cuda_env_src_path(
|
||||
CovidAndEconomyEnvironment.name,
|
||||
os.path.join(
|
||||
this_file_dir, "../foundation/scenarios/covid19/covid19_build.cu"
|
||||
),
|
||||
)
|
||||
env_wrapper = FoundationEnvWrapper(
|
||||
CovidAndEconomyEnvironment(**run_config["env"]),
|
||||
num_envs=num_envs,
|
||||
use_cuda=True,
|
||||
env_registrar=env_registrar,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# The policy_tag_to_agent_id_map dictionary maps
|
||||
# policy model names to agent ids.
|
||||
# ----------------------------------------------------
|
||||
policy_tag_to_agent_id_map = {
|
||||
"a": [str(agent_id) for agent_id in range(env_wrapper.env.n_agents)],
|
||||
"p": ["p"],
|
||||
}
|
||||
|
||||
# Flag indicating whether separate obs, actions and rewards placeholders
|
||||
# have to be created for each policy.
|
||||
# Set "create_separate_placeholders_for_each_policy" to True here
|
||||
# since the agents and planner have different observation
|
||||
# and action spaces.
|
||||
separate_placeholder_per_policy = True
|
||||
|
||||
# Flag indicating the observation dimension corresponding to
|
||||
# 'num_agents'.
|
||||
# Note: WarpDrive assumes that all the observation are shaped
|
||||
# (num_agents, *feature_dim), i.e., the observation dimension
|
||||
# corresponding to 'num_agents' is the first one. Instead, if the
|
||||
# observation dimension corresponding to num_agents is the last one,
|
||||
# we will need to permute the axes to align with WarpDrive's assumption
|
||||
obs_dim_corresponding_to_num_agents = "last"
|
||||
|
||||
# Trainer object
|
||||
# --------------
|
||||
trainer = Trainer(
|
||||
env_wrapper=env_wrapper,
|
||||
config=run_config,
|
||||
policy_tag_to_agent_id_map=policy_tag_to_agent_id_map,
|
||||
create_separate_placeholders_for_each_policy=separate_placeholder_per_policy,
|
||||
obs_dim_corresponding_to_num_agents=obs_dim_corresponding_to_num_agents,
|
||||
)
|
||||
|
||||
# Perform training
|
||||
# ----------------
|
||||
trainer.train()
|
||||
trainer.graceful_close()
|
||||
5
envs/__init__.py
Normal file
5
envs/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from . import (
|
||||
simple_market,
|
||||
econ_wrapper,
|
||||
base_econ_wrapper
|
||||
)
|
||||
169
envs/base_econ_wrapper.py
Normal file
169
envs/base_econ_wrapper.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from ai_economist.foundation.base import base_env
|
||||
from threading import Event, Lock, Thread
|
||||
from queue import Queue
|
||||
class BaseEconVecEnv():
    """Base class connecting receiver wrappers to a multi-threaded economic
    simulation and training session.

    One instance owns the underlying foundation environment and runs a
    background worker thread (started via :meth:`run`) that resets or steps
    the env once every registered receiver has voted / submitted actions.
    Receivers interact exclusively through the ``reciever_*`` methods.
    """

    def __init__(self, econ: "base_env.BaseEnvironment"):
        # The wrapped foundation environment (only stepped by the worker).
        self.env = econ

        # Coordination events between the worker thread and receivers.
        # Bug fix: all of this state used to be *class* attributes, which
        # would have been shared across every instance.
        self.base_notification = Event()   # wakes the worker thread
        self.reset_notification = Event()  # tells receivers a reset happened
        self.step_notification = Event()   # tells receivers a step happened

        # Pending per-agent actions, guarded by action_edit_lock.
        self.action_edit_lock = Lock()
        self.actor_actions = {}

        # Stop flag, guarded by stop_edit_lock.
        self.stop_edit_lock = Lock()
        self.stop = False

        # Reset-vote bookkeeping, guarded by vote_lock.
        self.vote_lock = Lock()
        self.n_voters = 0
        self.n_votes_reset = 0

        # Latest environment outputs, guarded by env_data_lock.
        self.env_data_lock = Lock()
        self.obs = None
        self.rew = None
        self.done = None
        self.info = None
        self.n_data_retrieved = 0

    def register_vote(self):
        """Register a receiver with the base.

        Counting registered receivers lets the worker know how many reset
        votes constitute unanimity (the original stub never incremented
        ``n_voters``, so resets could never be gated correctly).
        """
        with self.vote_lock:
            self.n_voters += 1

    def run(self):
        """Start the base wrapper's worker thread and return it."""
        thr = Thread(target=self._run, daemon=True)
        # Bug fix: Thread.run() executes the target in the *calling* thread;
        # start() is what launches the new thread.
        thr.start()
        return thr

    def _run(self):
        """Worker loop: wait for notifications, then reset or step the env."""
        # Reset all coordination state for a fresh run.  The original code
        # called .release() on locks that were never acquired, which raises
        # RuntimeError; acquiring them first via `with` is the fix.
        self.base_notification.clear()
        self.reset_notification.clear()
        self.step_notification.clear()

        with self.stop_edit_lock:
            self.stop = False
        with self.action_edit_lock:
            self.actor_actions = {}
        with self.vote_lock:
            self.n_votes_reset = 0
        with self.env_data_lock:
            self.obs = None
            self.rew = None
            self.done = None
            self.info = None

        # Reset the env once before entering the service loop.
        self._reset()

        while True:
            # Sleep until a receiver (or stop_env) pokes us.
            self.base_notification.wait()
            self.base_notification.clear()

            # Check for a stop signal.  Bug fix: the lock is now released on
            # the exit path too (the original returned while holding it).
            with self.stop_edit_lock:
                if self.stop:
                    return

            # Check whether every registered receiver voted for a reset.
            # The `n_voters > 0` guard avoids resetting on every wake-up
            # before any receiver has registered (0 == 0 was always true).
            with self.vote_lock:
                if self.n_voters > 0 and self.n_voters == self.n_votes_reset:
                    self.n_votes_reset = 0
                    self._reset()

            # Check whether every agent's action has arrived.
            # Bug fixes: `len(dict)` instead of `len(dict.keys)` (missing
            # call), and `and` instead of the bitwise `&` whose precedence
            # made the original condition nonsense.
            self.action_edit_lock.acquire()
            if (len(self.actor_actions) == self.env.n_agents
                    and not self.step_notification.is_set()):
                # _step() deliberately KEEPS action_edit_lock held; the last
                # receiver to read the results releases it (see
                # reciever_block_step).  The original re-acquired the same
                # non-reentrant Lock inside _step, which deadlocked.
                self._step()
            else:
                self.action_edit_lock.release()

    def stop_env(self):
        """Ask the worker thread to exit."""
        with self.stop_edit_lock:
            self.stop = True
        self.base_notification.set()

    def _reset(self):
        """Reset the wrapped env and publish fresh observations."""
        with self.env_data_lock:
            self.n_votes_reset = 0
            self.obs = self.env.reset()
            self.rew = None
            self.done = None
            self.info = None
        # Wake every receiver blocked in reciever_block_reset.
        self.reset_notification.set()

    def _step(self):
        """Step the wrapped env with the collected actions.

        Caller must hold ``action_edit_lock``; it stays held after this call
        until the last receiver has retrieved the results
        (see :meth:`reciever_block_step`), preventing further steps until
        everybody has had the chance to look at the data.
        """
        with self.env_data_lock:
            self.reset_notification.clear()  # stale after the first step
            self.obs, self.rew, self.done, self.info = self.env.step(
                self.actor_actions
            )
            self.n_data_retrieved = 0
        # Wake every receiver blocked in reciever_block_step.
        self.step_notification.set()

    def _prepare_step(self):
        """Clear the previous step's actions so a new set can be gathered."""
        with self.action_edit_lock:
            if self.step_notification.is_set():
                self.step_notification.clear()
                self.actor_actions = {}

    def reciever_request_step(self, actions):
        """Submit actions to the base processing queue.

        ``actions`` is a dict pairing of agent idx and action id.  Raises if
        any agent submits twice within one step.
        """
        self._prepare_step()  # new actions incoming: prepare for a new step
        with self.action_edit_lock:
            # Bug fixes: iterate .items() (the original unpacked bare keys),
            # and test membership instead of indexing (the original raised
            # KeyError for every *new* agent).
            for idx, action in actions.items():
                if idx in self.actor_actions:
                    raise Exception(
                        "Actor action has already been submitted. {}".format(idx)
                    )
                self.actor_actions[idx] = action
            self.base_notification.set()  # alert the worker to action changes

    def reciever_block_step(self):
        """Block until the step triggered by the submitted actions completes.

        Returns the newest ``(obs, rew, done, info)``.  The last receiver to
        read releases the step lock so new actions can be submitted.
        """
        self.step_notification.wait()  # new data available
        with self.env_data_lock:
            obs = self.obs
            rew = self.rew
            done = self.done
            info = self.info
            self.n_data_retrieved += 1
            if self.n_data_retrieved >= self.n_voters:
                # Last reader: release the lock held since _step so the next
                # round of action submissions can proceed.
                self.action_edit_lock.release()
        return obs, rew, done, info

    def reciever_request_reset(self):
        """Add one vote toward a reset; the reset occurs at unanimity."""
        with self.vote_lock:
            self.n_votes_reset += 1

    def reciever_block_reset(self):
        """Block until the requested reset occurs; return observations."""
        self.reset_notification.wait()
        with self.env_data_lock:
            obs = self.obs
        return obs
|
||||
67
envs/reciever_econ_wrapper.py
Normal file
67
envs/reciever_econ_wrapper.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, List, Optional, Sequence, Type, Union
|
||||
from ai_economist.foundation.base import base_env
|
||||
|
||||
import gym
|
||||
import gym.spaces
|
||||
import numpy as np
|
||||
from base_econ_wrapper import BaseEconVecEnv
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
|
||||
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
|
||||
|
||||
from ai_economist import foundation
|
||||
|
||||
class RecieverEconVecEnv(gym.Env):
    """Receiver part of BaseEconVecEnv.

    Filters the shared economic simulation by agent class and presents a
    gym-style API to RL algorithms, enabling multi-threaded learning for
    different agent types.
    """

    def __init__(self, base_econ: BaseEconVecEnv, agent_classname: str):
        self.base_econ = base_econ
        base_econ.register_vote()  # count this receiver toward reset votes
        self.econ = base_econ.env
        self.agent_name = agent_classname
        # NOTE(review): attribute keeps the original "agnet" spelling to stay
        # backward compatible with any external users of this attribute.
        self.agnet_idx = list(
            self.econ.world._agent_class_idx_map[agent_classname]
        )
        # Map env-level agent idx -> local (0-based) index for this class.
        # (The original also bound an unused `first_idx` local; removed.)
        self.idx_to_index = {
            idx: i for i, idx in enumerate(self.agnet_idx)
        }

    def step_async(self, actions: dict) -> None:
        """Submit actions to the shared env.

        ``actions`` is a dict with local index -> action pairs; it is
        re-keyed to env-level agent idx before submission.
        """
        self.base_econ.reciever_request_step(self._dict_index_to_idx(actions))

    def _dict_idx_to_index(self, data):
        # Re-key a dict from env-level agent idx to local index.
        return {self.idx_to_index[key]: value for key, value in data.items()}

    def _dict_index_to_idx(self, data):
        # Re-key a dict from local index to env-level agent idx.
        return {self.agnet_idx[key]: value for key, value in data.items()}

    def step_wait(self):
        """Block until the step completes; return locally re-keyed results."""
        obs, rew, done, info = self.base_econ.reciever_block_step()
        return (
            self._dict_idx_to_index(obs),
            self._dict_idx_to_index(rew),
            self._dict_idx_to_index(done),
            self._dict_idx_to_index(info),
        )

    def reset(self):
        """Vote for a reset, block until it occurs, return re-keyed obs."""
        self.base_econ.reciever_request_reset()
        obs = self.base_econ.reciever_block_reset()
        return self._dict_idx_to_index(obs)
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
ai-economist
|
||||
|
||||
gym
|
||||
ray[rllib]
|
||||
Reference in New Issue
Block a user