# Meta-Learning

This example demonstrates how to use the BOAT library for meta-learning. It frames meta-learning as bi-level optimization and applies it to few-shot classification on the Omniglot dataset. The walkthrough is split into steps, each with its code snippet.

---

## Step 1: Importing Libraries and Dependencies

```python
import os
import json
import math
import torch
import boat_torch as boat
import torch.nn.functional as F
from torch import nn
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader
from tqdm import tqdm
import argparse
```

### Explanation:
- Import the required libraries, including `torch`, `boat_torch` (imported as `boat`), and `torchmeta`.

---

## Step 2: Argument and Dataset Setup

```python
parser = argparse.ArgumentParser(description="BOAT Omniglot Meta-Training")
parser.add_argument("--gm_op", type=str, default="NGD")
parser.add_argument("--na_op", type=str, default="CG")
parser.add_argument("--fo_op", type=str, default=None)
parser.add_argument("--ways", type=int, default=20)
parser.add_argument("--shot", type=int, default=1)
args = parser.parse_args()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = omniglot(
    "./data/",
    ways=args.ways,
    shots=args.shot,
    test_shots=15,
    meta_train=True,
    download=True,
)
```

### Explanation:
- **Dataset**: The Omniglot dataset is used for few-shot meta-learning.
  - `ways` / `shots`: Set the number of classes per task and the number of support (training) examples per class. `test_shots=15` sets the number of query examples per class.
- **`device`**: Selects the computation device. It is CUDA when a GPU is available, otherwise CPU.
- **`--gm_op`, `--na_op`, `--fo_op`**: Operation choices that are written into the BOAT configuration in Step 5.

---

## Step 3: Model and Optimizer Setup

```python
# get_cnn_omniglot and initialize are helper functions assumed to be defined
# elsewhere in this example; they are not shown in these snippets.
meta_model_x, meta_model_y = get_cnn_omniglot(64, args.ways)
meta_model_x = meta_model_x.to(device)
meta_model_y = meta_model_y.to(device)

initialize(meta_model_x)
initialize(meta_model_y)

inner_opt = torch.optim.SGD(meta_model_y.parameters(), lr=0.4)
outer_opt = torch.optim.Adam(meta_model_x.parameters(), lr=0.05)
```

### Explanation:
- `meta_model_y`: the lower-level model (task-specific learner).
- `meta_model_x`: the upper-level model (meta-learner).
- **Optimizers**: SGD (`lr=0.4`) updates the lower-level (inner-loop) parameters, while Adam (`lr=0.05`) updates the upper-level (outer-loop) parameters.
- **Initialization**: `initialize` resets the parameters of both models before training.

---

## Step 4: DataLoader Construction

```python
batch_size = 8
dataloader = BatchMetaDataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=1,
    pin_memory=False,
)
```

### Explanation:
- `BatchMetaDataLoader` groups tasks into batches for meta-learning.
- Each batch holds `batch_size` (here 8) few-shot classification tasks. Each task has a `"train"` (support) split and a `"test"` (query) split.

---

## Step 5: Bi-Level Optimization Configuration

```python
with open("./configs/boat_config_CG.json", "r") as f:
    boat_config = json.load(f)

with open("./configs/loss_config_CG.json", "r") as f:
    loss_config = json.load(f)

# Override the operation choices with the command-line arguments
# (gm_op and na_op accept comma-separated lists).
boat_config["gm_op"] = args.gm_op.split(",") if args.gm_op else None
boat_config["na_op"] = args.na_op.split(",") if args.na_op else None
boat_config["fo_op"] = args.fo_op

# Register models, trainable variables, and optimizers for both levels.
boat_config["lower_level_model"] = meta_model_y
boat_config["upper_level_model"] = meta_model_x
boat_config["lower_level_var"] = list(meta_model_y.parameters())
boat_config["upper_level_var"] = list(meta_model_x.parameters())
boat_config["lower_level_opt"] = inner_opt
boat_config["upper_level_opt"] = outer_opt

b_optimizer = boat.Problem(boat_config, loss_config)
b_optimizer.build_ll_solver()
b_optimizer.build_ul_solver()
```

### Explanation:
- Load the BOAT and loss configurations from JSON, then override the operation choices with the command-line arguments.
- Register the models, trainable variables, and optimizers for the lower and upper levels.
- Create the bi-level problem with `boat.Problem`. Then build the lower-level solver (`build_ll_solver`) and the upper-level solver (`build_ul_solver`).
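The sketch below solves a toy scalar meta-learning problem in plain Python to show what the two levels compute. It is not BOAT code, and every function name in it is hypothetical. The upper-level variable `x` plays the role of a shared initialization. The lower level adapts a task parameter `y`, starting from `x`, with a few gradient steps on `0.5 * (y - support_target)**2`. The upper-level loss `0.5 * (y - query_target)**2` is then differentiated through those unrolled inner steps. The inner learning rate of 0.4 simply mirrors `inner_opt` above, and the task values are made up.

```python
# Toy bi-level meta-learning on scalars (illustrative only; not BOAT's implementation).
# Upper level: shared initialization x. Lower level: task parameter y adapted from x.


def inner_adapt(x, support_target, lr=0.4, steps=1):
    """Lower level: `steps` gradient-descent steps on 0.5 * (y - support_target)**2, from y = x."""
    y = x
    for _ in range(steps):
        y -= lr * (y - support_target)  # gradient of the lower-level loss w.r.t. y
    return y


def upper_loss(x, support_target, query_target, lr=0.4, steps=1):
    """Upper level: query loss of the adapted parameter y_K(x)."""
    y = inner_adapt(x, support_target, lr, steps)
    return 0.5 * (y - query_target) ** 2


def hypergradient(x, support_target, query_target, lr=0.4, steps=1):
    """Exact d(upper_loss)/dx through the unrolled inner loop.

    Each inner step maps y -> (1 - lr) * y + lr * support_target,
    so dy_K/dx = (1 - lr) ** steps by the chain rule.
    """
    y = inner_adapt(x, support_target, lr, steps)
    return (y - query_target) * (1.0 - lr) ** steps


def meta_train(tasks, x=0.0, outer_lr=0.5, iters=200):
    """Outer loop: gradient descent on the mean upper-level loss over (support, query) pairs."""
    for _ in range(iters):
        grad = sum(hypergradient(x, s, q) for s, q in tasks) / len(tasks)
        x -= outer_lr * grad
    return x


tasks = [(1.0, 1.0), (3.0, 3.0)]  # made-up (support_target, query_target) pairs
print(round(meta_train(tasks), 6))  # prints 2.0
```

In this toy the hypergradient has a closed form. For neural networks it must be computed or approximated numerically, which this sketch does not attempt.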
---

## Step 6: Meta-Training Loop

```python
max_iters = 20
print("Start meta-training")
with tqdm(dataloader, total=max_iters, desc="Meta Training") as pbar:
    for meta_iter, batch in enumerate(pbar):
        # Stop after exactly max_iters iterations (meta_iter starts at 0).
        if meta_iter >= max_iters:
            print(f"Reached {max_iters} iterations. Stop.")
            break

        # Reset the lower-level (task-specific) model for each new batch of tasks.
        initialize(meta_model_y)

        # One feed dict per task: the support ("train") split feeds the lower level...
        ll_feed = [
            {
                "data": batch["train"][0][k].float().to(device),
                "target": batch["train"][1][k].to(device),
            }
            for k in range(batch_size)
        ]
        # ...and the query ("test") split feeds the upper level.
        ul_feed = [
            {
                "data": batch["test"][0][k].float().to(device),
                "target": batch["test"][1][k].to(device),
            }
            for k in range(batch_size)
        ]

        log_results, _ = b_optimizer.run_iter(
            ll_feed, ul_feed, current_iter=meta_iter
        )
```

### Explanation:
- Iterate over task batches with `tqdm` for progress display, and stop after `max_iters` iterations.
- Re-initialize the lower-level model at the start of each iteration.
- Build one feed dictionary per task. The `"train"` (support) split feeds the lower level, and the `"test"` (query) split feeds the upper level.
- Call `b_optimizer.run_iter` to perform one bi-level optimization iteration on the batch.
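The sketch below runs the per-task feed construction and an exact iteration cap without torch, torchmeta, or BOAT. `mock_batch`, `mock_loader`, and `build_feeds` are hypothetical helpers. Nested lists stand in for tensors, and the batch layout (`batch[split] == (inputs, targets)`, indexed first by task) is assumed from the indexing used in the loop above. `itertools.islice` caps the stream at `max_iters` batches.

```python
from itertools import islice


def mock_batch(batch_size, ways=20, shots=1, test_shots=15):
    """Hypothetical stand-in for one BatchMetaDataLoader batch.

    Assumed layout: batch[split] == (inputs, targets), each indexed first by task.
    Strings and ints stand in for image and label tensors.
    """

    def make_split(per_class):
        inputs = [
            [f"task{t}_class{c}_img{j}" for c in range(ways) for j in range(per_class)]
            for t in range(batch_size)
        ]
        targets = [
            [c for c in range(ways) for _ in range(per_class)] for _ in range(batch_size)
        ]
        return inputs, targets

    return {"train": make_split(shots), "test": make_split(test_shots)}


def mock_loader(batch_size):
    """Endless stream of mock batches, standing in for iteration over a meta-dataloader."""
    while True:
        yield mock_batch(batch_size)


def build_feeds(batch, batch_size):
    """One feed dict per task: "train" split for the lower level, "test" split for the upper."""
    ll_feed = [
        {"data": batch["train"][0][k], "target": batch["train"][1][k]}
        for k in range(batch_size)
    ]
    ul_feed = [
        {"data": batch["test"][0][k], "target": batch["test"][1][k]}
        for k in range(batch_size)
    ]
    return ll_feed, ul_feed


max_iters = 20
batch_size = 8
completed = []
# islice yields exactly max_iters batches, so no manual break is needed.
for meta_iter, batch in enumerate(islice(mock_loader(batch_size), max_iters)):
    ll_feed, ul_feed = build_feeds(batch, batch_size)
    # The tutorial would call b_optimizer.run_iter(ll_feed, ul_feed, current_iter=meta_iter) here.
    completed.append(meta_iter)

print(len(completed), len(ll_feed), len(ll_feed[0]["target"]), len(ul_feed[0]["target"]))
# prints: 20 8 20 300
```

Wrapping the loader in `islice` is an alternative to a manual `break`. The loop body runs exactly `max_iters` times, and no extra batch is fetched only to be discarded.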