
Commit a44b2aa

sdaulton authored and facebook-github-bot committed
add utility for computing AIC/BIC/MLL from a model (#2785)
Summary:
Pull Request resolved: #2785

Add utility for computing in-sample model fit metrics

Reviewed By: saitcakmak

Differential Revision: D71827991

fbshipit-source-id: d69f08eddce95e547421998c596eb89ff7d2d6fb
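For reference, the criteria computed by the new utility (see botorch/utils/evaluation.py below) are the in-sample marginal log likelihood (MLL) of the fitted GP and the information criteria derived from it: AIC = 2k - 2*MLL and BIC = k*ln(n) - 2*MLL, where k is the number of model parameters and n is the number of training points.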
1 parent 557d016 commit a44b2aa

File tree

3 files changed: +109 −0 lines changed


botorch/utils/evaluation.py

+51
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from math import log

import torch
from botorch.utils.transforms import is_fully_bayesian
from gpytorch.models.exact_gp import ExactGP

MLL = "MLL"
AIC = "AIC"
BIC = "BIC"


def compute_in_sample_model_fit_metric(model: ExactGP, criterion: str) -> float:
    """Compute an in-sample model fit metric.

    Args:
        model: A fitted ExactGP.
        criterion: Evaluation criterion. One of "MLL", "AIC", "BIC". AIC
            penalizes the MLL based on the number of parameters. BIC uses
            a slightly different penalty based on the number of parameters
            and data points.

    Returns:
        The in-sample evaluation metric.
    """
    if criterion not in (AIC, BIC, MLL):
        raise ValueError(f"Invalid evaluation criterion {criterion}.")
    if is_fully_bayesian(model=model):
        model.train(reset=False)
    else:
        model.train()
    with torch.no_grad():
        output = model(*model.train_inputs)
        output = model.likelihood(output)
        mll = output.log_prob(model.train_targets)
    # compute average MLL over MCMC samples if the model is fully bayesian
    mll_scalar = mll.mean().item()
    model.eval()
    num_params = sum(p.numel() for p in model.parameters())
    if is_fully_bayesian(model=model):
        num_params /= mll.shape[0]
    if criterion == AIC:
        return 2 * num_params - 2 * mll_scalar
    elif criterion == BIC:
        return num_params * log(model.train_inputs[0].shape[-2]) - 2 * mll_scalar
    return mll_scalar
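
For context, a minimal usage sketch (not part of this commit) showing how the new utility can be called on a fitted SingleTaskGP; the synthetic training data and tensor shapes below are illustrative assumptions.

# Usage sketch (illustrative only): fit a SingleTaskGP on synthetic data,
# then compare the in-sample fit criteria computed by the new utility.
import math

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models.gp_regression import SingleTaskGP
from botorch.utils.evaluation import AIC, BIC, MLL, compute_in_sample_model_fit_metric
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood

# Synthetic training data (assumed for illustration): 20 points in 2 dimensions.
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = torch.sin(2 * math.pi * train_X.sum(dim=-1, keepdim=True))

model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

mll = compute_in_sample_model_fit_metric(model=model, criterion=MLL)
aic = compute_in_sample_model_fit_metric(model=model, criterion=AIC)  # 2k - 2 * MLL
bic = compute_in_sample_model_fit_metric(model=model, criterion=BIC)  # k * ln(n) - 2 * MLL
print(f"MLL={mll:.3f}  AIC={aic:.3f}  BIC={bic:.3f}")

A higher MLL indicates a better in-sample likelihood, while lower AIC/BIC values indicate a better fit after penalizing model complexity.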

sphinx/source/utils.rst

+5
@@ -32,6 +32,11 @@ Dispatcher
 .. automodule:: botorch.utils.dispatcher
     :members:

+Evaluation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: botorch.utils.evaluation
+    :members:
+
 Low-Rank Cholesky Update Utils
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. automodule:: botorch.utils.low_rank

test/utils/test_evaluation.py

+53
@@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from itertools import product
from math import log, pi

import torch
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.gp_regression import SingleTaskGP
from botorch.test_utils.mock import mock_optimize
from botorch.utils.evaluation import AIC, BIC, compute_in_sample_model_fit_metric, MLL
from botorch.utils.testing import BotorchTestCase
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood


class TestEvaluation(BotorchTestCase):
    @mock_optimize
    def test_compute_in_sample_model_fit_metric(self):
        torch.manual_seed(0)
        for dtype, model_cls in product(
            (torch.float, torch.double), (SingleTaskGP, SaasFullyBayesianSingleTaskGP)
        ):
            train_X = torch.linspace(
                0, 1, 10, dtype=dtype, device=self.device
            ).unsqueeze(-1)
            train_Y = torch.sin(2 * pi * train_X)
            model = model_cls(train_X=train_X, train_Y=train_Y)
            if model_cls is SingleTaskGP:
                fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))
            else:
                fit_fully_bayesian_model_nuts(
                    model,
                    warmup_steps=8,
                    num_samples=6,
                    thinning=2,
                    disable_progbar=True,
                )
            num_params = sum(p.numel() for p in model.parameters())
            if model_cls is SaasFullyBayesianSingleTaskGP:
                num_params /= 3  # divide by number of MCMC samples
            mll = compute_in_sample_model_fit_metric(model=model, criterion=MLL)
            aic = compute_in_sample_model_fit_metric(model=model, criterion=AIC)
            bic = compute_in_sample_model_fit_metric(model=model, criterion=BIC)
            self.assertEqual(aic, 2 * num_params - 2 * mll)
            self.assertEqual(bic, log(10) * num_params - 2 * mll)
        # test invalid criterion
        with self.assertRaisesRegex(
            ValueError, "Invalid evaluation criterion invalid."
        ):
            compute_in_sample_model_fit_metric(model=model, criterion="invalid")
