Подробное рассмотрение важнейшего компонента внимания скалярного произведения, языкового моделирования и полиномиальной логистической регрессии
Делиться

В предыдущей статье этой серии мы рассмотрели распространённую во всех областях информатики операцию: умножение матриц. Она широко используется в нейронных сетях для вычисления активации линейных слоёв. Однако сами активации сложно интерпретировать, поскольку их значения и статистические характеристики (среднее значение, дисперсия, минимальная и максимальная амплитуда) могут сильно различаться от слоя к слою. Это одна из причин, по которой мы используем функции активации, например, логистическую функцию (сигмоиду), которая проецирует любое действительное число в диапазоне [0; 1].
Функция softmax, также известная как нормализованная экспоненциальная функция, представляет собой многомерное обобщение сигмоиды. Она преобразует вектор исходных оценок (логитов) в распределение вероятностей по M классам. Её можно интерпретировать как средневзвешенное значение , которое ведёт себя как гладкая функция и легко дифференцируется . Она является важнейшим компонентом скалярного произведения внимания, языкового моделирования и мультиномиальной логистической регрессии.
В этой статье мы рассмотрим:
- Реализация эффективного ядра softmax в Triton.
- Реализация обратного прохода (автограда).
- Оптимизация: модификаторы кэша и автонастройка.
Если вы еще не знакомы с Triton, обратитесь к предыдущим статьям!
Изучаем Triton по одному ядру за раз: сложение векторов
Изучаем Triton по одному ядру за раз: умножение матриц
Отказ от ответственности: все иллюстрации и анимации сделаны автором, если не указано иное.
Определение
Softmax определяется следующим образом:

Нормализация гарантирует, что сумма вектора равна 1 , поэтому ее можно интерпретировать как допустимое распределение вероятностей.
Обратите внимание, что эта формулировка softmax крайне чувствительна к переполнению . Напомним, что максимальное значение, которое может представить стандартный float16, — это 65 504 , что примерно равно exp(11) . Это означает, что любое входное значение, превышающее ~11, приведёт к выходу exp(z_i) за пределы представимого диапазона, что приведёт к переполнению .
Распространенный прием для смягчения этой проблемы — вычитание максимального значения входного вектора из каждого элемента таким образом, чтобы новый максимум был равен 0 до возведения в степень и 1 после.

Наивная реализация
Как видите, вычисление softmax включает две операции редукции : max и sum . Наивный алгоритм требует трёх отдельных проходов по входному вектору. Сначала вычисляется максимум, затем сумма и, наконец, нормализованные выходные данные.
Вот как выглядит наивная реализация Numpy:
В этой серии Triton повторяющаяся тема — минимизация доступа к глобальной памяти с высокой задержкой. Наша текущая реализация Numpy требует трёх отдельных операций чтения всего входного вектора, что крайне неэффективно.
Онлайн Софтмакс
К счастью, мы можем использовать хитрый трюк, известный как онлайн-softmax , чтобы объединить шаги max и sum, сократив количество чтений памяти до 2 .
Сначала рекурсивно определим сумму экспонент. В следующем наборе равенств m_i обозначает максимум по x до i -го индекса.

Это равенство позволяет нам итеративно вычислять сумму экспонент, используя максимальное на данный момент значение. Мы можем использовать это для объединения первого и второго циклов в наивной реализации и итеративно вычислять максимум и сумму экспонент.
Наш алгоритм становится следующим:

Это легко перевести на Numpy:
Теперь, когда мы разобрались с основными принципами softmax, мы реализуем его в Triton, начав с простой одноблочной версии и постепенно переходя к онлайновой многоблочной. В конечном итоге мы хотим, чтобы наше ядро вело себя как модуль PyTorch и было совместимо с Autograd.
К сожалению, с точки зрения PyTorch, ядра Triton ведут себя как чёрные ящики: выполняемые ими операции не отслеживаются Autograd. Это требует от нас самостоятельной реализации обратного прохода и явного указания способа вычисления градиентов. Давайте освежим в памяти наше любимое цепочное правило и выведем градиент Softmax.
Градиент
Поскольку выходные данные softmax строго положительны, мы можем использовать логарифмическую производную для упрощения вычисления градиента. Здесь мы берём производную логарифма выходного значения и применяем цепное правило:

Отсюда мы переставляем термины и выполняем следующие шаги:

Предположим, что у нас есть некоторый градиент вверх по течению, например, сгенерированный функцией потерь L (например, кросс-энтропийной потерей). Получаем следующее выражение для градиента:

Упрощение левого члена в (9) обусловлено тем, что δ_ij будет равно 1 только для i -го элемента, сводя сумму по j к одному члену.
Реализация Тритона
Одиночный блок Softmax
Теперь, когда мы разобрались с выводом градиента, мы можем написать прямое и обратное ядра softmax. Сначала сосредоточимся на обёртке PyTorch, чтобы понять, как работает реализация с одним блоком на высоком уровне. Учитывая двумерный входной тензор, прямое и обратное ядра будут обрабатывать все строки параллельно.
Для простоты мы определим BLOCK_SIZE достаточно большим, чтобы обрабатывать все столбцы одновременно. В частности, мы установим его как следующую степень числа 2, превышающую количество столбцов, как того требует Triton.
Затем мы определим нашу `сетку` как количество строк (она потенциально может также обрабатывать пакетное измерение).
Обертка PyTorch для нашего SoftmaxSingleBlock — это класс, наследующий от torch.autograd.Function, который реализует прямой и обратный проходы. Оба метода принимают аргумент ctx, который мы будем использовать для кэширования выходных данных Softmax во время прямого прохода и их повторного использования во время обратного прохода.
Оба ядра довольно просты: мы начинаем с загрузки входных строк, используя тот же синтаксис, что и в моем предыдущем примере сложения векторов. Статья. Обратите внимание, что BLOCK_SIZE и num_warps вычисляются с помощью функции calculate_settings. Эта функция взята из библиотеки Unsloth и использовалась в других библиотеках ядра, таких как LigerKernel (на которой в некоторой степени основаны ядра, представленные в этой статье). Она предоставляет эвристический подход для настройки обеих переменных:
def calculate_settings(n: int) -> tuple[int, int]: MAX_FUSED_SIZE = 65536 # максимальный размер сетки на графических процессорах Nvidia BLOCK_SIZE = next_power_of_2(n) if BLOCK_SIZE > MAX_FUSED_SIZE: # в этой статье мы убираем это утверждение raise RuntimeError( f»Невозможно запустить ядро Triton, так как n = {n} превышает » f»максимальный размер блока CUDA = {MAX_FUSED_SIZE}.» ) num_warps = 4 if BLOCK_SIZE >= 32768: num_warps = 32 elif BLOCK_SIZE >= 8192: num_warps = 16 elif BLOCK_SIZE >= 2048: num_warps = 8 return BLOCK_SIZE, num_warps
Затем мы реализуем обычный softmax для прямого прохода и уравнение (10) для обратного прохода. Единственное новшество по сравнению с предыдущими статьями — использование модификаторов кэша, которые сообщают компилятору, как кэшировать и вытеснять данные. Сейчас мы сосредоточимся только на трёх модификаторах кэша:
- .ca ( Кэш на всех уровнях ): указывает компилятору загружать данные как в кэш L1, так и в кэш L2, указывая на возможность их повторного использования в ближайшее время. Этот модификатор следует использовать, когда данные достаточно малы для размещения в кэше L1 (~128–192 КБ на SM на A100) и, вероятно, будут использоваться многократно.
- .cs ( потоковая передача ): обрабатывать данные как потоковые , они будут использованы один раз, а затем удалены для освобождения места в L1.
- .wb ( обратная запись ): Обычная кэшированная запись, данные остаются в иерархии кэша, хорошо, если выходные данные можно использовать повторно.
В следующих ядрах мы будем использовать модификатор .ca для загрузки, поскольку мы выполняем несколько операций с загруженными данными. Для сохранения мы будем использовать .cs в прямом проходе, поскольку выходные данные не будут сразу повторно использоваться, и .wb в обратном проходе, поскольку в контексте autograd (т.е. цепочечного правила) выходные данные градиента будут использоваться ядрами, расположенными ниже по потоку.
Многоблочный Softmax
Теперь давайте рассмотрим онлайн-формулировку softmax. В этом разделе мы реализуем многоблочный вариант предыдущего ядра. В этой версии будет использоваться BLOCK_SIZE < n_cols, то есть мы будем загружать только плитку с BLOCK_SIZE элементов за раз, аналогично тому, как мы работали с плиточной GEMM в предыдущем уроке. Вы можете спросить: «Как выбрать размер блока?»
Это отличный повод представить утилиту автонастройки Triton. Имея список конфигураций, автонастройка выполнит поиск по сетке, чтобы определить и кэшировать оптимальную конфигурацию для конкретной входной формы. Этот процесс повторяется каждый раз, когда ядру передается новая входная форма.
Здесь мы выполняем поиск по сетке по размеру блока и количеству варпов, используя следующую функцию полезности:
из itertools import product # — Многоблочная настройка — BLOCK_SIZES = [256, 512, 1024, 2048, 4096, 8192] NUM_WARPS = [2, 4, 8, 16] def get_autotune_config( block_sizes: list[int], num_warps: list[int] ) -> list[triton.Config]: return [ triton.Config(kwargs={«BLOCK_SIZE»: bs}, num_warps=nw) for (bs, nw) in list(product(block_sizes, num_warps)) ]
Теперь мы можем снабдить наши многоблочные ядра автоматической настройкой и передать список конфигураций, key=”n_cols” указывает, что оптимальная конфигурация зависит от количества столбцов входных данных.
Реализация этих ядер концептуально очень близка к онлайн-программе SoftMax, которую мы рассматривали ранее. Главное отличие заключается в том, что мы итерируем по тайлам (а не по отдельным элементам, как в Numpy), что требует некоторых корректировок. Например, мы добавляем сумму по тайлу в обновлении d, а обратное ядро теперь также требует двух итераций.
Примечание: оболочка PyTorch точно такая же, за исключением того, что мы удалили строку, в которой объявлены BLOCK_SIZE и num_warps (поскольку они выбираются автонастройкой).
Тестирование и бенчмаркинг
Теперь мы можем выполнить прямой и обратный проход с обоими ядрами и убедиться, что они соответствуют базовым показателям PyTorch:
def validate_kernel(kernel_fn: callable) -> None: device = «cuda:0» if torch.cuda.is_available() else «cpu» torch.random.manual_seed(0) # Генерация входных данных x = torch.randn((256, 512), device=device) # входные данные triton x.requires_grad = True xt = deepcopy(x) # входные данные torch triton_output = kernel_fn(x) torch_output = torch.softmax(xt, dim=1) torch.testing.assert_close(triton_output, torch_output) # тест ядра fwd # Настройка поддельных меток y = torch.zeros_like(x) inds = (torch.arange(0, y.shape[0]), torch.randint(0, 3, (y.shape[0],))) y[inds] = 1 # Определение потерь и запуск обратного прохода loss_fn = torch.nn.CrossEntropyLoss() loss = loss_fn(torch_output, y) loss.backward() # Сохраняем тензор градиента для дальнейшего использования torch_xgrad = xt.grad.detach().clone() triton_loss = loss_fn(triton_output, y) triton_loss.backward() torch.testing.assert_close(x.grad, torch_xgrad) # тестовые выходные данные градиента validate_kernel(softmax_sb) validate_kernel(softmax_mb)
Наконец, мы сравниваем нашу реализацию с базовым уровнем PyTorch, используя следующий фрагмент:
# — Источник: Triton softmax tutorial — @triton.testing.perf_report( triton.testing.Benchmark( x_names=[«N»], # имена аргументов для использования в качестве оси X графика x_vals=[ 128 * i for i in range(2, 100) ], # различные возможные значения для `x_name` line_arg=»provider», # имя аргумента, значение которого соответствует другой линии графика line_vals=[ «triton_single_block», «triton_multi_block», «torch», ], # возможные значения для `line_arg« line_names=[ «Triton_single_block», «Triton_multi_block», «Torch», ], # имя метки для линий styles=[(«blue», «-«), («green», «-«), («red», «-«)], ylabel=»GB/s», # имя метки для оси Y plot_name=»softmax-performance», # имя для графика. Используется также как имя файла для сохранения графика. args={«M»: 4096}, # значения аргументов функции, не указанные в `x_names` и `y_name` ) ) def benchmark(M, N, provider): x = torch.randn(M, N, device=DEVICE, dtype=torch.float32) stream = getattr(torch, DEVICE.type).Stream() getattr(torch, DEVICE.type).set_stream(stream) if provider == «torch»: ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1)) if provider == «triton_single_block»: torch.cuda.synchronize() ms = triton.testing.do_bench(lambda: softmax_sb(x)) torch.cuda.synchronize() if provider == «triton_multi_block»: torch.cuda.synchronize() мс = triton.testing.do_bench(лямбда: softmax_mb(x)) torch.cuda.synchronize() гбит/с = лямбда мс: 2 * x.numel() * x.element_size() * 1e-9 / (мс * 1e-3) return гбит/с(мс) benchmark.run(show_plots=True, print_data=True)
Хорошие новости! Наше одноблочное ядро стабильно превосходит базовый уровень PyTorch, в то время как многоблочный вариант теряет производительность на входных данных с более чем 6 тысячами столбцов:

Рассматривая более крупные вложения, мы можем сделать несколько наблюдений:
- Многоблочное ядро в конечном итоге стабилизирует пропускную способность на уровне около 900 ГБ/с, превосходя базовый уровень PyTorch для входных данных с более чем 30 тыс. столбцов.
- Интересно, что, похоже, многоблочный вариант будет доминировать для входных данных с более чем 60 тыс. столбцов.
- Несмотря на то, что мы превышаем максимальный размер блока в одноблочном варианте, ядро почему-то всё равно работает без сбоев. Более того, Triton автоматически управляет размером блока изнутри.
Когда n_cols превышает аппаратный предел, Triton разбивает входные данные на части и выполняет итерации по ним. Однако, похоже, это медленнее, чем многоблочный подход.
Для дальнейшего развития мы могли бы объединить оба подхода в одном ядре, которое явно выбирает оптимальное ядро на основе размера входных данных. Таким образом, мы бы выиграли от высокой производительности одноблочного ядра для небольших входных данных и более высокой пропускной способности многоблочного варианта для входных данных с более чем 60 тысячами столбцов.

На этом завершается третий эпизод сериала «Тритон», еще раз спасибо за вашу поддержку!
В следующей статье мы применим онлайн-формулу softmax в контексте Flash Attention .
До следующего раза! 👋
Ресурсы:
- Реализация LigerKernel Softmax
- Вывод градиента Softmax Томасом Курбилем
- Оптимизация ядра графического процессора: Softmax — Часть 2, автор Хьюго Розенкранц-коста (ядра Cuda и Triton с большим упором на профилирование и оптимизацию оборудования)
- От онлайн-софтмакса до FlashAttention от Цзыхао Йе
Источник: towardsdatascience.com



























