# Source code for boat_torch.boat_opt

import time
import copy
from typing import Dict, Any, Callable
from boat_torch.utils.op_utils import copy_parameter_from_list, average_grad

import torch
from torch import Tensor
import higher

from boat_torch.operation_registry import get_registered_operation
from boat_torch.gm_ol import makes_functional_dynamical_system
from boat_torch.na_ol import makes_functional_na_operation
import matplotlib.pyplot as plt
import os
import json

importlib = __import__("importlib")


def _load_loss_function(loss_config: Dict[str, Any]) -> Callable:
    """
    Resolve a dotted path to a loss function and bind its default parameters.

    :param loss_config: Dictionary with keys:
        - "function": Dotted path of the callable (e.g., "module.path.to_function").
        - "params": Optional keyword arguments supplied on every call.
    :type loss_config: Dict[str, Any]

    :returns: Callable forwarding positional and keyword arguments to the
        resolved function; keyword arguments given at call time override
        the entries of "params".
    :rtype: Callable
    """
    qualified_name = loss_config["function"]
    module_path, attr_name = qualified_name.rsplit(".", 1)
    loss_fn = getattr(importlib.import_module(module_path), attr_name)

    def call_with_defaults(*args, **kwargs):
        # "params" is looked up on every call, so later edits to the
        # configuration dictionary are picked up by the wrapper.
        defaults = loss_config.get("params", {})
        return loss_fn(*args, **{**defaults, **kwargs})

    return call_with_defaults


class Problem:
    """
    Enhanced bi-level optimization problem class supporting flexible loss functions and operation configurations.
    """

    def __init__(self, config: Dict[str, Any], loss_config: Dict[str, Any]):
        """
        Initialize the Problem instance.

        :param config: Configuration dictionary for the optimization setup.
            - "fo_op": First Order Gradient based Method, or None.
            - "gm_op": List of gradient mapping operations, or None,
              e.g., ["NGD"], ["NGD", "GDA"], ["NGD", "GDA", "DI"].
            - "na_op": Hyper-optimization method, or None,
              e.g., ["RAD"], ["RAD", "PTT"], ["IAD", "NS", "PTT"].
            - "lower_level_model": The lower-level model to be optimized.
            - "upper_level_model": The upper-level model to be optimized.
            - "lower_level_var": Variables in the lower-level model.
            - "upper_level_var": Variables in the upper-level model.
            - "lower_level_opt": Optimizer of the lower-level variables.
            - "upper_level_opt": Optimizer of the upper-level variables.
            - "lower_iters": Number of lower-level iterations (optional, defaults to 10).
            - "device": Device configuration (e.g., "cpu", "cuda").
            - "loss_log_path": Path of the JSON file receiving the loss history.
        :type config: Dict[str, Any]

        :param loss_config: Loss function configuration dictionary.
            - "lower_level_loss": Configuration for the lower-level loss function.
            - "upper_level_loss": Configuration for the upper-level loss function.
            - "gda_loss": Configuration for the GDA loss function
              (required when "GDA" is part of "gm_op").
        :type loss_config: Dict[str, Any]

        :raises AssertionError: If "GDA" is requested without a "gda_loss" entry.

        :returns: None
        """
        self._fo_op = config["fo_op"]
        self._gm_op = config["gm_op"]
        self._na_op = config["na_op"]
        self._ll_model = config["lower_level_model"]
        self._ul_model = config["upper_level_model"]
        self._ll_var = config["lower_level_var"]
        self._ul_var = config["upper_level_var"]
        self._lower_opt = config["lower_level_opt"]
        self._upper_opt = config["upper_level_opt"]
        self._ll_loss = _load_loss_function(loss_config["lower_level_loss"])
        self._ul_loss = _load_loss_function(loss_config["upper_level_loss"])
        # The caller's dictionary is kept by reference (not copied); it is
        # handed to the solvers as ``solver_config`` and may be extended below.
        self.boat_configs = config
        self._lower_loop = config.get("lower_iters", 10)
        self._log_results = []
        self._device = torch.device(config["device"])
        self._ll_solver = None
        self._ul_solver = None
        self._lower_init_opt = None
        self._fo_op_solver = None
        self._track_opt_traj = False
        if config["gm_op"] is not None:
            if "GDA" in config["gm_op"]:
                assert (
                    loss_config.get("gda_loss", None) is not None
                ), "Set the 'gda_loss' in loss_config properly."
                self.boat_configs["gda_loss"] = _load_loss_function(
                    loss_config["gda_loss"]
                )
        self.loss_log_path = config["loss_log_path"]
        self.loss_history = []

    def build_ll_solver(self) -> "Problem":
        """
        Configure the lower-level solver.

        When no first-order operation is configured, the configuration is
        validated with :meth:`check_status` and a gradient-mapping solver is
        built; otherwise the registered first-order operation is instantiated.

        :raises AssertionError: If neither a first-order operation nor both
            'gm_op' and 'na_op' are configured, or if validation fails.

        :returns: The instance itself, so calls can be chained.
        """
        if self.boat_configs["fo_op"] is None:
            assert (self.boat_configs["gm_op"] is not None) and (
                self.boat_configs["na_op"] is not None
            ), "Set 'gm_op' and 'na_op' properly."
            self.check_status()
            sorted_ops = sorted([op.upper() for op in self._gm_op])
            self._ll_solver = makes_functional_dynamical_system(
                custom_order=sorted_ops,
                ll_objective=self._ll_loss,
                ul_objective=self._ul_loss,
                ll_model=self._ll_model,
                ul_model=self._ul_model,
                lower_loop=self._lower_loop,
                solver_config=self.boat_configs,
            )
            if "DI" in self.boat_configs["gm_op"]:
                # The optimizer of the lower-level initialization reuses the
                # class of the upper-level optimizer with the 'DI' learning rate.
                opt_cls = type(self._upper_opt)
                di_lr = float(self.boat_configs["DI"]["lr"])
                self._lower_init_opt = opt_cls(
                    [{"params": self._ll_var, "lr": di_lr}]
                )
        else:
            self._fo_op_solver = get_registered_operation(
                "%s" % self.boat_configs["fo_op"]
            )(
                ll_objective=self._ll_loss,
                ul_objective=self._ul_loss,
                ll_model=self._ll_model,
                ul_model=self._ul_model,
                lower_loop=self._lower_loop,
                ll_var=self._ll_var,
                ul_var=self._ul_var,
                solver_config=self.boat_configs,
            )
        return self

    def build_ul_solver(self) -> "Problem":
        """
        Configure the upper-level solver.

        Without a first-order operation, a hyper-gradient solver is built from
        'na_op' unless 'DM' is part of 'gm_op' (in which case the upper-level
        solver is left untouched). With a first-order operation, 'na_op' must
        be None and no upper-level solver is used.

        :raises AssertionError: If 'fo_op' and 'na_op' are inconsistent.

        :returns: The instance itself, so calls can be chained.
        """
        if self.boat_configs["fo_op"] is None:
            assert (
                self.boat_configs["na_op"] is not None
            ), "Choose FO_OL based methods from ['VSO','VFO','MESO', 'PGDO', 'ALTO'] or set 'gm_ol' and 'na_ol' properly. Currently, fo_op ={} is not None".format(
                self.boat_configs["fo_op"]
            )
            sorted_ops = sorted([op.upper() for op in self._na_op])
            if "DM" not in self._gm_op:
                self._ul_solver = makes_functional_na_operation(
                    custom_order=sorted_ops,
                    ul_objective=self._ul_loss,
                    ll_objective=self._ll_loss,
                    ll_model=self._ll_model,
                    ul_model=self._ul_model,
                    ll_var=self._ll_var,
                    ul_var=self._ul_var,
                    solver_config=self.boat_configs,
                )
        else:
            assert (
                self.boat_configs["na_op"] is None
            ), "Choose FO_OL based methods from ['VSO','VFO','MESO', 'PGDO', 'ALTO'] or set 'gm_ol' and 'na_ol' properly. Currently, na_op ={} is not None".format(
                self.boat_configs["na_op"]
            )
            self._ul_solver = None
        return self

    def _run_inner_loop(
        self,
        ll_feed_dict,
        ul_feed_dict,
        current_iter: int,
        *,
        log_dynamics: bool,
        copy_last_param: bool,
    ) -> float:
        """
        Run one differentiable lower-level solve followed by the upper-level
        gradient computation, and return the elapsed time of both phases.

        :param log_dynamics: Whether the lower-level results are appended to the log.
        :param copy_last_param: Whether the last unrolled parameters are copied
            back into the lower-level model.
        """
        with higher.innerloop_ctx(
            self._ll_model,
            self._lower_opt,
            copy_initial_weights=False,
            device=self._device,
            track_higher_grads=self._track_opt_traj,
        ) as (auxiliary_model, auxiliary_opt):
            forward_start = time.perf_counter()
            dynamic_results = self._ll_solver.optimize(
                ll_feed_dict=ll_feed_dict,
                ul_feed_dict=ul_feed_dict,
                auxiliary_model=auxiliary_model,
                auxiliary_opt=auxiliary_opt,
                current_iter=current_iter,
            )
            if log_dynamics:
                self._log_results.append(dynamic_results)
            max_loss_iter = list(dynamic_results[-1].values())[-1]
            forward_time = time.perf_counter() - forward_start

            backward_start = time.perf_counter()
            if self._ul_solver is not None:
                self._log_results.append(
                    self._ul_solver.compute_gradients(
                        ll_feed_dict=ll_feed_dict,
                        ul_feed_dict=ul_feed_dict,
                        auxiliary_model=auxiliary_model,
                        max_loss_iter=max_loss_iter,
                    )
                )
            backward_time = time.perf_counter() - backward_start

            if copy_last_param:
                copy_parameter_from_list(
                    self._ll_model,
                    list(auxiliary_model.parameters(time=-1)),
                )
        return forward_time + backward_time

    def _record_losses(self, ll_feed_dict, ul_feed_dict, current_iter: int) -> None:
        """
        Evaluate both objectives on the current models and store the values.

        When the feeds are lists of batches, only the first batch is evaluated.
        """
        if isinstance(ll_feed_dict, list):
            ll_fd = ll_feed_dict[0]
            ul_fd = ul_feed_dict[0]
        else:
            ll_fd = ll_feed_dict
            ul_fd = ul_feed_dict
        ll_loss = self._ll_loss(ll_fd, self._ul_model, self._ll_model)
        ul_loss = self._ul_loss(ul_fd, self._ul_model, self._ll_model)
        self.save_losses(current_iter=current_iter, ll_loss=ll_loss, ul_loss=ul_loss)

    def run_iter(
        self,
        ll_feed_dict: Dict[str, Tensor],
        ul_feed_dict: Dict[str, Tensor],
        current_iter: int,
    ) -> tuple:
        """
        Run a single iteration of the bi-level optimization process.

        :param ll_feed_dict: Dictionary containing the real-time data and parameters fed
            for the construction of the lower-level (LL) objective. Example::

                {
                    "image": train_images,
                    "text": train_texts,
                    "target": train_labels  # Optional
                }
        :type ll_feed_dict: Dict[str, Tensor]

        :param ul_feed_dict: Dictionary containing the real-time data and parameters fed
            for the construction of the upper-level (UL) objective. Example::

                {
                    "image": val_images,
                    "text": val_texts,
                    "target": val_labels  # Optional
                }
        :type ul_feed_dict: Dict[str, Tensor]

        :param current_iter: The current iteration number.
        :type current_iter: int

        :notes:
            - When `accumulate_grad` (or `fo_ol_batch_input`) is set to True, pass
              `ll_feed_dict` and `ul_feed_dict` as lists of dictionaries, i.e.,
              `[Dict[str, Tensor]]`, one entry per batch.
            - The lower- and upper-level losses are evaluated after the update
              (before returning when `return_grad` is True) and stored through
              :meth:`save_losses`.

        :returns: A tuple containing:
            - the accumulated log of solver results, or, when `return_grad` is
              True, the list of gradients of the upper-level variables;
            - **run_time** (*float*): The time measured for the iteration.
        :rtype: tuple
        """
        if self.boat_configs["fo_op"] is not None:
            start_time = time.perf_counter()
            if self.boat_configs["fo_ol_batch_input"]:
                # One solver call per (lower, upper) batch pair.
                for batch_ll_feed_dict, batch_ul_feed_dict in zip(
                    ll_feed_dict, ul_feed_dict
                ):
                    self._log_results.append(
                        self._fo_op_solver.optimize(
                            batch_ll_feed_dict, batch_ul_feed_dict, current_iter
                        )
                    )
            else:
                self._log_results.append(
                    self._fo_op_solver.optimize(
                        ll_feed_dict, ul_feed_dict, current_iter
                    )
                )
            run_time = time.perf_counter() - start_time
        else:
            if self.boat_configs["accumulate_grad"]:
                run_time = 0
                for batch_ll_feed_dict, batch_ul_feed_dict in zip(
                    ll_feed_dict, ul_feed_dict
                ):
                    run_time += self._run_inner_loop(
                        batch_ll_feed_dict,
                        batch_ul_feed_dict,
                        current_iter,
                        log_dynamics=True,
                        copy_last_param=False,
                    )
            else:
                # NOTE(review): unlike the accumulating path, the lower-level
                # results are not appended to the log here - confirm intent.
                run_time = self._run_inner_loop(
                    ll_feed_dict,
                    ul_feed_dict,
                    current_iter,
                    log_dynamics=False,
                    copy_last_param=self.boat_configs["copy_last_param"],
                )
            # Only meaningful on this path: 'gm_op' is a list here and the
            # initialization optimizer was created by build_ll_solver.
            if "DI" in self.boat_configs["gm_op"]:
                self._lower_init_opt.step()
                self._lower_init_opt.zero_grad()

        if self.boat_configs["return_grad"]:
            self._record_losses(ll_feed_dict, ul_feed_dict, current_iter)
            return [var.grad for var in list(self._ul_var)], run_time

        self._upper_opt.step()
        self._upper_opt.zero_grad()
        self._record_losses(ll_feed_dict, ul_feed_dict, current_iter)
        return self._log_results, run_time

    def set_track_trajectory(self, track_traj=True):
        """
        Enable or disable tracking of the lower-level optimization trajectory.

        :param track_traj: Value forwarded as ``track_higher_grads`` to the inner loop.
        :returns: None
        """
        self._track_opt_traj = track_traj

    @staticmethod
    def _same_structure(base_model, meta_model) -> bool:
        """Return True when both models expose parameters of identical count, shape, dtype and device."""
        base_params = list(base_model.parameters())
        meta_params = list(meta_model.parameters())
        # zip() alone would silently ignore surplus parameters of the longer model.
        if len(base_params) != len(meta_params):
            return False
        return all(
            p1.shape == p2.shape and p1.dtype == p2.dtype and p1.device == p2.device
            for p1, p2 in zip(base_params, meta_params)
        )

    def check_status(self):
        """
        Validate the combination of configured operations and their parameters.

        Trajectory tracking is switched on when 'PTT', 'IAD' or 'RAD' is part
        of 'na_op'.

        :raises AssertionError: If the configuration is inconsistent.
        :returns: None
        """
        if any(item in self._na_op for item in ("PTT", "IAD", "RAD")):
            self.set_track_trajectory(True)
        if "DM" in self.boat_configs["gm_op"]:
            assert (self.boat_configs["na_op"] == ["RAD"]) or (
                self.boat_configs["na_op"] == ["CG"]
            ), "When 'DM' is chosen, 'na_op' must be ['RAD'] or ['CG']."
        if "RGT" in self.boat_configs["na_op"]:
            assert (
                self.boat_configs["RGT"]["truncate_iter"] > 0
            ), "When 'RGT' is chosen, set the 'truncate_iter' properly ."
        if self.boat_configs["GDA"]["alpha_init"] > 0.0:
            assert (
                0.0 < self.boat_configs["GDA"]["alpha_decay"] <= 1.0
            ), "Parameter 'alpha_decay' used in method BDA should be in the interval (0,1]."
        if "FD" in self._na_op:
            assert (
                self.boat_configs["RGT"]["truncate_iter"] == 0
            ), "One-stage method doesn't need trajectory truncation."

        if "IAD" in self._na_op:
            assert self._same_structure(self._ll_model, self._ul_model), (
                "With IAD or FOA operation, 'upper_level_model' and 'lower_level_model' have the same structure, "
                "and 'lower_level_var' and 'upper_level_var' are the same group of variables."
            )
        assert not (
            ("DI" in self._gm_op) and ("IAD" in self._na_op)
        ), "Only one of the 'DI' and 'IAD' methods could be chosen."
        assert (
            0.0 <= self.boat_configs["GDA"]["alpha_init"] <= 1.0
        ), "Parameter 'alpha' used in method BDA should be in the interval [0,1]."
        # Use the value resolved in __init__ so the default of 'lower_iters'
        # applies when the key is absent from the configuration.
        assert (
            self.boat_configs["RGT"]["truncate_iter"] < self._lower_loop
        ), "The value of 'truncate_iter' should be smaller than 'lower_iters'."

    def plot_losses(self):
        """
        Plot the recorded lower- and upper-level losses side by side and save
        the figure as ``loss_curve.png`` next to the loss log file.

        :returns: None
        """
        iters = [entry["iter"] for entry in self.loss_history]
        ll_losses = [entry["ll_loss"] for entry in self.loss_history]
        ul_losses = [entry["ul_loss"] for entry in self.loss_history]

        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        try:
            panels = (
                (axes[0], ll_losses, "Lower-level Loss", "blue"),
                (axes[1], ul_losses, "Upper-level Loss", "orange"),
            )
            for axis, values, title, color in panels:
                axis.plot(iters, values, label=title, color=color)
                axis.set_xlabel("Iteration")
                axis.set_ylabel("Loss")
                axis.set_title(title)
                axis.legend(loc="upper left")
                axis.grid(True)

            plt.tight_layout()
            save_path = os.path.join(
                os.path.dirname(self.loss_log_path), "loss_curve.png"
            )
            plt.savefig(save_path)
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)

    @staticmethod
    def _as_float(value) -> float:
        """Convert a scalar tensor-like object (anything with ``item()``) or a plain number to float."""
        item = getattr(value, "item", None)
        return float(item()) if callable(item) else float(value)

    def save_losses(self, current_iter, ll_loss, ul_loss):
        """
        Append the losses to the history and rewrite the JSON log file.

        :param current_iter: Iteration number.
        :param ll_loss: Lower-level loss (scalar tensor or plain number).
        :param ul_loss: Upper-level loss (scalar tensor or plain number).
        :return: None
        """
        self.loss_history.append(
            {
                "iter": current_iter,
                "ll_loss": self._as_float(ll_loss),
                "ul_loss": self._as_float(ul_loss),
            }
        )
        log_dir = os.path.dirname(self.loss_log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        with open(self.loss_log_path, "w") as f:
            json.dump(self.loss_history, f)