Save test predictions on multiple GPUs #2926
Changes from all commits
b2eaf71
0888391
d590cab
5d6633b
3b44a44
4b84912
65a4c05
7cbcc46
1f535d6
71a8c99
bd45203
0a9fe65
6e82b20
68b2b54
8e974d2
de984c4
c09efa4
bf6b696
031729e
e8b35f1
7b9fd75
0d91462
ca6762f
9f23dc2
373ab8e
0a1f9c2
```diff
@@ -134,7 +134,7 @@
 from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType
 from pytorch_lightning.core.step_result import Result, EvalResult
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.trainer.supporters import PredictionCollection

 try:
     import torch_xla.distributed.parallel_loader as xla_pl
```
```diff
@@ -278,6 +278,7 @@ def _evaluate(
 # bookkeeping
 outputs = []
+predictions = PredictionCollection(self.global_rank, self.world_size)

 # convert max_batches to list
 if isinstance(max_batches, int):
```
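For context, here is a minimal sketch of the interface the evaluation loop relies on: a `PredictionCollection` constructed with the process's global rank and world size, plus the `add` and `to_disk` methods called later in this diff. The real class lives in `pytorch_lightning.trainer.supporters`; the internals below are assumptions, since the diff only shows the call sites.

```python
# Hypothetical sketch of PredictionCollection as implied by its call sites in
# this diff (constructor args, add(), to_disk()). The actual implementation in
# pytorch_lightning.trainer.supporters may differ; internals are assumptions.
import torch


class PredictionCollection:
    def __init__(self, global_rank: int, world_size: int):
        self.global_rank = global_rank  # rank of this process
        self.world_size = world_size    # total number of processes
        self.predictions = {}           # filename -> {column_name: [values]}

    def add(self, predictions):
        # The loop passes output.pop('predictions', None), so tolerate None.
        if predictions is None:
            return
        for filename, columns in predictions.items():
            dest = self.predictions.setdefault(filename, {})
            for name, values in columns.items():
                dest.setdefault(name, []).extend(values)

    def to_disk(self):
        # Write one shard per rank so concurrent GPU processes never write
        # to the same file.
        for filename, columns in self.predictions.items():
            shard = f"{filename}.rank_{self.global_rank}"
            torch.save(columns, shard)
```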
```diff
@@ -370,6 +371,12 @@ def _evaluate(
 # track outputs for collation
 if output is not None:
+
+    # Add step predictions to prediction collection to write later
+    do_write_predictions = is_result_obj and test_mode
+    if do_write_predictions:
+        predictions.add(output.pop('predictions', None))
+
     dl_outputs.append(output)

 self.__eval_add_step_metrics(output)
```
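From the user's side, this means a `test_step` that returns a result object carrying a `predictions` entry, which the loop above pops and hands to the collection. The sketch below assumes an `EvalResult.write`-style helper for stashing predictions; treat the method name and signature as assumptions rather than confirmed API.

```python
# Hedged usage sketch: a test_step whose returned result carries predictions
# under the 'predictions' key consumed by predictions.add(...) above.
# result.write(...) and its signature are assumptions for illustration.
from pytorch_lightning.core.step_result import EvalResult


def test_step(self, batch, batch_idx):
    x, y = batch
    preds = self(x).argmax(dim=-1)
    result = EvalResult()
    # Assumed helper: records predictions so they end up in
    # output['predictions'], keyed by the target filename.
    result.write('preds', preds.tolist(), filename='predictions.pt')
    return result
```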
```diff
@@ -388,6 +395,9 @@ def _evaluate(
 # log callback metrics
 self.__update_callback_metrics(eval_results, using_eval_result)

+# Write predictions to disk if they're available.
+predictions.to_disk()
+
 # enable train mode again
 model.train()
```

Review comments on `predictions.to_disk()`:

- this should just be an internal function of the prediction object.
- we can't wait till the end of epoch to write predictions because we will accumulate too much memory. the write needs to happen at every batch.
- how do we want to deal with writing to the cache file (w/

williamFalcon marked this conversation as resolved.
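On the reviewer's memory concern, one way to write at every batch instead of at epoch end is to append each batch's predictions to a per-rank cache file and merge the shards afterwards. A rough sketch, with the helper name, cache layout, and file format all assumptions:

```python
# Sketch of per-batch flushing, per the review comment above: each rank
# appends its batch predictions to its own cache file, so memory stays
# bounded and ranks never contend on the same file. Names and the file
# format here are assumptions, not the merged implementation.
import os
import torch


def flush_batch_predictions(predictions, global_rank, cache_dir="pred_cache"):
    """Append one batch of predictions to this rank's cache file."""
    if predictions is None:
        return
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, f"rank_{global_rank}.pt")
    # Read-extend-write keeps the example simple; a real implementation
    # would use an append-friendly format so each flush is O(batch),
    # not O(epoch).
    cached = torch.load(path) if os.path.exists(path) else []
    cached.append(predictions)
    torch.save(cached, path)
```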