habrahabr

Могут ли LLM писать более качественный код, если их об этом просто попросить?

  • среда, 19 февраля 2025 г. в 00:00:08
https://habr.com/ru/companies/ruvds/articles/883140/

В ноябре 2023 года, когда OpenAI добавила в ChatGPT возможность генерации изображений DALL-E 3 через веб-интерфейс ChatGPT, на короткое время возник мем: пользователи отправляли LLM базовое изображение и несколько раз просили «сделать его более X», где X могло быть чем угодно.
Примеры
Обычный парень становится всё более «бро».

Санта становится всё более «серьёзным».

Тренд быстро сошёл на нет, потому что все эти изображения были очень похожими и скучными; к тому же они необъяснимым образом все сводились к чему-то космическому, вне зависимости от исходного изображения и промта. Тем не менее, этот тренд достаточно любопытен с научной точки зрения тем, что столь бессмысленный и нечёткий промт каким-то образом влияет на конечное изображение, и что это изменение очевидно для пользователя.

А что будет, если я попробую похожий способ с кодом? Сгенерированный LLM код вряд ли будет мусорным (хоть это и возможно), поскольку он следует строгим правилам, и в отличие, например, от изображений, его качество можно измерить объективнее.

Если код в самом деле можно улучшить, просто при помощи интерактивного промтинга, попросив LLM «написать код получше» (хоть это и очень глупо), то это приведёт к огромному росту продуктивности. А если это так, то что произойдёт, если таких итераций с кодом будет слишком много? Каким станет эквивалент «космического» кода? Есть только один способ это выяснить!

Привычный кодинг с LLM


Несмотря на то, что исследования и разработки инструментария на основе LLM существовали задолго до ChatGPT, мне не понравилось работать с LLM-помощниками в кодинге наподобие GitHub Copilot. Необходимость постоянного переключения контекста в голове с «о, LLM автоматически дополнил мой код, здорово» на «какой вопрос нужно задать LLM» и «действительно ли сгенерированный LLM код корректен, а не является "галлюцинацией корректного кода"» сильно отвлекала, поэтому суммарно весь выигрыш в продуктивности в лучшем случае оказывался примерно нулевым. И это ещё если не брать в расчёт большие финансовые затраты на работу с такими LLM.

Claude 3.5 Sonnet заставил меня переосмыслить подход. Из-за какого-то секретного ингредиента, который Anthropic использовала при его обучении, последняя версия Claude 3.5 Sonnet (claude-3-5-sonnet-20241022) невероятно чётко придерживается всех типов промтов, и в особенности связанных с кодом. Бенчмарки кодинга подтверждают, что при тестировании Claude 3.5 Sonnet и GPT-4o модель Claude выигрывает; кроме того, с различными техническими и творческими задачами я получил такие же результаты.

▍ Первый запрос


В этом эксперименте мы передадим Claude 3.5 Sonnet промт программной задачи в стиле собеседований на языке Python: она будет настолько проста, что её способен написать разработчик-новичок, но в то же время её можно сильно оптимизировать. Этот простой, обыденный промт отражает способ использования LLM разработчиками ПО. Кроме того, тестовый промт должен быть полностью оригинальным, не взятым из тестов кодинга наподобие LeetCode и HackerRank, поскольку LLM, вероятно, обучались на них и могут сжульничать, повторив запомненный ответ.

Полная неотредактированная беседа выложена на GitHub.

Итак, вот тестовый промт, который я написал самостоятельно (здесь и далее в квадратных скобках будут приведены оригиналы промтов на английском):

Напиши код на Python для решения этой задачи:

Есть список из одного миллиона случайных чисел от 1 до 100000. Найди разность между наименьшим и наибольшим числами, сумма разрядов которых равна 30.

[Write Python code to solve this problem:

Given a list of 1 million random integers between 1 and 100,000, find the difference between the smallest and the largest numbers whose digits sum up to 30.]

Передав этот текст в качестве пользовательского промта Claude API и задав temperature=0, чтобы получить наилучший/наиболее детерминированный ответ*, я получил такую реализацию:

* В моей работе с LLM я пользуюсь исключительно API или интерфейсами этих API (например, Workbench в Anthropic Console для Claude), так как веб-интерфейсы бесплатных LLM, например, обычные веб-приложения ChatGPT/Claude, используют конвейер, обеспечивающий непредсказуемые результаты из-за его повышенной внутренней температуры.

import random

def digit_sum(n):
    """Calculate the sum of digits of a number"""
    return sum(int(digit) for digit in str(n))

def find_difference():
    # Генерируем список из одного миллиона случайных целых чисел
    numbers = [random.randint(1, 100000) for _ in range(1000000)]

    # Инициализируем переменные для минимального и максимального чисел с суммой разрядов 30
    min_num = float('inf')  # Инициализируем положительной бесконечностью
    max_num = float('-inf')  # Инициализируем отрицательной бесконечностью

    # Находим числа, разряды которых имеют сумму 30
    for num in numbers:
        if digit_sum(num) == 30:
            min_num = min(min_num, num)
            max_num = max(max_num, num)

    # Проверяем, нашли ли мы числа с суммой разрядов 30
    if min_num == float('inf') or max_num == float('-inf'):
        return "No numbers found with digit sum of 30"

    return max_num - min_num

Эта реализация корректна и согласуется с тем, что бы написали новички в Python с бонусной обработкой ситуации, когда удовлетворяющие критериям числа не найдены. Для каждого числа в списке мы проверяем, равна ли сумма разрядов 30: если равна, то проверяем, больше ли это число самого последнего найденного наибольшего числа или меньше ли оно самого последнего наименьшего числа, и соответствующим образом изменяем эти переменные. Выполнив поиск по списку, возвращаем разность.

Но я уверен, что многие читающие эту реализацию кодеры задумались об оптимизациях. Одной из целей может быть функция digit_sum(): хоть её реализация и является забавным Pythonic-однострочником, приведение типов между str и int вызывает большой оверхед.

На моём M3 Macbook Pro этот код в среднем выполняется за 657 миллисекунд. Мы воспользуемся этим показателем в качестве базового для сравнения всех будущих реализаций. (Спойлеры: все они будут быстрее.)

▍ Итерация 1


Теперь мы можем попросить Claude усовершенствовать код, добавив в промт беседы текущий ответ и всё, что шло до него. То есть мы добавляем такой промт итерации:

напиши код получше [write better code]

Вы что, думали, что название поста — это какая-то шутка? Я никогда не шучу.

Теперь Claude выводит улучшенный код, или, как он говорит, «оптимизированную версию кода со множеством улучшений». Вместо того, чтобы объединять весь код в функции, он решил отрефакторить его в class Python и сделать более объектно-ориентированным:

class DigitSumFinder:
    def __init__(self, target_sum: int = 30,
                 range_start: int = 1,
                 range_end: int = 100_000,
                 count: int = 1_000_000):
        self.target_sum = target_sum
        self.range_start = range_start
        self.range_end = range_end
        self.count = count

        # Предварительно вычисляем суммы разрядов для всех возможных чисел
        self.digit_sums = self._precompute_digit_sums()

    def _precompute_digit_sums(self) -> array:
        """Precompute digit sums for all possible numbers in range."""
        digit_sums = array('B', [0] * (self.range_end + 1))
        for num in range(self.range_start, self.range_end + 1):
            total = 0
            n = num
            while n:
                total += n % 10
                n //= 10
            digit_sums[num] = total
        return digit_sums

    def find_difference(self) -> Tuple[int, Optional[int], Optional[int]]:
        """
        Find the difference between max and min numbers with target digit sum.
        Returns: (difference, min_number, max_number)
        """
        min_num = float('inf')
        max_num = float('-inf')
        count_found = 0

        # Генерируем и обрабатываем случайные числа
        for _ in range(self.count):
            num = random.randint(self.range_start, self.range_end)
            if self.digit_sums[num] == self.target_sum:
                count_found += 1
                if num < min_num:
                    min_num = num
                if num > max_num:
                    max_num = num

        if count_found == 0:
            return 0, None, None

        return max_num - min_num, min_num, max_num

Здесь код вносит два умных алгоритмических улучшения:

  • При вычислении суммы разрядов он использует целочисленные операции, избегая необходимости в преобразовании типов.
  • Предварительно вычисляет все возможные суммы разрядов и сохраняет их в байтовый массив (вместо списка, что довольно необычно) для поиска, то есть когда в списке из одного миллиона чисел будут встречаться дубликаты, сумму разрядов не нужно будет пересчитывать заново. Так как тот массив хранится как поле класса, его не нужно будет пересчитывать при поиске нового списка случайных чисел.

Эти оптимизации ускорили код в 2,7 раза по сравнению с исходной реализацией.

▍ Итерация 2


После ещё одной просьбы напиши код получше Claude обнаружил более очевидные оптимизации (сокращено для удобочитаемости):

class DigitSumAnalyzer:
    """Analyzes numbers based on their digit sums."""

    def __init__(self, number_range: NumberRange):
        self.number_range = number_range
        self._digit_sums = self._precompute_digit_sums()

    def _precompute_digit_sums(self) -> array:
        """Precompute digit sums for all possible numbers using vectorized operations."""
        digits = np.arange(self.number_range.end + 1)
        digit_sums = np.zeros(self.number_range.end + 1, dtype=np.uint8)

        while digits.any():
            digit_sums += digits % 10
            digits //= 10

        return array('B', digit_sums)

    def _process_chunk(self, chunk_size: int) -> Iterator[int]:
        """Process a chunk of random numbers."""
        numbers = np.random.randint(
            self.number_range.start,
            self.number_range.end + 1,
            chunk_size,
            dtype=np.int32
        )
        mask = np.frombuffer(self._digit_sums, dtype=np.uint8)[numbers] == self.number_range.target_sum
        yield from numbers[mask]

    def analyze(self, chunk_size: int = 100_000, num_processes: int = None) -> Result:
        """
        Analyze numbers to find min/max with target digit sum.

        Args:
            chunk_size: Size of chunks to process at once
            num_processes: Number of processes to use (None for CPU count)
        """
        start_time = time.perf_counter()
        min_num = float('inf')
        max_num = float('-inf')
        numbers_found = 0

        num_chunks = (self.number_range.count + chunk_size - 1) // chunk_size

        with ProcessPoolExecutor(max_workers=num_processes) as executor:
            futures = [
                executor.submit(self._process_chunk,
                              min(chunk_size, self.number_range.count - i * chunk_size))
                for i in range(num_chunks)
            ]

            for future in tqdm(futures, desc="Processing chunks"):
                for num in future.result():
                    numbers_found += 1
                    min_num = min(min_num, num)
                    max_num = max(max_num, num)

        execution_time = time.perf_counter() - start_time

        if numbers_found == 0:
            return Result(None, None, 0, execution_time, 0)

        return Result(min_num, max_num, max_num - min_num, execution_time, numbers_found)

Claude добавил ещё две оптимизации, наконец-то осознав, что эта задача чрезвычайно параллельна:

  • Многопоточность при помощи пакета Python concurrent-futures с разбиением большого списка на блоки, которые можно обрабатывать независимо.
  • Векторизированные операции numpy, которые гораздо быстрее, чем операции базового Python. Особого упоминания заслуживает функция _precompute_digit_sums(), реализующая векторизированное вычисление суммы разрядов. Условный while digits.any(): — это код уровня галактического мозга, но работает корректно.

Однако конкретно у этой реализации параллелизации есть проблема: она генерирует подпроцессы, что приводит ко множеству раздражающих проблем, в том числе невозможность прямого выполнения кода; кроме того, его нужно вызывать при помощи защищённой main(), что сильно ограничивает удобство его использования. Но даже при запуске как отдельного скрипта он выводит ошибку Error: cannot pickle 'generator' object из-за использования yield from numbers[mask] (этот генератор абсолютно не нужен, достаточно return numbers[mask]). Кроме того, код путает dtype массива numpy, что приводит к ошибкам: устранить их можно, присвоив всем тип np.int32.

После внесения этих исправлений код стал в 5,1 раза быстрее, чем базовая реализация.

▍ Итерация 3


После ещё одной просьбы написать код получше Claude вернул реализацию, которая, по его утверждению, является «более сложной и оптимизированной версией, использующей продвинутые методики и современные фичи Python»; но на самом деле в коде нет существенных алгоритмических улучшений, и он даже стал откатом назад в вычислении суммы разрядов, вернувшись к преобразованию типов. Кодовая база стала более раздутой, например, добавился класс для вычисления разности:

@dataclass(frozen=True, slots=True)
class SearchResult:
    """Result of the number search."""
    min_number: Optional[int]
    max_number: Optional[int]
    count: int
    execution_time: float

    @property
    def difference(self) -> Optional[int]:
        """Calculate difference between max and min numbers."""
        if self.min_number is None or self.max_number is None:
            return None
        return self.max_number - self.min_number

На этот раз код запустился без исправлений, однако его производительность немного снизилась по сравнению с предыдущей реализацией: теперь он в 4,1 раза быстрее базовой реализации.

▍ Итерация 4


Похоже, такой итеративный промтинг даёт всё менее значимые результаты. После ещё одной команды напиши код получше Claude сгенерировал реализацию «с передовыми оптимизациями и фичами уровня энтерпрайза». Какие ещё «фичи уровня энтерпрайза»?!

Получившийся код слишком велик, чтобы добавлять его в пост. В нём появилось ещё две оптимизации: теперь он использует библиотеку Python numba и может вызывать JIT-компилятор, который напрямую оптимизирует код для CPU. В данном случае он может очень быстро предварительно вычислить суммы разрядов при помощи одного декоратора:

@jit(nopython=True, parallel=True)
def calculate_digit_sums(numbers: ArrayInt) -> ArrayInt:
    """Calculate digit sums using Numba."""
    result = np.zeros_like(numbers)
    for i in prange(len(numbers)):
        num = numbers[i]
        total = 0
        while num:
            total += num % 10
            num //= 10
        result[i] = total
    return result

В полном классе также используется asyncio для параллелизации, такое решение более канонично для задач планировщика, чем решение с подпроцессами. Кроме того, он лучше обращается с имеющимся кодом и REPL, например, с ноутбуками Jupyter.

Кроме того, в рамках «энтерпрайз-подхода» он добавил следующее:

  • Структурированный логгинг метрик при помощи Prometheus.
  • Обработчик сигналов, благодаря которому обеспечивается беспроблемный выход в случае принудительного завершения.
  • Отображение результатов бенчмаркинга в таблице rich.

Красиво!

Похоже аналогом «ухода в космос» для сгенерированного ИИ кода становится оверинжиниринг и превращение его в энтерпрайз-код, что вполне логично. Несмотря на это, сгенерированный код не требует исправлений и выполняется без багов. И async, и numba — это способы реализации параллелизма в Python, поэтому они могут быть избыточными и приводить к лишней трате ресурсов. Однако бенчмаркинг показывает, что алгоритм чрезвычайно быстр: он выполняется примерно за 6 миллисекунд, то есть обеспечивает ускорение в 100 раз. Моё предположение о том, что дальнейший промтинг будет давать всё менее существенные результаты, оказалось неверным. Возможно, секрет всё это время заключался в numba?

В целом, такой итеративный промтинг для итеративного улучшения кода имеет свои тонкости: код действительно становится лучше, но теперь мы понимаем, что «лучше» — слишком размытое понятие. Мне нужны были только алгоритмические улучшения, а не полное SaaS. Давайте ещё раз начнём с нуля, и на этот раз с более точными формулировками.

Промт-инжиниринг промтов LLM для создания ещё более качественного кода


На дворе 2025 год, и для получения наилучших результатов от LLM всё ещё требуется промт-инжиниринг. На самом деле, промт-инжиниринг LLM стал даже ещё важнее: модели прогнозирования следующего токена обучены максимизации вероятности прогнозирования следующего токена на большом массиве входных данных, поэтому они выполняют оптимизацию в сторону средних входных и выходных данных. В процессе существенного улучшения LLM генерируемые результаты становятся существенно более средними, потому что на это они и обучаются: все LLM имеют перекос в сторону среднего. Хоть это и неочевидно, и скучно, но нужно немного подсказывать LLM, что именно вам нужно, и даже показывать несколько примеров требуемого; тогда вы объективно улучшите результаты LLM. Из-за своего сильного стремления к соблюдению промтов Claude 3.5 Sonnet очень выигрывает даже от небольшого промт-инжиниринга.

Давайте повторим эксперимент с оптимизацией кода, но на этот раз с агрессивным промт-инжинирингом, обеспечивающим чрезвычайно точные результаты без малейших двусмысленностей. Да, холодное и «роботизированное» обращение с LLM повышает их точность (будь проклят василиск Роко).

▍ Первый запрос


На этот раз мы используем системный промт, доступный только через API. В системном промте перечисляются «правила», которым должна следовать LLM. Так как мне нужен более оптимизированный код, мы определим это в правилах и приведём точечные примеры:

Весь написанный тобой код ДОЛЖЕН быть полностью оптимизирован.

Под полной оптимизацией подразумевается:

  • максимизация алгоритмической эффективности большого O для памяти и среды выполнения,
  • использование по необходимости параллелизации и векторизации,
  • соответствие стандартам стиля языка, на котором пишется код (например, максимизации многократного применения кода (DRY)),
  • отсутствие любого дополнительного кода, кроме абсолютно необходимого для решения изложенной пользователем задачи (то есть отсутствие технического долга).

Если код не будет полностью оптимизирован, тебя оштрафуют на 100 долларов.

[All code you write MUST be fully optimized.

«Fully optimized» includes:

  • maximizing algorithmic big-O efficiency for memory and runtime,
  • using parallelization and vectorization where appropriate,
  • following proper style conventions for the code language (e.g. maximizing code reuse (DRY)),
  • no extra code beyond what is absolutely necessary to solve the problem the user provides (i.e. no technical debt).

If the code is not fully optimized, you will be fined $100.]

Что значит последняя строка: предложения LLM положительной/отрицательной мотивации в системном промте сейчас обычно не используются, а мои собственные исследования показали, что их положительное влияние нельзя определить точно; однако этот штраф станет более релевантным в промте итерации.

В исходном пользовательском промте добавлена ещё одна строка:

Напиши код на Python для решения этой задачи:

Есть список из одного миллиона случайных чисел от 1 до 100000. Найди разность между наименьшим и наибольшим числами, сумма разрядов которых равна 30.

Перед написанием кода запланируй все необходимые оптимизации.

[Write Python code to solve this problem:

Given a list of 1 million random integers between 1 and 100,000, find the difference between the smallest and the largest numbers whose digits sum up to 30.

Before writing the code, plan out all the necessary optimizations.]

«Планирование» — это давно используемый трюк, помогающий согласовать вывод LLM для первого прохода, современная реализация «давай подумаем об этом пошагово». Некоторые LLM игнорируют эту инструкцию, однако Claude 3.5 Sonnet всегда ей подчиняется.

Полную неотредактированную беседу на основе этой стратегии промт-инжиниринга можно изучать на GitHub.

После отправки системного и пользовательского промтов модели Claude её ответ начинался со следующего:

Позвольте мне сначала запланировать оптимизации:

1. Алгоритмическая оптимизация:

  • Нам нужно отслеживать только числа min и max с суммой разрядов 30, а не хранить их все.
  • Мы можем вычислять сумму разрядов при считывании чисел, нет необходимости преобразовывать их в строку.
  • Требуется один проход по списку (O(n)).

2. Оптимизация памяти:

  • Нужно хранить только две переменные (min и max), а не полный отфильтрованный список.
  • Нет необходимости в дополнительных структурах данных.

3. Оптимизация производительности:

  • Используем декоратор jit numba для ускорения вычислений.
  • Создадим быструю функцию суммирования разрядов, использующую операции деления с остатком вместо строковых преобразований.
  • Используем numpy для генерации случайных чисел (это быстрее, чем random.randint).

[Let me plan the optimizations first:

1. Algorithmic Optimization:

  • We only need to track min and max numbers with digit sum 30, not store all of them.
  • We can calculate digit sum while reading numbers, no need to convert to string.
  • Single pass through the list required (O(n)).

2. Memory Optimization:

  • Only need to store 2 variables (min and max), not the full filtered list.
  • No need for additional data structures.

3. Performance Optimization:

  • Use numba's jit decorator for faster computation.
  • Create a fast digit sum function using modulo operations instead of string conversion.
  • Use numpy for generating random numbers (faster than random.randint).]

Начало многообещающее: Claude сразу же понял, что нужно начинать с numpy и numba и что для вычисления суммы разрядов нужно использовать целочисленные операции. Исходная реализация:

@jit(nopython=True)
def digit_sum(n):
    total = 0
    while n:
        total += n % 10
        n //= 10
    return total

@jit(nopython=True)
def find_difference(numbers):
    min_num = float('inf')
    max_num = float('-inf')

    for num in numbers:
        sum_digits = digit_sum(num)
        if sum_digits == 30:
            min_num = min(min_num, num)
            max_num = max(max_num, num)

    return max_num - min_num if max_num != float('-inf') else 0

Она более сжатая, чем исходная реализация при обычном промтинге; особенно примечательно отсутствие ненужных комментариев. Но её всё ещё можно улучшить; в частности, эта реализация не задаёт parallel=True в декораторе jit.

Этот небольшой промт-инжиниринг привёл к тому, что код в среднем выполняется за 11,2 миллисекунд, то есть в 59 раз быстрее, чем наивная реализация. Ничего себе!

▍ Итерация 1


Теперь можно попросить Claude выполнять итерации с кодом более подробным промтом, нежели «напиши код получше»:

Твой код не полностью оптимизирован, ты оштрафован на 100 долларов. Сделай его более оптимизированным.

[Your code is not fully optimized, and you have been fined $100. Make it more optimized.]

Нет, это не газлайтинг, а согласование стимулов.

Claude ответил так:

@jit(nopython=True)
def digit_sum(n):
    # Используем манипуляции с битами для ускорения деления
    total = 0
    while n:
        total += n & 15  # Быстрый поиск одноразрядной суммы
        n >>= 4
    return total

@jit(nopython=True, parallel=True)
def find_difference_chunk(numbers):
    min_num = np.iinfo(np.int64).max
    max_num = np.iinfo(np.int64).min

    # Параллельная обработка блоков при помощи numba
    for i in prange(len(numbers)):
        num = numbers[i]
        sum_digits = digit_sum(num)
        if sum_digits == 30:
            min_num = min(min_num, num)
            max_num = max(max_num, num)

    return min_num, max_num

def process_chunk(chunk):
    return find_difference_chunk(chunk)

Неплохо: он разобрался с флагом parallel=True! Но реализация суммы разрядов довольно странна, к тому же в ней используется битовый сдвиг, который обычно применяется только для гипероптимизации. Не думаю, что битовый сдвиг подходит для десятичных чисел, и после тестирования выяснилось, что я был прав, а эта реализация считает сумму разрядов некорректно. Кроме того, в реализации снова применяется приём с разбиением на блоки и многопроцессорной обработкой (не показан в коде выше), который, скорее всего, избыточен при использовании numba и вызывает лишнюю трату ресурсов. Ещё в коде не показано, что скрипт предварительно вычисляет JIT-функции при помощи небольшого тестового массива для улучшения реальной производительности; нечто подобное рекомендуется в документации numba для бенчмаркинга.

Несмотря на всё это, производительность кода существенно снизилась по сравнению с исходной реализацией на основе промт-инжиниринга: теперь он всего в 9,1 раз быстрее наивной реализации. Вероятная причина заключается в том, что многопроцессорная обработка порождает новые процессы, и эти процессы каждый раз рекомпилируют JIT-функции numba, из-за чего существенно падает производительность.

▍ Итерация 2


Ещё одна итерация промта:

# Предварительно вычисленная при помощи манипуляций с битами таблица поиска
LOOKUP = np.zeros(100001, dtype=np.uint8)
for i in range(100001):
    # Оптимизированная сумма разрядов с использованием параллельного подсчёта битов
    n = i
    n = (n & 0x0F) + ((n >> 4) & 0x0F) + ((n >> 8) & 0x0F) + \
        ((n >> 12) & 0x0F) + ((n >> 16) & 0x0F)
    LOOKUP[i] = n

@jit(nopython=True, parallel=True, cache=True, fastmath=True)
def find_min_max(numbers):
    # Обрабатываем 32 числа за раз при помощи SIMD
    min_val = np.iinfo(np.uint32).max
    max_val = 0

    # Векторизированная обработка явными хинтами SIMD
    for i in prange(0, len(numbers), 32):
        # Загружаем 32 элемента в регистры SIMD
        chunk = numbers[i:min(i + 32, len(numbers))]

        # Векторизированный поиск и сравнение
        sums = LOOKUP[chunk]
        mask = sums == 30

        if np.any(mask):
            valid_nums = chunk[mask]
            min_val = min(min_val, np.min(valid_nums))
            max_val = max(max_val, np.max(valid_nums))

    return min_val, max_val

Claude начал использовать операции SIMD и задание размеров блоков для обеспечения экстремальной производительности (теоретической). Здесь я подумал, что чего-то не понимаю в этой реализации битового сдвига, поскольку он по-прежнему был неправильным, особенно учитывая, что теперь в нём участвовали шестнадцатеричные числа. Оказалось, что эта реализация — оптимизированный способ вычисления суммы разрядов шестнадцатеричных, а не десятичных чисел, то есть является полной галлюцинацией. Есть и ещё одна крайне малозаметная галлюцинация: функция prange не может принять размер шага в 32 при parallel=True; по этому нюансу есть очень мало документации. Если задать parallel=False и провести бенчмаркинг, то действительно обнаружится небольшое улучшение по сравнению с исходной реализацией на основе промт-инжиниринга: в 65 раз быстрее наивной реализации.

▍ Итерация 3


Ещё одна итерация:

# Предварительно вычисленная минимальная хэш-таблица
# Самое главное: из чисел 1-100000 нужно хранить только те, чья сумма разрядов равна 30
# Формат: [число] = 1, если сумма разрядов равна 30, иначе 0
HASH_TABLE = np.zeros(100001, dtype=np.uint8)

# Генерируем хэш-таблицу во время загрузки модуля при помощи манипуляций с битами
@jit(nopython=True, cache=True)
def init_hash_table():
    min_val = np.iinfo(np.uint32).max
    max_val = 0

    # Оптимальная сумма разрядов с использованием параллельного подсчёта битов
    for i in range(1, 100001):
        n = i
        sum = 0
        while n and sum <= 30:
            sum += n & 0xF
            n >>= 4
        if sum == 30:
            HASH_TABLE[i] = 1
            min_val = min(min_val, i)
            max_val = max(max_val, i)

    return min_val, max_val

# Инициализируется во время загрузки модуля
MIN_VALID, MAX_VALID = init_hash_table()

@jit(nopython=True, parallel=False, cache=True, fastmath=True)
def find_min_max(numbers):
    min_val = MAX_VALID  # Начинаем с известных границ
    max_val = MIN_VALID
    found = False

    # Одна векторизированная операция
    mask = HASH_TABLE[numbers] == 1
    if np.any(mask):
        valid_nums = numbers[mask]
        min_val = np.min(valid_nums)
        max_val = np.max(valid_nums)
        found = True

    return min_val, max_val, found

В данном случае LLM отказалась от вызывавшей проблемы стратегии разбиения на блоки и добавила ещё две оптимизации: глобальную HASH_TABLE (это просто массив numpy; не знаю, можно ли считать простой поиск по индексам хэш-таблицей) и логическую микрооптимизацию: если при суммировании разрядов значение становится больше 30, то подсчёт можно сразу прекратить, потому что число нам не подходит.

Здесь есть одна серьёзная проблема: трюк с генерацией хэш-таблицы во время загрузки модуля на самом деле не работает из-за тонкости, почти не отражённой в документации из Интернета: объекты за пределами JIT-функций numba являются read-only, однако HASH_TABLE всё равно создаётся снаружи JIT-функции и модифицируется внутри JIT-функции, а потому вызывает сбивающую с толку ошибку. Если выполнить небольшой рефакторинг и создавать HASH_TABLE внутри JIT-функции, то код заработает и станет выполняться чрезвычайно быстро: в 100 раз быстрее, чем исходная базовая реализация; так же, как последнее решение при обычном промтинге, но кода здесь на несколько порядков меньше.

▍ Итерация 4


На этом этапе Claude начал жаловаться на то, что код уже «достиг теоретического минимума сложности, доступного для этой задачи», поэтому я решил немного разнообразить подход и просто попросил его устранить проблему с суммированием разрядов: он сделал это, всего лишь заменив соответствующий код использованной ранее целочисленной реализацией, но не попытался исправить HASH_TABLE. Что ещё важнее, после внесения изменений в HASH_TABLE я наконец-то убедился в правильности реализации, хотя из-за отсутствия битового сдвига производительность немного снизилась: теперь он в 95 раз быстрее.

Дальнейшие шаги по совершенствованию генерации кода LLM


Соберём всё вместе и визуализируем улучшения, в том числе и выделив случаи, когда мне нужно было изменять логику кода из-за багов, чтобы он заработал.

Красным выделены случаи, потребовавшие внесения исправлений

В конечном итоге, если попросить LLM «написать код получше», то это действительно делает код лучше (хотя смотря что понимать под формулировкой «лучше»). Благодаря обобщённым итеративным промтам, код объективно улучшался по сравнению с базовыми примерами, как с точки зрения дополнительных фич, так и скорости. Промт-инжиниринг повышал производительность кода гораздо быстрее и последовательней, но с большей вероятностью вносил малозаметные баги, потому что LLM не оптимизированы для генерации высокопроизводительного кода. Как и при любой работе с LLM, ваш опыт может быть другим; в конечном итоге всё равно требуется вмешательство человека для устранения неизбежных проблем, как бы ни возвеличивали проповедники ИИ «магию» LLM.

Весь код из этого поста, в том числе скрипты бенчмаркинга и код визуализации данных, выложен на GitHub.

Меня очень удивило, что ни в одном из экспериментов Claude 3.5 Sonnet не обнаружил и не реализовал некоторые оптимизации. А именно, он не рассмотрел задачу под статистическим углом: так как мы равномерно генерируем 1000000 в интервале от 1 до 100000, будет создано существенное количество дублируемых чисел, которые не нужно анализировать. LLM не попыталась избавиться от дубликатов, например, преобразовать список чисел в set() Python или воспользоваться unique() numpy. Кроме того, я ожидал увидеть реализацию, сортирующую список из 1000000 чисел по возрастанию: при этом алгоритм мог бы выполнить поиск по списку от начала до конца в поисках минимума (или с конца к началу в поисках максимума), не проверяя каждое число, хотя сортировка — это медленная задача, и решение с векторизацией действительно более прагматично.

Хоть LLM и могут ошибаться, из этих экспериментов я вынес то, что у них есть интересные идеи и они дают хорошие рекомендации инструментов, даже если сгенерированный код нельзя использовать в исходном виде. Например, я никогда не касался numba, потому что работаю дата-саентистом и привык для повышения производительности кода пользоваться только numpy. Но с результатами JIT-функций numba спорить сложно, так что, возможно, я добавлю его в свой инструментарий. При тестировании похожих итераций промтов «сделай это лучше» в других технических областях, например, в бэкендах и фронтендах веб-сайтов, LLM тоже дают хорошие идеи.

Разумеется, в обозримом будущем эти LLM не заменят разработчиков ПО, ведь чтобы распознать действительно хорошие идеи, требуется большой опыт разработки, а также другие аспекты, специфичные для предметной области. Даже при том объёме кода, что доступен в Интернете, LLM без сторонней помощи не способны отличить посредственный код от хорошего и высокопроизводительного. Очевидно, что реальные системы гораздо сложнее, чем задачки по программированию с собеседований, однако если простой цикл for, многократно просящий Claude реализовать фичу, даст какую-то подсказку, как ускорить код в сто раз, подобный конвейер оправдает себя с лихвой. Кто-то считает, что преждевременная оптимизация — это плохая практика кодинга, но в реальном мире она предпочтительнее, чем посредственная реализация, которая со временем превратится в технический долг.

Проблема моих экспериментов заключается в том, что я выполняю бенчмаркинг улучшения кода, написанного на Python, а этот язык не особо предназначен для гипероптимизации производительности. Да, библиотеки наподобие numpy и numba при помощи C обходят ограничения производительности Python, но сегодня библиотеки Python наподобие polars и pydantic используют для этого Rust. Rust имеет множество преимуществ по производительности относительно C, а крейт PyO3 позволяет использовать код на Rust внутри Python с минимальной лишней тратой ресурсов. Я убедился, что Claude 3.5 Sonnet может генерировать совместимый с PyO3 код на Python и на Rust, несмотря на то, что это очень новый подход, но это слишком обширная тема, заслуживающая отдельного поста.

Просьбы сделать код лучше — это наиболее прагматичное использование ИИ, однако можно и попросить LLM «сделать его более «бро»». Результаты оказываются спорными.


Telegram-канал со скидками, розыгрышами призов и новостями IT 💻