Нейронная суррогатная дистилляция
Обучение наносекундных суррогатов из правдоподобия NextStat
NextStat может служить высокоточным оракулом для обучения нейросетевых суррогатов поверхности правдоподобия HistFactory. Суррогат работает за наносекунды вместо миллисекунд, обеспечивая MCMC в реальном времени, глобальные EFT-фиты или интерактивные дашборды.
Рабочий процесс
Sample parameter space (Sobol / LHS) → Evaluate NLL + gradient via NextStat GPU → Train a small MLP → Deploy the surrogate for real-time inference. NextStat предоставляет истину; суррогат обеспечивает скорость.
Быстрый старт
import nextstat
from nextstat.distill import generate_dataset, train_mlp_surrogate, predict_nll
model = nextstat.from_pyhf(workspace_json)
# 1. Generate 100k (params, NLL, gradient) tuples
ds = generate_dataset(model, n_samples=100_000, method="sobol")
print(f"{ds.n_samples} points, {ds.n_params} params")
print(f"NLL range: [{ds.nll.min():.1f}, {ds.nll.max():.1f}]")
# 2. Train a surrogate MLP (built-in convenience)
surrogate = train_mlp_surrogate(ds, epochs=100, device="cuda")
# 3. Predict NLL at new points (nanoseconds per eval)
import numpy as np
test_params = np.array(model.parameter_init())
pred_nll = predict_nll(surrogate, test_params)
print(f"Surrogate NLL: {pred_nll:.2f}")Методы сэмплирования
| Method | Покрытие | Лучше всего для |
|---|---|---|
| sobol | Квазислучайный, малое расхождение | По умолчанию. Лучшее покрытие минимумом точек. |
| lhs | Стратифицированный по измерениям | Хорошее покрытие, не требует степени двойки. |
| uniform | Чисто случайный | Базовое сравнение. |
| gaussian | Концентрация вблизи MLE | Тонкая настройка вблизи минимума. Фокусные суррогаты. |
Пользовательский обучающий цикл
Для продакшн-использования конвертируйте набор данных в PyTorch и напишите своё обучение:
from nextstat.distill import generate_dataset, to_torch_dataset
import torch
import torch.nn.functional as F
ds = generate_dataset(model, n_samples=500_000, method="sobol")
train_ds = to_torch_dataset(ds)
loader = torch.utils.data.DataLoader(train_ds, batch_size=4096, shuffle=True)
surrogate = torch.nn.Sequential(
torch.nn.Linear(ds.n_params, 256), torch.nn.SiLU(),
torch.nn.Linear(256, 256), torch.nn.SiLU(),
torch.nn.Linear(256, 1),
).cuda()
optimizer = torch.optim.Adam(surrogate.parameters(), lr=1e-3)
for epoch in range(100):
for params_batch, nll_batch, grad_batch in loader:
params_batch = params_batch.cuda()
nll_batch = nll_batch.cuda()
grad_batch = grad_batch.cuda()
pred = surrogate(params_batch).squeeze()
loss = F.mse_loss(pred, nll_batch)
# Опциональный: gradient-informed training (Sobolev loss)
params_batch.requires_grad_(True)
pred_g = surrogate(params_batch).squeeze()
pred_grad = torch.autograd.grad(pred_g.sum(), params_batch, create_graph=True)[0]
loss += 0.1 * F.mse_loss(pred_grad, grad_batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()Форматы экспорта
| Function | Формат | Случай использования |
|---|---|---|
| to_torch_dataset(ds) | TensorDataset | PyTorch DataLoader для обучения |
| to_numpy(ds) | dict of ndarray | SciPy, sklearn, JAX |
| to_npz(ds, path) | .npz (compressed) | Постоянное хранилище, воспроизводимость |
| to_parquet(ds, path) | .parquet (zstd) | Конвейеры Polars, DuckDB, Spark |
| from_npz(path) | .npz → Dataset | Загрузка ранее сохранённого набора данных |
Валидация
Всегда валидируйте суррогат против точных вычислений NextStat:
import numpy as np
from nextstat.distill import predict_nll
# Сравнить суррогат и точное значение в случайных точках
test_params = np.random.default_rng(99).uniform(
ds.parameter_bounds[:, 0], ds.parameter_bounds[:, 1],
size=(1000, ds.n_params)
)
pred = predict_nll(surrogate, test_params)
exact = np.array([model.nll(p.tolist()) for p in test_params])
rmse = np.sqrt(np.mean((pred - exact) ** 2))
max_err = np.max(np.abs(pred - exact))
print(f"RMSE: {rmse:.4f}, Макс. ошибка: {max_err:.4f}")Когда использовать суррогаты
- MCMC с большим числом параметров — суррогат заменяет дорогие вызовы NLL во внутреннем цикле HMC/NUTS.
- Интерактивные дашборды — контуры правдоподобия в реальном времени при перемещении ползунков.
- Глобальные EFT-фиты — сканирование 100+ коэффициентов Вильсона, где точные фиты непрактичны.
- Не рекомендуется для финальных результатов — всегда валидируйте точным фитом NextStat. Суррогат для исследования, точное вычисление — для публикации.
