Source code for boat_torch.fo_ol.meso

from boat_torch.utils.op_utils import (
    grad_unused_zero,
    update_tensor_grads,
    copy_parameter_from_list,
)
import numpy
import torch
from torch.nn import Module
import copy
from typing import Dict, Any, Callable, List
from boat_torch.operation_registry import register_class
from boat_torch.gm_ol.dynamical_system import DynamicalSystem


@register_class
class MESO(DynamicalSystem):
    """
    Moreau Envelope based Single-loop Method (MESO) for bi-level optimization [1].

    Each call to :meth:`optimize` performs one step on an auxiliary copy of the
    lower-level model (``y_hat``), one optimizer step on the lower-level
    variables, and writes gradients for the upper-level variables.

    Parameters
    ----------
    ll_objective : Callable
        Lower-level objective; called as ``ll_objective(ll_feed_dict, ul_model, model)``.
    lower_loop : int
        Number of lower-level iterations. Forwarded to the base class and kept
        on the instance as ``y_loop``.
    ul_model : torch.nn.Module
        Upper-level model of the BLO problem.
    ul_objective : Callable
        Upper-level objective; called as ``ul_objective(ul_feed_dict, ul_model, ll_model)``.
    ll_model : torch.nn.Module
        Lower-level model of the BLO problem. A deep copy of it becomes ``y_hat``.
    ll_var : List[torch.Tensor]
        Lower-level variables of the BLO problem.
    ul_var : List[torch.Tensor]
        Upper-level variables of the BLO problem.
    solver_config : Dict[str, Any]
        Solver configuration. The following entries are read:

        - ``"lower_level_opt"``: optimizer used to step the lower-level variables.
        - ``"MESO"``: nested dictionary with

          - ``"eta"``: step size of the update applied to ``y_hat``.
          - ``"gamma_1"``: weight of the proximal (squared-distance) term.
          - ``"c0"``: base constant of the ``ck`` schedule.
          - ``"y_hat_lr"``: learning rate of the SGD optimizer built for ``y_hat``.

    References
    ----------
    [1] Liu R, Liu Z, Yao W, et al. "Moreau Envelope for Nonconvex Bi-Level
        Optimization: A Single-loop and Hessian-free Solution Strategy,"
        ICML, 2024.
    """

    def __init__(
        self,
        ll_objective: Callable,
        lower_loop: int,
        ul_model: Module,
        ul_objective: Callable,
        ll_model: Module,
        ll_var: List,
        ul_var: List,
        solver_config: Dict[str, Any],
    ):
        super(MESO, self).__init__(
            ll_objective, ul_objective, lower_loop, ul_model, ll_model, solver_config
        )
        self.ll_opt = solver_config["lower_level_opt"]
        self.ll_var = ll_var
        self.ul_var = ul_var
        self.y_loop = lower_loop

        # Algorithm-specific hyper-parameters live under the "MESO" key.
        meso_cfg = solver_config["MESO"]
        self.eta = meso_cfg["eta"]
        self.gamma_1 = meso_cfg["gamma_1"]
        self.c0 = meso_cfg["c0"]

        # Auxiliary copy of the lower-level model plus a dedicated SGD optimizer.
        # NOTE(review): `optimize` updates y_hat by hand and never steps
        # `y_hat_opt`; confirm whether other code relies on this attribute.
        self.y_hat = copy.deepcopy(self.ll_model)
        self.y_hat_opt = torch.optim.SGD(
            self.y_hat.parameters(), lr=meso_cfg["y_hat_lr"], momentum=0.9
        )

    def optimize(self, ll_feed_dict: Dict, ul_feed_dict: Dict, current_iter: int):
        """
        Run one MESO iteration on the given lower- and upper-level data.

        Parameters
        ----------
        ll_feed_dict : Dict
            Data passed as the first argument of the lower-level objective.
        ul_feed_dict : Dict
            Data passed as the first argument of the upper-level objective.
        current_iter : int
            Index of the current iteration; it determines the coefficient ``ck``.

        Returns
        -------
        Dict
            ``{"upper_loss": float}`` - the scalar value of the loss whose
            gradient is taken with respect to the upper-level variables.
        """
        # ck is fixed at 0.2 on the very first iteration and follows
        # (current_iter + 1) ** 0.25 * c0 afterwards.
        ck = (
            0.2
            if current_iter == 0
            else numpy.power(current_iter + 1, 0.25) * self.c0
        )

        # --- step on the auxiliary model y_hat ---------------------------------
        aux_loss = self.ll_objective(ll_feed_dict, self.ul_model, self.y_hat)
        aux_params = list(self.y_hat.parameters())
        aux_grads = grad_unused_zero(aux_loss, aux_params)

        # theta_new = theta - eta * (grad + gamma_1 * (theta - y))
        stepped_params = [
            theta - self.eta * (grad + self.gamma_1 * (theta - y))
            for theta, grad, y in zip(aux_params, aux_grads, list(self.ll_var))
        ]
        # Presumably loads the stepped values into y_hat; see
        # boat_torch.utils.op_utils.copy_parameter_from_list.
        copy_parameter_from_list(self.y_hat, stepped_params)

        # --- step on the lower-level variables ---------------------------------
        # Squared L2 distance between the lower-level variables and the stepped
        # auxiliary parameters, summed over all tensors.
        sq_distance = sum(
            torch.norm(y - theta_new, p=2) ** 2
            for y, theta_new in zip(list(self.ll_var), stepped_params)
        )
        inv_ck = 1 / ck
        scaled_ul = inv_ck * self.ul_objective(
            ul_feed_dict, self.ul_model, self.ll_model
        )
        ll_value = self.ll_objective(ll_feed_dict, self.ul_model, self.ll_model)
        y_loss = scaled_ul + ll_value - 0.5 * self.gamma_1 * sq_distance

        self.ll_opt.zero_grad()
        y_grads = grad_unused_zero(y_loss, list(self.ll_var))
        update_tensor_grads(self.ll_var, y_grads)
        self.ll_opt.step()

        # --- gradients for the upper-level variables ---------------------------
        # Objectives are re-evaluated here, i.e. after the lower-level step.
        scaled_ul = inv_ck * self.ul_objective(
            ul_feed_dict, self.ul_model, self.ll_model
        )
        ll_value = self.ll_objective(ll_feed_dict, self.ul_model, self.ll_model)
        aux_value = self.ll_objective(ll_feed_dict, self.ul_model, self.y_hat)
        upper_loss = scaled_ul + ll_value - aux_value

        x_grads = grad_unused_zero(upper_loss, self.ul_var)
        update_tensor_grads(self.ul_var, x_grads)

        return {"upper_loss": upper_loss.item()}