boat_torch.utils
Submodules
boat_torch.utils.op_utils
- class boat_torch.utils.op_utils.DynamicalSystemRules
Bases: object
A class to store and manage gradient operator rules.
- class boat_torch.utils.op_utils.HyperGradientRules
Bases: object
A class to store and manage gradient operator rules.
- class boat_torch.utils.op_utils.ResultStore
Bases: object
A simple class to store and manage intermediate results of hyper-gradient computation.
- boat_torch.utils.op_utils.average_grad(model, batch_size)
Average gradients over a batch.
- Parameters:
model (torch.nn.Module) – Model whose gradients are averaged.
batch_size (int) – The batch size used for averaging.
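The entry above does not spell out how the averaging is done. A minimal sketch, assuming it divides each accumulated `.grad` in place by `batch_size`; the helper name `average_grad_sketch` is hypothetical and is not the library's implementation:

```python
import torch
from torch import nn


def average_grad_sketch(model: nn.Module, batch_size: int) -> None:
    """Divide every accumulated gradient in place by batch_size (assumed behavior)."""
    for param in model.parameters():
        if param.grad is not None:
            param.grad.div_(batch_size)


model = nn.Linear(2, 1, bias=False)
with torch.no_grad():
    model.weight.fill_(1.0)

batch = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
# A summed (not averaged) loss, so the gradient accumulates over the batch.
model(batch).sum().backward()
average_grad_sketch(model, batch_size=batch.shape[0])
```

In this toy case the summed gradient of the weight is the column sum of `batch`, so dividing by the batch size yields the per-sample mean.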
- boat_torch.utils.op_utils.cat_list_to_tensor(list_tx)
Concatenate a list of tensors into a single tensor.
- Parameters:
list_tx (List[Tensor]) – List of tensors to concatenate.
- Returns:
The concatenated tensor.
- Return type:
Tensor
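One plausible reading of this helper is "flatten each tensor and concatenate the results into one vector". The sketch below illustrates that reading; the flattening step and the name `cat_list_to_tensor_sketch` are assumptions, not confirmed library behavior:

```python
import torch


def cat_list_to_tensor_sketch(list_tx):
    """Flatten each tensor to 1-D, then concatenate (assumed behavior)."""
    return torch.cat([t.reshape(-1) for t in list_tx])


# Tensors of different shapes collapse into one flat vector.
flat = cat_list_to_tensor_sketch([torch.ones(2, 3), torch.zeros(4)])
```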
- boat_torch.utils.op_utils.cg_step(Ax, b, max_iter=100, epsilon=1e-05)
Solve the linear system Ax = b with the conjugate gradient method.
- Parameters:
Ax (Callable) – Function that computes the matrix-vector product A·x for a given x.
b (List[Tensor]) – Right-hand side of the linear system.
max_iter (int, optional) – Maximum number of iterations, by default 100.
epsilon (float, optional) – Tolerance for early stopping, by default 1.0e-5.
- Returns:
Solution vector for the linear system.
- Return type:
List[Tensor]
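For orientation, here is a textbook conjugate gradient loop over lists of tensors, written to mirror the signature above. It is a sketch under assumptions: the stopping rule (residual norm below `epsilon`) and the zero initial guess are guesses, and `cg_sketch` and `dot` are hypothetical names.

```python
import torch


def dot(xs, ys):
    """Inner product of two lists of tensors."""
    return sum((x * y).sum() for x, y in zip(xs, ys))


def cg_sketch(Ax, b, max_iter=100, epsilon=1e-5):
    """Solve A x = b for a symmetric positive-definite operator A."""
    x = [torch.zeros_like(bi) for bi in b]   # initial guess x = 0
    r = [bi.clone() for bi in b]             # residual b - A x
    p = [ri.clone() for ri in r]             # search direction
    r_dot = dot(r, r)
    for _ in range(max_iter):
        if r_dot.sqrt() < epsilon:           # assumed stopping rule
            break
        Ap = Ax(p)
        alpha = r_dot / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        new_r_dot = dot(r, r)
        beta = new_r_dot / r_dot
        p = [ri + beta * pi for ri, pi in zip(r, p)]
        r_dot = new_r_dot
    return x


A = torch.tensor([[4.0, 1.0], [1.0, 3.0]])
b = [torch.tensor([1.0, 2.0])]
solution = cg_sketch(lambda vs: [A @ v for v in vs], b)
```

Passing `Ax` as a callable means the matrix never has to be materialized, which is what makes the method usable with Hessian-vector or Jacobian-vector products.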
- boat_torch.utils.op_utils.conjugate_gradient(params, hparams, upper_loss, lower_loss, K, fp_map, tol=1e-10)
Compute hypergradients using the conjugate gradient method.
- Parameters:
params (List[Tensor]) – List of lower-level parameters.
hparams (List[Tensor]) – List of upper-level hyperparameters.
upper_loss (Tensor) – The upper-level loss.
lower_loss (Tensor) – The lower-level loss.
K (int) – Maximum number of iterations for the conjugate gradient method.
fp_map (Callable) – Fixed-point mapping function.
tol (float, optional) – Tolerance for early stopping, by default 1e-10.
- Returns:
Hypergradients for the upper-level hyperparameters.
- Return type:
List[Tensor]
- boat_torch.utils.op_utils.copy_parameter_from_list(y, z)
Copy parameters from a list to a model.
- Parameters:
y (torch.nn.Module) – Target model to which parameters are copied.
z (List[torch.Tensor]) – List of source parameters.
- Returns:
The updated model with copied parameters.
- Return type:
torch.nn.Module
- boat_torch.utils.op_utils.get_outer_gradients(outer_loss, params, hparams, retain_graph=True)
Compute the gradients of the outer-level loss with respect to parameters and hyperparameters.
- Parameters:
outer_loss (Tensor) – The outer-level loss.
params (List[Tensor]) – List of tensors representing parameters.
hparams (List[Tensor]) – List of tensors representing hyperparameters.
retain_graph (bool, optional) – Whether to retain the computation graph, by default True.
- Returns:
Gradients with respect to parameters and hyperparameters.
- Return type:
Tuple[List[Tensor], List[Tensor]]
- boat_torch.utils.op_utils.grad_unused_zero(output, inputs, grad_outputs=None, retain_graph=False, create_graph=False)
Compute gradients for the given inputs, substituting zeros for unused gradients.
- Parameters:
output (torch.Tensor) – The output tensor for which gradients are computed.
inputs (List[torch.Tensor]) – List of input tensors with respect to which gradients are computed.
grad_outputs (torch.Tensor, optional) – Gradient with respect to output, used as the vector in the vector-Jacobian product, by default None.
retain_graph (bool, optional) – If True, the computation graph is retained after the gradient computation, by default False.
create_graph (bool, optional) – If True, constructs the graph for higher-order gradient computations, by default False.
- Returns:
Gradients for the inputs, with unused gradients replaced by zeros.
- Return type:
Tuple[torch.Tensor]
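The behavior described above can be sketched with `torch.autograd.grad` and `allow_unused=True`, replacing each `None` result with a zero tensor of the input's shape. The name `grad_unused_zero_sketch` is hypothetical, and the library's actual implementation may differ in detail:

```python
import torch


def grad_unused_zero_sketch(output, inputs, grad_outputs=None,
                            retain_graph=False, create_graph=False):
    """Like torch.autograd.grad, but unused inputs get zero gradients."""
    grads = torch.autograd.grad(
        output,
        inputs,
        grad_outputs=grad_outputs,
        allow_unused=True,          # return None instead of raising
        retain_graph=retain_graph,
        create_graph=create_graph,
    )
    return tuple(
        torch.zeros_like(v) if g is None else g
        for g, v in zip(grads, inputs)
    )


x = torch.tensor([1.0, 2.0], requires_grad=True)
unused = torch.tensor([5.0], requires_grad=True)  # never enters the loss
loss = (x ** 2).sum()
gx, gu = grad_unused_zero_sketch(loss, [x, unused])
```

Returning zeros rather than `None` lets downstream code add or concatenate gradients without special-casing inputs that do not influence the output.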
- boat_torch.utils.op_utils.l2_reg(parameters)
Compute the L2 regularization term for a list of parameters.
- Parameters:
parameters (List[torch.Tensor]) – List of tensors for which the L2 regularization term is computed.
- Returns:
The L2 regularization loss value.
- Return type:
torch.Tensor
- boat_torch.utils.op_utils.list_tensor_matmul(list1, list2)
Multiply two lists of tensors element-wise and sum the products into a single tensor.
- Parameters:
list1 (List[torch.Tensor]) – First list of tensors.
list2 (List[torch.Tensor]) – Second list of tensors.
- Returns:
Result of the element-wise multiplication and summation.
- Return type:
torch.Tensor
- boat_torch.utils.op_utils.list_tensor_norm(list_tensor, p=2)
Compute the p-norm of a list of tensors.
- Parameters:
list_tensor (List[torch.Tensor]) – List of tensors for which the norm is computed.
p (int, optional) – Order of the norm, by default 2.
- Returns:
The computed p-norm of the list of tensors.
- Return type:
torch.Tensor
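The two list helpers above (`list_tensor_matmul` and `list_tensor_norm`) can be read as treating a list of tensors as one long vector. The sketch below shows that interpretation; whether the library computes the norm over the concatenation of all tensors is an assumption, and both function names here are hypothetical:

```python
import torch


def list_inner_product(list1, list2):
    """Sum of element-wise products across paired tensors."""
    return sum((a * b).sum() for a, b in zip(list1, list2))


def list_p_norm(list_tensor, p=2):
    """p-norm of all entries, treating the list as one flat vector (assumed)."""
    flat = torch.cat([t.reshape(-1) for t in list_tensor])
    return torch.linalg.vector_norm(flat, ord=p)


u = [torch.tensor([3.0, 0.0]), torch.tensor([[4.0]])]
v = [torch.tensor([1.0, 1.0]), torch.tensor([[2.0]])]

inner = list_inner_product(u, v)   # 3*1 + 0*1 + 4*2
norm = list_p_norm(u)              # sqrt(3^2 + 0^2 + 4^2)
```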
- boat_torch.utils.op_utils.neumann(params, hparams, upper_loss, lower_loss, k, fp_map, tol=1e-10)
Compute hypergradients using Neumann series approximation.
- Parameters:
params (List[Tensor]) – List of lower-level parameters.
hparams (List[Tensor]) – List of upper-level hyperparameters.
upper_loss (Tensor) – The upper-level loss.
lower_loss (Tensor) – The lower-level loss.
k (int) – Number of iterations for Neumann approximation.
fp_map (Callable) – Fixed-point mapping function.
tol (float, optional) – Tolerance for early stopping, by default 1e-10.
- Returns:
Hypergradients for the upper-level hyperparameters.
- Return type:
List[Tensor]
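As background for the Neumann approach: with a fixed-point map Φ(w, λ) whose Jacobian with respect to w is J, the inverse (I − Jᵀ)⁻¹ applied to the upper-level gradient is approximated by the truncated series Σᵢ (Jᵀ)ⁱ. The sketch below is one way to write this with autograd. It assumes a hypothetical call convention `fp_map(params, hparams)` that returns a list of tensors, omits the `lower_loss` argument, and is not the library's implementation.

```python
import torch


def neumann_hypergrad_sketch(params, hparams, upper_loss, k, fp_map, tol=1e-10):
    # Direct gradients of the upper-level loss.
    grad_w = torch.autograd.grad(upper_loss, params, retain_graph=True)
    grad_h = torch.autograd.grad(upper_loss, hparams, retain_graph=True,
                                 allow_unused=True)
    grad_h = [torch.zeros_like(h) if g is None else g
              for g, h in zip(grad_h, hparams)]

    w_mapped = fp_map(params, hparams)  # Phi(w, lambda)
    v = list(grad_w)                    # current series term (J^T)^i grad_w
    acc = list(grad_w)                  # running sum of the series
    for _ in range(k):
        v_new = torch.autograd.grad(w_mapped, params, grad_outputs=v,
                                    retain_graph=True)
        acc = [a + vn for a, vn in zip(acc, v_new)]
        if sum((vn ** 2).sum() for vn in v_new).sqrt() < tol:
            break                       # series terms have become negligible
        v = list(v_new)

    # Push the accumulated vector through dPhi/dlambda.
    indirect = torch.autograd.grad(w_mapped, hparams, grad_outputs=acc)
    return [d + i for d, i in zip(grad_h, indirect)]


# Toy problem: lower loss 0.5*||w - lam||^2 (so w* = lam),
# upper loss 0.5*||w - target||^2, hence d(upper)/d(lam) = lam - target.
eta = 0.5
target = torch.tensor([1.0, -2.0])
lam = torch.tensor([3.0, 0.5], requires_grad=True)
w = lam.detach().clone().requires_grad_(True)


def fp_map(params, hparams):
    lower = 0.5 * ((params[0] - hparams[0]) ** 2).sum()
    (g,) = torch.autograd.grad(lower, params, create_graph=True)
    return [params[0] - eta * g]        # one gradient-descent step


upper = 0.5 * ((w - target) ** 2).sum()
hypergrad = neumann_hypergrad_sketch([w], [lam], upper, k=30, fp_map=fp_map)
```

The series converges only when the fixed-point map is a contraction in w, which holds here because the Jacobian is (1 − eta)·I. A conjugate-gradient variant would solve the same linear system instead of summing the series.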
- boat_torch.utils.op_utils.require_model_grad(model=None)
Enable gradient computation for all parameters in the given model.
- Parameters:
model (torch.nn.Module, optional) – The model whose parameters should require gradient computation. Although the default is None, a model must be supplied.
- Raises:
AssertionError – If the model is None.
- boat_torch.utils.op_utils.stop_grads(grads)
Stop gradient computation for the given gradients.
- Parameters:
grads (List[torch.Tensor]) – Gradients to be detached from the computation graph.
- Returns:
Gradients detached from the computation graph.
- Return type:
List[torch.Tensor]
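Detaching gradients matters when they were computed with `create_graph=True` and would otherwise keep the autograd graph alive. A minimal sketch, assuming the helper simply calls `detach()` on each tensor; `stop_grads_sketch` is a hypothetical name:

```python
import torch


def stop_grads_sketch(grads):
    """Return copies of the gradients that are cut off from the graph."""
    return [g.detach() for g in grads]


x = torch.tensor([1.0, 2.0], requires_grad=True)
# create_graph=True makes the gradient itself part of the autograd graph.
(g,) = torch.autograd.grad((x ** 2).sum(), x, create_graph=True)
detached = stop_grads_sketch([g])
```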
- boat_torch.utils.op_utils.stop_model_grad(model=None)
Disable gradient computation for all parameters in the given model.
- Parameters:
model (torch.nn.Module, optional) – The model whose parameters should no longer require gradient computation. Although the default is None, a model must be supplied.
- Raises:
AssertionError – If the model is None.
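Both `require_model_grad` and `stop_model_grad` amount to toggling `requires_grad` on every parameter of a model. The sketch below folds the two into a single hypothetical helper, `set_model_grad`, to show the pattern; the assertion mirrors the documented `AssertionError`, but the rest is an assumption about the implementation:

```python
import torch
from torch import nn


def set_model_grad(model, enabled):
    """Enable or disable gradient tracking for all parameters of a model."""
    assert model is not None, "model must not be None"
    for param in model.parameters():
        param.requires_grad_(enabled)


model = nn.Linear(3, 1)

set_model_grad(model, False)   # analogous to stop_model_grad
frozen = [p.requires_grad for p in model.parameters()]

set_model_grad(model, True)    # analogous to require_model_grad
trainable = [p.requires_grad for p in model.parameters()]
```

Freezing one level's model while differentiating the other is a common pattern in bilevel optimization, since it keeps unwanted parameters out of the autograd graph.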