Meta-Learning
This example demonstrates how to use the BOAT library for meta-learning, formulated as a bi-level optimization problem, with Omniglot few-shot classification as the dataset. The explanation is broken down into steps, each with its corresponding code snippet.
Step 1: Importing Libraries and Dependencies
import os
import json
import math
import torch
import boat_torch as boat
import torch.nn.functional as F
from torch import nn
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader
from tqdm import tqdm
import argparse
Explanation:
Import the necessary libraries, including torch, boat, and torchmeta.
Step 2: Argument and Dataset Setup
parser = argparse.ArgumentParser(description="BOAT Omniglot Meta-Training")
parser.add_argument("--gm_op", type=str, default="NGD")
parser.add_argument("--na_op", type=str, default="CG")
parser.add_argument("--fo_op", type=str, default=None)
parser.add_argument("--ways", type=int, default=20)
parser.add_argument("--shot", type=int, default=1)
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = omniglot(
    "./data/",
    ways=args.ways,
    shots=args.shot,
    test_shots=15,
    meta_train=True,
    download=True,
)
Explanation:
Dataset: The Omniglot dataset is used for few-shot meta-learning.
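ways/shots: Control the number of classes per task and the number of training samples per class.
test_shots: Sets the number of test (query) samples per class, 15 here.
device: Selects the computation device: CUDA if available, otherwise CPU.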
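As a usage sketch, the flags defined in this step can be overridden from the command line. The snippet below rebuilds the same parser and passes an explicit argument list to parse_args so that it runs without a shell; the 5-way 5-shot values and the script name in the comment are illustrative placeholders, not settings taken from the example.

```python
import argparse


def build_parser():
    # Same flags and defaults as in Step 2.
    parser = argparse.ArgumentParser(description="BOAT Omniglot Meta-Training")
    parser.add_argument("--gm_op", type=str, default="NGD")
    parser.add_argument("--na_op", type=str, default="CG")
    parser.add_argument("--fo_op", type=str, default=None)
    parser.add_argument("--ways", type=int, default=20)
    parser.add_argument("--shot", type=int, default=1)
    return parser


# Equivalent to running: python <script>.py --ways 5 --shot 5
# (<script>.py is a placeholder for the example's file name)
args = build_parser().parse_args(["--ways", "5", "--shot", "5"])
print(args.ways, args.shot, args.gm_op, args.na_op, args.fo_op)
# prints: 5 5 NGD CG None
```

Flags that are not passed keep their defaults, so the operator choices stay at NGD and CG here.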
Step 3: Model and Optimizer Setup
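# get_cnn_omniglot and initialize are helper functions assumed to be defined elsewhere in the example script.
meta_model_x, meta_model_y = get_cnn_omniglot(64, args.ways)
meta_model_x = meta_model_x.to(device)
meta_model_y = meta_model_y.to(device)
initialize(meta_model_x)
initialize(meta_model_y)
# Inner loop (lower level) uses SGD; outer loop (upper level) uses Adam.
inner_opt = torch.optim.SGD(meta_model_y.parameters(), lr=0.4)
outer_opt = torch.optim.Adam(meta_model_x.parameters(), lr=0.05)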
Explanation:
meta_model_y: The lower-level model (task-specific learner).
meta_model_x: The upper-level model (meta-learner).
Optimizers: SGD is used for the inner loop, while Adam is used for the outer loop.
Initialization: Model parameters are re-initialized before training.
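The initialize helper is called above but its definition is not shown in these snippets. The following is a minimal sketch of what such a helper could look like, assuming Kaiming initialization for convolutional layers and a small normal initialization for linear layers. This is an assumption for illustration, not BOAT's actual implementation, and the toy model exists only to exercise the function.

```python
import torch
from torch import nn


def initialize(net: nn.Module) -> nn.Module:
    """Hypothetical re-initialization helper (not taken from BOAT)."""
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            # Kaiming initialization suits ReLU-activated conv layers.
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
    return net


# Toy stand-in model for 1x28x28 inputs and 20 classes (illustrative only).
toy = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 28 * 28, 20),
)
weights_before = toy[3].weight.detach().clone()
initialize(toy)
logits = toy(torch.zeros(2, 1, 28, 28))
print(logits.shape)
```

Calling the helper again, as the training loop does for meta_model_y, draws fresh random weights each time, so every meta-iteration starts the task-specific learner from a new initialization.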
Step 4: DataLoader Construction
batch_size = 8
dataloader = BatchMetaDataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=1,
    pin_memory=False,
)
Explanation:
BatchMetaDataLoader constructs batches of tasks for meta-learning.
Each batch contains multiple few-shot classification tasks (batch_size = 8 here).
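To make the batch layout concrete, the sketch below computes the tensor shapes one would expect for a batch of tasks from the settings used in this example. It is plain arithmetic and does not call torchmeta; the 1x28x28 image shape is an assumption about the default transform of the omniglot helper, and expected_batch_shapes is a hypothetical name introduced only for this illustration.

```python
def expected_batch_shapes(batch_size, ways, shots, test_shots, image_shape=(1, 28, 28)):
    """Expected shapes of batch["train"] and batch["test"] (inputs, targets)."""
    n_support = ways * shots        # training samples per task
    n_query = ways * test_shots     # test samples per task
    return {
        "train_inputs": (batch_size, n_support) + tuple(image_shape),
        "train_targets": (batch_size, n_support),
        "test_inputs": (batch_size, n_query) + tuple(image_shape),
        "test_targets": (batch_size, n_query),
    }


# Settings from this example: 8 tasks per batch, 20-way, 1-shot, 15 test shots.
shapes = expected_batch_shapes(batch_size=8, ways=20, shots=1, test_shots=15)
print(shapes["train_inputs"])  # (8, 20, 1, 28, 28)
print(shapes["test_inputs"])   # (8, 300, 1, 28, 28)
```

The leading dimension indexes tasks, which is why the training loop slices each split with an index k to obtain the data for a single task.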
Step 5: Bi-Level Optimization Configuration
with open("./configs/boat_config_CG.json", "r") as f:
    boat_config = json.load(f)
with open("./configs/loss_config_CG.json", "r") as f:
    loss_config = json.load(f)
boat_config["gm_op"] = args.gm_op.split(",") if args.gm_op else None
boat_config["na_op"] = args.na_op.split(",") if args.na_op else None
boat_config["fo_op"] = args.fo_op
boat_config["lower_level_model"] = meta_model_y
boat_config["upper_level_model"] = meta_model_x
boat_config["lower_level_var"] = list(meta_model_y.parameters())
boat_config["upper_level_var"] = list(meta_model_x.parameters())
boat_config["lower_level_opt"] = inner_opt
boat_config["upper_level_opt"] = outer_opt
b_optimizer = boat.Problem(boat_config, loss_config)
b_optimizer.build_ll_solver()
b_optimizer.build_ul_solver()
Explanation:
Configure and initialize the bi-level optimizer using BOAT.
Define models, variables, and optimizers for both levels.
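The gm_op and na_op flags are turned into lists by splitting on commas, while fo_op is passed through unchanged. A small sketch of that parsing rule follows; parse_ops is a hypothetical helper name, and "A" and "B" are placeholder strings rather than real operator names.

```python
def parse_ops(value):
    """Mirror of `value.split(",") if value else None` from the configuration step."""
    return value.split(",") if value else None


print(parse_ops("NGD"))    # ['NGD']
print(parse_ops("A,B"))    # ['A', 'B']
print(parse_ops(None))     # None
print(parse_ops(""))       # None
print(parse_ops("A, B"))   # ['A', ' B']  (whitespace is kept, so avoid spaces)
```

Because the split does not strip whitespace, a multi-operator value should be written without spaces after the commas.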
Step 6: Meta-Training Loop
max_iters = 20
print("Start meta-training")
with tqdm(dataloader, total=max_iters, desc="Meta Training") as pbar:
    for meta_iter, batch in enumerate(pbar):
        # Reset the task-specific learner before each batch of tasks.
        initialize(meta_model_y)
        # Lower level: one feed dict per task, built from the train split.
        ll_feed = [
            {
                "data": batch["train"][0][k].float().to(device),
                "target": batch["train"][1][k].to(device),
            }
            for k in range(batch_size)
        ]
        # Upper level: one feed dict per task, built from the test split.
        ul_feed = [
            {
                "data": batch["test"][0][k].float().to(device),
                "target": batch["test"][1][k].to(device),
            }
            for k in range(batch_size)
        ]
        log_results, _ = b_optimizer.run_iter(
            ll_feed, ul_feed, current_iter=meta_iter
        )
        # meta_iter is zero-based, so stop once max_iters iterations have run.
        if meta_iter + 1 >= max_iters:
            print(f"Reached {max_iters} iterations. Stop.")
            break
Explanation:
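Iterate through batches using tqdm for progress visualization.
Re-initialize meta_model_y at the start of each iteration.
Prepare feed dictionaries for lower-level and upper-level optimizations: the lower level receives each task's train split and the upper level receives its test split.
Call run_iter to perform one bi-level optimization step, and stop after max_iters iterations.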
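The two list comprehensions that build the feeds differ only in the split they read, so they can be factored into one function. The sketch below is one possible refactoring, not part of BOAT: make_feed is a hypothetical helper, and fake_batch is a small synthetic stand-in with the same (inputs, targets) layout as a torchmeta batch, so the snippet runs without downloading Omniglot.

```python
import torch


def make_feed(split, device):
    """Build one {"data", "target"} dict per task from an (inputs, targets) pair."""
    inputs, targets = split
    # Use the actual number of tasks in the batch rather than a fixed batch_size.
    return [
        {"data": inputs[k].float().to(device), "target": targets[k].to(device)}
        for k in range(inputs.shape[0])
    ]


device = torch.device("cpu")
# Synthetic batch: 2 tasks, 5 training and 10 test samples each (illustrative sizes).
fake_batch = {
    "train": (torch.zeros(2, 5, 1, 28, 28), torch.zeros(2, 5, dtype=torch.long)),
    "test": (torch.zeros(2, 10, 1, 28, 28), torch.zeros(2, 10, dtype=torch.long)),
}
ll_feed = make_feed(fake_batch["train"], device)
ul_feed = make_feed(fake_batch["test"], device)
print(len(ll_feed), tuple(ll_feed[0]["data"].shape), tuple(ul_feed[0]["data"].shape))
```

Reading the task count from the tensor itself avoids an index error if a batch ever contains fewer tasks than batch_size.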