Как поделить не деля или оптимизация деления компиляторам(и)
- пятница, 9 августа 2024 г. в 00:00:09
Если вы никогда не пробовали смотреть как код на C++ разворачивается компилятором в код Assembly – вас ждёт много сюрпризов, причём, не нужно смотреть какой-то замудренный исходный код полный templates или других сложных конструкций: рассмотрите следующий snippet:
uint8_t div10(uint8_t x)
{
return x/10;
}
Конечно, я это уже сделал, и приведу результаты прямо здесь, хотя, советую и самим сходить на замечательный ресурс https://godbolt.org/ – выставить там, например, x86-64 gcc 14.1, добавить опцию -O2 и убедиться в том, что результаты крайне интересные:
div10(unsigned char):
mov edx, -51
mov eax, edi
mul dl
shr ax, 11
ret
Действительно, чего совсем не видно в этом куске кода так это инструкции div, которая несомненно существует и собственно осуществляет деление на x86 архитектуре. Зато – откуда-то взялась магическая константа, ещё и отрицательная! Ну ладно, на самом деле она положительная и равна 205 (т.к. 205+51=256 = mod 256), так что этот вопрос закрыт, но как же всё таки это всё работает?
Работает это следующим образом после умножения на 205 мы делаем суммарный сдвиг вправо на 11 разрядов, что эквивалентно делению на 2^11 = 2048 с отбрасыванием остатка. Последнее – очень важно. Заметим, что 205/2048 = 0.10009765625, иначе говоря чуть-чуть больше чем 0.1, если вы умножите последнее число на калькуляторе (или в Python) на 255 вы получите 25.524, иначе говоря, после отбрасывания дробной части – это правильный ответ.
А можно было взять число чуть-чуть меньше чем 0.1? Нет – тривиально умножив на 10 мы бы получили "чуть-чуть меньше чем 1", и просто отбрасывая дробную часть получили бы 0 – немного не тот результат которого ожидаешь деля 10 на 10. А насколько чуть-чуть больше должно быть число, что б трюк сработал? Для uint8_t максимальное число – 255, соответственно, число должно отличаться от 1/10 не больше чем на 1/256. А с любым ли числом (хотя бы из базового типа uint8_t это возможно) – да, с любым.
Тут, я напомню про такие математические функции как floor, ceil и trunc: мы работаем только с положительными числами, поэтому, без лишних сложностей floor=trunc и просто отбрасывает дробную часть, a ceil всегда округляет вверх. Пусть d – наш делитель, N – наша разрядность (у нас 8), мы хотим получить выражение вида: m / 2 ^ (N + k) такое что оно чуть больше 1/d (тут обо всём думаем в вещественных числах), а если точнее то:
m / (2 ^ (N + k)) - (1 / d) >= 1 / 2^N (просто обобщение предыдущего абзаца).
Утверждение: m = ceil(2^(ceil(log(d)) + N) / d), k = ceil(log(d)) – как это всё понять, что тут такое написано? ... Это можно понять, например, таким образом: число внизу по условию это степень 2ки, и эта степень точно не меньше чем 2^N, далее предположим, я как-то нашёл k – как теперь по заданному k как подобрать m ? Я уже знаю знаменатель – 2 ^ (N + k), я хочу подобрать целое число, что б оно было чуть больше чем 1/d, что если я возьму просто trunc(2 ^ (N + k) / d) ?
Давайте в числах из примера: я хочу делить uint8_t на 10: т.е. N=8, d = 10, а k пока возьмем равным 2, например, тогда trunc(2 ^ (N + k) / d) = trunk(2^10 / 10) = trunk (1024 /10) = 102, а всё выражение m / (2 ^ (N + k)) = 102/ 1024 = 0,099609375, в общем, близко, но чуть меньше чем нам надо. А вот ceil - всегда будет больше потому что: ceil(x) >= x для положительных чисел. Я ещё не сделал эту оговорку, но сделаю, что d > 1 и d не является точной степенью двойки, то есть точных делений у нас тут не будет, вторая оговорка trunc (x/y) в C++ это просто обычное целочисленное деление.
Итак, я надеюсь, к этому моменту стало понятно, что m в том виде как я ищу действительно аппроксимирует 1/d сверху. Теперь посмотрим почему я выбрал такое k: ceil(2^(ceil(log(d)) + N) / d, вот здесь делая внешний ceil я прибавляю к числителю число не более чем d – потому что остаток не может быть ни больше ни равен d, и понятно что 2^(ceil(log(d)) > d.
Думаю, время показать код:
template<typename InputInteger, typename OutputInteger>
std::pair<OutputInteger, uint8_t> getDivisionMultiplier(InputInteger divisor)
{
if (!divisor)
{
throw std::invalid_argument("Division by zero is impossible");
}
if (divisor == 1)
{
return {1,0};
}
constexpr uint8_t n = sizeof(InputInteger) * CHAR_BIT;
const double log_d_temp = std::log2(static_cast<double>(divisor));
const uint8_t log_d = std::ceil(log_d_temp);
if (log_d == std::floor(log_d_temp))
{
return {1, log_d};
}
OutputInteger res = std::ceil(static_cast<double>(static_cast<OutputInteger>(1) << (log_d + n)) / double(divisor));
return {res, n + log_d};
}
// somewhere in the main function
for(uint8_t divisor = 1; divisor > 0; divisor++)
{
auto [multiplier, shift] = getDivisionMultiplier<uint8_t, uint16_t>(divisor);
for(uint8_t numenator = 1; numenator > 0; numenator++)
{
uint32_t res = static_cast<uint32_t>(numenator * multiplier) >> shift;
if (res != numenator / divisor)
{
std::cout << "panic: did something went wrong?" << std::endl;
}
}
}
Наверное, очевидно, что фразу про панику мы никогда не увидим – значит всё? Работает и статью пора заканчивать? – Нет.
Компилятор, точно делает по-другому, действительно, выведем, что код, приведенный выше дает для d=10:
auto p = getDivisionMultiplier<uint8_t, uint16_t>(static_cast<uint8_t>(10));
std::cout << p.first << " " << (uint16_t)(p.second) << std::endl;
410 12
А из куска Assembly из начала статьи понятно, что должно было быть 205 и 11... Дело в том, что gcc хочет получить константу m того же размера, что и d и использует для этого более хитрый алгоритм (если вы приглядитесь к моему коду выше – я предусмотрительно использовал тип uint16_t).
Алгоритм gcc основан на подсчёте пары констант m_low = trunc(2^(ceil(log(d)) + N) / d) и m_high = trunc(2^(ceil(log(d)) + N) + 2^(ceil(log(d)) / d), тут важно понимать что m_low – не может быть настоящим m (показывал это выше), а m_high – может (извините, я не буду это тоже пытаться тут расписать как и все оставшиеся выкладки), но m_high, вообще говоря, даже больше m из моего наивного алгоритма. Да, но, и m_low и m_high – точно меньше чем 2^(N + 1), то есть для нашего случая это было бы 9 бит, и ещё точно в целых числах m_high > m_low (без равенства). Что же делает этот алгоритм дальше, чтобы сделать из 9-ти битного числа 8-ми битное число? Правильно, он просто сдвинет m_high на разряд вправо: тут важно понимать, что таким образом он делает trunc(m_high/2), и конечно параллельно, он уменьшит k (число сдвигов направо в конечном коде), но ...если число нечетное, это же не тоже самое? Да, в этом случае есть риск что мы начнём аппроксимировать снизу...поэтому, компилятор так делает и trunc(m_low / 2) и сравнивает эти два числа, потому что m_low уже изначально слишком мала – если m_high деградировала до неё – то нельзя дальше уменьшать m_high и соответствующий сдвиг. А вот и код делающий это:
template<typename InputInteger>
std::tuple<InputInteger, uint8_t, bool> getDivisionMultiplier(InputInteger divisor)
{
if (!divisor)
{
throw std::invalid_argument("Division by zero is impossible");
}
if (divisor == 1)
{
return {1, 0, false};
}
constexpr uint8_t n = sizeof(InputInteger) * CHAR_BIT;
const double log_d_temp = std::log2(static_cast<double>(divisor));
const uint8_t log_d = std::ceil(log_d_temp);
if (log_d == std::floor(log_d_temp))
{
return {1, log_d, false};
}
uint64_t temp_low = (1UL << (log_d + n));
uint64_t temp_hight = (1UL << log_d) | (1UL << (log_d + n));
temp_hight /= divisor;
temp_low /= divisor;
uint8_t additionla_shift = log_d;
while (additionla_shift)
{
if (temp_low /2 >= temp_hight/2)
{
break;
}
temp_low /= 2;
temp_hight /= 2;
--additionla_shift;
}
return {temp_hight, n + additionla_shift, temp_hight > std::numeric_limits<uint8_t>::max()};
}
// somewhere in the main function
auto [coeff, shift, _] = getDivisionMultiplier(static_cast<uint8_t>(10));
std::cout << (uint16_t)coeff << " " << (uint16_t)(shift) << std::endl;
205 11
Успех! Коэффициенты совпали! Всё ли на этом? Увы: нет гарантии, что цикл, который редуцирует temp_hight, не выйдет сразу после первой итерации. То есть нет гарантии получить 8-ми битное число, но тогда у нас возможен срез по модулю в return. Gcc умеет успешно использовать это срезанное значение – но это явно отдельная тема.
А если очень интересно или просто не терпится, то я оставлю ссылки, на которые я опирался