NextStatNextStat

NextStat v0.9.6: нулевой JIT-налог, ESS/grad и сходимость

Финальные канонические результаты со строгим разделением бэкендов: Metal, CUDA V100 и EPYC CPU.

NextStatMAMSLAPSBlackJAXБенчмаркиБайесовский вывод

2026-02-18 · 18 мин


TL;DR (только финальные числа v0.9.6)

  • LAPS Metal: финальная матрица — 8/8 ok, Div%=0 по всем кейсам.
  • CUDA V100 паритет (3-seed медиана, canonical): NextStat LAPS сохраняет нулевой JIT-налог (cold ≈ warm), тогда как cold-start BlackJAX — 11.8–90.6 с в данной конфигурации.
  • Time-to-result в реальных циклах редактирования: при изменении структуры/формы модели (приоры, параметризация, размерности) JAX/XLA-воркфлоу обычно рекомпилируются; эта задержка компиляции повторяется на каждой итерации. AOT-ядра NextStat сохраняют латентность итерации близкой к warm-path.
  • ESS/grad (V100 сэмплирование, report-chain нормализация): на совпадающих целях NS LAPS в диапазоне от 2.46× до 45.11× vs BlackJAX в этом canonical прогоне.
  • Исправлена честность CPU funnel: FunnelNcpModel (NCP) — 6/6 ok по 3 сидам на EPYC для MAMS и NUTS; centered funnel остаётся известным патологическим контролем.

0.MAMS и LAPS за одну страницу (для пользователей NUTS)

Если вы уже знаете NUTS, ключевая ментальная модель:

  • NUTS — динамический HMC с адаптивным расширением дерева на каждом переходе (адаптивная длина пути на каждом шаге).
  • MAMS (Metropolis Adjusted Microcanonical Sampler) использует микроканоническую/изокинетическую динамику с фиксированной длиной траектории в прекондиционированном пространстве.
  • LAPS — массивно параллельный GPU/Metal-путь сэмплирования, который применяет MAMS-динамику на тысячах цепей с аппаратно-ориентированным исполнением.

Что нового относительно стандартных реализаций NUTS

  • Ядра переходов фиксированной формы проще эффективно исполнять на SIMD/GPU-бэкендах, чем рекурсивное построение деревьев.
  • Сэмплер построен для очень большого параллелизма цепей (4096 цепей — штатная рабочая точка в этом отчёте).
  • В этом релизе MAMS/LAPS используют микроканонические диагностические гейты и явное раскрытие NCP-vs-centered для funnel-подобных геометрий.

Где этот дизайн сильнее всего

  • Time-to-result в итеративных воркфлоу (нет повторной JIT-компиляции в нашем AOT-пути).
  • Иерархические и мультимасштабные цели, где мы наблюдаем высокий ESS/grad в matched-target запусках.
  • Локальные пути ускорения (Metal) и серверные GPU-пути (CUDA) с единой семантикой сэмплера.

Где NUTS или XLA-тяжёлые стеки могут выиграть

  • Простые низкоразмерные цели после прогрева компиляции (возможна более высокая raw warm-path пропускная способность).
  • Некоторые гладкие концентрированные постериоры на CPU (например, large-n логистическая регрессия в этом отчёте).

Поэтому этот отчёт не заявляет универсального победителя среди сэмплеров — он документирует, где каждая модель исполнения сильнее, с явными оговорками о честности.


1.Протокол и правила честности

Разделение бэкендов (без смешивания)

  • Результаты LAPS Metal публикуются отдельно от LAPS CUDA.
  • Результаты CPU (EPYC) публикуются отдельно от GPU.
  • Ни одна таблица не смешивает значения Metal и CUDA.

Мульти-ран агрегация (против cherry-pick)

  • Финальные таблицы сравнения CPU/GPU используют три независимых сида: 42, 123, 777.
  • Мы публикуем медиану как основное число (устойчиво к выбросам), с mean ± std где полезно.
  • Значения одиночных сидов хранятся только в сырых артефактах, не как заголовочные утверждения.

Time-to-result в итеративном моделировании

  • Байесовская работа — это цикл редактирования (изменить приор, добавить ковариату, репараметризовать, перезапустить).
  • В JIT/XLA-стеках изменения графа/формы часто инвалидируют скомпилированные исполняемые файлы и запускают рекомпиляцию, поэтому cold-start затраты повторяются при исследовании.
  • NextStat использует AOT-компилированные Rust/CUDA-ядра, поэтому wall-clock латентность итерации остаётся близкой к warm-path даже при эволюции моделей.

Раскрытие параметризации funnel

Для std_normal, eight_schools и glm_logistic оба движка сэмплируют одну и ту же целевую плотность (идентичные log-density функции).

Для neal_funnel_10d параметризации различаются в V100 parity run:

  • NS LAPS сэмплирует Non-Centered Parameterization (NCP): log p(v, z) = -v²/18 - 0.5 · Σ(z_i²).
  • BlackJAX сэмплирует centered-параметризацию: log p(v, x) = -v²/18 - 0.5·exp(-v)·Σ(x_i²) - 0.5·(d-1)·v.

Это не одна и та же задача оптимизации. Centered funnel имеет позиционно-зависимую кривизну, фундаментально более сложную для сэмплеров с фиксированной метрикой. Строки neal_funnel в секции 3 и Приложении отражают различия как алгоритма, так и параметризации, и не должны интерпретироваться как сравнение «при прочих равных». Они сохранены для демонстрации поведения сходимости (NS сходится, BlackJAX нет), но исключены из заголовочных ESS/grad утверждений.

  • На CPU теперь есть явная FunnelNcpModel для честных NCP-сравнений (секция 6).
  • Centered FunnelModel остаётся отдельным контролем сложной геометрии.

Алгоритмические изменения в v0.9.6

  • MAMS использует eps_jitter=0.1 по умолчанию (±10% равномерный шум размера шага на переход), ломая периодичность фиксированного L и улучшая tail ESS на периодических целях вроде std_normal.
  • Длина траектории по умолчанию: L = √d в прекондиционированном пространстве (Robnik et al. 2023).

Конфигурация BlackJAX (V100 parity run)

Для предупреждения вопросов о мисконфигурации конкурента, полная конфигурация BlackJAX:

  • Сэмплер: blackjax.adjusted_mclmc с интегратором isokinetic_mclachlan.
  • Прогрев: встроенная blackjax.adjusted_mclmc_find_L_and_step_size (500 итераций, single-chain прогрев, target_accept=0.9, diagonal_preconditioning=True).
  • Длина траектории: настроена встроенным прогревом BlackJAX (L, step_size), затем n_steps = round(L / step_size).
  • Mass matrix: сэмплирование использует настроенную inverse_mass_matrix из прогрева BlackJAX.
  • Multi-chain: 4096 цепей, jax.vmap(run_chain), block_until_ready() + device_get() для честного host-side тайминга.
  • Cold/warm: cold = первый вызов vmap (включает XLA-компиляцию); warm = второй вызов с кэшированным JIT.
  • Init: цепи инициализированы вокруг прогретого single-chain состояния (warmed_state.position + N(0, 0.5)).
  • Seed: 42 (cold), 1042 (warm).
  • Seeds: 42, 123, 777 (для каждого сида, warm run использует seed + 1000 key path).
  • Исходник: benchmarks/gpu_triple_bench.py, функции _blackjax_builtin_warmup() и bench_blackjax().

Конфигурация V100 parity run (NS LAPS, 3 seeds)

  • n_chains=4096, n_warmup=500, n_samples=1000, report_chains=256, seeds=42/123/777.
  • Секции 3/4 публикуют медиану по 3 сидам.
  • R̂ вычислена по 256 report chains (512 полу-цепей), давая существенно более точные диагностики по сравнению с предыдущим 64-chain репортингом.

2.Каноничные результаты LAPS Metal (финальные)

Железо: Apple M5, 10 GPU-ядер, 24 ГБ unified memory.

МодельЦепиw+sWall (с)ESS/sDiv%Статус
std_normal_10d256100+1000.141.1753 6800.0ok
std_normal_10d_4096ch4096200+5000.091.03812 5850.0ok
eight_schools4096500+20000.251.007124 7050.0ok
neal_funnel_10d4096500+20000.311.00622 7910.0ok
neal_funnel_riemannian4096500+20000.271.01014 1420.0ok
glm_logistic_n200_p64096500+20002.151.0054 6470.0ok
glm_logistic_n1000_p204096500+200034.321.0102480.0ok
glm_logistic_n5000_p204096500+200059.061.0151100.0ok

Примечание: строка 256-chain std_normal_10d (R̂ 1.175) демонстрирует минимально жизнеспособное количество цепей; строка 4096-chain — каноничная конфигурация бенчмарка.

На практике это показывает, что локальный Apple Silicon способен запускать массивно параллельные inference-нагрузки дата-центрового масштаба с надёжной диагностикой сходимости — без настройки CUDA и без задержки JIT-компиляции.

Политика quality gates для этой матрицы:

  • MAMS/LAPS: QualityGates::microcanonical() (EBFMI — только предупреждение).
  • NUTS: строгий дефолтный gate сохранён (EBFMI fail < 0.20).

3.CUDA V100 parity run (LAPS vs BlackJAX, 3-seed медиана)

Железо: Tesla V100-PCIE-16GB.

МодельДвижокCold (с)Warm (с)min ESSESS/s (warm)
std_normal_10dNS LAPS GPU1.5540.240159 753680 7851.0062
std_normal_10dBlackJAX GPU14.0640.2251 7717 8471.1010
eight_schoolsNS LAPS GPU1.4250.24175 682314 4761.0065
eight_schoolsBlackJAX GPU11.7690.34628 02075 2551.0080
neal_funnel_10dNS LAPS GPU1.4040.25954 768211 5811.0083
neal_funnel_10dBlackJAX GPU15.5170.4127061 7591.2732
glm_logisticNS LAPS GPU23.7919.25477 8528 4151.0086
glm_logisticBlackJAX GPU90.61577.76519 5832261.0122

Как читать эту таблицу

  • Нулевой JIT-налог: NS LAPS cold остаётся близким к warm (AOT-компилированный Rust/CUDA). Cold-start BlackJAX существенно выше в этой конфигурации (11.8–90.6 с).
  • Warm-start пропускная способность (canonical прогон): NS LAPS выше на всех совпадающих целях в этой конфигурации.
  • neal_funnel — не сравнение «при прочих равных» (см. секцию 1: NS сэмплирует NCP, BlackJAX — centered). В этих 3 сидах R̂ centered-funnel BlackJAX в диапазоне 1.260–1.275 — это ожидаемо из-за сложности параметризации, а не дефект сэмплера.

4.ESS/grad на V100 (фаза сэмплирования, только совпадающие цели, 3-seed медиана)

МодельNS LAPS ESS/gradBlackJAX ESS/gradОтношение (NS/BJ)
std_normal_10d0.3120170.00691745.11×
eight_schools0.0985440.0401042.46×
glm_logistic0.1013700.00263838.43×

neal_funnel исключён из таблицы, так как два движка сэмплируют разные параметризации (см. секцию 1).

Основной вклад в изменение по сравнению с ранними драфтами — нормализация знаменателя: оба движка теперь вычисляют ESS/grad на одном бюджете report_chains.

Практическая интерпретация для этого canonical прогона:

  • NS LAPS достигает более высокого ESS/grad на всех совпадающих целях в этом отчёте.
  • glm_logistic остаётся самой дорогой целью для обоих движков по абсолютному wall time.

5.Верификация качества LAPS на V100 (report_chains=256)

Отдельный запуск с более строгой диагностикой (report_chains=256 → 512 полу-цепей → SE(R̂) ≈ 0.015).

МодельR̂ maxESS_tail minE-BFMIСтатус
StdNormal 10d1.017518 9471.035ok
NealFunnel NCP 10d1.012648 2020.970ok
GLM n=5000 p=201.014949 6600.863ok
GLM n=200 p=61.004455 4230.449ok
NealFunnel centered 10d1.29142570.000fail (ожидаемый контроль)

Это подтверждает, что сходимость LAPS надёжна при достаточном количестве диагностических цепей. Значения R̂ из parity-run (секция 3, report_chains=256) напрямую сопоставимы с quality run.


6.CPU EPYC (MAMS vs NUTS) и исправление паритета funnel

Железо: AMD EPYC 7502P, 32 ядра / 64 потока, 128 ГБ RAM (Hetzner dedicated).

Сводка EPYC multi-seed (42/123/777, 3-run агрегат)

Конфигурация: n_chains=4, n_warmup=1000, n_samples=1000, eps_jitter=0.1.

МодельMAMS ESS/s (медиана)MAMS (mean ± std)NUTS ESS/s (медиана)NUTS (mean ± std)Отношение
std_normal_d2129 592137 761 ± 75 444200 841200 329 ± 13 4600.645
std_normal_d10100 420103 641 ± 4 69285 15995 604 ± 15 8151.179
std_normal_d5013 00713 150 ± 86728 30526 113 ± 3 6380.460
eight_schools98 20193 408 ± 8 22748 57746 018 ± 5 7812.022
logreg_n1000_p10714711 ± 103 8963 914 ± 280.183
logreg_n5000_p203736 ± 4186190 ± 110.200

Наблюдаемый паттерн в этой real-run матрице:

Кейсdimn_dataОтношение MAMS/NUTSЛидер
std_normal_d220.645NUTS
eight_schools1082.022MAMS
std_normal_d10101.179MAMS
std_normal_d50500.460NUTS
logreg_n1000_p101010000.183NUTS
logreg_n5000_p202050000.200NUTS

Почему large-n логистическая регрессия благоприятствует NUTS

  • Стоимость градиента масштабируется как O(n·p) на leapfrog-шаг; при n=5000, p=20 каждый дополнительный шаг дорог.
  • NUTS может завершать траектории раньше через U-turn, тогда как MAMS использует фиксированную длину траектории в прекондиционированном пространстве.
  • По мере роста n геометрия постериора приближается к хорошо обусловленному гауссиану; это сильный режим для NUTS с адаптивной длиной пути.

Практическая рекомендация

  • Предпочитайте MAMS для иерархических / мультимасштабных геометрий.
  • Предпочитайте NUTS для large-n GLM-подобных постериоров на CPU.
  • Разумное продуктовое направление — явная эвристика method="auto" (напр.: GLM с большим n → NUTS; иерархические/funnel-подобные цели → MAMS), с сохранением ручного оверрайда.

Контроль параметризации funnel (EPYC, 3 сида)

Конфигурация: n_chains=4, n_warmup=1000, n_samples=1000.

MAMS:

МодельSeedESS_tailEBFMIСтатус
Centered (FunnelModel)421.0785221n/aok
Centered (FunnelModel)1231.035331n/afail
Centered (FunnelModel)7771.0781244n/aok
NCP (FunnelNcpModel)421.00671 914n/aok
NCP (FunnelNcpModel)1231.01001 897n/aok
NCP (FunnelNcpModel)7771.00481 924n/aok

NUTS:

МодельSeedESS_tailEBFMIСтатус
Centered (FunnelModel)422.384414n/afail
Centered (FunnelModel)1231.363672n/afail
Centered (FunnelModel)7771.948017n/afail
NCP (FunnelNcpModel)421.00262 516n/aok
NCP (FunnelNcpModel)1231.00271 604n/aok
NCP (FunnelNcpModel)7771.00242 385n/aok

Интерпретация:

  • NCP — 6/6 ok по всем сидам для MAMS и NUTS. ESS_tail в диапазоне 1 604–2 516 (NUTS) и 1 897–1 924 (MAMS).
  • Centered — 3/3 fail для NUTS и 1/3 fail для MAMS.
  • Предыдущее несоответствие CPU funnel было методологическим (centered vs NCP), а не проблемой «CPU слабый».
  • FunnelNcpModel — рекомендуемая параметризация бенчмарка для CPU/GPU паритета.
  • Centered FunnelModel сохранён как известный патологический контроль; это демонстрация ограничений, а не регрессия продукта.
  • В этих артефактах EPYC funnel-control EBFMI не экспортирован (n/a в таблицах), поэтому pass/fail основан на R̂/ESS quality gates.

7.Воспроизводимость и метаданные окружения

  • JSON бенчмарка V100 содержит top-level снимок environment (python, jax, cuda, gpu, версии пакетов).
  • Сьют EPYC хранит метаданные железа/конфига/сида и метрики по кейсам; полный снимок окружения на уровне пакетов сейчас только в V100 parity JSON.

Артефакты (все в docs/blog/artifacts/v096-zero-jit-tax/):

  • Матрица V100 3-seed (canonical): v100-multi-seed-matrix-canonical.json
  • Графические данные V100 (canonical): v100-parity-chart-data-canonical.csv, v100-essgrad-ratio-canonical.csv
  • Сырые V100 3-seed (canonical): v100_v096_builtinwarmup_3seed_20260218T224654Z/seed_42/gpu_triple_bench.json, seed_123/..., seed_777/...
  • V100 funnel addendum: v100_ns_funnel_3seed_20260218T231337Z/*, v100_bj_funnel_builtin3seed_20260218T231204Z/*
  • V100 quality run: v100-quality-report256-5models.json
  • V100 + EPYC refresh note: 2026-02-17-v096-refresh-v100-epyc.md
  • EPYC multi-seed матрица: epyc-multi-seed-matrix.json
  • Выход EPYC сьюта: epyc-mams-suite.json
  • EPYC funnel-control: epyc-funnel-control-3seed.json

A.Приложение: V100 neal_funnel (разные параметризации)

Сохранено для прозрачности. Эти строки сравнивают NS LAPS (NCP) с BlackJAX (centered) — это не сравнение «при прочих равных».

МетрикаNS LAPS (NCP)BlackJAX (centered)
Cold (с)1.40415.517
Warm (с)0.2590.412
min_ESS54 768706
ESS/s (warm)211 5811 759
1.00831.2732
ESS/grad0.0713120.000710

Несходимость BlackJAX на centered funnel ожидаема (см. секцию 6: даже NUTS проваливается 3/3 на centered funnel с 4 цепями и стандартным бюджетом). Это сравнение прежде всего демонстрирует, что дефолтная NCP-диспатч NS производит сходящиеся результаты там, где centered-параметризация не сходится.


Ссылки

  • Robnik, Cohn-Gordon, Seljak. Metropolis Adjusted Microcanonical Hamiltonian Monte Carlo (MAMS). arXiv:2503.01707
  • BlackJAX. Composable Bayesian inference in JAX. arXiv:2402.10797

Связанное чтение