BOAT-Torch Structure

Core Problem Class

class boat_torch.boat_opt.Problem(config, loss_config)[source]

Bases: object

Enhanced bi-level optimization problem class supporting flexible loss functions and operation configurations.

build_ll_solver()[source]

Configure the lower-level solver.

Returns:

None

build_ul_solver()[source]

Configure the upper-level solver.

Returns:

None

check_status()[source]
plot_losses()[source]
run_iter(ll_feed_dict, ul_feed_dict, current_iter)[source]

Run a single iteration of the bi-level optimization process.

Parameters:
  • ll_feed_dict (Dict[str, Tensor]) –

    Dictionary containing the real-time data and parameters fed for the construction of the lower-level (LL) objective.

    Example:

    {
        "image": train_images,
        "text": train_texts,
        "target": train_labels  # Optional
    }
    

  • ul_feed_dict (Dict[str, Tensor]) –

    Dictionary containing the real-time data and parameters fed for the construction of the upper-level (UL) objective.

    Example:

    {
        "image": val_images,
        "text": val_texts,
        "target": val_labels  # Optional
    }
    

  • current_iter (int) – The current iteration number.

Notes:
  • When accumulate_grad is set to True, pack the data of each batch in the format shown above.

  • In that case, pass ll_feed_dict and ul_feed_dict as lists of dictionaries, i.e., List[Dict[str, Tensor]].
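The list-of-dictionaries layout described in the notes can be sketched as follows. This is only an illustration of the data shape: pack_feed_dicts is a hypothetical helper, not part of boat_torch, and plain strings stand in for tensors so the snippet runs without PyTorch.

```python
from typing import Any, Dict, Iterable, List, Tuple


def pack_feed_dicts(batches: Iterable[Tuple[Any, Any, Any]]) -> List[Dict[str, Any]]:
    """Hypothetical helper: turn (image, text, target) batches into feed dicts."""
    return [
        {"image": image, "text": text, "target": target}
        for image, text, target in batches
    ]


# Placeholder values stand in for tensors; real code would pass Tensor objects.
train_batches = [("img_0", "txt_0", "lbl_0"), ("img_1", "txt_1", "lbl_1")]
ll_feed_dict = pack_feed_dicts(train_batches)
```

Each element of the resulting list uses the same keys as the single-dictionary example, so the accumulated form differs only in the outer list.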

Returns:

A tuple containing:
  • loss (float) – The loss value for the current iteration.
  • run_time (float) – The total time taken for the iteration.

Return type:

tuple
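A driver loop that consumes the (loss, run_time) tuple might look like the sketch below. Building a real Problem requires config and loss_config values not described here, so StubProblem is a hypothetical stand-in that only mirrors the documented run_iter signature and returns dummy numbers.

```python
import time
from typing import Any, Dict, List, Tuple


class StubProblem:
    """Hypothetical stand-in; mirrors only the documented run_iter signature."""

    def run_iter(
        self,
        ll_feed_dict: Dict[str, Any],
        ul_feed_dict: Dict[str, Any],
        current_iter: int,
    ) -> Tuple[float, float]:
        start = time.perf_counter()
        loss = 1.0 / (current_iter + 1)  # dummy value, not a real objective
        return loss, time.perf_counter() - start


problem = StubProblem()
# None stands in for the tensors a real run would supply.
ll_feed_dict = {"image": None, "text": None, "target": None}
ul_feed_dict = {"image": None, "text": None, "target": None}

history: List[Tuple[int, float, float]] = []
for current_iter in range(3):
    loss, run_time = problem.run_iter(ll_feed_dict, ul_feed_dict, current_iter)
    history.append((current_iter, loss, run_time))
```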

save_losses(current_iter, ll_loss, ul_loss)[source]

Save the losses to a JSON file and update the loss history.

Parameters:
  • current_iter – The iteration number.
  • ll_loss – The lower-level loss.
  • ul_loss – The upper-level loss.

Returns:

None
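The general pattern of appending to a loss history and writing it out as JSON can be sketched as below. This is not the library's implementation: save_losses_sketch, the record keys, and the file name losses.json are all assumptions made for illustration.

```python
import json
import os
import tempfile
from typing import Dict, List


def save_losses_sketch(
    path: str,
    history: List[Dict[str, float]],
    current_iter: int,
    ll_loss: float,
    ul_loss: float,
) -> None:
    """Hypothetical helper: append one record, then rewrite the JSON file."""
    history.append(
        {"iter": current_iter, "ll_loss": float(ll_loss), "ul_loss": float(ul_loss)}
    )
    with open(path, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)


loss_history: List[Dict[str, float]] = []
loss_path = os.path.join(tempfile.mkdtemp(), "losses.json")  # assumed file name
save_losses_sketch(loss_path, loss_history, 0, 0.9, 1.2)
save_losses_sketch(loss_path, loss_history, 1, 0.7, 1.0)
```

Rewriting the whole file on each call keeps the JSON valid at every step, at the cost of more I/O for long runs.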

set_track_trajectory(track_traj=True)[source]

Main Subpackages

Extension with Operation_Registry

boat_torch.operation_registry.get_registered_operation(name)[source]

Retrieve a registered operation class by name.

boat_torch.operation_registry.register_class(cls)[source]

Register a new operation class to the global registry.