Обучение с SignificanceLoss
Сквозной дифференцируемый пайплайн
Это руководство показывает, как обучать нейросеть в PyTorch, где функция потерь равна профилированной значимости открытия Z₀. Градиент идет от статистического теста, через дифференцируемую гистограмму, и дальше до весов нейросети.
Обзор архитектуры
Нейросеть NextStat (Rust/CUDA)
───────── ────────────────────
Входные признаки HistFactory model
│ (систематики, фоны)
▼ │
Classifier(x) -> scores │
│ │
▼ │
SoftHistogram(scores) -> bins ─────────▶ SignificanceLoss
│ │
│ ◄─── ∂(-Z₀)/∂bins ─┘
▼
loss.backward() -> ∂loss/∂weights -> optimizer.step()Шаг 1: подготовьте статистическую модель
NextStat использует формат HistFactory: тот же JSON, который производит pyhf. Если у вас уже есть pyhf workspace, его можно загрузить напрямую.
import json
import nextstat
# Загрузить pyhf-style workspace JSON
with open("workspace.json") as f:
ws = json.load(f)
model = nextstat.from_pyhf(ws)
# Модель содержит: каналы, сэмплы, систематики, наблюдаемые данныеШаг 2: создайте функцию потерь
from nextstat.torch import SignificanceLoss
# SignificanceLoss оборачивает профилированную Z₀ в привычный интерфейс __call__.
# По умолчанию возвращает -Z₀, чтобы минимизация SGD максимизировала значимость.
loss_fn = SignificanceLoss(
model,
signal_sample_name="signal", # какой сэмпл контролирует нейросеть
device="auto", # "cuda", "metal" или "auto"
negate=True, # -Z₀ для минимизации (по умолчанию)
eps=1e-12, # численная устойчивость в sqrt
)
print(f"Бины сигнала: {loss_fn.n_bins}") # например, 10
print(f"Мешающие параметры: {loss_fn.n_params}") # например, 23Шаг 3: дифференцируемое биннингование
Обычная гистограмма недифференцируема (жесткие границы бинов дают нулевой градиент).SoftHistogram решает это через гауссов KDE или сигмоидальные аппроксимации.
from nextstat.torch import SoftHistogram
# Границы бинов должны соответствовать статистической модели
soft_hist = SoftHistogram(
bin_edges=torch.linspace(0.0, 1.0, 11), # 10 бинов на [0, 1]
bandwidth=0.05, # ширина KDE (меньше = резче, но шумнее)
mode="kde", # "kde" (Gaussian) или "sigmoid" (быстрее)
)
# Использование:
scores = classifier(batch_features) # [N] непрерывные выходы
histogram = soft_hist(scores, weights) # [10] дифференцируемые счета по бинам| Режим | Скорость | Качество градиента |
|---|---|---|
| kde | O(N × B) | Гладко, низкая дисперсия. Рекомендуется для обучения. |
| sigmoid | O(N × B) | Более резкие бины, но градиенты шумнее. Удобно для fine-tuning. |
Шаг 4: цикл обучения
import torch
classifier = MyClassifier(input_dim=20, hidden=64).cuda()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
for epoch in range(50):
for batch_x, batch_w in dataloader:
optimizer.zero_grad()
# Прямой проход: NN → scores → soft histogram → -Z₀
scores = classifier(batch_x.cuda())
histogram = soft_hist(scores, batch_w.cuda())
loss = loss_fn(histogram.double().cuda())
# Обратный проход: градиенты проходят через NextStat к весам NN
loss.backward()
optimizer.step()
# Мониторинг: negate, чтобы получить положительную Z₀
with torch.no_grad():
z0 = -loss.item()
print(f"Эпоха {epoch}: Z₀ = {z0:.3f}σ")Шаг 5: интероперабельность (JAX, CuPy)
Если ваш data pipeline использует JAX или CuPy, используйте as_tensor(), чтобы конвертировать массивы без копирования через протокол DLPack.
from nextstat.torch import as_tensor
# JAX → PyTorch (без копирования на GPU через DLPack)
import jax.numpy as jnp
jax_hist = jnp.array([10.0, 20.0, 30.0, 40.0])
torch_hist = as_tensor(jax_hist).double()
# CuPy → PyTorch (без копирования через __dlpack__)
import cupy as cp
cupy_hist = cp.array([10.0, 20.0, 30.0, 40.0])
torch_hist = as_tensor(cupy_hist).double()
# Работает с: PyTorch, JAX, CuPy, NumPy, Apache Arrow, спискиРасширенное: прямой доступ к якобиану
Для внешних оптимизаторов (SciPy, Optuna) или анализа на уровне бинов можно извлечь сырой градиент ∂q₀/∂signal, не проходя через autograd:
from nextstat.torch import signal_jacobian, signal_jacobian_numpy
# Как тензор PyTorch (на том же устройстве)
grad = signal_jacobian(signal_hist, loss_fn.session)
# Как массив NumPy (для SciPy / Optuna)
grad_np = signal_jacobian_numpy(signal_hist, loss_fn.session)
# Быстрое прореживание: найти бины с низким вкладом
important = grad.abs() > 0.01
print(f"Важные бины: {important.sum()}/{len(important)}")Расширенное: пакетная оценка
from nextstat.torch import batch_profiled_q0_loss
# Оценить сразу несколько гистограмм (например, элементы ансамбля)
histograms = torch.stack([hist_1, hist_2, hist_3]) # [3, n_bins]
q0_list = batch_profiled_q0_loss(histograms, loss_fn.session)
# q0_list = [Tensor(q0_1), Tensor(q0_2), Tensor(q0_3)]Советы
- dtype - SignificanceLoss ожидает float64. Всегда вызывайте
.double()перед передачей тензоров. - подбор bandwidth - начните с
bandwidth="auto", затем уменьшайте для более четких бинов после стабилизации обучения. - learning rate - ландшафт потерь невыпуклый (два L-BFGS-B фита за шаг). Используйте Adam с lr ≈ 1e-3 ... 1e-4.
- warm-up - рассмотрите предобучение с кросс-энтропией на несколько эпох перед переключением на SignificanceLoss.
- Metal (macOS) - Apple Silicon поддерживается. Сигнальная гистограмма загружается через CPU (без zero-copy), но L-BFGS-B фиты выполняются на Metal GPU.
