learnables / learn2learn
PyTorch Meta-learning Framework for Researchers
learn2learn is a PyTorch library for meta-learning implementations.
The goal of meta-learning is to enable agents to learn how to learn. That is, we would like our agents to become better learners as they solve more and more tasks. For example, the animation below shows an agent that learns to run after only one parameter update.
Features
learn2learn provides high- and low-level utilities for meta-learning. The high-level utilities allow casual users to take advantage of existing meta-learning algorithms, while the low-level utilities enable researchers to develop new and better ones.
Some features of learn2learn include high-level implementations of existing algorithms such as MAML, low-level routines (cloning and adaptation) for building new algorithms, and data utilities (MetaDataset, TaskGenerator) for turning standard datasets into meta-learning tasks.
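To make the low-level side concrete, here is a minimal sketch, in plain PyTorch rather than the learn2learn API, of the differentiable inner-loop update that such routines revolve around; the function and variable names below are illustrative assumptions, not library symbols.

# Conceptual sketch in plain PyTorch (illustrative names, not the learn2learn API):
# one differentiable "inner" adaptation step, the primitive that algorithms
# such as MAML are built from.
import torch
import torch.nn as nn
import torch.nn.functional as F

def differentiable_sgd_step(model, loss, inner_lr=0.1):
    # create_graph=True keeps the update inside the autograd graph, so an
    # outer (meta) loss can later backpropagate through this step.
    grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    # Re-attaching these adapted tensors to a module is the bookkeeping that
    # a meta-learning library handles for you.
    return [p - inner_lr * g for p, g in zip(model.parameters(), grads)]

# Toy usage on a random regression batch.
model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
adapted = differentiable_sgd_step(model, F.mse_loss(model(x), y))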
Installation
pip install learn2learn

The following is an example of using the high-level MAML implementation on MNIST. For more algorithms and lower-level utilities, please refer to the documentation or the examples.
import torch
import torchvision
from torch import optim

import learn2learn as l2l

mnist = torchvision.datasets.MNIST(root="/tmp/mnist", train=True)
mnist = l2l.data.MetaDataset(mnist)
task_generator = l2l.data.TaskGenerator(mnist,
                                        ways=3,
                                        classes=[0, 1, 4, 6, 8, 9],
                                        tasks=10)

model = Net()  # Net, compute_loss, num_iterations and adaptation_steps are
               # user-defined; one possible version is sketched after the demo.
maml = l2l.algorithms.MAML(model, lr=1e-3, first_order=False)
opt = optim.Adam(maml.parameters(), lr=4e-3)

for iteration in range(num_iterations):
    learner = maml.clone()  # Creates a clone of model
    adaptation_task = task_generator.sample(shots=1)

    # Fast adapt: take a few gradient steps on the adaptation task
    for step in range(adaptation_steps):
        error = compute_loss(learner, adaptation_task)
        learner.adapt(error)

    # Compute evaluation loss on a fresh sample of the same task
    evaluation_task = task_generator.sample(shots=1,
                                            task=adaptation_task.sampled_task)
    evaluation_error = compute_loss(learner, evaluation_task)

    # Meta-update the original model parameters
    opt.zero_grad()
    evaluation_error.backward()
    opt.step()
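The demo leaves Net, compute_loss, num_iterations, and adaptation_steps undefined. One possible way to fill them in is sketched below; in particular, the loss helper assumes a sampled task can be iterated as (image batch, label batch) tensor pairs, which is an assumption about the TaskGenerator output rather than something the demo specifies. Because first_order=False, the gradients of evaluation_error flow back through the adapt() steps into the original model's parameters, which is what makes this full (second-order) MAML.

# Hypothetical fill-ins for the names the demo leaves undefined.
import torch.nn as nn
import torch.nn.functional as F

num_iterations = 1000    # illustrative settings, not prescribed by learn2learn
adaptation_steps = 1

class Net(nn.Module):
    # A small MLP over flattened 28x28 MNIST digits, with a 3-way output
    # head to match ways=3 above.
    def __init__(self, ways=3):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, ways)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.fc2(F.relu(self.fc1(x)))

def compute_loss(learner, task):
    # Assumes the task iterates as (image_batch, label_batch) tensors;
    # adjust to whatever TaskGenerator actually yields.
    losses = [F.cross_entropy(learner(x), y) for x, y in task]
    return sum(losses) / len(losses)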