Мой решатель ОДУ на SciPy губил мой байесовский вывод: честный рассказ космолога об открытии дифракционной картины.
Во сколько это обойдется, что это даст и три ошибки, которые я совершаю.
Делиться

Проблема, которая заставила меня искать альтернативу.
Я — теоретический космолог. Моя работа заключается в том, чтобы брать модели Вселенной — уравнения состояния темной энергии, модифицированную гравитацию, тахионные поля — и задаваться вопросом: что же на самом деле говорят данные о параметрах? Инструментом для ответа на этот вопрос является байесовский вывод. Обычно я использую метод вложенной выборки династий для нескольких тысяч или нескольких сотен тысяч оценок вероятности в зависимости от сложности модели.
На протяжении большей части моей работы над докторской диссертацией я не особо задумывался о решателе ОДУ внутри функции правдоподобия, поскольку solve_ivp работал исправно. Он был надежным. Поэтому я использовал его и перешел к следующему этапу.
Затем я начал работать над тахионной моделью темной энергии DBI, где поле темной энергии описывается нестандартным кинетическим членом, а уравнения фона и возмущений представляют собой связанную систему относительно жестких уравнений. Каждый вызов функции правдоподобия решал эти ОДУ, вычислял сопутствующее расстояние и оценивал модуль расстояния на красных смещениях 30 сверхновых.
Я провел профилирование. Решение ОДУ занимало 0,4 мс на один вызов. В режиме вложенной выборки с 10⁵ вычислениями это составляет 40 секунд — только на вызовы ОДУ, не считая каких-либо дополнительных вычислений. А для 10-параметрической модели получение градиента с помощью центральных конечных разностей требует 20 дополнительных прямых вычислений, превращая эти 0,4 мс в 8 мс на каждый градиент. Это 300 секунд, или около 5 минут, только на градиенты. Для одного запуска вложенной выборки.
Что-то должно было измениться.

Что я обнаружил: дифракс
После целого дня поисков я наткнулся на diffrax [1], библиотеку численных решателей ОДУ, полностью написанную на JAX. Не нейронный аналог. Не приближение. Те же самые встроенные алгоритмы Рунге-Кутты, которые я уже использую в scipy — Tsit5 вместо RK45, но то же семейство методов — просто скомпилированные, дифференцируемые и векторизуемые.
Три свойства обусловлены дизайном, «полностью написанным на JAX»:
JIT-компиляция – Весь цикл адаптивного пошагового выполнения компилируется в единое ядро XLA. После первого вызова отсутствуют накладные расходы Python.
Autodiff – Поскольку каждая операция внутри решателя является примитивом JAX, jax.grad распространяет градиенты на протяжении всего решения. Точные градиенты. Один обратный проход. Независимо от количества параметров.
vmap – С помощью jax.vmap можно параллельно решить целый пакет векторов параметров. Это критически важно для вложенной выборки.
Установка занимает 10 секунд:
pip install jax diffrax
Задача проверки: плоская ΛCDM из сверхновых
Чтобы сравнение стало нагляднее, позвольте мне показать точную задачу, над которой я работал. В плоской ΛCDM-вселенной сопутствующее расстояние удовлетворяет следующему условию:
dχdz=cH(z),H(z)=H0Ωm(1+z)3+(1−Ωm),χ(0)=0frac{dchi}{dz} = frac{c}{H(z)}, quad H(z) = H_0sqrt{Omega_m(1+z)^3 + (1-Omega_m)}, quad chi(0)=0
Модуль расстояния определяется следующим образом: μ(z) = 5 log₁₀[(1+z)χ(z) / 10 пк]. Я хочу вывести (Ωₘ, H₀) из 30 смоделированных наблюдений модуля расстояния сверхновых типа SNIa.
from scipy.integrate import solve_ivp import numpy as np C_KMS = 299792.458 # speed of light [km/s] def rhs(z, chi, Om, H0): return C_KMS / (H0 * np.sqrt(Om*(1+z)**3 + (1-Om))) def forward_scipy(Om, H0, z_obs): sol = solve_ivp(rhs, t_span=(0, z_obs[-1]), y0=[0.0], t_eval=z_obs, args=(Om, H0), method="RK45", rtol=1e-8, atol=1e-10) chi = sol.y[0] return 5 * np.log10((1 + z_obs) * chi * 1e5) # distance modulus
Старый способ: SciPy
from scipy.integrate import solve_ivp import numpy as np C_KMS = 299792.458 # speed of light [km/s] def rhs(z, chi, Om, H0): return C_KMS / (H0 * np.sqrt(Om*(1+z)**3 + (1-Om))) def forward_scipy(Om, H0, z_obs): sol = solve_ivp(rhs, t_span=(0, z_obs[-1]), y0=[0.0], t_eval=z_obs, args=(Om, H0), method="RK45", rtol=1e-8, atol=1e-10) chi = sol.y[0] return 5 * np.log10((1 + z_obs) * chi * 1e5) # distance modulus
Новый подход: Дифракс
import jax, jax.numpy as jnp import diffrax as dfx # Non-negotiable: enable 64-bit (more on this below) jax.config.update("jax_enable_x64", True) def H_jax(z, Om, H0): return H0 * jnp.sqrt(Om*(1+z)**3 + (1-Om)) @jax.jit # compile once, call fast forever def forward_diffrax(theta, z_obs): Om, H0 = theta[0], theta[1] sol = dfx.diffeqsolve( dfx.ODETerm(lambda z, chi, a: C_KMS / H_jax(z, a[0], a[1])), dfx.Tsit5(), t0=0.0, t1=float(z_obs[-1]), # initial and final value dt0=1e-3, # initial step-size y0=jnp.array(0.0), # initial condition args=(Om, H0), saveat=dfx.SaveAt(ts=z_obs), stepsize_controller=dfx.PIDController(rtol=1e-8, atol=1e-10), max_steps=10_000, ) chi = sol.ys return 5 * jnp.log10((1 + z_obs) * chi * 1e5)
Физические принципы идентичны. Алгоритм решателя практически идентичен (Tsit5 очень похож на RK45). Единственные структурные различия заключаются в @jax.jit и API diffrax. Давайте посмотрим, что дают эти два изменения.
Сюрприз 1: скорость
solve_ivp: 404 мкс на вызов. diffrax post-JIT: 59 мкс на вызов. Это в 7 раз быстрее.
Когда я впервые увидел это число, я несколько секунд смотрел на него. Честно говоря, откуда на самом деле берется это ускорение, ведь это не магия.
В solve_ivp Python повторно обращается к бэкенду C/Cython при каждом вызове. Память выделяется заново. Адаптивный цикл while проходит через интерпретатор Python, спрашивая: «слишком ли велика локальная ошибка? Отклонить; иначе увеличить шаг; повторить». Для решения задачи из 12 шагов это означает 12 циклов диспетчеризации Python, 12 выделений памяти и 12 вычислений оценки ошибки, находящихся за блокировкой интерпретатора.
В diffrax первый вызов @jax.jit отслеживает все вычисления, включая адаптивный цикл while, который преобразуется в lax.while_loop и передается XLA для компиляции в ядро машинного кода. Каждый последующий вызов выполняет это ядро напрямую. Следовательно, нет необходимости в Python, выделении памяти и диспетчеризации.

Для 100 000 оценок вероятности 404 мкс против 59 мкс соответствуют 40,4 секундам против 5,9 секунд. Эта разница усиливается с увеличением сложности модели.
Сюрприз 2: градиенты становятся бесплатными
Именно этот момент изменил не только мой рабочий процесс, но и мое представление об инференции. В библиотеке scipy получение градиента логарифма функции правдоподобия относительно двух параметров (Ωₘ, H₀) требует 4 прямых решения (центральные конечные разности). Как только вы начинаете увеличивать количество параметров, это быстро становится дорого: 10 параметров означают 20 прямых решений, 50 параметров — 100. Затраты растут линейно с количеством параметров.
∂ℱ∂Ωm≈ℱ(Ωm+h,H0)−ℱ(Ωm−h,H0)2h,∂ℱ∂H0≈ℱ(Ωm,H0+h)−ℱ(Ωm,H0−h)2hfrac{partialmathcal{F}}{partialOmega_m} approx frac{mathcal{F}(Omega_m+h,H_0) – mathcal{F}(Omega_m-h,H_0)}{2h}, qquad frac{partialmathcal{F}}{partial H_0} approx frac{mathcal{F}(Omega_m,H_0+h) – mathcal{F}(Omega_m,H_0-h)}{2h}
С помощью Diffrax я пишу:
def loss(theta): mu_pred = forward_diffrax(theta, z_obs) return 0.5 * jnp.sum(((mu_pred - mu_obs) / sigma_mu)**2) grad_fn = jax.jit(jax.grad(loss)) # that is the entire change g = grad_fn(jnp.array([0.3, 70.0])) # exact gradient
Внутри JAX обратный режим автодифференциала интегрирует сопряженные уравнения [2] в обратном направлении через решение ОДУ – но мне никогда не приходится записывать эти уравнения. В результате получается точный градиент по времени, сравнимый с одним прямым проходом, независимо от количества параметров.

Как выбрать решатель
При выборе решателя нужно быть немного осторожным. Я почти всегда использовал Tsit5 , и он без нареканий справился примерно с 95% моих задач. Если вам нужен полный процесс принятия решения, вот он:
- Нежесткие ОДУ (большинство космологических задач) →
dfx.Tsit5()← начать здесь - Очень жесткие допуски (< 10⁻⁹) →
dfx.Dopri8() - Жесткое ОДУ (много шагов, решатель, кажется, работает медленно) →
dfx.Kvaerno5() - Жесткие + нежесткие члены (IMEX) →
dfx.KenCarp4() - SDE →
dfx.EulerHeun()илиdfx.SPaRK()
Быстрый способ определить, является ли ваша задача «жесткой»: выведите sol.stats["num_steps"] . Если это значение в 10–100 раз больше, чем вы ожидаете, задача является «жесткой», и вам нужен неявный решатель.
Результат: космологические выводы от начала до конца.
Теперь позвольте мне показать полное сравнение результатов вывода. Я запускаю оба конвейера с одного и того же плохого начального предположения (Ωₘ, H₀) = (0,10, 60), которое находится далеко от истинного значения (0,30, 70), и выполняю 350 шагов градиента.
- Конвейер обработки данных scipy: градиент из центральных конечных разностей, простой градиентный спуск, фиксированная скорость обучения.
- Конвейер обработки данных diffrax: градиент из autodiff, оптимизатор Adam с расписанием скорости обучения, изменяющимся по косинусному закону.
import optax # optimisers for JAX # Scale parameters so Adam can handle them equally # Om ~ 0.3, h = H0/100 ~ 0.7 -- both O(1) now def loss_scaled(theta_s): theta = jnp.array([theta_s[0], 100.0 * theta_s[1]]) return loss(theta) grad_scaled = jax.jit(jax.grad(loss_scaled)) schedule = optax.cosine_decay_schedule( init_value=0.05, decay_steps=350, alpha=0.04) opt = optax.adam(schedule) theta = jnp.array([0.10, 0.60]) # start far from truth state = opt.init(theta) for step in range(350): g = grad_scaled(theta) updates, state = opt.update(g, state) theta = optax.apply_updates(theta, updates) if (step + 1) % 50 == 0: print(f"Step {step+1}: Om={theta[0]:.3f} H0={100*theta[1]:.2f}")

В то время как конвейер diffrax восстанавливает физически обоснованные параметры, конвейер scipy не может одновременно изменять оба параметра — это классический пример неудачи градиентного спуска на задачах с недостаточным масштабированием. Adam обрабатывает это автоматически с помощью адаптивных скоростей обучения для каждого параметра, но Adam доступен только потому, что autodiff предоставляет мне точные градиенты для его передачи.
Три вещи, в которых я ошибся (чтобы вам не пришлось этого делать)

Предостережение 1: забудьте о 64-битной точности. JAX по умолчанию использует 32-битные числа с плавающей запятой. Если вы выйдете за пределы допустимых отклонений (rtol < 10⁻⁷), это может привести к очень странным результатам: в моем случае решателю ОДУ требуется 69 шагов в 32-битном режиме, но только 12 в 64-битном. Если еще больше ужесточить допуски, он может полностью выйти из строя. Решение простое — включите 64-битную точность перед тем, как что-либо делать:
jax.config.update("jax_enable_x64", True) # must be first
Предостережение 2: тестирование без предварительной подготовки. Первый вызов любой функции, помеченной атрибутом @jax.jit включает в себя одноразовое время компиляции около 90–100 мс. Если вы учтете это в своих измерениях, diffrax будет казаться медленнее, чем scipy, по неправильной причине. Решение состоит в том, чтобы один раз провести предварительную подготовку и отбросить результаты первого запуска:
_ = forward_diffrax(theta, z_obs).block_until_ready() # compile # NOW benchmark -- this is the real speed
Кроме того: JAX выполняет асинхронную отправку данных. Всегда вызывайте метод .block_until_ready() в циклах измерения времени, иначе вы будете измерять время отправки задания, а не его завершения.
Предостережение 3: ловушка порядка аргументов. scipy.odeint ожидает f(y, t) (сначала состояние, затем время). Почти все остальные функции ( solve_ivp , diffrax) ожидают f(t, y) . Если вы перенесете старый код odeint в diffrax без замены аргументов, вы в итоге решите другое ОДУ, и обычно ошибки не возникнет. Вы просто получите неправильный ответ.
Стоит ли вам переходить на другую систему?
Честный ответ таков: если вы решаете разовое ОДУ и вам не нужны градиенты, solve_ivp вполне подходит — нет необходимости изучать новый API. Но если вы занимаетесь выводом (повторные вычисления функции правдоподобия, градиенты параметров или пакетные решения), то переход стоит затраченных усилий.
| Ситуация | solve_ivp | одеинт | дифракция |
|---|---|---|---|
| Одноразовое решение, без выводов. | ✓ | ✓ | тоже хорошо |
| Вложенная выборка / MCMC | медленный | медленный | ДА |
| Необходимы градиенты | Только FD | Только FD | точный, бесплатный |
| Пакетная обработка сетки параметров | цикл for | цикл for | vmap |
| Жесткая система | Радау | авто (LSODA) | Kvaerno5 |
| SDE или нейронный ODE | нет | нет | ДА |
| GPU/TPU | нет | нет | ДА |
Сама миграция незначительна. Прямая модель изменяется примерно на шесть строк. Градиент появляется при добавлении еще одной строки. Остальная часть кода вывода остается идентичной.
Здесь необходимо отметить, что Diffrax не является «основанным на машинном обучении» в смысле использования нейронной сети. Это та же классическая математика Рунге-Кутты, написанная на JAX. «Ускорение машинного обучения» обеспечивается за счет JIT-компиляции и автодифференцирования — инструментов инфраструктуры из мира машинного обучения, применяемых к классическому численному решателю. Единственным действительно основанным на машинном обучении подходом был бы нейронный суррогат, который обучается θ → μ(z) на основе обучающих данных — это отдельная и более сложная тема.
Полный рабочий код
Всё вышеперечисленное в одном автономном скрипте ( pip install jax diffrax optax ):
""" flat_lcdm_inference.py Infer (Omega_m, H0) from 30 mock supernovae using diffrax + Adam. pip install jax diffrax optax """ import jax, jax.numpy as jnp, numpy as np import diffrax as dfx, optax from scipy.integrate import solve_ivp # only for generating mock data jax.config.update("jax_enable_x64", True) # -- Constants and data ----------------------------------------------- C_KMS = 299792.458 z_obs = jnp.linspace(0.05, 1.5, 30) SIGMA = 0.10 # Mock data at truth (Om=0.30, H0=70) def chi_np(Om, H0): sol = solve_ivp(lambda z, y: C_KMS/(H0*np.sqrt(Om*(1+z)**3+(1-Om))), (0, 1.5), [0.], t_eval=np.array(z_obs), rtol=1e-10) return sol.y[0] mu_true = 5*np.log10((1+np.array(z_obs))*chi_np(0.3, 70.)*1e5) mu_obs = jnp.array(mu_true + 0.10*np.random.default_rng(42).standard_normal(30)) # -- diffrax forward model -------------------------------------------- @jax.jit def forward(theta): Om, H0 = theta[0], theta[1] sol = dfx.diffeqsolve( dfx.ODETerm(lambda z, chi, a: C_KMS/(a[1]*jnp.sqrt(a[0]*(1+z)**3+(1-a[0])))), dfx.Tsit5(), t0=0., t1=1.5, dt0=1e-3, y0=jnp.array(0.), args=(Om, H0), saveat=dfx.SaveAt(ts=z_obs), stepsize_controller=dfx.PIDController(rtol=1e-8, atol=1e-10), max_steps=10_000, ).ys return 5*jnp.log10((1+z_obs)*sol*1e5) # -- Loss and gradient ------------------------------------------------ def loss(th_s): # optimise in scaled coords (Om, h=H0/100) mu = forward(jnp.array([th_s[0], 100.*th_s[1]])) return 0.5*jnp.sum(((mu - mu_obs)/SIGMA)**2) grad_fn = jax.jit(jax.grad(loss)) # Warm up the JIT compiler theta_init = jnp.array([0.10, 0.60]) _ = forward(jnp.array([0.3, 0.7])).block_until_ready() _ = grad_fn(theta_init).block_until_ready() # -- Adam optimiser with cosine LR schedule --------------------------- sched = optax.cosine_decay_schedule(init_value=0.05, decay_steps=350, alpha=0.04) opt = optax.adam(sched) theta = theta_init state = opt.init(theta) print(f"{'Step':>5} {'Om':>7} {'H0':>7} {'Loss':>8}") for step in range(350): g = grad_fn(theta) upd, state = opt.update(g, state) theta = optax.apply_updates(theta, upd) if (step + 1) % 70 == 0 or step == 0: L = float(loss(theta)) print(f"{step+1:5d} {float(theta[0]):7.4f} {100*float(theta[1]):7.3f} {L:8.2f}") Om_fit, H0_fit = float(theta[0]), 100*float(theta[1]) print(f"nFinal: Om = {Om_fit:.3f} H0 = {H0_fit:.2f}") print(f"Truth: Om = 0.300 H0 = 70.00")
Краткий обзор цифр
| Измерение | сципи | дифракция | Ускорение |
|---|---|---|---|
| Одиночный переадресованный звонок | 0,4 мс | 57 мкс | ~ 07× |
| Градиент (2 параметра) | 1,62 мс | 195 мкс | ~ 08× |
| 10⁵ переадресация звонков | 40 с | 5,9 с | ~ 07× |
| 10⁵ градиентных вызовов | ~98 с | ~19,6 с | ~ 05× |
| Итоговое значение Ωₘ (350 шагов) | 0,652 (неверно) | 0,270 | — |
| Финальный H₀ (350 шагов) | 60.10 (застрял) | 70.94 | — |
«Неправильный» результат в scipy не является ошибкой решателя — он отражает тот факт, что простой градиентный спуск с градиентами конечных разностей не может справиться с 200-кратным несоответствием масштабов между Ωₘ и H₀.
Заключительная мысль
Переход от прямой модели к дифракционной модели не изменил ни физику, ни метод вывода. Он изменил практическую осуществимость этого вывода вообще. Расчет с использованием вложенной выборки, который требовал больших временных затрат для прямой модели, теперь занимает менее минуты. Градиенты, которые раньше требовали 20 дополнительных вычислений на шаг, теперь стали практически бесплатными.
Освоить все тонкости удалось примерно за один день. Отладка в основном касалась особенностей 64-битной архитектуры и путаницы с прогревом JIT-компилятора. Результат оказался ощутимым и незамедлительным.
Если вы физик, использующий библиотеку scipy для многократных вычислений функции правдоподобия, и вы еще не знакомы с библиотекой diffrax, надеюсь, это даст вам повод это сделать.
Примечание по воспроизводимости: точные значения времени выполнения могут отличаться на вашем компьютере и даже между запусками на одном и том же компьютере. На моем Mac (Macbook Air M3 Base Model) время выполнения прямого вызова diffrax варьировалось от 55 мкс до 62 мкс в разных сессиях, а scipy — от 400 мкс до 407 мкс. Это нормально — тепловое состояние процессора, планирование операционной системы и состояние кэша памяти влияют на абсолютные значения на 10–15%. Стабильным остается соотношение: diffrax стабильно в 7–8 раз быстрее, чем scipy, в решении этой задачи. Важно именно соотношение, а не абсолютное время.
Код на Python, сгенерировавший все рисунки в этой статье, доступен по адресу: github.com/Samit1424/ODE_solver_comparison
Примечание: За исключением изображения, представленного на главной странице и созданного с помощью инструмента искусственного интеллекта, все остальные иллюстрации являются оригинальными работами автора.
Ссылки
[1] П. Киджер, О нейронных дифференциальных уравнениях, диссертация на соискание степени доктора философии, Оксфордский университет, 2021. docs.kidger.site/diffrax/
[2] RTQ Чен, Ю. Рубанова, Ж. Бетанкур, Д. Дювено, Нейронные обыкновенные дифференциальные уравнения, NeurIPS 2018.
Самит Гангули. Все материалы от Самита Гангули.
Источник: towardsdatascience.com

Добавить комментарий
Для отправки комментария вам необходимо авторизоваться.