Files
ai-econ/ai_economist/real_business_cycle/rbc/util.py

111 lines
4.5 KiB
Python

# Copyright (c) 2021, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root
# or https://opensource.org/licenses/BSD-3-Clause
import torch
def dict_merge(dct, merge_dct):
"""Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
updating only top-level keys, dict_merge recurses down into dicts nested
to an arbitrary depth, updating keys. The ``merge_dct`` is merged into
``dct``.
:param dct: dict onto which the merge is executed
:param merge_dct: dct merged into dct
:return: None
"""
for k, v in merge_dct.items():
# dct does not have k yet. Add it with value v.
if (k not in dct) and (not isinstance(dct, list)):
dct[k] = v
else:
# dct[k] and merge_dict[k] are both dictionaries. Recurse.
if isinstance(dct[k], (dict, list)) and isinstance(v, dict):
dict_merge(dct[k], merge_dct[k])
else:
# dct[k] and merge_dict[k] are both tuples or lists.
if isinstance(dct[k], (tuple, list)) and isinstance(v, (tuple, list)):
# They don't match. Overwrite with v.
if len(dct[k]) != len(v):
dct[k] = v
else:
for i, (d_val, v_val) in enumerate(zip(dct[k], v)):
if isinstance(d_val, dict) and isinstance(v_val, dict):
dict_merge(d_val, v_val)
else:
dct[k][i] = v_val
else:
dct[k] = v
def min_max_consumer_budget_delta(hparams_dict):
# largest single round changes
max_wage = hparams_dict["agents"]["max_possible_wage"]
max_hours = hparams_dict["agents"]["max_possible_hours_worked"]
max_price = hparams_dict["agents"]["max_possible_price"]
max_singlefirm_consumption = hparams_dict["agents"]["max_possible_consumption"]
num_firms = hparams_dict["agents"]["num_firms"]
min_budget = (
-max_singlefirm_consumption * max_price * num_firms
) # negative budget from consuming only
max_budget = max_hours * max_wage * num_firms # positive budget from only working
return min_budget, max_budget
def min_max_stock_delta(hparams_dict):
# for now, assuming 1.0 capital
max_hours = hparams_dict["agents"]["max_possible_hours_worked"]
max_singlefirm_consumption = hparams_dict["agents"]["max_possible_consumption"]
alpha = hparams_dict["world"]["production_alpha"]
if isinstance(alpha, str):
alpha = 0.8
num_consumers = hparams_dict["agents"]["num_consumers"]
max_delta = (max_hours * num_consumers) ** alpha
min_delta = -max_singlefirm_consumption * num_consumers
return min_delta, max_delta
def min_max_firm_budget(hparams_dict):
max_wage = hparams_dict["agents"]["max_possible_wage"]
max_hours = hparams_dict["agents"]["max_possible_hours_worked"]
max_singlefirm_consumption = hparams_dict["agents"]["max_possible_consumption"]
num_consumers = hparams_dict["agents"]["num_consumers"]
max_price = hparams_dict["agents"]["max_possible_price"]
max_delta = max_singlefirm_consumption * max_price * num_consumers
min_delta = -max_hours * max_wage * num_consumers
return min_delta, max_delta
def expand_to_digit_form(x, dims_to_expand, max_digits):
# first split x up
requires_grad = (
x.requires_grad
) # don't want to backprop through these ops, but do want
# gradients if x had them
with torch.no_grad():
tensor_pieces = []
expanded_digit_shape = x.shape[:-1] + (max_digits,)
for i in range(x.shape[-1]):
if i not in dims_to_expand:
tensor_pieces.append(x[..., i : i + 1])
else:
digit_entries = torch.zeros(expanded_digit_shape, device=x.device)
for j in range(max_digits):
digit_entries[..., j] = (x[..., i] % (10 ** (j + 1))) / (
10 ** (j + 1)
)
tensor_pieces.append(digit_entries)
output = torch.cat(tensor_pieces, dim=-1)
output.requires_grad_(requires_grad)
return output
def size_after_digit_expansion(existing_size, dims_to_expand, max_digits):
num_expanded = len(dims_to_expand)
# num non expanded digits, + all the expanded ones
return (existing_size - num_expanded) + (max_digits * num_expanded)