habrahabr

На практике пробуем KAN – принципиально новую архитектуру нейросетей

  • понедельник, 6 мая 2024 г. в 00:00:12
https://habr.com/ru/articles/812147/

На днях ученые из MIT показали альтернативу многослойному перцептрону (MLP). MLP с самого момента изобретения глубокого обучения лежит в основе всех нейросетей, какими мы их знаем сегодня. На его идее в том числе построены большие языковые модели и системы компьютерного зрения.

Однако теперь все может измениться. В KAN (Kolmogorov-Arnold Networks) исследователи реализовали перемещение функций активации с нейронов на ребра нейросети, и такой подход показал блестящие результаты.

Идею KAN ученые подчерпнули из теоремы Колмогорова-Арнольда, именно в их честь и названа архитектура. Вообще говоря, исследование очень математичное, в статье 50 страниц с формулами, повсюду термины из мат.анализа, высшей алгебры, функана и прочего.

В общем, если хотите разобраться с тем, как эта сенсация работает, и при этом не сойти с ума, на нашем сайте мы, команда канала Data Secrets, написали для вас длинный и интересный разбор. Там мы на пальцах объяснили всю математику, рассказали про строение сети, привели примеры и ответили на вопрос "а почему до этого раньше никто не додумался".

Прочитайте, не пожалеете: https://datasecrets.ru/articles/9.

А эта статья - для тех, кто хочет поиграть с новой архитектурой на практике. Мы рассмотрим несколько примеров кода на Python и понаблюдаем, как KAN справляется с привычными нам задачами машинного обучения. Поехали!

Установка

Чтобы участники сообщества могли сразу же потрогать все своими руками, добрые исследователи вместе со статьей представили библиотеку pykan, благодаря которой можно запускать KAN из коробки. Именно с ней мы сегодня и будем работать.

Итак, начнем с установки. Библиотеку можно поставить привычно через pip (pip install pykan) или с помощью клонирования репозитория:

git clone https://github.com/KindXiaoming/pykan.git
cd pykan
pip install -e .
# pip install -r requirements.txt # install requirements

Далее импортируем библиотеку с помощью from kan import * и наконец-то переходим к написанию кода!

Регрессия

Ну куда же без задачи регрессии? Ведь именно с нее началось машинное обучение в 50-х годах прошлого века... Ладно, краткие исторические справки оставим на потом.

Давайте загадаем KAN такую загадку: возьмем функцию от двух переменных f(x,y) = exp(sin(pi*x)+y^2) и попросим KAN по входам и выходам функции найти ее формулу. Это так называемая символьная регрессия. Надо сказать, что задача хоть и кажется тривиальной, но обычно математически трудна для нейросетей.

from kan import *
# формируем KAN: 2D входы, 1D выходы, 5 скрытых нейронов, 
# кубические сплайны и сетка на 5 точках.
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)

# сгенерируем датасет
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
dataset['train_input'].shape, dataset['train_label'].shape
#(torch.Size([1000, 2]), torch.Size([1000, 1]))

Сплайны в KAN – это как раз те самые обучаемые функции на ребрах. В математике сплайн – это такая гладкая кривая, кусочно-полиномиальная функция, которая на разных отрезках задается различными полиномами. Каждый сплайн аппроксимируется с помощью заданного количества точек (сетки). Чем больше точек - тем точнее аппроксимация.

Обучающую и тестовую выборки получили, значит можно обучать. Тут ничего нового – привычный метод train:

# train the model
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.);

Можно визуализировать KAN, который у нас получился:

Давай посмотрим на эту картинку внимательнее. Наверху мы видим сплайн, похожий на экспоненту, а затем слева и справа наблюдаем соотвественно синус и параболу. Ничего не напоминает?

Все верно, если сложить все вместе, то получится формула которую мы загадывали: f(x,y) = exp(sin(pi*x)+y^2). Благодаря тому, что в KAN обучаются не параметры (числа), а функции, он почти идеально справляется с задачей регрессии на сложных функциях и, как показали исследователи, гораздо эффективнее генерализирует данные. В частности, в этой задаче мы получаем метрику r2 равной 0.99.

В статье исследователи также показали, как KAN помогает решать дифференциальные уравнения и (пере)открывает законы физики и математики.

Классификация

Тут все еще интереснее. Но все по порядку. Снова сгенерируем игрушечный датасет (в сообществе его прозвали "две луны"):

from kan import KAN
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
import torch
import numpy as np

dataset = {}
train_input, train_label = make_moons(n_samples=1000, shuffle=True, noise=0.1, random_state=None)
test_input, test_label = make_moons(n_samples=1000, shuffle=True, noise=0.1, random_state=None)

dataset['train_input'] = torch.from_numpy(train_input)
dataset['test_input'] = torch.from_numpy(test_input)
dataset['train_label'] = torch.from_numpy(train_label[:,None])
dataset['test_label'] = torch.from_numpy(test_label[:,None])

X = dataset['train_input']
y = dataset['train_label']
plt.scatter(X[:,0], X[:,1], c=y[:,0])

Для начала давайте немного развлечемся и решим задачу, как будто это регрессия: будем предсказывать некоторое число, округлять его и сравнивать с реальной меткой класса.

model = KAN(width=[2,1], grid=3, k=3)

def train_acc():
    return torch.mean((torch.round(model(dataset['train_input'])[:,0]) == dataset['train_label'][:,0]).float())

def test_acc():
    return torch.mean((torch.round(model(dataset['test_input'])[:,0]) == dataset['test_label'][:,0]).float())

results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc));
results['train_acc'][-1], results['test_acc'][-1]
# (1.0, 1.0)

По последней строке видно: KAN справился идеально. Если заглянуть глубже, то мы увидим, что (опять же с помощью обучения функций) сетка вывела для себя "формулу ответа", которая и помогает ей безупречно справится с задачей:

А теперь попробуем по-взрослому, с кросс-энтропией, логитами и argmax. Вот код, в котором мы немного подправляем размерности в датасете и обучаем KAN:

dataset = {}
train_input, train_label = make_moons(n_samples=1000, shuffle=True, noise=0.1, random_state=None)
test_input, test_label = make_moons(n_samples=1000, shuffle=True, noise=0.1, random_state=None)

dataset['train_input'] = torch.from_numpy(train_input)
dataset['test_input'] = torch.from_numpy(test_input)
dataset['train_label'] = torch.from_numpy(train_label)
dataset['test_label'] = torch.from_numpy(test_label)

X = dataset['train_input']
y = dataset['train_label']

model = KAN(width=[2,2], grid=3, k=3)

def train_acc():
    return torch.mean((torch.argmax(model(dataset['train_input']), dim=1) == dataset['train_label']).float())

def test_acc():
    return torch.mean((torch.argmax(model(dataset['test_input']), dim=1) == dataset['test_label']).float())

results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc), loss_fn=torch.nn.CrossEntropyLoss());

Точность в этом случае немного ниже, но все еще достаточно хороша: 0.9660. Кстати, вот так можно посмотреть на формулы KAN (для каждого класса формула своя):

lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
model.auto_symbolic(lib=lib)
formula1, formula2 = model.symbolic_formula()[0]

В данном случае они получаются такими:

Заключение

В статье мы рассмотрели, как запустить KAN для привычных задач регрессии и классификации и немного заглянули "под капот" архитектуры. Если хотите больше примеров – загляните в документацию или в репозитория проекта, там лежат очень красивые и понятные ноутбуки, в которых можно найти туториалы по библиотеке и кейсы использования KAN.

Авторы KAN доказали, что ему требуется во много раз меньше нейронов, чтобы достичь точности MLP. Также KAN гораздо лучше генерализует данные и лучше справляется с аппроксимацией сложных математических функций (мы увидели это на примерах), у него, можно сказать, "технический склад ума".

Однако у архитектуры есть бутылочное горлышко: KAN учится медленнее MLP примерно в 10 раз. Возможно, это станет серьезным камнем преткновения, а возможно инженеры быстро научатся оптимизировать эффективность таких сетей.

Больше новостей из мира машинного обучения можно найти в нашем телеграм-канале. Подписывайтесь, чтобы быть в курсе: @data_secrets.