Плиточный GEMM, память GPU, объединение и многое другое!
Делиться

Умножение матриц, несомненно, является самой распространённой операцией, выполняемой графическими процессорами. Это фундаментальный строительный блок линейной алгебры, который применяется в широком спектре областей, таких как графика, физическое моделирование и научные вычисления, а также повсеместно используется в машинном обучении.
В сегодняшней статье мы подробно разберём концептуальную реализацию общего метода умножения матриц на матрицу (GEMM), а также познакомимся с несколькими концепциями оптимизации, такими как тайлинг и объединение памяти. Наконец, мы реализуем GEMM в Triton!
Эта статья — вторая из серии, посвящённой Triton и ядрам GPU. Если вы не знакомы с Triton или хотите освежить знания по основам GPU, ознакомьтесь с предыдущей статьёй! Весь код, представленный в этой статье, доступен на GitHub.
Изучаем Triton по одному ядру за раз: сложение векторов
Отказ от ответственности: все представленные ниже рисунки и анимации были сделаны автором, если не указано иное.
Наивный GEMM
Начнём с простого: нам нужно перемножить две матрицы X и Y с размерами (M,N) и (N,K) соответственно. Таким образом, выходная матрица Z=X@Y будет иметь размер (M,K).
Эта операция включает вычисление скалярных произведений всех пар строк и столбцов по осям X и Y соответственно. Простая реализация на NumPy может выглядеть примерно так:
Несмотря на простоту написания, чтения и понимания, эта реализация крайне неэффективна с точки зрения доступа к памяти и кэширования. Как упоминалось в первой статье этой серии, фундаментальным аспектом оптимизации GPU является минимизация передачи данных .
Однако наша текущая реализация начинается с загрузки строки из X, итеративно загружает все K столбцов Y, вычисляет их скалярное произведение и повторяет процесс для каждой строки в X. Это приводит к общему количеству операций загрузки M(K+1).

Как видно из анимации, такой подход к доступу к памяти расточителен, поскольку каждый столбец Y загружается M раз. Аналогично: это всё равно, что бегать в магазин за продуктами (глобальная память) каждый раз, когда нужен новый ингредиент для блюда, вместо того, чтобы приготовить все ингредиенты на кухонном столе (общая память). В идеале хотелось бы минимизировать количество загрузок каждого блока данных и максимально увеличить возможность его повторного использования после загрузки. Это оставляет нам два основных направления оптимизации:
- Как можно улучшить схему доступа, чтобы минимизировать избыточные нагрузки?
- Какой объем данных мы можем загрузить одновременно и где они должны храниться на графическом процессоре?
Плиточный GEMM
Как упоминалось ранее, наивный подход к GEMM приводит к множеству избыточных загрузок, что приводит к ненужным накладным расходам. В идеале хотелось бы загружать каждый сегмент данных только один раз и выполнять все операции, в которых он используется, прежде чем удалять его из памяти.
Элегантным подходом к этой задаче является тайлинг (разбиение на плитки ), который подразумевает разделение больших матриц на более мелкие «плитки» или подматрицы. Рассмотрим две матрицы X и Y с формами (4,6) и (6,4) соответственно. X@Y даёт матрицу Z с формами (4,4).
Чтобы вычислить первый элемент матрицы Z, Z[0,0], нам нужно вычислить скалярное произведение первой строки матрицы X и первого столбца матрицы Y: Z[0,0] = dot(X[0, :], Y[:, 0]). Мы также можем разбить скалярное произведение на более мелкие части, например, на группы по 3 элемента: Z[0,0] = dot(X[0,0:3], Y[0:3, 0]) + dot(X[0,3:6], Y[3:6, 0]).
В качестве альтернативы мы можем расширить этот подход до двух измерений и вычислить весь блок (2,2) Z за один раз: Z[0:2, 0:2] = dot(X[0:2, 0:2], Y[0:2, 0:2]) + dot(X[0:2, 2:4], Y[2:4, 0:2]) + dot(X[0:2, 4:6], Y[4:6, 0:2]).
Вот наглядное представление процесса умножения матриц:

Анимация выше иллюстрирует повторное использование данных в тайловой модели GEMM. Для каждого блока 2×2 по осям X и Y мы вычисляем 4 скалярных произведения, что приводит к выходной матрице (2,2) в Z. Поскольку каждая плитка содержит 3 блока, нам необходимо суммировать 3 такие матрицы для вычисления итогового выходного значения (2,2) в Z. Это суммирование представлено цветными ячейками в Z.
Если использовать аналогию с кухней, это можно сравнить с покупкой ингредиентов из магазина и их приготовлением на кухонном столе (т. е. небольшой общей памятью), а затем их повторным использованием несколько раз перед тем, как вернуть их в магазин.
Важно отметить, что повторное использование загруженных данных на нескольких этапах позволяет этому подходу значительно сократить количество операций загрузки. Для блоков (2,2) каждая строка X и каждый столбец Y используются в двух скалярных произведениях. Таким образом, мы выполняем вдвое больше операций с каждым блоком загруженных данных, что примерно вдвое сокращает количество операций загрузки! Обратите внимание, что это распространяется и на блоки большего размера: использование блока (32,32) сократит количество загрузок примерно в 32 раза.
Теперь вы, вероятно, задаетесь вопросом: «Насколько большими могут быть эти блоки?» Чтобы ответить на этот вопрос, давайте вспомним, как управляется память в современных видеокартах.
Иерархия памяти графического процессора
Мы различаем четыре основных типа памяти в видеокартах Nvidia. В качестве примера рассмотрим A100:
- Регистры: самый быстрый и компактный тип памяти в графическом процессоре, расположенный непосредственно в каждом потоковом мультипроцессоре (SM). В A100 каждый SM предоставляет 256 КБ пространства регистрового файла (65 536 × 32-битных регистров), распределенного между потоками. Каждый поток получает собственные 32-битные регистры для хранения временных переменных и промежуточных результатов, что позволяет полностью исключить трафик памяти. Однако использование регистров потоком напрямую влияет на занятость, поскольку использование слишком большого количества регистров на поток ограничивает количество потоков, которые могут выполняться одновременно.
- L1/Общая память : В A100 каждый SM имеет 192 КБ статической памяти (SRAM), которую можно гибко настроить как аппаратно управляемый кэш L1 или как управляемую программистом общую память . Для критически важных с точки зрения производительности ядер, таких как умножение матриц, мы явно используем это пространство как общую память для размещения фрагментов данных рядом с вычислительными блоками, полностью минуя кэш L1. Это обеспечивает нам точный контроль над повторным использованием данных.
- Кэш L2 : этот кэш медленнее, чем L1, но значительно больше, с общим объёмом около 40 МБ для всех SM в A100. Он служит глобальным кэшем для данных и инструкций, сокращая количество обращений к памяти HBM с высокой задержкой. Кэш L2 когерентен между SM , то есть обновления одного SM видны другим, что обеспечивает синхронизацию между блоками потоков. Его пропускная способность может достигать нескольких терабайт в секунду, что позволяет ему выступать в качестве буфера между быстрой встроенной SRAM и более медленной HBM.
- Память с высокой пропускной способностью (HBM) : это память устройства, объём которой составляет 40 или 80 ГБ в зависимости от модели A100. Она обеспечивает чрезвычайно высокую пропускную способность (до 2 ТБ/с в версии 80 ГБ) , но с гораздо большей задержкой, чем внутрикристальные кэши. В HBM хранятся большие тензоры, веса моделей и наборы данных во время выполнения. Поскольку доступ к HBM требует больших затрат, эффективные ядра стремятся минимизировать перемещение данных и максимально увеличить повторное использование данных на кристалле через регистры и общую память.
Как видите, иерархия памяти обычно обеспечивает баланс между ёмкостью и задержкой. Таким образом, максимизация производительности сводится к эффективной загрузке данных из HBM в разделяемую память и максимально возможному их повторному использованию.

Выбор размера блока критически важен. Мы хотим, чтобы блоки были достаточно большими для выполнения большого объёма параллельной работы, но при этом достаточно маленькими, чтобы их данные помещались в общую память и регистры SM. BLOCK_SIZE, равный 64, — распространённая отправная точка, поскольку он кратен размеру варпа (32 потока), что гарантирует полное использование оборудования.
Параллельный плиточный GEMM
Учитывая эти соображения, естественным продолжением нашей плиточной модели GEMM является распараллеливание вычисления каждой пары плиток по нескольким блокам потоков, как показано на следующей анимации.

Объединение памяти
Прежде чем писать плиточный GEMM в Triton, нам нужно учесть ещё одну деталь: объединение памяти (memory joining ) — метод, позволяющий оптимально использовать глобальную пропускную способность памяти. Объединение памяти достигается, когда последующие потоки в варпе обращаются к последующим адресам памяти (memory address) . Представьте себе библиотекаря, которому нужно принести книги для клиента: если все книги стоят рядом на полке, он может взять их все сразу. Если же все книги лежат на разных полках, ему придётся брать их по одной, что займёт значительно больше времени.
Чтобы понять, как это применимо к нашему случаю, обратите внимание, что матрицы хранятся в памяти линейно, то есть матрица (2,2) хранится как последовательность из 4 последовательных элементов. Такие фреймворки, как PyTorch, используют построчную компоновку, то есть элементы матрицы располагаются в памяти последовательно по строкам . Например, элементы нашей матрицы (2,2) будут храниться следующим образом: [(0,0), (0,1), (1,0), (1,1)]. Обратите внимание, что элементы одной строки являются смежными (соприкасаются), а элементы одного столбца имеют шаг 1 (разделены одним элементом).

Это означает, что мы можем загружать строки, используя объединенные загрузки , но столбцы этому условию не удовлетворяют. Однако нам необходим доступ к столбцам Y для вычисления скалярных произведений. Для максимизации производительности рекомендуется транспонировать Y так, чтобы итерации выполнялись по его строкам, а не по столбцам.
Однако транспонирования Y недостаточно для изменения расположения матрицы в памяти. Как упоминалось ранее, PyTorch хранит матрицы в плоском массиве. Каждое измерение матрицы связано с атрибутом «шаг», обозначающим шаг, необходимый для перехода от одного элемента к следующему по этому измерению. Например, матрица (10,10) будет иметь шаг = (10,1). Действительно, начиная с элемента [0,0], элемент [1,0] находится на расстоянии 10 ячеек памяти (т.е. в одной строке), тогда как элемент [0,1] является соседним.
При транспонировании тензора PyTorch не изменяет структуру в памяти, а просто пересчитывает шаги. Чтобы транспонирование было эффективным с точки зрения памяти, необходимо вызвать функцию YTcontiguous().
Это необходимые шаги для эффективной загрузки столбцов Y, однако нам понадобится транспонировать загруженные блоки внутри ядра, чтобы правильно выполнить скалярное произведение: z_block = tl.dot(X_block, Y_block.T).

Реализация Тритона
Далее мы сначала опишем ядро без объединения памяти, чтобы упростить логику и арифметику указателей, а затем обобщим изменения, необходимые для объединения операций загрузки по Y столбцам.
Начнём с обёртки PyTorch вокруг ядра. Нам нужно считать M, N, K из входных матриц и вычислить их шаги, так как эти константы будут полезны позже в ядре. Затем мы определяем BLOCK_SIZE и объявляем сетку.
Теперь перейдём к самому коду ядра. Мы воспользуемся утилитой Triton make_block_ptr, которая упрощает арифметику указателей. Мы создаём один указатель на блок для каждой матрицы и передаём в качестве входных данных форму матрицы, её шаги и размер блока. Кроме того, мы указываем смещение — координату верхнего левого элемента в текущем блоке. Для оси X это соответствует (m_idx * BLOCK_SIZE, 0), где m_idx — индекс текущего блока по оси M.
Далее мы определяем z_acc, нулевую матрицу, которая будет принимать частичные скалярные произведения по мере итерации по плиткам. Теперь мы итерируем по общему измерению N, загружая блоки размером (BLOCK_SIZE, BLOCK_SIZE) и накапливая их скалярные произведения в z_acc. Затем мы перемещаем указатели блоков по общему измерению с помощью .advance.
Вы могли заметить, что при загрузке данных мы используем bound_check и padding_option вместо mask и other, как в предыдущей статье. Эти аргументы относятся к использованию указателей блоков и определяют, какие оси следует проверять на наличие операций, выходящих за пределы диапазона (в данном случае (0,1) для x и y), и как обрабатывать эти недопустимые значения. Здесь мы устанавливаем их равными нулю, чтобы они не учитывались при вычислении скалярного произведения.
Теперь мы можем оценить производительность этого ядра, используя следующую функцию:
def bench(fn: callable, x: torch.Tensor, y: torch.Tensor, repeat: int): flops = [] med_latency = [] for _ in tqdm(range(repeat), desc=f»Benchmarking {fn.__name__}»): latency_ms = triton.testing.do_bench(lambda: fn(x, y), quantiles=[0.5], # получаем медианную задержку return_mode=»all», ) n_flops = 2 * M * N * K # matmul примерно требует 2*M*N*K операций tflops = n_flops / (latency_ms / 1e3) / 1e12 med_latency.append(latency_ms) flops.append(tflops) flops = np.array(flops) med_latency = np.array(med_latency) print(f»Абсолютная погрешность: {torch.sum(torch.abs(X@Y — fn(x, y)))}») print(f»Медианная задержка: {med_latency.mean():.4f} ± {med_latency.std():.3f} мс») print(f»Пропускная способность: {flops.mean():.4f} ± {flops.std():.3f} ТФЛОПС») M = 8192 N = 6144 K = 4096 X = torch.randn((M, N), device=»cuda», dtype=torch.float32) Y = torch.randn((N, K), device=»cuda», dtype=torch.float32) bench(block_matmul, X, Y, repeat=10)
Мы получаем следующие результаты (используя графический процессор T4 на Colab):
Абсолютная ошибка: 0,0 # ядро выдаёт правильный результат! Медианная задержка: 130,7831 ± 1,794 мс. Пропускная способность: 3,1533 ± 0,043 ТФЛОПС.
Теперь рассмотрим изменения, необходимые для объединённых загрузок по оси Y: нам в основном нужно изменить форму, шаги и смещения при определении указателя блока для оси Y. Кроме того, мы обновляем указатель блока для перемещения вдоль измерения столбца (ранее — измерения строки). Полный код этой реализации доступен на GitHub.
@triton.jit def coalesced_block_matmul_kernel( X_ptr, X_m_stride, X_n_stride, Y_ptr, Y_k_stride, Y_n_stride, Z_ptr, Z_m_stride, Z_k_stride, M, N, K, BLOCK_SIZE: tl.constexpr, ): … y_block_ptr = tl.make_block_ptr( base=Y_ptr, # перевернуть форму, шаги и смещения, чтобы они соответствовали YT shape=(K, N), strides=(Y_k_stride, Y_n_stride), offsets=(k_idx * BLOCK_SIZE, 0), block_shape=(BLOCK_SIZE, BLOCK_SIZE), order=(0, 1), ) … for _ in range(0, N, BLOCK_SIZE): … # загрузить z_acc += tl.dot(x, yT) # транспонировать Y обратно для скалярного произведения x_block_ptr = tl.advance(x_block_ptr, offsets=(0, BLOCK_SIZE)) # продвигать указатель блока вдоль столбцов YT (т. е. строк Y) y_block_ptr = tl.advance(y_block_ptr, offsets=(0, BLOCK_SIZE)) tl.store(pointer=z_block_ptr, value=z_acc, border_check=(0, 1)) def coalesced_block_matmul(X, Y): Y = YTcontiguous() # Y теперь (K,N) M, N = X.shape K, _ = Y.shape Z = torch.empty((M, K), device=»cuda») x_stride_m, x_stride_n = X.stride() y_stride_k, y_stride_n = Y.stride() z_stride_m, z_stride_k = Z.stride() … # определяем BLOCK_SIZE и объединяем сетку coalesced_block_matmul_kernel[grid]( X, x_stride_m, x_stride_n, Y, y_stride_n, y_stride_k, Z, z_stride_m, z_stride_k, M, N, K, BLOCK_SIZE, ) возвращаем Z
Вот результаты нашего теста для ядра с объединенными нагрузками для Y:
Абсолютная погрешность: 0,0 # И снова, ядро верно! Медианная задержка: 261,9420 ± 0,858 мс Пропускная способность: 1,5741 ± 0,005 Тфлопс
Удивительно, но пропускная способность этого второго ядра составляет всего половину того, что мы получили с первым, несмотря на повышение эффективности операций загрузки 🤔
Быстрый осмотр с помощью nsight (профилировщика ядра NVIDIA, подробнее об этом в следующей статье) показывает, что операция транспонирования в ядре создаёт «пробку». В частности, транспонирование создаёт конфликты банков , из-за чего потоки большую часть времени простаивают. Примечательно, что у планировщика варпов нет подходящих варпов для отправки в 87,6% случаев, поскольку он ожидает разрешения конфликта банков. Кроме того, в отчёте говорится:
———————– ———– ————–
Название метрики Единица измерения метрики Значение метрики
———————– ———– ————–
…
Пропускная способность DRAM % 8,20
Пропускная способность вычислений (SM) % 21,14
…
Это означает, что ядро ограничено задержкой (т.е. не ограничено ни памятью, ни вычислительными мощностями; подробнее см. в предыдущей статье). В отличие от этого, первое ядро ограничено вычислительными мощностями (т.е. увеличение вычислительных мощностей повысит производительность), поскольку его вычислительная мощность выше, чем пропускная способность DRAM.
———————– ———– ————–
Название метрики Единица измерения метрики Значение метрики
———————– ———– ————–
…
Пропускная способность DRAM % 29,35
Пропускная способность вычислений (SM) % 74,39
…
Заключение
Этот эксперимент подчёркивает важность профилирования и эмпирической проверки. Даже такие продуманные оптимизации, как объединение обращений к памяти, могут создавать новые узкие места, если их не оценить тщательно. Первое ядро, хотя и было проще, было ограничено вычислительной мощностью и лучше соответствовало характеристикам оборудования.
В следующих статьях этой серии мы реализуем ядро softmax, уделив особое внимание интеграции Triton с autograd PyTorch и профилированию ядер с помощью Nsight.
До следующего раза! 👋
Полезные ресурсы
- Полная реализация
- Введение в GEMM и назначение
- Архитектура Nvidia Ampere (спецификации A100)
Источник: towardsdatascience.com





















