Source code for boat_torch.fo_ol.alto

from boat_torch.utils.op_utils import (
    grad_unused_zero,
    update_tensor_grads,
)
import torch
from torch.nn import Module
from typing import Dict, Any, Callable, List
from boat_torch.operation_registry import register_class
from boat_torch.gm_ol.dynamical_system import DynamicalSystem


@register_class
class ALTO(DynamicalSystem):
    """
    Simple Alternating Optimization (ALT) procedure for bi-level optimization.

    Each call to :meth:`optimize` performs a single lower-level optimizer
    step and then computes upper-level gradients for ``ul_var``. No
    upper-level optimizer is stepped inside this class.

    Parameters
    ----------
    ll_objective : Callable
        Lower-level objective, called as
        ``ll_objective(ll_feed_dict, ul_model, ll_model)``.
    lower_loop : int
        Forwarded to the ``DynamicalSystem`` base class; not read directly
        by this class.
    ul_model : Module
        Upper-level model passed to both objectives.
    ul_objective : Callable
        Upper-level objective, called as
        ``ul_objective(ul_feed_dict, ul_model, ll_model)``.
    ll_model : Module
        Lower-level model passed to both objectives.
    ll_var : List
        Lower-level variables that the lower loss is differentiated against.
    ul_var : List
        Upper-level variables that the upper loss is differentiated against.
    solver_config : Dict[str, Any]
        Solver configuration; must contain the key ``"lower_level_opt"``
        holding the lower-level optimizer.
    """

    def __init__(
        self,
        ll_objective: Callable,
        lower_loop: int,
        ul_model: Module,
        ul_objective: Callable,
        ll_model: Module,
        ll_var: List,
        ul_var: List,
        solver_config: Dict[str, Any],
    ):
        # NOTE: the base class takes the objectives in (lower, upper) order,
        # which differs from the parameter order of this constructor.
        super(ALTO, self).__init__(
            ll_objective,
            ul_objective,
            lower_loop,
            ul_model,
            ll_model,
            solver_config,
        )
        self.ll_opt = solver_config["lower_level_opt"]
        self.ll_var = ll_var
        self.ul_var = ul_var

    def optimize(self, ll_feed_dict: Dict, ul_feed_dict: Dict, current_iter: int):
        """
        Run one round of the alternating optimization procedure.

        Parameters
        ----------
        ll_feed_dict : Dict
            Input handed to the lower-level objective.
        ul_feed_dict : Dict
            Input handed to the upper-level objective.
        current_iter : int
            Accepted as part of the method signature; not used in this method.

        Returns
        -------
        dict
            ``{"upper_loss": <float>}`` where the value is ``.item()`` of the
            upper-level loss evaluated after the lower-level step.
        """
        # --- Lower-level update: one optimizer step on ll_var ---
        self.ll_opt.zero_grad()
        ll_loss = self.ll_objective(ll_feed_dict, self.ul_model, self.ll_model)
        ll_grads = grad_unused_zero(ll_loss, list(self.ll_var))
        update_tensor_grads(self.ll_var, ll_grads)
        self.ll_opt.step()

        # --- Upper-level update: gradients only, no optimizer step here ---
        # Presumably the caller steps the upper-level optimizer afterwards;
        # verify against the call site.
        ul_loss = self.ul_objective(ul_feed_dict, self.ul_model, self.ll_model)
        ul_grads = grad_unused_zero(ul_loss, self.ul_var)
        update_tensor_grads(self.ul_var, ul_grads)

        result = {"upper_loss": ul_loss.item()}
        return result