Source code for boat_torch.utils.op_utils

import torch
from torch import Tensor
from typing import List, Callable, Dict


class ResultStore:
    """
    Accumulate intermediate results of hyper-gradient computation.

    Entries are kept in insertion order in ``self.results``. Each entry is a
    single-key dictionary mapping the supplied name to the stored result.
    """

    def __init__(self):
        self.results = list()

    def add(self, name: str, result: Dict):
        """
        Append a named result to the store.

        :param name: The name of the result (e.g., 'gradient_operator_results_0').
        :type name: str
        :param result: The result dictionary to store.
        :type result: Dict
        """
        entry = {name: result}
        self.results.append(entry)

    def clear(self):
        """Discard all stored results by rebinding ``results`` to a new empty list."""
        self.results = list()

    def get_results(self) -> List[Dict]:
        """Return the stored entries (the internal list itself, not a copy)."""
        return self.results
class DynamicalSystemRules:
    """
    Class-level registry of the gradient operator order.

    The order is a list of groups, each group being a list of operator names.
    It is stored on the class itself, so every caller shares the same value.
    """

    # Default static gradient operator order
    _gradient_order = [
        ["GDA", "DI"],
        ["NGD", "DM"],
    ]

    @staticmethod
    def get_gradient_order() -> List[List[str]]:
        """
        Return the gradient operator order currently in effect.

        Returns
        -------
        List[List[str]]
            The stored order (the same list object, not a copy).
        """
        return DynamicalSystemRules._gradient_order

    @staticmethod
    def set_gradient_order(new_order: List[List[str]]):
        """
        Replace the gradient operator order.

        Parameters
        ----------
        new_order : List[List[str]]
            The new gradient operator order to set.

        Raises
        ------
        ValueError
            If ``new_order`` is not a list, or any of its elements is not a list.
        """
        is_list_of_lists = isinstance(new_order, list) and all(
            isinstance(group, list) for group in new_order
        )
        if not is_list_of_lists:
            raise ValueError("Gradient order must be a list of lists.")
        DynamicalSystemRules._gradient_order = new_order
class HyperGradientRules:
    """
    Class-level registry of the hyper-gradient operator order.

    The order is a list of groups, each group being a list of operator names.
    It is stored on the class itself, so every caller shares the same value.
    """

    # Default static gradient operator order
    _gradient_order = [
        ["PTT", "FOA", "RGT"],
        ["IAD", "RAD", "FD", "IGA"],
        ["CG", "NS"],
    ]

    @staticmethod
    def get_gradient_order() -> List[List[str]]:
        """
        Return the gradient operator order currently in effect.

        Returns
        -------
        List[List[str]]
            The stored order (the same list object, not a copy).
        """
        return HyperGradientRules._gradient_order

    @staticmethod
    def set_gradient_order(new_order: List[List[str]]):
        """
        Replace the gradient operator order.

        Parameters
        ----------
        new_order : List[List[str]]
            The new gradient operator order to set.

        Raises
        ------
        ValueError
            If ``new_order`` is not a list, or any of its elements is not a list.
        """
        is_list_of_lists = isinstance(new_order, list) and all(
            isinstance(group, list) for group in new_order
        )
        if not is_list_of_lists:
            raise ValueError("Gradient order must be a list of lists.")
        HyperGradientRules._gradient_order = new_order
def l2_reg(parameters):
    """
    Sum the squared 2-norms of the given tensors.

    Parameters
    ----------
    parameters : List[torch.Tensor]
        Tensors whose squared 2-norms are accumulated.

    Returns
    -------
    torch.Tensor
        The accumulated L2 regularization value. When ``parameters`` is
        empty, the plain integer ``0`` is returned instead of a tensor.
    """
    total = 0
    for weight in parameters:
        squared_norm = torch.norm(weight, 2) ** 2
        total += squared_norm
    return total
def grad_unused_zero(
    output, inputs, grad_outputs=None, retain_graph=False, create_graph=False
):
    """
    Differentiate ``output`` with respect to ``inputs``, zero-filling gaps.

    A gradient is replaced by a zero tensor shaped like its input when
    autograd reports the input as unused (``None``) or when the gradient
    contains at least one NaN entry.

    Parameters
    ----------
    output : torch.Tensor
        The output tensor for which gradients are computed.
    inputs : List[torch.Tensor]
        Input tensors with respect to which gradients are computed.
    grad_outputs : torch.Tensor, optional
        Passed through to ``torch.autograd.grad``, by default None.
    retain_graph : bool, optional
        If True, the computation graph is kept after the call, by default False.
    create_graph : bool, optional
        If True, builds the graph needed for higher-order gradients,
        by default False.

    Returns
    -------
    Tuple[torch.Tensor]
        One gradient per input, with unused or NaN-containing gradients
        replaced by zeros.
    """
    raw_grads = torch.autograd.grad(
        output,
        inputs,
        grad_outputs=grad_outputs,
        allow_unused=True,
        retain_graph=retain_graph,
        create_graph=create_graph,
    )

    cleaned = []
    for grad, var in zip(raw_grads, inputs):
        if grad is not None and not torch.isnan(grad).any():
            cleaned.append(grad)
        else:
            # Unused input, or a gradient polluted by NaN: fall back to zeros.
            cleaned.append(torch.zeros_like(var))
    return tuple(cleaned)
def list_tensor_matmul(list1, list2):
    """
    Compute the inner product of two lists of tensors.

    Corresponding tensors are multiplied element-wise, each product is
    summed, and the per-pair sums are added together. Pairing stops at the
    shorter list.

    Parameters
    ----------
    list1 : List[torch.Tensor]
        First list of tensors.
    list2 : List[torch.Tensor]
        Second list of tensors.

    Returns
    -------
    torch.Tensor
        The accumulated sum. When no pairs exist, the plain integer ``0`` is
        returned instead of a tensor.
    """
    return sum(torch.sum(left * right) for left, right in zip(list1, list2))
def list_tensor_norm(list_tensor, p=2):
    """
    Sum the p-norms of the individual tensors in a list.

    Each tensor's p-norm is computed separately and the norms are added, so
    the result is the sum of per-tensor norms rather than the p-norm of the
    tensors concatenated into one vector.

    Parameters
    ----------
    list_tensor : List[torch.Tensor]
        Tensors whose norms are summed.
    p : int, optional
        Order of the norm, by default 2.

    Returns
    -------
    torch.Tensor
        The sum of the per-tensor p-norms. When the list is empty, the plain
        integer ``0`` is returned instead of a tensor.
    """
    return sum(torch.norm(tensor, p) for tensor in list_tensor)
def require_model_grad(model=None):
    """
    Enable gradient computation for all parameters in the given model.

    Parameters
    ----------
    model : torch.nn.Module, optional
        The model whose parameters should require gradients.

    Raises
    ------
    AssertionError
        If the model is None.
    """
    if model is None:
        # Raised explicitly instead of via `assert` so the check still runs
        # when Python is started with -O (which strips assert statements).
        raise AssertionError("The module is not defined!")
    for param in model.parameters():
        param.requires_grad_(True)
def update_grads(grads, model):
    """
    Accumulate the given gradients into a model's parameter gradients.

    Gradients are paired with ``model.parameters()`` in order. A parameter
    without a gradient receives a copy of the supplied one; otherwise the
    supplied gradient is added in place to the existing ``.grad``. Entries
    that are ``None`` are skipped and leave the parameter untouched.

    Parameters
    ----------
    grads : List[torch.Tensor]
        Gradients to be applied to the model's parameters. ``None`` entries
        are allowed.
    model : torch.nn.Module
        Model whose gradients are updated.
    """
    for grad, param in zip(grads, model.parameters()):
        if grad is None:
            # Nothing to contribute for this parameter.
            continue
        if param.grad is None:
            # Store a copy: aliasing the caller's tensor would let the later
            # in-place accumulation silently modify the caller's list.
            param.grad = grad.clone()
        else:
            param.grad += grad
def update_tensor_grads(hparams, grads):
    """
    Accumulate the given gradients into the hyperparameters' ``.grad``.

    Every hyperparameter that has no gradient yet first gets a zero
    gradient of its own shape; the supplied gradient, when not ``None``,
    is then added to it in place.

    Parameters
    ----------
    hparams : List[torch.Tensor]
        Hyperparameters whose gradients are updated.
    grads : List[torch.Tensor]
        Gradients to be applied to the hyperparameters. ``None`` entries
        contribute nothing.
    """
    for hparam, grad in zip(hparams, grads):
        # Make sure a gradient buffer exists even if there is nothing to add.
        if hparam.grad is None:
            hparam.grad = torch.zeros_like(hparam)
        if grad is None:
            continue
        hparam.grad += grad
def stop_grads(grads):
    """
    Detach the given gradients from the computation graph.

    Parameters
    ----------
    grads : List[torch.Tensor]
        Gradients to detach. ``None`` entries are passed through unchanged.

    Returns
    -------
    List[torch.Tensor]
        Detached gradients with ``requires_grad`` set to False, in the same
        order as the input.
    """
    detached = []
    for grad in grads:
        if grad is None:
            detached.append(grad)
        else:
            detached.append(grad.detach().requires_grad_(False))
    return detached
def average_grad(model, batch_size):
    """
    Divide every existing parameter gradient of a model by the batch size.

    Parameters whose ``.grad`` is ``None`` (for example frozen parameters, or
    parameters that did not take part in the backward pass) are skipped.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose gradients are averaged.
    batch_size : int
        The batch size used for averaging.
    """
    for param in model.parameters():
        if param.grad is None:
            # No gradient was accumulated for this parameter; nothing to scale.
            continue
        param.grad.data = param.grad.data / batch_size
def stop_model_grad(model=None):
    """
    Disable gradient computation for all parameters in the given model.

    Parameters
    ----------
    model : torch.nn.Module, optional
        The model whose parameters should no longer require gradients.

    Raises
    ------
    AssertionError
        If the model is None.
    """
    if model is None:
        # Raised explicitly instead of via `assert` so the check still runs
        # when Python is started with -O (which strips assert statements).
        raise AssertionError("The module is not defined!")
    for param in model.parameters():
        param.requires_grad_(False)
def copy_parameter_from_list(y, z):
    """
    Overwrite a model's parameter values with copies of the given tensors.

    Parameters of ``y`` are paired in order with the tensors in ``z``; each
    parameter's ``.data`` is replaced by a detached clone of its source.
    Pairing stops at the shorter of the two sequences.

    Parameters
    ----------
    y : torch.nn.Module
        Target model to which parameters are copied.
    z : List[torch.Tensor]
        List of source parameters.

    Returns
    -------
    torch.nn.Module
        The same model object ``y``, after its parameters were overwritten.
    """
    for target, source in zip(y.parameters(), z):
        fresh_copy = source.clone().detach().requires_grad_()
        target.data = fresh_copy
    return y
def get_outer_gradients(outer_loss, params, hparams, retain_graph=True):
    """
    Compute the gradients of the outer-level loss with respect to parameters
    and hyperparameters.

    Both gradient sets come from a single ``grad_unused_zero`` call over the
    concatenated inputs. This makes ``retain_graph=False`` usable: two
    separate backward passes would free the graph in the first pass and
    fail in the second.

    Parameters
    ----------
    outer_loss : Tensor
        The outer-level loss.
    params : List[Tensor]
        List of tensors representing parameters.
    hparams : List[Tensor]
        List of tensors representing hyperparameters.
    retain_graph : bool, optional
        Whether to retain the computation graph, by default True.

    Returns
    -------
    Tuple[Tuple[Tensor], Tuple[Tensor]]
        Gradients with respect to parameters and hyperparameters, in that
        order. Unused gradients are replaced by zeros.
    """
    param_list = list(params)
    hparam_list = list(hparams)
    all_grads = grad_unused_zero(
        outer_loss, param_list + hparam_list, retain_graph=retain_graph
    )
    # Split the combined result back into its two groups.
    split_at = len(param_list)
    grad_outer_w = all_grads[:split_at]
    grad_outer_hparams = all_grads[split_at:]
    return grad_outer_w, grad_outer_hparams
def cat_list_to_tensor(list_tx):
    """
    Flatten each tensor in a list and concatenate them into one 1-D tensor.

    Parameters
    ----------
    list_tx : List[Tensor]
        List of tensors to concatenate.

    Returns
    -------
    Tensor
        A 1-D tensor holding the flattened tensors back to back, in list
        order.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g.
    # transposed ones; it only copies when a view is impossible.
    return torch.cat([tensor.reshape(-1) for tensor in list_tx])
def neumann(
    params: List[Tensor],
    hparams: List[Tensor],
    upper_loss,
    lower_loss,
    k: int,
    fp_map: Callable[[List[Tensor], List[Tensor]], List[Tensor]],
    tol=1e-10,
) -> List[Tensor]:
    """
    Compute hypergradients using a Neumann series approximation.

    Starting from the upper-level gradient with respect to ``params``, the
    routine repeatedly applies a vector-Jacobian product of
    ``fp_map(params, lower_loss)`` with respect to ``params`` and adds each
    new term to a running sum. It stops after ``k`` terms, or earlier once
    the running sum changes by less than ``tol`` in 2-norm.

    Parameters
    ----------
    params : List[Tensor]
        List of lower-level parameters.
    hparams : List[Tensor]
        List of upper-level hyperparameters.
    upper_loss : Tensor
        The upper-level loss.
    lower_loss : Tensor
        The lower-level loss; passed as the second argument to ``fp_map``.
    k : int
        Maximum number of series terms to accumulate.
    fp_map : Callable
        Fixed-point mapping function, called as ``fp_map(params, lower_loss)``.
    tol : float, optional
        Tolerance for early stopping, by default 1e-10.

    Returns
    -------
    List[Tensor]
        Hypergradients for the upper-level hyperparameters.
    """
    grad_outer_w, grad_outer_hparams = get_outer_gradients(upper_loss, params, hparams)
    w_mapped = fp_map(params, lower_loss)

    # `term` is the latest series term, `series_sum` the running total.
    term = grad_outer_w
    series_sum = grad_outer_w
    flat_sum = cat_list_to_tensor(series_sum)

    for _ in range(k):
        flat_previous = flat_sum
        term = torch.autograd.grad(
            w_mapped, params, grad_outputs=term, retain_graph=True
        )
        series_sum = [total + new for total, new in zip(series_sum, term)]
        flat_sum = cat_list_to_tensor(series_sum)
        change = torch.norm(flat_sum - flat_previous)
        if float(change) < tol:
            break

    indirect = torch.autograd.grad(w_mapped, hparams, grad_outputs=series_sum)
    # Add the direct dependence of the upper loss on the hyperparameters.
    return [ind + direct for ind, direct in zip(indirect, grad_outer_hparams)]
def conjugate_gradient(
    params: List[Tensor],
    hparams: List[Tensor],
    upper_loss,
    lower_loss,
    K: int,
    fp_map: Callable[[List[Tensor], List[Tensor]], List[Tensor]],
    tol=1e-10,
) -> List[Tensor]:
    """
    Compute hypergradients using the conjugate gradient method.

    ``cg_step`` is used to solve a linear system whose right-hand side is the
    upper-level gradient with respect to ``params``. The operator applied to
    a vector ``v`` is ``v`` minus the vector-Jacobian product of
    ``fp_map(params, lower_loss)`` with respect to ``params``.

    Parameters
    ----------
    params : List[Tensor]
        List of lower-level parameters.
    hparams : List[Tensor]
        List of upper-level hyperparameters.
    upper_loss : Tensor
        The upper-level loss.
    lower_loss : Tensor
        The lower-level loss; passed as the second argument to ``fp_map``.
    K : int
        Maximum number of iterations for the conjugate gradient method.
    fp_map : Callable
        Fixed-point mapping function, called as ``fp_map(params, lower_loss)``.
    tol : float, optional
        Tolerance for early stopping, by default 1e-10.

    Returns
    -------
    List[Tensor]
        Hypergradients for the upper-level hyperparameters.
    """
    grad_outer_w, grad_outer_hparams = get_outer_gradients(upper_loss, params, hparams)
    w_mapped = fp_map(params, lower_loss)

    def identity_minus_vjp(vectors):
        # Linear operator handed to the CG solver: v - (d w_mapped / d params)^T v.
        vjp = torch.autograd.grad(
            w_mapped, params, grad_outputs=vectors, retain_graph=True
        )
        differences = []
        for vec, vjp_part in zip(vectors, vjp):
            differences.append(vec - vjp_part)
        return differences

    solution = cg_step(identity_minus_vjp, grad_outer_w, max_iter=K, epsilon=tol)
    indirect = torch.autograd.grad(w_mapped, hparams, grad_outputs=solution)
    # Add the direct dependence of the upper loss on the hyperparameters.
    return [ind + direct for ind, direct in zip(indirect, grad_outer_hparams)]
def cg_step(Ax, b, max_iter=100, epsilon=1.0e-5):
    """
    Solve ``A x = b`` with the conjugate gradient method.

    The unknown and the right-hand side are lists of tensors that are
    treated together as one flat vector. Iteration starts from ``x = 0`` and
    stops once the residual 2-norm falls below ``epsilon`` or after
    ``max_iter`` iterations. The most recent iterate is returned, including
    the one computed in the iteration that reached convergence.

    Parameters
    ----------
    Ax : Callable
        Function mapping a list of tensors ``p`` to the list ``A p``.
    b : List[Tensor]
        Right-hand side of the linear system.
    max_iter : int, optional
        Maximum number of iterations, by default 100.
    epsilon : float, optional
        Residual-norm tolerance for early stopping, by default 1.0e-5.

    Returns
    -------
    List[Tensor]
        Approximate solution of the linear system, shaped like ``b``.
    """

    def _dot(xs, ys):
        # Inner product of two tensor lists viewed as single flat vectors.
        return sum(torch.sum(x * y) for x, y in zip(xs, ys))

    x = [torch.zeros_like(bb) for bb in b]
    r = [bb.clone() for bb in b]  # residual b - A x, with x = 0
    p = [rr.clone() for rr in r]  # first search direction
    rTr = _dot(r, r)

    # The starting point already satisfies the tolerance (e.g. b == 0).
    # Iterating would compute alpha = 0 / 0 and fill the result with NaN.
    if float(rTr) ** 0.5 < epsilon:
        return x

    for _ in range(max_iter):
        Ap = Ax(p)
        alpha = rTr / _dot(p, Ap)
        x = [xx + alpha * pp for xx, pp in zip(x, p)]
        r = [rr - alpha * ap for rr, ap in zip(r, Ap)]
        rTr_next = _dot(r, r)
        # x is updated before this check so that the converged iterate, not
        # the previous one, is what gets returned.
        if float(rTr_next) ** 0.5 < epsilon:
            break
        beta = rTr_next / rTr
        p = [rr + beta * pp for rr, pp in zip(r, p)]
        rTr = rTr_next
    return x