Skip to content

Commit c869dd8

Browse files
authored
make evaluate private (#1260)
* make evaluate private * changelog
1 parent 6dfe995 commit c869dd8

File tree

4 files changed

+9
-9
lines changed

4 files changed

+9
-9
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2020

2121
### Changed
2222

23-
-
23+
- Made `evalaute` method private >> `Trainer._evaluate(...)`. ([#1260](https://github.com/PyTorchLightning/pytorch-lightning/pull/1260))
2424

2525
### Deprecated
2626

pytorch_lightning/trainer/evaluation_loop.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def reset_test_dataloader(self, *args):
217217
def reset_val_dataloader(self, *args):
218218
"""Warning: this is just empty shell for code implemented in other class."""
219219

220-
def evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_mode: bool = False):
220+
def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_mode: bool = False):
221221
"""Run evaluation code.
222222
223223
Args:
@@ -365,7 +365,7 @@ def run_evaluation(self, test_mode: bool = False):
365365
setattr(self, f'{"test" if test_mode else "val"}_progress_bar', pbar)
366366

367367
# run evaluation
368-
eval_results = self.evaluate(self.model, dataloaders, max_batches, test_mode)
368+
eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)
369369
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
370370
eval_results)
371371

pytorch_lightning/trainer/trainer.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -893,10 +893,10 @@ def run_pretrain_routine(self, model: LightningModule):
893893
# dummy validation progress bar
894894
self.val_progress_bar = tqdm(disable=True)
895895

896-
eval_results = self.evaluate(model,
897-
self.val_dataloaders,
898-
self.num_sanity_val_steps,
899-
False)
896+
eval_results = self._evaluate(model,
897+
self.val_dataloaders,
898+
self.num_sanity_val_steps,
899+
False)
900900
_, _, _, callback_metrics, _ = self.process_output(eval_results)
901901

902902
# close progress bars

tests/test_deprecated.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_tbd_remove_in_v1_0_0_model_hooks():
9595

9696
trainer = Trainer(logger=False)
9797
# TODO: why `dataloder` is required if it is not used
98-
result = trainer.evaluate(model, dataloaders=[[None]], max_batches=1)
98+
result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
9999
assert result == {'val_loss': 0.6}
100100

101101
model = ModelVer0_7(hparams)
@@ -106,5 +106,5 @@ def test_tbd_remove_in_v1_0_0_model_hooks():
106106

107107
trainer = Trainer(logger=False)
108108
# TODO: why `dataloder` is required if it is not used
109-
result = trainer.evaluate(model, dataloaders=[[None]], max_batches=1)
109+
result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
110110
assert result == {'val_loss': 0.7}

0 commit comments

Comments
 (0)