JAX: компиляция vs исполнение
Бенчмарк, который действительно нужен: в коротких научных ML-циклах задержка компиляции может доминировать над временем исполнения.
2026-02-08 · 6 мин. чтения
Многие ML-бенчмарки измеряют только пропускную способность в установившемся режиме. Это правильная метрика, если вы гоняете модель часами. Но в научных пайплайнах значительная часть работы происходит в коротких циклах: подбор гиперпараметров, повторяющиеся небольшие фиты, интерактивные итерации анализа, короткие обучения для ablations. В таких режимах задержка компиляции может доминировать в общей стоимости.
Аннотация. Нам нужен не бенчмарк «examples/sec в установившемся режиме», а измерение двух режимов: время до первого результата (TTFR) в новом процессе (cold start) и тёплая пропускная способность после прогрева компиляции и кэшей. Чтобы это можно было публиковать, каждый запуск мы рассматриваем как снимок: с сырыми распределениями, зафиксированным окружением и явной политикой кэша.
0.Репликабельность: минимальный запуск и артефакты
Эта статья опирается на seed-харнесс ML-сьюта в репозитории публичных бенчмарков. Он специально сделан config-first: публикуемый результат — это JSON-артефакты со схемой, таймингами по компонентам и манифестом окружения.
Минимальный запуск из монорепозитория
Если вы запускаете из монорепозитория nextstat.io (и у вас уже есть venv), вот минимально воспроизводимый сценарий. Он создаёт каталог out/ml, пишет артефакты по кейсам и генерирует готовый фрагмент README-таблицы.
./.venv/bin/python benchmarks/nextstat-public-benchmarks/suites/ml/suite.py \
--deterministic --out-dir benchmarks/nextstat-public-benchmarks/out/ml
./.venv/bin/python benchmarks/nextstat-public-benchmarks/suites/ml/report.py \
--suite benchmarks/nextstat-public-benchmarks/out/ml/ml_suite.json \
--out benchmarks/nextstat-public-benchmarks/out/README_snippet_ml.mdОпционально: включить JAX backend
Seed-харнесс поддерживает numpy всегда и включает опциональные кейсыjax_jit_*. Если JAX не установлен, эти кейсы переходят в статус warn и это фиксируется в JSON.
pip install -r benchmarks/nextstat-public-benchmarks/env/python/requirements-ml-jax-cpu.txt
# для GPU-окружений (CUDA 12):
pip install -r benchmarks/nextstat-public-benchmarks/env/python/requirements-ml-jax-cuda12.txtАртефакты и схемы
| Артефакт | Схема | Назначение |
|---|---|---|
| out/ml/cases/*.json | nextstat.ml_benchmark_result.v1 | Один кейс: cold/warm распределения, конфиг, окружение, статус. |
| out/ml/ml_suite.json | nextstat.ml_benchmark_suite_result.v1 | Индекс всех кейсов + ссылки на файлы + sha256. |
| out/README_snippet_ml.md | — | Короткая сводка для людей (генерируется из suite JSON; не заменяет JSON‑артефакты). |
Каноничные JSON Schema для артефактов ML‑сьюта (seed‑контракт): benchmarks/nextstat-public-benchmarks/manifests/schema/ml_benchmark_result_v1.schema.json и benchmarks/nextstat-public-benchmarks/manifests/schema/ml_benchmark_suite_result_v1.schema.json. Каждый кейс фиксирует config (backend/workload/n/dtype, cold_runs, warm_iters, cache_policy),meta (Python/OS/версии NumPy/JAX) и, когда применимо, device (platform/kind/count). Политика кэша реализуется через переменные окружения (например, JAX_COMPILATION_CACHE_DIR), а в публикуемом артефакте закрепляется как значение cache_policy.
1.Два режима, две метрики
Режим A: cold start / время до первого результата
Включает: время импорта, трассировку графа, компиляцию и первое выполнение. Это метрика, важная для коротких запусков.
Режим B: тёплая пропускная способность
Установившееся выполнение после прогрева кэшей компиляции, загрузки ядер и в уже запущенном процессе. Это важно для длинных прогонов.
Публикация одного числа без указания режима не имеет смысла.
2.Определения: что именно мы измеряем
Для публикуемых запусков мы сообщаем тайминги по компонентам, а не одно агрегированное число:
| Метрика | Что измеряет |
|---|---|
| import_s | Импорт и инициализация backend’а в новом процессе |
| first_call_s | Первый вызов: трассировка + компиляция (если есть) + первое выполнение |
| ttfr_s | Time‑to‑first‑result: сумма import_s + first_call_s |
| warm.calls_s | Распределение по тёплым вызовам в отдельном процессе после компиляции (median/p95) |
Для публикуемых кейсов важны распределения (median/p95), а не единичные значения. В seed‑контракте TTFR фиксируется как ttfr_s.
Для GPU-бэкендов нужно явно синхронизироваться (или «блокироваться до готовности»), иначе вы измерите только CPU-dispatch.
В seed-харнессе для JAX это зафиксировано как часть протокола: результат каждого вызова блокируется через block_until_ready() (best-effort).
3.Протокол бенчмарка (что должно быть явно задано)
- ›Новый это процесс (fresh process) или долгоживущий (persistent)
- ›Состояние кэша (чистый кэш vs прогретый кэш)
- ›Размеры датасета и раскладка (layout) данных
- ›Что включено/исключено из окна измерения
- ›Политика синхронизации для GPU/async рантаймов (иначе измеряется dispatch)
- ›Число прогонов cold start (
--cold-runs) и число тёплых итераций (--warm-iters)
Для бенчмарков холодного старта единственный честный базовый сценарий это новый процесс с явно объявленной политикой кэша.
4.Политика кэша: «cold» означает разное
Результаты по задержке компиляции крайне чувствительны к кэшированию. Поэтому вместо притворства, что существует один «cold start», мы публикуем явные режимы:
- ›Холодный процесс, тёплый кэш — новый Python-процесс, но допускается постоянный кэш компиляции
- ›Холодный процесс, холодный кэш — новый Python-процесс и пустой каталог кэша компиляции (когда это осуществимо)
- ›Тёплый процесс — тот же долгоживущий процесс (типично для интерактивного анализа)
Если мы не можем надёжно очистить кэш (потому что рантайм хранит его вне нашего контроля), мы рассматриваем это как ограничение и публикуем его как часть условий эксперимента.
В seed-харнессе кэш компиляции JAX задаётся явно через JAX_COMPILATION_CACHE_DIR, а для режима process_cold каталог кэша best-effort удаляется перед каждым fresh‑process прогоном.
5.Что мы будем публиковать
Для каждого снимка:
- ›Распределения холодного старта (не просто один тайминг)
- ›Распределения тёплой пропускной способности
- ›Манифест базовой линии (версии, железо, настройки)
- ›Политика кэша и версия харнесса
Спецификация публикации: публичные бенчмарки.
6.Почему это часть бенчмарк-программы NextStat
Ценность NextStat не в том, чтобы «выиграть микробенчмарк». Ценность в том, чтобы целые научные пайплайны становились быстрее, воспроизводимее и проще для аудита. Компромиссы compile vs execution это часть этой истории, когда ML стоит внутри цикла.
7.Ограничения и режимы отказа (что считается «корректностью»)
- ›GPU тайминги требуют строгой синхронизации. Без
block_until_ready()(или эквивалента) измерение становится некорректным. - ›«Холодный кэш» не всегда достижим. Если рантайм кэширует вне контролируемого каталога, это фиксируется как условие эксперимента, а не замалчивается.
- ›Опциональные зависимости — часть протокола. Если JAX отсутствует, кейсы
jax_jit_*должны бытьwarnс причинойmissing_dependency: jax, а не «тихо пропущены». - ›Запросили GPU — убедитесь, что это реально GPU. Seed-харнесс валидирует, что при
platform=gpuJAX действительно выбрал GPU; иначе результат помечается какwarnс причинойgpu_unavailable. - ›Ошибки воркера должны быть видны. Если отдельный процесс‑воркер не смог записать корректный JSON или вернул ошибку, кейс помечается как
failedс причинойworker_failed. - ›Публиковать нужно распределения, не «одно число». Минимально: median + p95 по cold TTFR и median по тёплым вызовам.
Минимальные ссылки
- ›JAX documentation — описание JIT, кэширования и модели исполнения.
- ›Спецификация публичных бенчмарков — требования к публикации снапшотов (артефакты, схемы, манифесты).
A.Приложение: состояние seed-харнесса (на сегодня)
В seed-репозитории публичных бенчмарков уже есть запускаемый ML-сьют в suites/ml/:
- ›измеряет распределение TTFR в режиме fresh process (несколько независимых запусков воркера)
- ›измеряет распределение тёплых вызовов как
warm.calls_sв отдельном процессе после компиляции - ›по умолчанию включает
numpyи опциональныеjax_jit_cpu/jax_jit_gpuкейсы - ›если JAX отсутствует, JAX‑кейсы публикуются как
warnс причинойmissing_dependency: jax
