boat_ms.utils
Submodules
boat_ms.utils.op_utils
- boat_ms.utils.op_utils.average_grad(model, batch_size)
Average the gradients of all model parameters by the batch size.
- Parameters:
model (mindspore.nn.Cell) – The model whose gradients need to be averaged.
batch_size (int) – The batch size to divide gradients by.
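As a rough illustration of the arithmetic described for average_grad, the sketch below divides every gradient entry by the batch size. Plain Python lists stand in for MindSpore tensors, and `average_grad_sketch` is a hypothetical name that takes a list of gradients rather than a model. It is not the library's implementation.

```python
def average_grad_sketch(grads, batch_size):
    """Divide every gradient entry by batch_size (lists stand in for tensors)."""
    if batch_size <= 0:
        # Assumption: a non-positive batch size is treated as an error here.
        raise ValueError("batch_size must be positive")
    return [[g / batch_size for g in grad] for grad in grads]


# Gradients accumulated over a batch of 4 samples (illustrative values).
accumulated = [[4.0, 8.0], [2.0]]
averaged = average_grad_sketch(accumulated, batch_size=4)
```

Averaging this way keeps the gradient magnitude independent of how many samples were accumulated.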
- boat_ms.utils.op_utils.copy_parameter_from_list(model, param_list)
Copy parameters from a list to a model’s trainable parameters.
- Parameters:
model (mindspore.nn.Cell) – The model whose parameters need to be updated.
param_list (list of mindspore.Tensor) – The list of parameters to copy.
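The copy described for copy_parameter_from_list can be pictured with the minimal sketch below. `ParamStub` and `copy_parameter_from_list_sketch` are hypothetical stand-ins rather than MindSpore or boat_ms objects. The in-order pairing and the length check are assumptions about reasonable behavior, not confirmed details of the library.

```python
class ParamStub:
    """Hypothetical stand-in for a model parameter (not a MindSpore class)."""

    def __init__(self, data, requires_grad=True):
        self.data = list(data)
        self.requires_grad = requires_grad


def copy_parameter_from_list_sketch(params, param_list):
    """Overwrite each trainable parameter with the matching list entry."""
    trainable = [p for p in params if p.requires_grad]
    if len(trainable) != len(param_list):
        # Assumption: a length mismatch is reported instead of silently truncated.
        raise ValueError("param_list must have one entry per trainable parameter")
    for param, new_value in zip(trainable, param_list):
        param.data = list(new_value)  # copy, so later edits to the list do not leak in


# The frozen middle parameter is skipped; only trainable ones are overwritten.
params = [ParamStub([0.0, 0.0]), ParamStub([9.0], requires_grad=False), ParamStub([0.0])]
copy_parameter_from_list_sketch(params, [[1.0, 2.0], [3.0]])
copied = [p.data for p in params]
```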
- boat_ms.utils.op_utils.grad_unused_zero(output, inputs, grad_outputs=None)
Compute gradients for the given inputs, substituting zeros for unused gradients. (MindSpore version)
- Parameters:
output (mindspore.Tensor) – The scalar output tensor for which gradients are computed.
inputs (List[mindspore.Tensor]) – List of input tensors with respect to which gradients are computed.
grad_outputs (mindspore.Tensor, optional) – Sensitivity tensor with the same shape as output. Defaults to ones_like(output).
- Returns:
Gradients for the inputs, with unused gradients replaced by zeros.
- Return type:
List[mindspore.Tensor]
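The zero-substitution step of grad_unused_zero can be sketched as follows. The sketch assumes a differentiation routine that reports `None` for inputs the output does not depend on. Lists stand in for tensors, and `fill_unused_with_zeros` is a hypothetical helper name.

```python
def fill_unused_with_zeros(grads, inputs):
    """Replace each missing gradient with zeros shaped like its input."""
    if len(grads) != len(inputs):
        raise ValueError("grads and inputs must have the same length")
    return [
        [0.0] * len(x) if g is None else g  # zeros for inputs the output never used
        for g, x in zip(grads, inputs)
    ]


inputs = [[1.0, 2.0], [3.0, 4.0, 5.0]]
raw_grads = [[0.5, 0.5], None]  # the second input did not contribute to the output
filled = fill_unused_with_zeros(raw_grads, inputs)
```

Returning zeros instead of `None` lets downstream code combine gradients elementwise without special-casing unused inputs.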
- boat_ms.utils.op_utils.l2_reg(params)
Compute the L2 regularization loss.
- Parameters:
params (list) – List of model parameters (trainable).
- Returns:
The computed L2 regularization loss.
- Return type:
mindspore.Tensor
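For orientation, the sketch below computes an L2 penalty under the common definition, the sum of squared entries over all parameters. Whether l2_reg applies a scaling factor or takes a square root is not stated here, so treat the exact formula as an assumption. `l2_reg_sketch` is a hypothetical name, and lists stand in for tensors.

```python
def l2_reg_sketch(params):
    """Sum of squared entries over all parameters (assumed definition)."""
    return sum(w * w for param in params for w in param)


params = [[3.0, 4.0], [1.0, 2.0]]
loss = l2_reg_sketch(params)  # 9 + 16 + 1 + 4
```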
- boat_ms.utils.op_utils.require_model_grad(model=None)
Ensure all parameters of a MindSpore model require gradients.
- Parameters:
model (mindspore.nn.Cell, optional) – MindSpore model instance. The signature defaults to None, but None is rejected, so a model must always be passed.
- Raises:
AssertionError – If the model is None.
- boat_ms.utils.op_utils.stop_grads(grads)
Detach and stop gradient computation for a list of gradients.
- Parameters:
grads (list of mindspore.Tensor) – Gradients to process.
- Returns:
Detached gradients with requires_grad set to False.
- Return type:
list of mindspore.Tensor
- boat_ms.utils.op_utils.stop_model_grad(model=None)
Stop gradient computation for all parameters in a model.
- Parameters:
model (mindspore.nn.Cell, optional) – The model to stop gradients for. The signature defaults to None, but None is rejected, so a model must always be passed.
- Raises:
AssertionError – If the model is None.
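require_model_grad and stop_model_grad read as mirror images of each other: one switches gradient tracking on for every parameter, the other switches it off, and both reject None. A minimal sketch of that behavior, using hypothetical `ParamStub` and `ModelStub` classes in place of MindSpore objects, might look like this:

```python
class ParamStub:
    """Hypothetical stand-in for a parameter with a requires_grad flag."""

    def __init__(self, requires_grad=True):
        self.requires_grad = requires_grad


class ModelStub:
    """Hypothetical stand-in for a model that exposes its parameters."""

    def __init__(self, params):
        self._params = list(params)

    def get_parameters(self):
        return iter(self._params)


def set_model_grad_sketch(model=None, requires_grad=True):
    """Set requires_grad on every parameter; None is rejected, as documented."""
    assert model is not None, "model must not be None"
    for param in model.get_parameters():
        param.requires_grad = requires_grad


model = ModelStub([ParamStub(True), ParamStub(False)])
set_model_grad_sketch(model, requires_grad=False)  # analogous to stop_model_grad
frozen = [p.requires_grad for p in model.get_parameters()]
set_model_grad_sketch(model, requires_grad=True)  # analogous to require_model_grad
unfrozen = [p.requires_grad for p in model.get_parameters()]
```

The single helper with a flag is a choice made for brevity in this sketch. The library exposes the two directions as separate functions.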