NextStatNextStat

JAX: компиляция vs исполнение

Бенчмарк, который действительно нужен: в коротких научных ML-циклах задержка компиляции может доминировать над временем исполнения.

JAXЗадержка компиляцииMLБенчмарки

2026-02-08 · 6 мин. чтения


Многие ML-бенчмарки измеряют только пропускную способность в установившемся режиме. Это правильная метрика, если вы гоняете модель часами. Но в научных пайплайнах значительная часть работы происходит в коротких циклах: подбор гиперпараметров, повторяющиеся небольшие фиты, интерактивные итерации анализа, короткие обучения для ablations. В таких режимах задержка компиляции может доминировать в общей стоимости.

Аннотация. Нам нужен не бенчмарк «examples/sec в установившемся режиме», а измерение двух режимов: время до первого результата (TTFR) в новом процессе (cold start) и тёплая пропускная способность после прогрева компиляции и кэшей. Чтобы это можно было публиковать, каждый запуск мы рассматриваем как снимок: с сырыми распределениями, зафиксированным окружением и явной политикой кэша.

0.Репликабельность: минимальный запуск и артефакты

Эта статья опирается на seed-харнесс ML-сьюта в репозитории публичных бенчмарков. Он специально сделан config-first: публикуемый результат — это JSON-артефакты со схемой, таймингами по компонентам и манифестом окружения.

Минимальный запуск из монорепозитория

Если вы запускаете из монорепозитория nextstat.io (и у вас уже есть venv), вот минимально воспроизводимый сценарий. Он создаёт каталог out/ml, пишет артефакты по кейсам и генерирует готовый фрагмент README-таблицы.

bash
./.venv/bin/python benchmarks/nextstat-public-benchmarks/suites/ml/suite.py \
  --deterministic --out-dir benchmarks/nextstat-public-benchmarks/out/ml

./.venv/bin/python benchmarks/nextstat-public-benchmarks/suites/ml/report.py \
  --suite benchmarks/nextstat-public-benchmarks/out/ml/ml_suite.json \
  --out benchmarks/nextstat-public-benchmarks/out/README_snippet_ml.md

Опционально: включить JAX backend

Seed-харнесс поддерживает numpy всегда и включает опциональные кейсыjax_jit_*. Если JAX не установлен, эти кейсы переходят в статус warn и это фиксируется в JSON.

bash
pip install -r benchmarks/nextstat-public-benchmarks/env/python/requirements-ml-jax-cpu.txt

# для GPU-окружений (CUDA 12):
pip install -r benchmarks/nextstat-public-benchmarks/env/python/requirements-ml-jax-cuda12.txt

Артефакты и схемы

АртефактСхемаНазначение
out/ml/cases/*.jsonnextstat.ml_benchmark_result.v1Один кейс: cold/warm распределения, конфиг, окружение, статус.
out/ml/ml_suite.jsonnextstat.ml_benchmark_suite_result.v1Индекс всех кейсов + ссылки на файлы + sha256.
out/README_snippet_ml.mdКороткая сводка для людей (генерируется из suite JSON; не заменяет JSON‑артефакты).

Каноничные JSON Schema для артефактов ML‑сьюта (seed‑контракт): benchmarks/nextstat-public-benchmarks/manifests/schema/ml_benchmark_result_v1.schema.json и benchmarks/nextstat-public-benchmarks/manifests/schema/ml_benchmark_suite_result_v1.schema.json. Каждый кейс фиксирует config (backend/workload/n/dtype, cold_runs, warm_iters, cache_policy),meta (Python/OS/версии NumPy/JAX) и, когда применимо, device (platform/kind/count). Политика кэша реализуется через переменные окружения (например, JAX_COMPILATION_CACHE_DIR), а в публикуемом артефакте закрепляется как значение cache_policy.

1.Два режима, две метрики

Режим A: cold start / время до первого результата

Включает: время импорта, трассировку графа, компиляцию и первое выполнение. Это метрика, важная для коротких запусков.

Режим B: тёплая пропускная способность

Установившееся выполнение после прогрева кэшей компиляции, загрузки ядер и в уже запущенном процессе. Это важно для длинных прогонов.

Публикация одного числа без указания режима не имеет смысла.


2.Определения: что именно мы измеряем

Для публикуемых запусков мы сообщаем тайминги по компонентам, а не одно агрегированное число:

МетрикаЧто измеряет
import_sИмпорт и инициализация backend’а в новом процессе
first_call_sПервый вызов: трассировка + компиляция (если есть) + первое выполнение
ttfr_sTime‑to‑first‑result: сумма import_s + first_call_s
warm.calls_sРаспределение по тёплым вызовам в отдельном процессе после компиляции (median/p95)

Для публикуемых кейсов важны распределения (median/p95), а не единичные значения. В seed‑контракте TTFR фиксируется как ttfr_s.

Для GPU-бэкендов нужно явно синхронизироваться (или «блокироваться до готовности»), иначе вы измерите только CPU-dispatch.

В seed-харнессе для JAX это зафиксировано как часть протокола: результат каждого вызова блокируется через block_until_ready() (best-effort).


3.Протокол бенчмарка (что должно быть явно задано)

  • Новый это процесс (fresh process) или долгоживущий (persistent)
  • Состояние кэша (чистый кэш vs прогретый кэш)
  • Размеры датасета и раскладка (layout) данных
  • Что включено/исключено из окна измерения
  • Политика синхронизации для GPU/async рантаймов (иначе измеряется dispatch)
  • Число прогонов cold start (--cold-runs) и число тёплых итераций (--warm-iters)

Для бенчмарков холодного старта единственный честный базовый сценарий это новый процесс с явно объявленной политикой кэша.


4.Политика кэша: «cold» означает разное

Результаты по задержке компиляции крайне чувствительны к кэшированию. Поэтому вместо притворства, что существует один «cold start», мы публикуем явные режимы:

  • Холодный процесс, тёплый кэш — новый Python-процесс, но допускается постоянный кэш компиляции
  • Холодный процесс, холодный кэш — новый Python-процесс и пустой каталог кэша компиляции (когда это осуществимо)
  • Тёплый процесс — тот же долгоживущий процесс (типично для интерактивного анализа)

Если мы не можем надёжно очистить кэш (потому что рантайм хранит его вне нашего контроля), мы рассматриваем это как ограничение и публикуем его как часть условий эксперимента.

В seed-харнессе кэш компиляции JAX задаётся явно через JAX_COMPILATION_CACHE_DIR, а для режима process_cold каталог кэша best-effort удаляется перед каждым fresh‑process прогоном.


5.Что мы будем публиковать

Для каждого снимка:

  • Распределения холодного старта (не просто один тайминг)
  • Распределения тёплой пропускной способности
  • Манифест базовой линии (версии, железо, настройки)
  • Политика кэша и версия харнесса

Спецификация публикации: публичные бенчмарки.


6.Почему это часть бенчмарк-программы NextStat

Ценность NextStat не в том, чтобы «выиграть микробенчмарк». Ценность в том, чтобы целые научные пайплайны становились быстрее, воспроизводимее и проще для аудита. Компромиссы compile vs execution это часть этой истории, когда ML стоит внутри цикла.


7.Ограничения и режимы отказа (что считается «корректностью»)

  • GPU тайминги требуют строгой синхронизации. Без block_until_ready() (или эквивалента) измерение становится некорректным.
  • «Холодный кэш» не всегда достижим. Если рантайм кэширует вне контролируемого каталога, это фиксируется как условие эксперимента, а не замалчивается.
  • Опциональные зависимости — часть протокола. Если JAX отсутствует, кейсы jax_jit_* должны быть warn с причиной missing_dependency: jax, а не «тихо пропущены».
  • Запросили GPU — убедитесь, что это реально GPU. Seed-харнесс валидирует, что при platform=gpu JAX действительно выбрал GPU; иначе результат помечается как warn с причиной gpu_unavailable.
  • Ошибки воркера должны быть видны. Если отдельный процесс‑воркер не смог записать корректный JSON или вернул ошибку, кейс помечается как failed с причиной worker_failed.
  • Публиковать нужно распределения, не «одно число». Минимально: median + p95 по cold TTFR и median по тёплым вызовам.

Минимальные ссылки


A.Приложение: состояние seed-харнесса (на сегодня)

В seed-репозитории публичных бенчмарков уже есть запускаемый ML-сьют в suites/ml/:

  • измеряет распределение TTFR в режиме fresh process (несколько независимых запусков воркера)
  • измеряет распределение тёплых вызовов как warm.calls_s в отдельном процессе после компиляции
  • по умолчанию включает numpy и опциональные jax_jit_cpu/jax_jit_gpu кейсы
  • если JAX отсутствует, JAX‑кейсы публикуются как warn с причиной missing_dependency: jax

Связанное чтение