NextStatNextStat

Нейронная суррогатная дистилляция

Обучение наносекундных суррогатов из правдоподобия NextStat

NextStat может служить высокоточным оракулом для обучения нейросетевых суррогатов поверхности правдоподобия HistFactory. Суррогат работает за наносекунды вместо миллисекунд, обеспечивая MCMC в реальном времени, глобальные EFT-фиты или интерактивные дашборды.

Рабочий процесс

Sample parameter space (Sobol / LHS) → Evaluate NLL + gradient via NextStat GPU → Train a small MLP → Deploy the surrogate for real-time inference. NextStat предоставляет истину; суррогат обеспечивает скорость.

Быстрый старт

import nextstat
from nextstat.distill import generate_dataset, train_mlp_surrogate, predict_nll

model = nextstat.from_pyhf(workspace_json)

# 1. Generate 100k (params, NLL, gradient) tuples
ds = generate_dataset(model, n_samples=100_000, method="sobol")
print(f"{ds.n_samples} points, {ds.n_params} params")
print(f"NLL range: [{ds.nll.min():.1f}, {ds.nll.max():.1f}]")

# 2. Train a surrogate MLP (built-in convenience)
surrogate = train_mlp_surrogate(ds, epochs=100, device="cuda")

# 3. Predict NLL at new points (nanoseconds per eval)
import numpy as np
test_params = np.array(model.parameter_init())
pred_nll = predict_nll(surrogate, test_params)
print(f"Surrogate NLL: {pred_nll:.2f}")

Методы сэмплирования

MethodПокрытиеЛучше всего для
sobolКвазислучайный, малое расхождениеПо умолчанию. Лучшее покрытие минимумом точек.
lhsСтратифицированный по измерениямХорошее покрытие, не требует степени двойки.
uniformЧисто случайныйБазовое сравнение.
gaussianКонцентрация вблизи MLEТонкая настройка вблизи минимума. Фокусные суррогаты.

Пользовательский обучающий цикл

Для продакшн-использования конвертируйте набор данных в PyTorch и напишите своё обучение:

from nextstat.distill import generate_dataset, to_torch_dataset
import torch
import torch.nn.functional as F

ds = generate_dataset(model, n_samples=500_000, method="sobol")
train_ds = to_torch_dataset(ds)
loader = torch.utils.data.DataLoader(train_ds, batch_size=4096, shuffle=True)

surrogate = torch.nn.Sequential(
    torch.nn.Linear(ds.n_params, 256), torch.nn.SiLU(),
    torch.nn.Linear(256, 256), torch.nn.SiLU(),
    torch.nn.Linear(256, 1),
).cuda()

optimizer = torch.optim.Adam(surrogate.parameters(), lr=1e-3)

for epoch in range(100):
    for params_batch, nll_batch, grad_batch in loader:
        params_batch = params_batch.cuda()
        nll_batch = nll_batch.cuda()
        grad_batch = grad_batch.cuda()

        pred = surrogate(params_batch).squeeze()
        loss = F.mse_loss(pred, nll_batch)

        # Опциональный: gradient-informed training (Sobolev loss)
        params_batch.requires_grad_(True)
        pred_g = surrogate(params_batch).squeeze()
        pred_grad = torch.autograd.grad(pred_g.sum(), params_batch, create_graph=True)[0]
        loss += 0.1 * F.mse_loss(pred_grad, grad_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Форматы экспорта

FunctionФорматСлучай использования
to_torch_dataset(ds)TensorDatasetPyTorch DataLoader для обучения
to_numpy(ds)dict of ndarraySciPy, sklearn, JAX
to_npz(ds, path).npz (compressed)Постоянное хранилище, воспроизводимость
to_parquet(ds, path).parquet (zstd)Конвейеры Polars, DuckDB, Spark
from_npz(path).npz → DatasetЗагрузка ранее сохранённого набора данных

Валидация

Всегда валидируйте суррогат против точных вычислений NextStat:

import numpy as np
from nextstat.distill import predict_nll

# Сравнить суррогат и точное значение в случайных точках
test_params = np.random.default_rng(99).uniform(
    ds.parameter_bounds[:, 0], ds.parameter_bounds[:, 1],
    size=(1000, ds.n_params)
)

pred = predict_nll(surrogate, test_params)
exact = np.array([model.nll(p.tolist()) for p in test_params])

rmse = np.sqrt(np.mean((pred - exact) ** 2))
max_err = np.max(np.abs(pred - exact))
print(f"RMSE: {rmse:.4f}, Макс. ошибка: {max_err:.4f}")

Когда использовать суррогаты

  • MCMC с большим числом параметров — суррогат заменяет дорогие вызовы NLL во внутреннем цикле HMC/NUTS.
  • Интерактивные дашборды — контуры правдоподобия в реальном времени при перемещении ползунков.
  • Глобальные EFT-фиты — сканирование 100+ коэффициентов Вильсона, где точные фиты непрактичны.
  • Не рекомендуется для финальных результатов — всегда валидируйте точным фитом NextStat. Суррогат для исследования, точное вычисление — для публикации.