from typing import Any, Callable, Dict, Iterable, List
import torch
from torch.nn import Module
from boat_torch.gm_ol.dynamical_system import DynamicalSystem
from boat_torch.operation_registry import register_class
from boat_torch.utils.op_utils import grad_unused_zero, update_tensor_grads
@register_class
class GAFFO(DynamicalSystem):
    """
    Implements the optimization procedure of the Gap-Function-based
    First-Order Method (GAFFO) [1].

    This first-order implementation treats lower-level variables as the
    current task state and writes the meta-gradient to matching upper-level
    variables.

    Parameters
    ----------
    ll_objective : Callable
        The lower-level objective of the BLO problem.
    ul_objective : Callable
        The upper-level objective of the BLO problem.
    ll_model : torch.nn.Module
        The lower-level model of the BLO problem.
    ul_model : torch.nn.Module
        The upper-level model of the BLO problem.
    ll_var : List[torch.Tensor]
        The list of lower-level variables of the BLO problem.
    ul_var : List[torch.Tensor]
        The list of upper-level variables of the BLO problem. Must have the
        same length and per-tensor shapes as ``ll_var``.
    lower_loop : int
        Number of iterations for lower-level optimization. It is forwarded
        to the base class; ``optimize`` itself performs a single lower-level
        step per call.
    solver_config : Dict[str, Any]
        A dictionary containing solver configurations. Expected keys include:

        - "lower_level_opt": The optimizer for the lower-level model. Only
          read (for its default ``lr``) when "lower_step_size" is not given.
        - "GAFFO" (Dict): A dictionary containing the following optional keys:

            - "lambda" or "gap_lambda": Regularization weight for the gap
              term ("lambda" takes precedence). Defaults to 1.0.
            - "sigma": Probe step size used to estimate the regularized
              gap. Defaults to 0.01.
            - "lower_step_size": Step size for the lower-level update. If
              not specified, the learning rate of "lower_level_opt" is used
              (1.0 if that optimizer has no default ``lr``).
            - "use_sign_lower_step": Whether to use the sign of lower-level
              gradients for the lower-level update. Defaults to False.
            - "maximize": Whether to use the ascent-form GAFFO direction.
              When True the direction is negated before being handed to
              ``update_tensor_grads``. Defaults to True.
            - "sync_lower_from_upper": Whether to synchronize lower-level
              variables from upper-level variables before each update.
              Defaults to True.
            - "projection": Optional projection operator for constrained
              lower-level variables. Called with the list of lower-level
              tensors; may return the projected tensors or ``None``.

    References
    ----------
    [1] Yao W, Yin H, Zeng S, and Zhang J. "Overcoming Lower-Level Constraints
        in Bilevel Optimization: A Novel Approach with Regularized Gap
        Functions," arXiv:2406.01992, 2024.
    """

    def __init__(
        self,
        ll_objective: Callable,
        lower_loop: int,
        ul_model: Module,
        ul_objective: Callable,
        ll_model: Module,
        ll_var: List,
        ul_var: List,
        solver_config: Dict[str, Any],
    ):
        # Zero-argument super() does not look up the module-level name
        # ``GAFFO``, which is rebound to whatever ``register_class`` returns.
        super().__init__(
            ll_objective, ul_objective, lower_loop, ul_model, ll_model, solver_config
        )
        config = solver_config.get("GAFFO", {})

        # Private list copies so later mutation of the caller's lists does
        # not change which tensors are optimized.
        self.ll_var = list(ll_var)
        self.ul_var = list(ul_var)

        self.gap_lambda = float(config.get("lambda", config.get("gap_lambda", 1.0)))
        self.sigma = float(config.get("sigma", 0.01))

        # Fall back to the lower-level optimizer's learning rate only when no
        # explicit step size is configured.
        lower_step_size = config.get("lower_step_size")
        self.lower_step_size = float(
            solver_config["lower_level_opt"].defaults.get("lr", 1.0)
            if lower_step_size is None
            else lower_step_size
        )

        self.use_sign_lower_step = bool(config.get("use_sign_lower_step", False))
        self.maximize = bool(config.get("maximize", True))
        self.sync_lower_from_upper = bool(config.get("sync_lower_from_upper", True))
        self.projection = config.get("projection", None)

    def optimize(self, ll_feed_dict: Dict, ul_feed_dict: Dict, current_iter: int):
        """
        Run one GAFFO update and hand the meta-gradient to the upper-level
        variables.

        Parameters
        ----------
        ll_feed_dict : Dict
            Data passed to the lower-level objective.
        ul_feed_dict : Dict
            Data passed to the upper-level objective.
        current_iter : int
            Current iteration index (not used by this method).

        Returns
        -------
        Dict
            ``{"upper_loss": <python scalar>}`` with the upper-level loss
            evaluated at the updated lower-level variables.

        Raises
        ------
        ValueError
            If the lower- and upper-level variable lists do not match in
            length or shape.
        """
        self._check_meta_shapes()
        if self.sync_lower_from_upper:
            self._copy_params(self.ul_var, self.ll_var)
        # NOTE(review): assumed to run whether or not the sync above
        # happened -- confirm against upstream.
        self._project_params(self.ll_var)

        # One (optionally sign-based) gradient step on the lower-level loss.
        lower_loss = self.ll_objective(ll_feed_dict, self.ul_model, self.ll_model)
        lower_grads = grad_unused_zero(lower_loss, self.ll_var)
        self._lower_step(lower_grads)
        self._project_params(self.ll_var)

        # Lower-level gradient at the updated point (the "base" point).
        lower_loss_next = self.ll_objective(ll_feed_dict, self.ul_model, self.ll_model)
        gap_base_grads = grad_unused_zero(lower_loss_next, self.ll_var)

        # Temporarily move the lower-level variables to the probe point
        # ``base + sigma * base_grad``. The finally-clause guarantees they
        # are restored even if the objective or the gradient call fails.
        current_lower = [param.detach().clone() for param in self.ll_var]
        try:
            self._set_probe_params(current_lower, gap_base_grads)
            lower_loss_probe = self.ll_objective(
                ll_feed_dict, self.ul_model, self.ll_model
            )
            gap_probe_grads = grad_unused_zero(lower_loss_probe, self.ll_var)
        finally:
            self._copy_params(current_lower, self.ll_var)

        upper_loss = self.ul_objective(ul_feed_dict, self.ul_model, self.ll_model)
        risk_grads = grad_unused_zero(upper_loss, self.ll_var)

        # Guard against division by a zero (or negative) gap weight.
        gap_weight = max(self.gap_lambda, 1e-8)
        # Upper-level gradient scaled by 1/lambda, minus the change of the
        # lower-level gradient between the probe and the base point.
        ascent_direction = [
            risk_grad / gap_weight - (probe_grad - base_grad)
            for risk_grad, probe_grad, base_grad in zip(
                risk_grads, gap_probe_grads, gap_base_grads
            )
        ]
        meta_grads = [-grad if self.maximize else grad for grad in ascent_direction]
        update_tensor_grads(self.ul_var, meta_grads)
        return {"upper_loss": upper_loss.item()}

    def _check_meta_shapes(self):
        """Raise ``ValueError`` unless ``ll_var`` and ``ul_var`` match in length and shapes."""
        if len(self.ll_var) != len(self.ul_var):  # pragma: no cover
            raise ValueError(
                "GAFFO expects lower_level_var and upper_level_var to have the "
                "same structure."
            )
        for ll_param, ul_param in zip(self.ll_var, self.ul_var):
            if ll_param.shape != ul_param.shape:  # pragma: no cover
                raise ValueError(
                    "GAFFO expects matching lower/upper variable shapes. "
                    f"Got lower {tuple(ll_param.shape)} and upper {tuple(ul_param.shape)}."
                )

    def _lower_step(self, grads: Iterable[torch.Tensor]):
        """Apply one in-place descent step ``param -= lower_step_size * grad`` to ``ll_var``.

        When ``use_sign_lower_step`` is set, the element-wise sign of each
        gradient is used instead of the gradient itself.
        """
        with torch.no_grad():
            for param, grad in zip(self.ll_var, grads):
                step_grad = grad.sign() if self.use_sign_lower_step else grad
                param.sub_(self.lower_step_size * step_grad)

    def _set_probe_params(
        self, base_params: List[torch.Tensor], base_grads: Iterable[torch.Tensor]
    ):
        """Overwrite ``ll_var`` in place with ``base_params + sigma * base_grads``, then project."""
        with torch.no_grad():
            for param, base_param, grad in zip(self.ll_var, base_params, base_grads):
                param.copy_(base_param + self.sigma * grad)
        self._project_params(self.ll_var)

    def _project_params(self, params: List[torch.Tensor]):  # pragma: no cover
        """Apply the configured projection to ``params`` in place.

        Does nothing when no projection is configured or when the projection
        returns ``None`` (presumably because it already modified ``params``
        in place -- confirm with the projection's author).

        Raises
        ------
        ValueError
            If the projection returns a different number of tensors than it
            was given; previously the surplus parameters were silently left
            unprojected.
        """
        if self.projection is None:
            return
        projected = self._call_projection(params)
        if projected is None:
            return
        # Materialize once so generators are supported and the length can be
        # validated before any parameter is overwritten.
        projected = list(projected)
        if len(projected) != len(params):
            raise ValueError(
                "GAFFO projection returned "
                f"{len(projected)} tensors for {len(params)} parameters."
            )
        with torch.no_grad():
            for param, projected_param in zip(params, projected):
                param.copy_(projected_param)

    def _call_projection(self, params: List[torch.Tensor]):  # pragma: no cover
        """Call the projection positionally, falling back to ``params=`` on ``TypeError``.

        NOTE(review): a ``TypeError`` raised from inside the projection body
        also triggers the keyword retry, so the projection may run twice.
        """
        try:
            return self.projection(params)
        except TypeError:
            return self.projection(params=params)

    @staticmethod
    def _copy_params(source: Iterable[torch.Tensor], target: Iterable[torch.Tensor]):
        """Copy each tensor of ``source`` into the matching tensor of ``target`` in place."""
        with torch.no_grad():
            for src, dst in zip(source, target):
                dst.copy_(src)