Meta-Learning

This example demonstrates how to use the BOAT library for meta-learning formulated as a bi-level optimization problem, trained on few-shot classification tasks drawn from the Omniglot dataset. The walkthrough is broken into steps, each with its code snippet and a short explanation.


Step 1: Importing Libraries and Dependencies

import os
import json
import math
import torch
import boat_torch as boat
import torch.nn.functional as F
from torch import nn
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader
from tqdm import tqdm
import argparse

Explanation:

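  • Import the required libraries: torch, BOAT's PyTorch package (boat_torch, imported as boat), torchmeta for the Omniglot few-shot loaders, tqdm for progress bars, and argparse for command-line options.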


Step 2: Argument and Dataset Setup

parser = argparse.ArgumentParser(description="BOAT Omniglot Meta-Training")
parser.add_argument("--gm_op", type=str, default="NGD")
parser.add_argument("--na_op", type=str, default="CG")
parser.add_argument("--fo_op", type=str, default=None)
parser.add_argument("--ways", type=int, default=20)
parser.add_argument("--shot", type=int, default=1)
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = omniglot(
    "./data/",
    ways=args.ways,
    shots=args.shot,
    test_shots=15,
    meta_train=True,
    download=True,
)

Explanation:

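  • Dataset: torchmeta's omniglot helper builds few-shot classification tasks from the Omniglot handwritten-character dataset. download=True fetches the data into ./data/ on first use, and meta_train=True selects the meta-training split.

  • ways / shots / test_shots: Each task contains ways classes (20 by default), with shots support samples (1 by default) and test_shots query samples (15) per class.

  • --gm_op / --na_op / --fo_op: Select the BOAT operations used by the bi-level solver; they are applied to the configuration in Step 5.

  • device: CUDA is used when a GPU is available; otherwise computation falls back to the CPU.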
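To make the ways / shots / test_shots terminology concrete, the pure-Python sketch below draws one N-way K-shot episode from a toy pool of labelled examples. sample_task and the toy pool are hypothetical illustrations: torchmeta's omniglot helper performs its own sampling (with its own transforms and class augmentation), so this is not its implementation. The toy pool uses 20 examples per class, matching the 20 drawings per character in Omniglot.

```python
import random


def sample_task(pool, ways, shots, test_shots, rng=random):
    """Build one N-way K-shot episode from a {class_name: [examples]} pool.

    Hypothetical helper for illustration only; not torchmeta's sampler.
    """
    classes = rng.sample(sorted(pool), ways)  # pick `ways` distinct classes
    support, query = [], []
    for label, cls in enumerate(classes):  # relabel classes as 0..ways-1
        # Draw without replacement so support and query never share a sample.
        picks = rng.sample(pool[cls], shots + test_shots)
        support += [(x, label) for x in picks[:shots]]
        query += [(x, label) for x in picks[shots:]]
    return support, query


# Toy pool: 50 classes with 20 examples each.
pool = {f"char{c}": [f"img{c}_{i}" for i in range(20)] for c in range(50)}
support, query = sample_task(pool, ways=20, shots=1, test_shots=15,
                             rng=random.Random(0))
print(len(support), len(query))  # 20 300
```

Labels are reassigned per episode, which is why a task-specific learner has to adapt on the support set before it can classify the query set.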


Step 3: Model and Optimizer Setup

meta_model_x, meta_model_y = get_cnn_omniglot(64, args.ways)
meta_model_x = meta_model_x.to(device)
meta_model_y = meta_model_y.to(device)
initialize(meta_model_x)
initialize(meta_model_y)

inner_opt = torch.optim.SGD(meta_model_y.parameters(), lr=0.4)
outer_opt = torch.optim.Adam(meta_model_x.parameters(), lr=0.05)

Explanation:

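  • Helpers: get_cnn_omniglot and initialize come from the example's accompanying code and are not shown here. get_cnn_omniglot(64, args.ways) returns the upper-level and lower-level networks, and initialize re-initializes a model's parameters.

  • meta_model_x: the upper-level model (meta-learner), whose parameters are shared across tasks.

  • meta_model_y: the lower-level model (task-specific learner), adapted on each task.

  • Optimizers: SGD (learning rate 0.4) updates the lower level in the inner loop, while Adam (learning rate 0.05) updates the upper level in the outer loop.

  • Initialization: Both models are initialized before training, and meta_model_y is re-initialized again at the start of every meta-iteration (Step 6).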
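The division of labour between an SGD inner loop and an Adam outer loop can be seen on a toy scalar bi-level problem. In the hypothetical sketch below, the lower level fits y to x under an L2 penalty with plain SGD (learning rate 0.4, as for inner_opt), restarting from y = 0 at every outer step much like initialize(meta_model_y) resets the task learner. The upper level moves x with Adam (learning rate 0.05, as for outer_opt) so that the adapted y matches a target. The hypergradient is obtained by differentiating through the unrolled SGD steps in forward mode; this illustrates the general idea only and is not how BOAT computes its gradients.

```python
import math


def inner_sgd(x, lam=0.5, lr=0.4, steps=10):
    """Lower level: minimise 0.5*(y - x)**2 + 0.5*lam*y**2 over y with SGD.

    Also returns dy/dx, tracked through the unrolled steps (forward mode).
    """
    y, dy_dx = 0.0, 0.0  # the task learner restarts from scratch
    for _ in range(steps):
        grad_y = (y - x) + lam * y
        y = y - lr * grad_y
        # Update is y <- (1 - lr*(1 + lam)) * y + lr * x, so:
        dy_dx = (1 - lr * (1 + lam)) * dy_dx + lr
    return y, dy_dx


def outer_adam(target=2.0, lr=0.05, iters=1000, b1=0.9, b2=0.999, eps=1e-8):
    """Upper level: minimise 0.5*(y_T(x) - target)**2 over x with Adam."""
    x, m, v = 0.0, 0.0, 0.0
    for t in range(1, iters + 1):
        y, dy_dx = inner_sgd(x)           # re-solve the lower level
        hypergrad = (y - target) * dy_dx  # chain rule through unrolled SGD
        m = b1 * m + (1 - b1) * hypergrad
        v = b2 * v + (1 - b2) * hypergrad ** 2
        m_hat = m / (1 - b1 ** t)         # bias-corrected moments
        v_hat = v / (1 - b2 ** t)
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return x


x_star = outer_adam()
y_star, _ = inner_sgd(x_star)
print(round(x_star, 3), round(y_star, 3))
```

Forward-mode tracking is convenient here because x is a single scalar; with many upper-level parameters, reverse-mode differentiation or implicit methods are the usual choice.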


Step 4: DataLoader Construction

batch_size = 8
dataloader = BatchMetaDataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=1,
    pin_memory=False,
)

Explanation:

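  • BatchMetaDataLoader groups tasks into batches for meta-learning; here each batch holds batch_size = 8 tasks.

  • Each task is a few-shot classification problem (20-way 1-shot with the default arguments), so batch["train"] and batch["test"] each hold an (inputs, targets) pair with a leading task dimension.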
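Because each task holds ways x shots support samples and ways x test_shots query samples, the size of a meta-batch follows from simple arithmetic. The hypothetical helper below computes these counts for the default arguments. Assuming torchmeta's default 28x28 grayscale Omniglot transform, the support inputs of one batch would then have shape (8, 20, 1, 28, 28).

```python
def episode_sizes(ways, shots, test_shots, batch_size):
    """Count samples per task and per meta-batch (hypothetical helper)."""
    support_per_task = ways * shots       # lower-level (train) samples
    query_per_task = ways * test_shots    # upper-level (test) samples
    return {
        "support_per_task": support_per_task,
        "query_per_task": query_per_task,
        "support_per_batch": batch_size * support_per_task,
        "query_per_batch": batch_size * query_per_task,
    }


sizes = episode_sizes(ways=20, shots=1, test_shots=15, batch_size=8)
print(sizes)
# {'support_per_task': 20, 'query_per_task': 300, 'support_per_batch': 160, 'query_per_batch': 2400}
```

The query sets dominate the batch size here, which is worth keeping in mind when choosing batch_size for a GPU with limited memory.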


Step 5: Bi-Level Optimization Configuration

with open("./configs/boat_config_CG.json", "r") as f:
    boat_config = json.load(f)
with open("./configs/loss_config_CG.json", "r") as f:
    loss_config = json.load(f)

boat_config["gm_op"] = args.gm_op.split(",") if args.gm_op else None
boat_config["na_op"] = args.na_op.split(",") if args.na_op else None
boat_config["fo_op"] = args.fo_op

boat_config["lower_level_model"] = meta_model_y
boat_config["upper_level_model"] = meta_model_x
boat_config["lower_level_var"] = list(meta_model_y.parameters())
boat_config["upper_level_var"] = list(meta_model_x.parameters())
boat_config["lower_level_opt"] = inner_opt
boat_config["upper_level_opt"] = outer_opt

b_optimizer = boat.Problem(boat_config, loss_config)
b_optimizer.build_ll_solver()
b_optimizer.build_ul_solver()

Explanation:

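  • Load the BOAT configuration and the loss configuration from the JSON files in ./configs/.

  • Override the operation choices with the command-line flags: --gm_op and --na_op accept comma-separated names that are split into lists (NGD and CG by default), while --fo_op is passed through unchanged.

  • Register the models, trainable variables, and optimizers for both levels, then create the boat.Problem and build its lower-level and upper-level solvers.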
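The default --na_op value, CG, refers to the conjugate-gradient method. In implicit hypergradient computations, CG is typically used to approximate an inverse-Hessian-vector product by solving a linear system that needs only Hessian-vector products, never the full Hessian. The sketch below is a generic pure-Python conjugate-gradient solver on a small symmetric positive-definite system, written for illustration; it is not BOAT's implementation, and the matrix A merely stands in for a lower-level Hessian.

```python
def conjugate_gradient(matvec, b, iters=50, tol=1e-10):
    """Approximately solve A x = b for a symmetric positive-definite A.

    `matvec(v)` returns A @ v; in hypergradient methods this would be a
    Hessian-vector product, so A never has to be formed explicitly.
    """
    x = [0.0] * len(b)
    r = list(b)          # residual b - A x, with x starting at zero
    p = list(r)          # first search direction
    rs_old = sum(ri * ri for ri in r)
    for _ in range(iters):
        if rs_old < tol:  # squared residual small enough: stop
            break
        Ap = matvec(p)
        alpha = rs_old / sum(pi * api for pi, api in zip(p, Ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = sum(ri * ri for ri in r)
        # Next direction is A-conjugate to the previous ones.
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


A = [[4.0, 1.0], [1.0, 3.0]]


def matvec(v):
    """Dense matrix-vector product standing in for a Hessian-vector product."""
    return [sum(a * vi for a, vi in zip(row, v)) for row in A]


x = conjugate_gradient(matvec, [1.0, 2.0])
print(x)  # approximately [0.0909, 0.6364], i.e. [1/11, 7/11]
```

In exact arithmetic CG converges in at most n iterations for an n x n system, so capping the iteration count is what turns it into a cheap approximation for large models.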


Step 6: Meta-Training Loop

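max_iters = 20
print("Start meta-training")

with tqdm(dataloader, total=max_iters, desc="Meta Training") as pbar:
    for meta_iter, batch in enumerate(pbar):
        # Check before updating so that exactly max_iters updates are run.
        if meta_iter >= max_iters:
            print(f"Reached {max_iters} iterations. Stop.")
            break

        # Reset the task-specific learner at the start of each meta-iteration.
        initialize(meta_model_y)

        # Number of tasks actually present in this batch.
        n_tasks = batch["train"][0].size(0)

        # Support ("train") samples feed the lower-level problem.
        ll_feed = [
            {
                "data": batch["train"][0][k].float().to(device),
                "target": batch["train"][1][k].to(device),
            }
            for k in range(n_tasks)
        ]

        # Query ("test") samples feed the upper-level problem.
        ul_feed = [
            {
                "data": batch["test"][0][k].float().to(device),
                "target": batch["test"][1][k].to(device),
            }
            for k in range(n_tasks)
        ]

        # One bi-level update over all tasks in the batch.
        log_results, _ = b_optimizer.run_iter(
            ll_feed, ul_feed, current_iter=meta_iter
        )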

Explanation:

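  • Iterate over task batches with tqdm for progress reporting, stopping after max_iters meta-iterations.

  • Re-initialize meta_model_y at the start of every meta-iteration so the lower-level learner adapts to each batch of tasks from scratch.

  • Build feed lists with one dictionary per task: support ("train") samples for the lower level and query ("test") samples for the upper level.

  • Call run_iter to perform one bi-level update; it returns logging information (log_results) for the iteration. No learning-rate scheduler is used in this example.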
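Instead of checking meta_iter inside the loop, the number of meta-iterations can be bounded with itertools.islice, which is convenient because the Omniglot meta-dataset can produce a very large number of task combinations. The sketch below uses a hypothetical endless generator in place of BatchMetaDataLoader and shows that exactly max_iters batches are consumed.

```python
from itertools import count, islice


def endless_loader():
    """Hypothetical stand-in for a loader that keeps yielding task batches."""
    for i in count():
        yield {"batch_id": i}


max_iters = 20
seen = []
for meta_iter, batch in enumerate(islice(endless_loader(), max_iters)):
    seen.append(meta_iter)  # one bi-level update (run_iter) would go here

print(len(seen), seen[-1])  # 20 19
```

With the real loader, the same idea reads tqdm(islice(dataloader, max_iters), total=max_iters, desc="Meta Training"), which keeps the progress bar's total and the actual number of updates in agreement.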