Ускоряем анализ данных в 170 000 раз с помощью Python
- понедельник, 5 февраля 2024 г. в 00:00:17
В статье «Ускоряем анализ данных в 180 000 раз с помощью Rust» показано, как неоптимизированный код на Python, после переписывания и оптимизации на Rust, ускоряется в 180 000 раз. Автор отмечает: «есть множество способов сделать код на Python быстрее, но смысл этого поста не в том, чтобы сравнить высокооптимизированный Python с высокооптимизированным Rust. Смысл в том, чтобы сравнить "стандартный-Jupyter-notebook" Python с высокооптимизированным Rust».
Возникает вопрос: какого ускорения мы могли бы достичь, если бы остановились на Python?
Под катом разработчик Сидни Рэдклифф* проходит путь профилирования и итеративного ускорения кода на Python, чтобы выяснить это.
*Обращаем ваше внимание, что позиция автора может не всегда совпадать с мнением МойОфис.
Как и в упомянутой выше статье, мы используем M1 Macbook, и по тем же бенчмаркам получаем сопоставимые показатели:
Среднее время итерации исходного неоптимизированного кода, измеренное за 1000 итераций, — 35 мс. В оригинальной статье — 36 мс.
После полной оптимизации код на Rust ускорен в 180,081 раз. В оригинальной статье сообщается о 182 450-кратном ускорении.
Вот неоптимизированный код на Python из ранее упомянутой статьи.
from itertools import combinations
import pandas as pd
from pandas import IndexSlice as islice
def k_corrset(data, K):
all_qs = data.question.unique()
q_to_score = data.set_index(['question', 'user'])
all_grand_totals = data.groupby('user').score.sum().rename('grand_total')
# Inner loop
corrs = []
for qs in combinations(all_qs, K):
qs_data = q_to_score.loc[islice[qs,:],:].swaplevel()
answered_all = qs_data.groupby(level=[0]).size() == K
answered_all = answered_all[answered_all].index
qs_totals = qs_data.loc[islice[answered_all,:]] \
.groupby(level=[0]).sum().rename(columns={'score': 'qs'})
r = qs_totals.join(all_grand_totals).corr().qs.grand_total
corrs.append({'qs': qs, 'r': r})
corrs = pd.DataFrame(corrs)
return corrs.sort_values('r', ascending=False).iloc[0].qs
data = pd.read_json('scores.json')
print(k_corrset(data, K=5))
А вот первые две строки DataFrame (далее — датафрейм), data
.
user | question | score |
e213cc2b-387e-4d7d-983c-8abc19a586b1 | d3bdb068-7245-4521-ae57-d0e9692cb627 | 1 |
951ffaee-6e17-4599-a8c0-9dfd00470cd9 | d3bdb068-7245-4521-ae57-d0e9692cb627 | 0 |
Для проверки корректности нашего оптимизированного кода, мы можем использовать вывод исходного кода.
Поскольку мы пытаемся оптимизировать внутренний цикл, поместим его в собственную функцию, чтобы профилировать с помощью line_profiler.
Avg time per iteration: 35 ms
Speedup over baseline: 1.0x
% Time Line Contents
=====================
def compute_corrs(
qs_iter: Iterable, q_to_score: pd.DataFrame, grand_totals: pd.DataFrame
):
0.0 result = []
0.0 for qs in qs_iter:
13.5 qs_data = q_to_score.loc[islice[qs, :], :].swaplevel()
70.1 answered_all = qs_data.groupby(level=[0]).size() == K
0.4 answered_all = answered_all[answered_all].index
0.0 qs_total = (
6.7 qs_data.loc[islice[answered_all, :]]
1.1 .groupby(level=[0])
0.6 .sum()
0.3 .rename(columns={"score": "qs"})
)
7.4 r = qs_total.join(grand_totals).corr().qs.grand_total
0.0 result.append({"qs": qs, "r": r})
0.0 return result
Мы видим значения, которые пытаемся оптимизировать (среднее время итерации/ускорение), а также долю времени, потраченного на выполнение каждой строки.
Это позволяет оптимизировать код следующим образом:
Запускаем профилировщик
Определяем самые медленные строки
Пробуем сделать медленные строки более быстрыми
Повторяем
В приведённом выше коде мы видим, что есть наиболее медленная строка, которая занимает ~70% времени.
Однако есть еще один важный шаг, который предшествует вышеупомянутым:
Проверяем вывод на корректность
Запускаем профилировщик
Определяем самые медленные строки
Пробуем сделать медленные строки более быстрыми
Повторяем
Проверки корректности вывода помогают экспериментировать, пробовать различные методы, библиотеки и т.д., зная при этом, что любые случайные изменения в вычисляемой информации будут отслежены.
Наш базовый код выполняет различные тяжёлые операции Pandas, выясняя, какие пользователи ответили на заданный набор вопросов — qs
. В частности, для этого он проверяет каждую строку датафрейма, чтобы определить, какие пользователи отвечали на вопросы. В первой оптимизации вместо полноценного датафрейма мы можем использовать словарь множеств пользователей. Это позволит нам быстро выяснить, какие пользователи ответили на каждый вопрос qs
, и использовать пересечение множеств в Python, чтобы выявить пользователей, ответивших на все вопросы.
Avg time per iteration: 10.0 ms
Speedup over baseline: 3.5x
% Time Line Contents
=====================
def compute_corrs(qs_iter, users_who_answered_q, q_to_score, grand_totals):
0.0 result = []
0.0 for qs in qs_iter:
0.0 user_sets_for_qs = [users_who_answered_q[q] for q in qs]
3.6 answered_all = set.intersection(*user_sets_for_qs)
40.8 qs_data = q_to_score.loc[islice[qs, :], :].swaplevel()
0.0 qs_total = (
22.1 qs_data.loc[islice[list(answered_all), :]]
3.7 .groupby(level=[0])
1.9 .sum()
1.1 .rename(columns={"score": "qs"})
)
26.8 r = qs_total.join(grand_totals).corr().qs.grand_total
0.0 result.append({"qs": qs, "r": r})
0.0 return result
Так мы значительно ускоряем вычисление строки answered_all
, которая вместо 70 % теперь занимает 4 %, и наш код становится быстрее в 3 раза.
Если сложить время, затрачиваемое на каждую строку, участвующую в вычислении qs_total
(включая строку qs_data
), то получится ~65%; таким образом, наша следующая задача по оптимизации ясна. Нужно снова заменить тяжёлые операции над полным датафреймом (индексирование, группировка и т. д.) быстрым поиском по словарю. Для этого вводим score_dict
, словарь, который позволяет проводить оценку для заданной пары вопрос-пользователь.
Avg time per iteration: 690 μs
Speedup over baseline: 50.8x
% Time Line Contents
=====================
def compute_corrs(qs_iter, users_who_answered_q, score_dict, grand_totals):
0.0 result = []
0.0 for qs in qs_iter:
0.1 user_sets_for_qs = [users_who_answered_q[q] for q in qs]
35.9 answered_all = set.intersection(*user_sets_for_qs)
3.4 qs_total = {u: sum(score_dict[q, u] for q in qs) for u in answered_all}
8.6 qs_total = pd.DataFrame.from_dict(qs_total, orient="index", columns=["qs"])
0.1 qs_total.index.name = "user"
51.8 r = qs_total.join(grand_totals).corr().qs.grand_total
0.0 result.append({"qs": qs, "r": r})
0.0 return result
Это помогает нам ускорить код в 50 раз.
Самая медленная строка в коде выше делает несколько вещей: сперва Pandas join
, чтобы объединить grand_totals
и qs_total
, а затем вычисляет для этого коэффициент корреляции. Опять же, мы можем ускорить процесс, используя поиск по словарю вместо join, и поскольку у нас больше нет объектов Pandas, используем np.corrcoef
вместо Pandas corr
.
Avg time per iteration: 380 μs
Speedup over baseline: 91.6x
% Time Line Contents
=====================
def compute_corrs(qs_iter, users_who_answered_q, score_dict, grand_totals):
0.0 result = []
0.0 for qs in qs_iter:
0.2 user_sets_for_qs = [users_who_answered_q[q] for q in qs]
83.9 answered_all = set.intersection(*user_sets_for_qs)
7.2 qs_total = [sum(score_dict[q, u] for q in qs) for u in answered_all]
0.5 user_grand_total = [grand_totals[u] for u in answered_all]
8.1 r = np.corrcoef(qs_total, user_grand_total)[0, 1]
0.1 result.append({"qs": qs, "r": r})
0.0 return result
Получаем ~90-кратное ускорение кода.
Эта оптимизация не вносит никаких изменений в код внутреннего цикла. Но она ускоряет некоторые операции. Мы заменяем длинные uuid пользователя/вопроса (например, e213cc2b-387e-4d7d-983c-8abc19a586b1) на гораздо более короткие целочисленные данные. Как это делается:
data.user = data.user.map({u: i for i, u in enumerate(data.user.unique())})
data.question = data.question.map(
{q: i for i, q in enumerate(data.question.unique())}
)
Теперь измеряем:
Avg time per iteration: 210 μs
Speedup over baseline: 168.5x
% Time Line Contents
=====================
def compute_corrs(qs_iter, users_who_answered_q, score_dict, grand_totals):
0.0 result = []
0.1 for qs in qs_iter:
0.4 user_sets_for_qs = [users_who_answered_q[q] for q in qs]
71.6 answered_all = set.intersection(*user_sets_for_qs)
13.1 qs_total = [sum(score_dict[q, u] for q in qs) for u in answered_all]
0.9 user_grand_total = [grand_totals[u] for u in answered_all]
13.9 r = np.corrcoef(qs_total, user_grand_total)[0, 1]
0.1 result.append({"qs": qs, "r": r})
0.0 return result
Видно, что операция с множествами пользователей в коде выше по-прежнему самая медленная. Вместо использования наборов ints мы переходим к использованию массива пользователей np.bool_
и применяем np.logical_and.reduce
для поиска пользователей, ответивших на все вопросы qs. (Обратите внимание, что np.bool_
использует целый байт для каждого элемента, но np.logical_and.reduce
все равно довольно быстр.) Это даёт нам значительное ускорение:
Avg time per iteration: 75 μs
Speedup over baseline: 466.7x
% Time Line Contents
=====================
def compute_corrs(qs_iter, users_who_answered_q, score_dict, grand_totals):
0.0 result = []
0.1 for qs in qs_iter:
12.0 user_sets_for_qs = users_who_answered_q[qs, :] # numpy indexing
9.9 answered_all = np.logical_and.reduce(user_sets_for_qs)
10.7 answered_all = np.where(answered_all)[0]
33.7 qs_total = [sum(score_dict[q, u] for q in qs) for u in answered_all]
2.6 user_grand_total = [grand_totals[u] for u in answered_all]
30.6 r = np.corrcoef(qs_total, user_grand_total)[0, 1]
0.2 result.append({"qs": qs, "r": r})
0.0 return result
Теперь самая медленная строка — вычисление qs_total
. Следуя примеру из оригинальной статьи, мы переходим к использованию плотного массива np.array для поиска оценок вместо словаря, и используем быструю индексацию NumPy для получения оценок.
Avg time per iteration: 56 μs
Speedup over baseline: 623.7x
% Time Line Contents
=====================
def compute_corrs(qs_iter, users_who_answered_q, score_matrix, grand_totals):
0.0 result = []
0.2 for qs in qs_iter:
16.6 user_sets_for_qs = users_who_answered_q[qs, :]
14.0 answered_all = np.logical_and.reduce(user_sets_for_qs)
14.6 answered_all = np.where(answered_all)[0]
7.6 qs_total = score_matrix[answered_all, :][:, qs].sum(axis=1)
3.9 user_grand_total = [grand_totals[u] for u in answered_all]
42.7 r = np.corrcoef(qs_total, user_grand_total)[0, 1]
0.4 result.append({"qs": qs, "r": r})
0.0 return result
Самая медленная строка в коде — np.corrcoef
... Мы всеми силами пытаемся оптимизировать код, поэтому вот наша собственная реализация corrcoef, которая в данном случае будет в два раза быстрее:
def corrcoef(a: list[float], b: list[float]) -> float | None:
"""same as np.corrcoef(a, b)[0, 1]"""
n = len(a)
sum_a = sum(a)
sum_b = sum(b)
sum_ab = sum(a_i * b_i for a_i, b_i in zip(a, b))
sum_a_sq = sum(a_i**2 for a_i in a)
sum_b_sq = sum(b_i**2 for b_i in b)
num = n * sum_ab - sum_a * sum_b
den = sqrt(n * sum_a_sq - sum_a**2) * sqrt(n * sum_b_sq - sum_b**2)
if den == 0:
return None
return num / den
Получаем приличное ускорение:
Avg time per iteration: 43 μs
Speedup over baseline: 814.6x
% Time Line Contents
=====================
def compute_corrs(qs_iter, users_who_answered_q, score_matrix, grand_totals):
0.0 result = []
0.2 for qs in qs_iter:
21.5 user_sets_for_qs = users_who_answered_q[qs, :] # numpy indexing
18.7 answered_all = np.logical_and.reduce(user_sets_for_qs)
19.7 answered_all = np.where(answered_all)[0]
10.0 qs_total = score_matrix[answered_all, :][:, qs].sum(axis=1)
5.3 user_grand_total = [grand_totals[u] for u in answered_all]
24.1 r = corrcoef(qs_total, user_grand_total)
0.5 result.append({"qs": qs, "r": r})
0.0 return result
Мы ещё не закончили оптимизацию структур данных в приведённом выше коде, но давайте посмотрим, что нам даст внедрение на текущем этапе Numba? Речь о библиотеке в экосистеме Python, которая «переводит подмножество кода Python и NumPy в быстрый машинный код».
Чтобы иметь возможность использовать Numba, выполним два изменения:
Модификация 1. Передаем qs_combinations как массив numpy, вместо qs_iter
Numba не очень хорошо работает с itertools
или генераторами, поэтому мы заранее превращаем qs_iter
в массив NumPy, чтобы передать его функции. Влияние этого изменения на скорость выполнения (до добавления Numba) показано ниже.
Avg time per iteration: 42 μs
Speedup over baseline: 829.2x
Модификация 2. Массив результатов вместо списка
Вместо добавления в список, мы инициализируем массив и помещаем в него результаты. Вот как это изменение повлияло на скорость.
Avg time per iteration: 42 μs
Speedup over baseline: 833.8x
В итоге наш код выглядит так:
import numba
@numba.njit(parallel=False)
def compute_corrs(qs_combinations, users_who_answered_q, score_matrix, grand_totals):
result = np.empty(len(qs_combinations), dtype=np.float64)
for i in numba.prange(len(qs_combinations)):
qs = qs_combinations[i]
user_sets_for_qs = users_who_answered_q[qs, :]
# numba doesn't support np.logical_and.reduce
answered_all = user_sets_for_qs[0]
for j in range(1, len(user_sets_for_qs)):
answered_all *= user_sets_for_qs[j]
answered_all = np.where(answered_all)[0]
qs_total = score_matrix[answered_all, :][:, qs].sum(axis=1)
user_grand_total = grand_totals[answered_all]
result[i] = corrcoef_numba(qs_total, user_grand_total)
return result
(Обратите внимание, что мы также дополнили corrcoef
с помощью Numba, потому что функции, вызываемые внутри функции Numba, тоже должны быть скомпилированы.)
Результаты с параметром parallel=False
:
Avg time per iteration: 47 μs
Speedup over baseline: 742.2x
Результаты с параметром parallel=True
:
Avg time per iteration: 8.5 μs
Speedup over baseline: 4142.0x
Видно, что при значении parallel=False
код Numba работает немного медленнее, чем наш предыдущий код на Python, но когда мы включаем параллелизм, то начинаем использовать все ядра процессора (10 на нашей рабочей машине) — и это даёт хороший множитель скорости.
Однако мы теряем возможность использовать line_profiler на JIT-компилированном коде; (возможно, мы захотим обратиться к сгенерированному LLVM IR / сборке).
Пока отвлечёмся от Numba. В оригинальной статье для быстрого вычисления пользователей, ответивших на текущий qs
, используются bitsets — проверим, применим ли такой подход в нашем случае. Для реализации bitsets мы можем использовать массивы NumPy np.int64
и np.bitwise_and.reduce
. В отличие от использования массива np.bool_
, теперь мы используем отдельные биты в байте для представления сущностей в множестве. Обратите внимание, что для данного bitset может понадобиться несколько байтов, в зависимости от максимального количества элементов, которые нам нужны. Мы можем использовать быстрый bitwise_and для байтов каждого вопроса в qs
, чтобы найти пересечение множеств и, следовательно, количество пользователей, ответивших на все qs
.
Вот функции bitset
, которые мы будем использовать:
def bitset_create(size):
"""Initialise an empty bitset"""
size_in_int64 = int(np.ceil(size / 64))
return np.zeros(size_in_int64, dtype=np.int64)
def bitset_add(arr, pos):
"""Add an element to a bitset"""
int64_idx = pos // 64
pos_in_int64 = pos % 64
arr[int64_idx] |= np.int64(1) << np.int64(pos_in_int64)
def bitset_to_list(arr):
"""Convert a bitset back into a list of ints"""
result = []
for idx in range(arr.shape[0]):
if arr[idx] == 0:
continue
for pos in range(64):
if (arr[idx] & (np.int64(1) << np.int64(pos))) != 0:
result.append(idx * 64 + pos)
return np.array(result)
И мы можем инициализировать bitsets следующим образом:
users_who_answered_q = np.array(
[bitset_create(data.user.nunique()) for _ in range(data.question.nunique())]
)
for q, u in data[["question", "user"]].values:
bitset_add(users_who_answered_q[q], u)
Посмотрим, какое ускорение мы получим:
Avg time per iteration: 550 μs
Speedup over baseline: 64.2x
% Time Line Contents
=====================
def compute_corrs(qs_combinations, users_who_answered_q, score_matrix, grand_totals):
0.0 num_qs = qs_combinations.shape[0]
0.0 bitset_size = users_who_answered_q[0].shape[0]
0.0 result = np.empty(qs_combinations.shape[0], dtype=np.float64)
0.0 for i in range(num_qs):
0.0 qs = qs_combinations[i]
0.3 user_sets_for_qs = users_who_answered_q[qs_combinations[i]]
0.4 answered_all = np.bitwise_and.reduce(user_sets_for_qs)
96.7 answered_all = bitset_to_list(answered_all)
0.6 qs_total = score_matrix[answered_all, :][:, qs].sum(axis=1)
0.0 user_grand_total = grand_totals[answered_all]
1.9 result[i] = corrcoef(qs_total, user_grand_total)
0.0 return result
Как видно, мы замедлились, поскольку операция bitset_to_list
занимает слишком много времени.
Преобразуем bitset_to_list
в скомпилированный код. Для этого мы можем добавить декоратор Numba:
@numba.njit
def bitset_to_list(arr):
result = []
for idx in range(arr.shape[0]):
if arr[idx] == 0:
continue
for pos in range(64):
if (arr[idx] & (np.int64(1) << np.int64(pos))) != 0:
result.append(idx * 64 + pos)
return np.array(result)
Измерим скорость:
Benchmark #14: bitsets, with numba on bitset_to_list
Using 1000 iterations...
Avg time per iteration: 19 μs
Speedup over baseline: 1801.2x
% Time Line Contents
=====================
def compute_corrs(qs_combinations, users_who_answered_q, score_matrix, grand_totals):
0.0 num_qs = qs_combinations.shape[0]
0.0 bitset_size = users_who_answered_q[0].shape[0]
0.0 result = np.empty(qs_combinations.shape[0], dtype=np.float64)
0.3 for i in range(num_qs):
0.6 qs = qs_combinations[i]
8.1 user_sets_for_qs = users_who_answered_q[qs_combinations[i]]
11.8 answered_all = np.bitwise_and.reduce(user_sets_for_qs)
7.7 answered_all = bitset_to_list(answered_all)
16.2 qs_total = score_matrix[answered_all, :][:, qs].sum(axis=1)
1.1 user_grand_total = grand_totals[answered_all]
54.1 result[i] = corrcoef(qs_total, user_grand_total)
0.0 return result
Мы получили ускорение в 1800 раз по сравнению с исходным кодом. Вспомните, что оптимизация 7, до введения Numba, дала 814x. (Оптимизация 8 дала 4142x, но это было с parallel=True
во внутреннем цикле, так что показатель здесь нерелевантен.)
Строчка с corrcoef снова выделяется как слишком медленная. Навесим на corrcoef
декоратор Numba.
@numba.njit
def corrcoef_numba(a, b):
"""same as np.corrcoef(a, b)[0, 1]"""
n = len(a)
sum_a = sum(a)
sum_b = sum(b)
sum_ab = sum(a * b)
sum_a_sq = sum(a * a)
sum_b_sq = sum(b * b)
num = n * sum_ab - sum_a * sum_b
den = math.sqrt(n * sum_a_sq - sum_a**2) * math.sqrt(n * sum_b_sq - sum_b**2)
return np.nan if den == 0 else num / den
Смотрим результаты:
Avg time per iteration: 11 μs
Speedup over baseline: 3218.9x
% Time Line Contents
=====================
def compute_corrs(qs_combinations, users_who_answered_q, score_matrix, grand_totals):
0.0 num_qs = qs_combinations.shape[0]
0.0 bitset_size = users_who_answered_q[0].shape[0]
0.0 result = np.empty(qs_combinations.shape[0], dtype=np.float64)
0.7 for i in range(num_qs):
1.5 qs = qs_combinations[i]
15.9 user_sets_for_qs = users_who_answered_q[qs_combinations[i]]
26.1 answered_all = np.bitwise_and.reduce(user_sets_for_qs)
16.1 answered_all = bitset_to_list(answered_all)
33.3 qs_total = score_matrix[answered_all, :][:, qs].sum(axis=1)
2.0 user_grand_total = grand_totals[answered_all]
4.5 result[i] = corrcoef_numba(qs_total, user_grand_total)
0.0 return result
Прекрасно, очередное значительное ускорение!
Вместо использования np.bitwise_and.reduce
мы вводим функцию bitwise_and
и применяем к ней jit-компиляцию.
@numba.njit
def bitset_and(arrays):
result = arrays[0].copy()
for i in range(1, len(arrays)):
result &= arrays[i]
return result
Benchmark #16: numba also on bitset_and
Using 1000 iterations...
Avg time per iteration: 8.9 μs
Speedup over baseline: 3956.7x
% Time Line Contents
=====================
def compute_corrs(qs_combinations, users_who_answered_q, score_matrix, grand_totals):
0.1 num_qs = qs_combinations.shape[0]
0.0 bitset_size = users_who_answered_q[0].shape[0]
0.1 result = np.empty(qs_combinations.shape[0], dtype=np.float64)
1.0 for i in range(num_qs):
1.5 qs = qs_combinations[i]
18.4 user_sets_for_qs = users_who_answered_q[qs_combinations[i]]
16.1 answered_all = bitset_and(user_sets_for_qs)
17.9 answered_all = bitset_to_list(answered_all)
37.8 qs_total = score_matrix[answered_all, :][:, qs].sum(axis=1)
2.4 user_grand_total = grand_totals[answered_all]
4.8 result[i] = corrcoef_numba(qs_total, user_grand_total)
0.0 return result
Код стал значительно быстрее исходного, причём вычисления распределены довольно равномерно между несколькими строками цикла. Похоже, самая медленная строка выполняет индексацию NumPy, которая и так довольно быстрая. Давайте скомпилируем всю функцию с помощью Numba.
@numba.njit(parallel=False)
def compute_corrs(qs_combinations, users_who_answered_q, score_matrix, grand_totals):
result = np.empty(len(qs_combinations), dtype=np.float64)
for i in numba.prange(len(qs_combinations)):
qs = qs_combinations[i]
user_sets_for_qs = users_who_answered_q[qs, :]
answered_all = user_sets_for_qs[0]
# numba doesn't support np.logical_and.reduce
for j in range(1, len(user_sets_for_qs)):
answered_all *= user_sets_for_qs[j]
answered_all = np.where(answered_all)[0]
qs_total = score_matrix[answered_all, :][:, qs].sum(axis=1)
user_grand_total = grand_totals[answered_all]
result[i] = corrcoef_numba(qs_total, user_grand_total)
return result
Avg time per iteration: 4.2 μs
Speedup over baseline: 8353.2x
А теперь с параметром parallel=True
:
Avg time per iteration: 960 ns
Speedup over baseline: 36721.4x
Отлично, наш код уже в 36 000 раз быстрее исходного.
Куда двигаться дальше?... Ну, в коде все ещё достаточно много значений помещается в массивы, а затем передаётся по ним. Если мы посмотрим, как вычисляется corrcoef, то поймём, что нам не нужно создавать массивы answered_all
и user_grand_total
, мы можем накапливать значения по мере выполнения цикла.
Вот код (мы также включили некоторые оптимизации компилятора, например, отключили boundschecking
для массивов и включили fastmath
):
@numba.njit(boundscheck=False, fastmath=True, parallel=False, nogil=True)
def compute_corrs(qs_combinations, users_who_answered_q, score_matrix, grand_totals):
num_qs = qs_combinations.shape[0]
bitset_size = users_who_answered_q[0].shape[0]
corrs = np.empty(qs_combinations.shape[0], dtype=np.float64)
for i in numba.prange(num_qs):
# bitset will contain users who answered all questions in qs_array[i]
bitset = users_who_answered_q[qs_combinations[i, 0]].copy()
for q in qs_combinations[i, 1:]:
bitset &= users_who_answered_q[q]
# retrieve stats for the users to compute correlation
n = 0.0
sum_a = 0.0
sum_b = 0.0
sum_ab = 0.0
sum_a_sq = 0.0
sum_b_sq = 0.0
for idx in range(bitset_size):
if bitset[idx] != 0:
for pos in range(64):
if (bitset[idx] & (np.int64(1) << np.int64(pos))) != 0:
user_idx = idx * 64 + pos
score_for_qs = 0.0
for q in qs_combinations[i]:
score_for_qs += score_matrix[user_idx, q]
score_for_user = grand_totals[user_idx]
n += 1.0
sum_a += score_for_qs
sum_b += score_for_user
sum_ab += score_for_qs * score_for_user
sum_a_sq += score_for_qs * score_for_qs
sum_b_sq += score_for_user * score_for_user
num = n * sum_ab - sum_a * sum_b
den = np.sqrt(n * sum_a_sq - sum_a**2) * np.sqrt(n * sum_b_sq - sum_b**2)
corrs[i] = np.nan if den == 0 else num / den
return corrs
Посмотрим со значением parallel=False
.
Avg time per iteration: 1.7 μs
Speedup over baseline: 20850.5x
Результат можно сравнить с оптимизацией 12 с parallel=False
, которая показала 8353x.
Теперь с параметром parallel=True
.
Avg time per iteration: 210 ns
Speedup over baseline: 170476.3x
Мы достигли ускорения в 170 000 по сравнению с исходным кодом!
Благодаря Numba и NumPy мы получили большинство тех инструментов, которые сделали оптимизированный код Rust быстрым: в частности, bitsets, SIMD и параллелизм на уровне циклов. Сперва мы значительно ускорили оригинальный код на Python с помощью нескольких вспомогательных функций с JIT-компиляцией, в итоге использовали JIT-компиляцию повсеместно и оптимизировали код для этого. Мы использовали подход проб и ошибок, применяя профилирование, чтобы сосредоточить усилия на самых медленных строках кода. Мы показали, что можем использовать Numba для постепенного добавления JIT-компилированного кода в нашу кодовую базу Python. Мы можем сразу же добавить этот код в существующую кодовую базу Python. Однако мы не достигли 180 000-кратного ускорения оптимизированного кода Rust, и развернули собственную реализацию корреляции и bitsets, в то время как код Rust смог использовать библиотеки для них, оставаясь при этом быстрым.
Это было забавное упражнение, которое, надеюсь, продемонстрировало вам некоторые полезные инструменты в экосистеме Python.
Стал бы я рекомендовать один подход вместо другого? Нет, все зависит от конкретной ситуации.