Как обновить и оптимизировать устаревшие модели ИИ/МО
Делиться

Пролог
Чтобы оправдать ожидания читателей и избежать разочарования, мы хотели бы сразу отметить, что данная публикация не предлагает полностью удовлетворительного решения проблемы, описанной в заголовке. Мы предложим и оценим две возможные схемы автоматического преобразования моделей TensorFlow в PyTorch: первая, основанная на формате и библиотеках Open Neural Network Exchange (ONNX), и вторая, использующая API Keras3. Однако, как мы увидим, каждая из них имеет свои сложности и ограничения. Насколько нам известно, на момент написания статьи общедоступных и надёжных решений этой проблемы не существует.
Большое спасибо Рому Мальцеру за его вклад в этот пост.
Упадок TensorFlow
За годы своего существования компьютерные науки познали немало «религиозных войн» — жарких, порой враждебных, споров между программистами и инженерами о «лучших» инструментах, языках и методологиях. Ещё несколько лет назад религиозная война между PyTorch и TensorFlow, двумя известными фреймворками глубокого обучения с открытым исходным кодом, была особенно острой. Сторонники TensorFlow подчёркивали его быстрый режим выполнения графов, в то время как сторонники PyTorch подчёркивали его «питоновскую» природу и простоту использования.
Однако в наши дни активность вокруг PyTorch значительно затмевает активность вокруг TensorFlow. Об этом свидетельствует количество крупных технологических компаний, перешедших на PyTorch вместо TensorFlow, количество моделей для каждого фреймворка в репозитории моделей HuggingFace, а также объём инноваций и оптимизации в каждом фреймворке. Проще говоря, TensorFlow — это лишь тень от себя прежнего. Война окончена, и PyTorch — безусловный победитель. Краткую историю войны PyTorch и TensorFlow, а также причины краха TensorFlow можно найти в публикации Пань Синханя «TensorFlow мёртв. PyTorch победил».
Проблема: Что нам делать со всеми нашими устаревшими моделями TensorFlow?!!
В свете этой новой реальности многие организации, ранее использовавшие TensorFlow, перенесли всю разработку новых моделей искусственного интеллекта и машинного обучения (ИИ/МО) в PyTorch. Однако перед ними стоит сложная задача, связанная с устаревшим кодом: что делать со всеми моделями, которые уже были созданы и развернуты в TensorFlow?
Вариант 1: Ничего не делать.
Вы можете задаться вопросом, почему это вообще проблема — модели TensorFlow работают, давайте не будем их трогать. Хотя этот подход и допустим, у него есть ряд недостатков, которые следует учитывать:
- Сокращение затрат на поддержку : по мере ухудшения ситуации с TensorFlow, его поддержка будет снижаться. Неизбежно, что-то начнёт ломаться. Например, могут возникнуть проблемы с совместимостью с новыми пакетами Python или системными библиотеками.
- Ограниченная экосистема : решения на базе ИИ/МО обычно включают в себя множество вспомогательных программных библиотек и сервисов, взаимодействующих с выбранным нами фреймворком, будь то PyTorch или TensorFlow. Со временем можно ожидать, что многие из них прекратят поддержку TensorFlow. Например, HuggingFace недавно объявила о прекращении поддержки TensorFlow.
- Ограниченное сообщество : Индустрия искусственного интеллекта и машинного обучения (ИИ/МО) во многом обязана своим быстрым развитием сообществу. Количество проектов с открытым исходным кодом, количество онлайн-уроков и активность на специализированных каналах поддержки в сфере ИИ/МО не имеют себе равных. По мере упадка TensorFlow будет сокращаться и его сообщество, и вам может стать всё сложнее получить необходимую помощь. Стоит ли говорить, что сообщество PyTorch процветает.
- Стоимость возможностей : Экосистема PyTorch процветает благодаря постоянным инновациям и оптимизациям. В последние годы появились ядра с мгновенным вниманием, поддержка восьмибитных типов данных с плавающей запятой, компиляция графов и множество других достижений, которые продемонстрировали значительный рост производительности и снижение затрат на ИИ/МО. За тот же период набор функций TensorFlow практически не изменился. Использование TensorFlow означает отказ от многих возможностей оптимизации затрат на ИИ/МО.
Вариант 2: ручное преобразование моделей TensorFlow в PyTorch
Второй вариант — переписать устаревшие модели TensorFlow в PyTorch. Это, вероятно, лучший вариант с точки зрения результата, но для компаний, накопивших технический долг за многие годы, конвертация даже одной модели может оказаться непростой задачей. Учитывая требуемые усилия, можно использовать этот подход только для моделей, находящихся в стадии активной разработки (например, на этапе обучения). Повторное использование всех уже развёрнутых моделей может оказаться непомерным.
Вариант 3: Автоматизация преобразования TensorFlow в PyTorch
Третий вариант, который мы рассмотрим в этой статье, — автоматизация преобразования устаревших моделей TensorFlow в PyTorch. Таким образом, мы надеемся реализовать преимущества выполнения моделей в PyTorch, но без огромных усилий, связанных с ручным преобразованием каждой модели.
Для упрощения обсуждения мы определим модель TensorFlow и рассмотрим два предложения по её конвертации в PyTorch. В качестве среды выполнения мы будем использовать Amazon EC2 g6e.xlarge с графическим процессором NVIDIA L40S, AWS Deep Learning Ubuntu (22.04) AMI и среду Python, включающую библиотеки TensorFlow (2.20), PyTorch (2.9), torchvision (0.24.0) и transformers (4.55.4). Обратите внимание, что фрагменты кода, которые мы будем публиковать, предназначены для демонстрационных целей. Пожалуйста, не воспринимайте использование нами какого-либо кода, библиотеки или платформы как одобрение их использования.
Преобразование моделей — почему это сложно?
Определение модели ИИ состоит из двух компонентов: архитектуры модели и её обученных весов. Решение для преобразования модели должно учитывать оба компонента. Преобразование весов модели довольно просто; веса обычно хранятся в формате, который можно легко преобразовать в отдельные тензорные массивы и использовать повторно в выбранном фреймворке. В отличие от этого, преобразование архитектуры модели представляет собой гораздо более сложную задачу.
Один из подходов мог бы заключаться в создании соответствия между строительными блоками модели в каждом из фреймворков. Однако существует ряд факторов, которые делают этот подход практически нереализуемым:
- Наложение и распространение API : если принять во внимание огромное количество часто перекрывающихся API TensorFlow для построения компонентов модели, а затем добавить огромное количество элементов управления и аргументов API для каждого слоя, можно увидеть, как создание всеобъемлющего однозначного сопоставления может быстро стать некрасивым.
- Различные подходы к реализации : на уровне реализации TensorFlow и PyTorch используют принципиально разные подходы. Хотя некоторые допущения обычно скрыты за API верхнего уровня, они требуют особого внимания пользователя. Например, в то время как TensorFlow по умолчанию использует формат «channels-last» (NHWC), PyTorch предпочитает формат «channels-first» (NCHW). Это различие в индексации и хранении тензоров усложняет преобразование операций модели, поскольку для корректного порядка измерений необходимо проверять/изменять каждый слой.
Вместо того, чтобы пытаться преобразовать данные на уровне API, альтернативным подходом может быть захват и преобразование внутреннего представления графа TensorFlow. Однако, как скажет вам любой, кто когда-либо заглядывал «под капот» TensorFlow, это тоже может быстро привести к серьёзным проблемам. Внутреннее представление графа TensorFlow невероятно сложно и часто включает в себя множество низкоуровневых операций, поток управления и вспомогательные узлы, не имеющие прямого эквивалента в PyTorch (особенно если вы работаете со старыми версиями TensorFlow). Даже простое его понимание кажется за пределами человеческих возможностей, не говоря уже о его преобразовании в PyTorch.
Обратите внимание, что те же самые проблемы затруднили бы для генеративной модели ИИ выполнение преобразования способом, который был бы абсолютно надежным.
Предлагаемые схемы преобразования
В связи с этими трудностями мы отказываемся от реализации собственного конвертера моделей и вместо этого рассматриваем инструменты, предлагаемые сообществом AI/ML. В частности, мы рассматриваем две различные стратегии преодоления описанных нами трудностей:
- Преобразование через унифицированное графовое представление : это решение предполагает наличие общего стандарта представления определения модели искусственного интеллекта/машинного обучения и утилит для преобразования моделей в этот стандарт и обратно. Решение, которое мы рассмотрим, использует популярный формат ONNX.
- Конвертация на основе стандартизированного высокоуровневого API : в этом решении мы упрощаем задачу конвертации, ограничивая нашу модель определённым набором высокоуровневых абстрактных API с поддерживаемыми реализациями в каждой из интересующих нас платформ искусственного интеллекта/машинного обучения. Для этого подхода мы будем использовать библиотеку Keras3.
В следующих разделах мы оценим эти стратегии на экспериментальной модели TensorFlow.
Игрушечная модель TensorFlow
В приведённом ниже блоке кода мы инициализируем и запускаем модель TensorFlow Vision Transformer (ViT) из популярной библиотеки трансформеров HuggingFace (версии 4.55.4) — TFViTForImageClassification. Обратите внимание, что в соответствии с решением HuggingFace прекратить поддержку TensorFlow, этот класс был удалён из последних версий библиотеки. Модель TensorFlow HuggingFace зависит от Keras 2, который мы добросовестно устанавливаем через пакет tf-keras (2.20.1). Для совместимости с ONNX мы устанавливаем поле ViTConfig.hidden_act в значение «gelu_new»:
импортировать тензорный поток как tf gpu = tf.config.list_physical_devices('GPU')[0] tf.config.experimental.set_memory_growth(gpu, True) из трансформаторов импортировать ViTConfig, TFViTForImageClassification vit_config = ViTConfig(hidden_act=»gelu_new», return_dict=False) tf_model = TFViTForImageClassification(vit_config)
Преобразование моделей с использованием ONNX
Первый рассматриваемый нами метод основан на Open Neural Network Exchange (ONNX) — проекте сообщества, целью которого является определение открытого формата для построения моделей искусственного интеллекта/машинного обучения (ИИ/МО), который повышает взаимодействие между фреймворками ИИ/машинного обучения и снижает зависимость от какого-либо одного из них. В состав API ONNX входят утилиты для преобразования моделей из распространённых фреймворков, включая TensorFlow, в формат ONNX. Существует также несколько общедоступных библиотек для преобразования моделей ONNX в PyTorch. В этой статье мы используем утилиту onnx2torch. Таким образом, преобразование модели из TensorFlow в PyTorch может быть выполнено путём последовательного применения преобразования TensorFlow в ONNX, а затем преобразования ONNX в PyTorch.
Для оценки этого решения мы устанавливаем библиотеки onnx (1.19.1), tf2onnx (1.16.1) и onnx2torch (1.5.15). Мы используем флаг no-deps, чтобы предотвратить нежелательное понижение версии библиотеки protobuf:
pip install —no-deps onnx tf2onnx onnx2torch
Схема преобразования представлена в блоке кода ниже:
импортировать tensorflow как tf импорт torch импорт tf2onnx, onnx2torch BATCH_SIZE = 32 DEVICE = «cuda» spec = (tf.TensorSpec((BATCH_SIZE, 3, 224, 224), tf.float32, name=»input»),) onnx_model, _ = tf2onnx.convert.from_keras(tf_model, input_signature=spec) converter_model = onnx2torch.convert(onnx_model)
Чтобы убедиться, что полученная модель действительно является модулем PyTorch, мы выполняем следующее утверждение:
assert isinstance(converted_model, torch.nn.Module)
Давайте теперь оценим качество и состав получившейся модели PyTorch.
Числовая точность
Чтобы проверить достоверность преобразованной модели, мы запускаем как модель TensorFlow, так и преобразованную модель на одних и тех же входных данных и сравниваем результаты:
import numpy as np batch_input = np.random.randn(BATCH_SIZE, 3, 224, 224).astype(np.float32) # выполнить модель tf tf_input = tf.convert_to_tensor(batch_input) tf_output = tf_model(tf_input, training=False) tf_output = tf_output[0].numpy() # выполнить преобразованную модель converter_model = converter_model.to(DEVICE) converter_model = converter_model.eval() torch_input = torch.from_numpy(batch_input).to(DEVICE) torch_output = converter_model(torch_input) torch_output = torch_output.detach().cpu().numpy() # сравнить результаты print(«Max diff:», np.max(np.abs(tf_output — torch_output))) # пример вывода: # Максимальная разница: 9.3877316e-07
Результаты, безусловно, достаточно близки для подтверждения правильности преобразованной модели.
Структура модели
Чтобы получить представление о структуре преобразованной модели, мы подсчитываем количество обучаемых сравнений и сравниваем его с исходной моделью:
num_tf_params = sum([np.prod(v.shape) for v in tf_model.trainable_weights]) num_pyt_params = sum([p.numel() for p in converter_model.parameters() if p.requires_grad]) print(f»Параметры обучения TensorFlow: {num_tf_params}») print(f»Параметры обучения PyTorch: {num_pyt_params:,}»)
Разница в количестве обучаемых параметров колоссальна: всего 589 824 в преобразованной модели по сравнению с более чем 85 миллионами в исходной. Анализ слоёв преобразованной модели приводит к тому же выводу: преобразование на основе ONNX полностью изменило структуру модели, сделав её практически неузнаваемой. Этот вывод имеет ряд последствий, включая:
- Обучение/тонкая настройка преобразованной модели : хотя мы показали, что преобразованную модель можно использовать для вывода, изменение структуры — в частности, тот факт, что некоторые параметры модели были заложены в нее, означает, что мы не можем использовать преобразованную модель для обучения или точной настройки.
- Применение точечных оптимизаций PyTorch к модели : преобразованная модель состоит из очень большого количества слоёв, каждый из которых представляет собой относительно низкоуровневую операцию. Это значительно ограничивает наши возможности по замене неэффективных операций оптимизированными эквивалентами PyTorch, такими как torch.nn.functional.scaled_dot_product_attention (SPDA).
Оптимизация модели
Мы уже видели, что наши возможности доступа к операциям модели и их изменения ограничены, но существует ряд оптимизаций, которые можно применить, не требуя такого доступа. В блоке кода ниже мы применяем компиляцию PyTorch и автоматическую смешанную точность (AMP) и сравниваем полученную производительность с производительностью модели TensorFlow. Для дополнительной информации мы также тестируем время выполнения версии модели ViTForImageClassification для PyTorch:
# Установить смешанную политику точности tf на bfloat16 tf.keras.mixed_precision.set_global_policy('mixed_bfloat16') # Установить высокую точность matmul для torch.set_float32_matmul_precision('high') @tf.function def tf_infer_fn(batch): return tf_model(batch, training=False) def get_torch_infer_fn(model): def infer_fn(batch): with torch.inference_mode(), torch.amp.autocast(DEVICE, dtype=torch.bfloat16, enabled=DEVICE=='cuda' ): output = model(batch) return output return infer_fn def benchmark(infer_fn, batch): # разминка для _ in range(20): _ = infer_fn(batch) start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) torch.cuda.synchronize() start.record() iters = 100 for _ in range(iters): _ = infer_fn(batch) end.record() torch.cuda.synchronize() return start.elapsed_time(end) / iters # оценка производительности модели TF avg_time = benchmark(tf_infer_fn, tf_input) print(f»nСреднее время шага TensorFlow: {(avg_time):.4f}») # оценка производительности преобразованной модели torch_infer_fn = get_torch_infer_fn(converted_model) avg_time = benchmark(torch_infer_fn, torch_input) print(f»nСреднее время шага преобразованной модели: {(avg_time):.4f}») # оценка производительности скомпилированной модели torch_infer_fn = get_torch_infer_fn(torch.compile(converted_model)) avg_time = benchmark(torch_infer_fn, torch_input) print(f»nСреднее время шага скомпилированной модели: {(avg_time):.4f}») # оценка производительности Torch ViT из трансформаторов import ViTForImageClassification torch_model = ViTForImageClassification(vit_config).to(DEVICE) Torch_infer_fn = get_torch_infer_fn(torch_model) avg_time = benchmark(torch_infer_fn, torch_input) print(f»nСреднее время шага модели PyTorch ViT: {(avg_time):.4f}») # оценка производительности скомпилированной Torch ViT Torch_infer_fn = get_torch_infer_fn(torch.compile(torch_model)) avg_time = benchmark(torch_infer_fn, torch_input) print(f»nСреднее время шага скомпилированной модели ViT: {(avg_time):.4f}»)
Обратите внимание, что изначально компиляция PyTorch на преобразованной модели завершается сбоем из-за использования оператора torch.Size в слое OnnxReshape. Хотя это легко исправить (например, tuple([int(i) for i in shape])), это указывает на более серьёзное препятствие для оптимизации модели: слой reshape, который встречается в модели десятки раз, обрабатывает формы как тензоры PyTorch, хранящиеся на графическом процессоре. Это означает, что каждый вызов требует отсоединения тензора формы от графика и копирования его в центральный процессор. Вывод заключается в том, что, хотя преобразованная модель функционально точна, её результирующее определение не оптимизировано для производительности во время выполнения. Это видно из результатов пошагового выполнения для различных конфигураций модели:

Преобразованная модель медленнее исходного потока TensorFlow и значительно медленнее версии PyTorch модели ViT.
Ограничения
Хотя (в случае нашей игрушечной модели) схема преобразования на основе ONNX работает, она имеет ряд существенных ограничений:
- В ходе преобразования в модель было включено множество параметров, что ограничило ее применение только рабочими нагрузками вывода.
- Преобразование ONNX разбивает граф вычислений на низкоуровневые операторы таким образом, что это затрудняет применение и/или извлечение выгоды из некоторых оптимизаций PyTorch.
- Использование ONNX подразумевает, что наша схема преобразования будет работать только с моделями, совместимыми с ONNX. Она не будет работать с моделями, которые невозможно сопоставить со стандартным набором операторов ONNX (например, с моделями с динамическим потоком управления).
- Схема конвертации зависит от работоспособности и обслуживания сторонней библиотеки, которая не является частью официального предложения ONNX.
Хотя эта схема работает — по крайней мере, для вычислений методом вывода, — вы можете обнаружить, что её ограничения слишком строги для использования в ваших собственных моделях TensorFlow. Один из возможных вариантов — отказаться от преобразования ONNX в PyTorch и выполнять вывод с помощью библиотеки ONNX Runtime.
Преобразование модели через Keras3
Keras3 — это высокоуровневый API для глубокого обучения, ориентированный на максимальную читаемость, удобство поддержки и использования приложений искусственного интеллекта и машинного обучения. В предыдущей публикации мы оценили Keras3 и отметили его поддержку нескольких бэкендов. В этой публикации мы вновь рассмотрим его поддержку нескольких фреймворков и оценим, можно ли её использовать для преобразования моделей. Предлагаемая нами схема включает в себя 1) миграцию существующей модели TensorFlow в Keras3 и затем 2) запуск модели с бэкендом PyTorch из Keras3.
Обновление TensorFlow до Keras3
В отличие от схемы конвертации на основе ONNX, наше текущее решение может потребовать внесения некоторых изменений в код модели TensorFlow для её миграции в Keras3. Хотя документация и выглядит просто, на практике сложность миграции будет во многом зависеть от деталей реализации модели. В случае нашей экспериментальной модели HuggingFace явно требует использования устаревшего tf-keras, предотвращая использование Keras3. Для реализации нашей схемы нам необходимо 1) переопределить модель без этого ограничения и 2) заменить собственные операторы TensorFlow эквивалентами из Keras3. Приведённый ниже блок кода содержит урезанную версию модели с необходимыми корректировками. Чтобы полностью понять, какие изменения потребовались, выполните параллельное сравнение кода с исходным определением модели.
импорт математика импорт keras HIDDEN_SIZE = 768 IMG_SIZE = 224 PATCH_SIZE = 16 ATTN_HEADS = 12 NUM_LAYERS = 12 INTER_SZ = 4*HIDDEN_SIZE N_LABELS = 2 класс TFViTEmbeddings(keras.layers.Layer): def __init__(self, **kwargs): super().__init__(**kwargs) self.patch_embeddings = TFViTPatchEmbeddings() num_patches = self.patch_embeddings.num_patches self.cls_token = self.add_weight((1, 1, HIDDEN_SIZE)) self.position_embeddings = self.add_weight((1, num_patches+1, HIDDEN_SIZE)) def вызов (self, pixel_values, training=False): bs, num_channels, height, width = pixel_values.shape встраивания = self.patch_embeddings(pixel_values, training=training) cls_tokens = keras.ops.repeat(self.cls_token, repeats=bs, axis=0) встраивания = keras.ops.concatenate((cls_tokens, embeddings), axis=1) встраивания = встраивания + self.position_embeddings возвращают встраивания класс TFViTPatchEmbeddings(keras.layers.Layer): def __init__(self, **kwargs): super().__init__(**kwargs) patch_size = (PATCH_SIZE, PATCH_SIZE) image_size = (IMG_SIZE, IMG_SIZE) num_patches = (image_size[1]//patch_size[1]) * (image_size[0]//patch_size[0]) self.patch_size = patch_size self.num_patches = num_patches self.projection = keras.layers.Conv2D( фильтры = HIDDEN_SIZE, kernel_size = patch_size, strides = patch_size, padding=»valid», data_format=»channels_last» ) def call(self, pixel_values, training=False): bs, num_channels, height, width = pixel_values.shape pixel_values = keras.ops.transpose(pixel_values, (0, 2, 3, 1)) проекция = self.projection(pixel_values) num_patches = (ширина // self.patch_size[1]) * (высота // self.patch_size[0]) вложения = keras.ops.reshape(projection, (bs, num_patches, -1)) возвращают вложения класс TFViTSelfAttention(keras.layers.Layer): def __init__(self, **kwargs): super().__init__(**kwargs) self.num_attention_heads = ATTN_HEADS self.attention_head_size = int(HIDDEN_SIZE / ATTN_HEADS) self.all_head_size = ATTN_HEADS * self.attention_head_size self.sqrt_att_head_size = math.sqrt(self.attention_head_size) self.query = keras.layers.Dense(self.all_head_size, name=»query») self.key = keras.layers.Dense(self.all_head_size, name=»key») self.value = keras.layers.Dense(self.all_head_size, name=»value») def transpose_for_scores(self, тензор, batch_size: int): тензор = keras.ops.reshape(тензор, (batch_size, -1, ATTN_HEADS, self.attention_head_size)) return keras.ops.transpose(тензор, [0, 2, 1, 3]) def call(self, hidden_states, training=False): bs = hidden_states.shape[0] mixed_query_layer = self.query(входные данные=скрытые_состояния) mixed_key_layer = self.key(входные данные=скрытые_состояния) mixed_value_layer = self.value(входные данные=скрытые_состояния) query_layer = self.transpose_for_scores(mixed_query_layer, bs) key_layer = self.transpose_for_scores(mixed_key_layer, bs) value_layer = self.transpose_for_scores(mixed_value_layer, bs) key_layer_T = keras.ops.transpose(key_layer, [0,1,3,2]) attention_scores = keras.ops.matmul(query_layer, key_layer_T) dk = keras.ops.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) attention_scores = keras.ops.divide(attention_scores, dk) attention_probs = keras.ops.softmax(attention_scores+1e-9, axis=-1) attention_output = keras.ops.matmul(attention_probs, value_layer) attention_output = keras.ops.transpose(attention_output,[0,2,1,3]) attention_output = keras.ops.reshape(attention_output, (bs, -1, self.all_head_size)) return (attention_output,) class TFViTSelfOutput(keras.layers.Layer): def __init__(self, **kwargs): super().__init__(**kwargs) self.dense = keras.layers.Dense(HIDDEN_SIZE) def call(self, hidden_states, input_tensor, training = False): return self.dense(inputs=hidden_states) class TFViTAttention(keras.layers.Layer): def __init__(self, **kwargs): super().__init__(**kwargs) self.self_attention = TFViTSelfAttention() self.dense_output = TFViTSelfOutput() def call(self, input_tensor, training = False): self_outputs = self.self_attention( hidden_states=input_tensor, training=training ) attention_output = self.dense_output( hidden_states=self_outputs[0], input_tensor=input_tensor, training=training ) return (attention_output,) class TFViTIntermediate(keras.layers.Layer): def __init__(self, **kwargs): super().__init__(**kwargs) self.dense = keras.layers.Dense(INTER_SZ) self.intermediate_act_fn = keras.activations.gelu def call(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states class TFViTOutput(keras.layers.Layer): def __init__(self, **kwargs): super().__init__(**kwargs) self.dense = keras.layers.Dense(HIDDEN_SIZE) def call(self, hidden_states, input_tensor, training: bool = False): hidden_states = self.dense(inputs=hidden_states) hidden_states = hidden_states + input_tensor return hidden_states class TFViTLayer(keras.layers.Layer): def __init__(self, **kwargs): super().__init__(**kwargs) self.attention = TFViTAttention() self.intermediate = TFViTIntermediate() self.vit_output = TFViTOutput() self.layernorm_before = keras.layers.LayerNormalization( epsilon=1e-12 ) self.layernorm_after = keras.layers.LayerNormalization( epsilon=1e-12 ) def call(self, hidden_states, training=False): attention_outputs = self.attention( input_tensor=self.layernorm_before(inputs=hidden_states), training=training, ) attention_outputs = attention_outputs[0] hidden_states = attention_output + hidden_states layer_output = self.layernorm_after(hidden_states) intermediate_output = self.intermediate(layer_output) layer_output = self.vit_output( hidden_states = промежуточный_выход, input_tensor = скрытые_состояния, training = обучение ) outputs = (layer_output,) return outputs class TFViTEncoder(keras.layers.Layer): def __init__(self, **kwargs): super().__init__(**kwargs) self.layer = [TFViTLayer(name = f»layer_{i}») for i in range(NUM_LAYERS)] def call(self, hidden_states, training = False): for i, layer_module in enumerate(self.layer): layer_outputs = layer_module( hidden_states = скрытые_состояния, training = обучение, ) hidden_states = layer_outputs[0] return tuple([hidden_states]) class TFViTMainLayer(keras.layers.Layer): def __init__(self, **kwargs): super().__init__(**kwargs) self.embeddings = TFViTEmbeddings() self.encoder = TFViTEncoder() self.layernorm = keras.layers.LayerNormalization(epsilon=1e-12) def call(self, pixel_values, training=False): embedding_output = self.embeddings( pixel_values=pixel_values, training=обучение, ) encoder_outputs = self.encoder( hidden_states=embedding_output, training=обучение, ) sequence_output = encoder_outputs[0] sequence_output = self.layernorm(inputs=sequence_output) return (sequence_output,) class TFViTForImageClassification(keras.Model): def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) self.vit = TFViTMainLayer() self.classifier = keras.layers.Dense(N_LABELS) def call(self, pixel_values, training=False): outputs = self.vit(pixel_values, training=training) sequence_output = outputs[0] logits = self.classifier(inputs=sequence_output[:, 0, :]) return (logits,)
Преобразование TensorFlow в PyTorch
Последовательность преобразования представлена в блоке кода ниже. Как и прежде, мы проверяем выходные данные полученной модели, а также количество обучаемых параметров.
# сохранить веса модели TensorFlow tf_model.save_weights(«model_weights.h5») import keras keras.config.set_backend(«torch») from keras3_vit import TFViTForImageClassification as Keras3ViT keras3_model = Keras3ViT() # вызвать модель для инициализации всех слоев keras3_model(torch_input, training=False) # загрузить веса из модели TensorFlow keras3_model.load_weights(«model_weights.h5») # проверить преобразованную модель assert isinstance(keras3_model, torch.nn.Module) keras3_model = keras3_model.to(DEVICE) keras3_model = keras3_model.eval() torch_output = keras3_model(torch_input, training=False) torch_output = torch_output[0].detach().cpu().numpy() print(«Максимальная разница:», np.max(np.abs(tf_output — torch_output))) num_pyt_params = sum([p.numel() for p in keras3_model.parameters() if p.requires_grad]) print(f»Обучаемые параметры Keras3: {num_pyt_params:,}»)
Обучение/тонкая настройка модели
В отличие от модели, преобразованной в ONNX, модель Keras3 сохраняет ту же структуру и обучаемые параметры. Это позволяет возобновить обучение и/или выполнить тонкую настройку преобразованной модели. Это можно сделать как в рамках обучения Keras3, так и с помощью стандартного цикла обучения PyTorch.
Оптимизация слоев модели
В отличие от модели, преобразованной в ONNX, согласованность определения модели Keras3 позволяет легко модифицировать и оптимизировать реализации слоёв. В приведённом ниже блоке кода мы заменяем существующий механизм внимания высокоэффективным оператором SDPA из PyTorch.
из torch.nn.functional импортировать scaled_dot_product_attention как класс sdpa TFViTSelfAttention(keras.layers.Layer): def __init__(self, **kwargs): super().__init__(**kwargs) self.num_attention_heads = ATTN_HEADS self.attention_head_size = int(HIDDEN_SIZE / ATTN_HEADS) self.all_head_size = ATTN_HEADS * self.attention_head_size self.sqrt_att_head_size = math.sqrt(self.attention_head_size) self.query = keras.layers.Dense(self.all_head_size, name=»query») self.key = keras.layers.Dense(self.all_head_size, name=»key») self.value = keras.layers.Dense(self.all_head_size, name=»value») def transpose_for_scores(self, тензор, batch_size: int): тензор = keras.ops.reshape(тензор, (batch_size, -1, ATTN_HEADS, self.attention_head_size)) return keras.ops.transpose(тензор, [0, 2, 1, 3]) def call(self, hidden_states, training=False): bs = hidden_states.shape[0] mixed_query_layer = self.query(inputs=hidden_states) mixed_key_layer = self.key(inputs=hidden_states) mixed_value_layer = self.value(inputs=hidden_states) query_layer = self.transpose_for_scores(mixed_query_layer, bs) key_layer = self.transpose_for_scores(mixed_key_layer, bs) value_layer = self.transpose_for_scores(mixed_value_layer, bs) sdpa_output = sdpa(query_layer, key_layer, value_layer) attention_output = keras.ops.transpose(sdpa_output,[0,2,1,3]) attention_output = keras.ops.reshape(attention_output, (bs, -1, self.all_head_size)) return (attention_output,)
Мы используем ту же функцию сравнительного анализа, что и выше, чтобы оценить влияние этой оптимизации на производительность модели во время выполнения:
torch_infer_fn = get_torch_infer_fn(keras3_model) avg_time = benchmark(torch_infer_fn, torch_input) print(f»Среднее время шага преобразованной модели Keras3: {(avg_time):.4f}»)
Результаты представлены в таблице ниже:

Используя схему преобразования моделей на основе Keras3 и применяя оптимизацию SDPA, нам удалось ускорить производительность вывода модели на 22% по сравнению с исходной моделью TensorFlow.
Компиляция модели
Ещё одна оптимизация, которую мы хотели бы применить, — это компиляция PyTorch. К сожалению (на момент написания этой статьи) компиляция PyTorch в Keras3 ограничена. В случае нашей модели как попытка применить torch.compile непосредственно к модели, так и попытка установить поле jit_compile функции Model.compile в Keras3 не удалась. В обоих случаях сбой был вызван многократными перекомпиляциями, инициированными внутренним механизмом Keras3. Хотя Keras3 предоставляет доступ к экосистеме PyTorch, её высокоуровневая абстракция может накладывать некоторые ограничения.
Ограничения
И снова у нас есть схема конвертации, которая работает, но имеет несколько ограничений:
- Модели TensorFlow должны быть совместимы с Keras3. Объём необходимых работ будет зависеть от особенностей реализации вашей модели. Для этого может потребоваться настройка слоёв Keras.
- Хотя полученная модель представляет собой torch.nn.Module, она не является «чистой» моделью PyTorch, поскольку состоит из слоёв Keras3 и включает в себя множество дополнительного кода Keras3. Это может потребовать некоторой адаптации нашего инструментария PyTorch и наложить определённые ограничения, как мы видели при попытке компиляции PyTorch.
- Решение зависит от работоспособности и обслуживания Keras3 и его поддержки бэкэндов TensorFlow и PyTorch.
Краткое содержание
В этой статье мы предложили и оценили два метода автоматического преобразования устаревших моделей TensorFlow в PyTorch. Результаты представлены в следующей таблице.

В конечном счете, наилучший подход, будь то один из обсуждаемых здесь методов, ручное преобразование, решение на основе генеративного ИИ или решение вообще не выполнять преобразование, будет во многом зависеть от деталей модели и ситуации.
Источник: towardsdatascience.com























