
Better fix for load_from_checkpoint() not working with absolute path on Windows #2294

Merged (4 commits) on Jun 26, 2020
10 changes: 6 additions & 4 deletions CHANGELOG.md
@@ -24,6 +24,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298))

- Fixed `load_from_checkpoint()` not working with absolute path on Windows ([#2294](https://github.com/PyTorchLightning/pytorch-lightning/pull/2294))

- Fixed an issue with how `_has_len` handles `NotImplementedError`, e.g. raised by `torchtext.data.Iterator` ([#2293](https://github.com/PyTorchLightning/pytorch-lightning/pull/2293), [#2307](https://github.com/PyTorchLightning/pytorch-lightning/pull/2307))

- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319))
@@ -49,7 +51,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added `overfit_batches`, `limit_{val|test}_batches` flags (overfit now uses training set for all three) ([#2213](https://github.com/PyTorchLightning/pytorch-lightning/pull/2213))
- Added metrics
* Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
* Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
* Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488), [#2062](https://github.com/PyTorchLightning/pytorch-lightning/pull/2062))
@@ -59,7 +61,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Allow dataloaders without sampler field present ([#1907](https://github.com/PyTorchLightning/pytorch-lightning/pull/1907))
- Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` [(#1908)](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908)
- Early stopping checks `on_validation_end` ([#1458](https://github.com/PyTorchLightning/pytorch-lightning/pull/1458))
- Attribute `best_model_path` to `ModelCheckpoint` for storing and later retrieving the path to the best saved model file ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799))
- Speed up single-core TPU training by loading data using `ParallelLoader` ([#2033](https://github.com/PyTorchLightning/pytorch-lightning/pull/2033))
- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756))
- Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610))
@@ -74,7 +76,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))
- Removed non-finite values from loss in `LRFinder` ([#1862](https://github.com/PyTorchLightning/pytorch-lightning/pull/1862))
- Allow passing model hyperparameters as complete kwarg list ([#1896](https://github.com/PyTorchLightning/pytorch-lightning/pull/1896))
- Renamed `ModelCheckpoint`'s attributes `best` to `best_model_score` and `kth_best_model` to `kth_best_model_path` ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799))
- Re-Enable Logger's `ImportError`s ([#1938](https://github.com/PyTorchLightning/pytorch-lightning/pull/1938))
- Changed the default value of the Trainer argument `weights_summary` from `full` to `top` ([#2029](https://github.com/PyTorchLightning/pytorch-lightning/pull/2029))
@@ -107,7 +109,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Run graceful training teardown on interpreter exit ([#1631](https://github.com/PyTorchLightning/pytorch-lightning/pull/1631))
- Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))
- Fixed multiple calls of `EarlyStopping` callback ([#1863](https://github.com/PyTorchLightning/pytorch-lightning/pull/1863))
- Fixed an issue with `Trainer.from_argparse_args` when passing in unknown Trainer args ([#1932](https://github.com/PyTorchLightning/pytorch-lightning/pull/1932))
- Fixed bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))
- Fixed root node resolution for SLURM cluster with dash in host name ([#1954](https://github.com/PyTorchLightning/pytorch-lightning/pull/1954))
7 changes: 3 additions & 4 deletions pytorch_lightning/utilities/cloud_io.py
@@ -5,8 +5,7 @@


 def load(path_or_url: str, map_location=None):
-    parsed = urlparse(path_or_url)
-    if parsed.scheme == '' or Path(path_or_url).is_file():
-        # no scheme or local file
+    if urlparse(path_or_url).scheme == '' or Path(path_or_url).drive:  # no scheme or with a drive letter
         return torch.load(path_or_url, map_location=map_location)
-    else:
-        return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
+    return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)

Review comment (Contributor):
Stylistic nitpick: we don't need to have this else clause. I personally find it more "assuring" without the else clause, because you can be sure that the final return statement catches everything.
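The reason the earlier check was insufficient is that `urlparse` interprets a Windows drive letter as a URL scheme: `'C:\Users\me\model.ckpt'` parses with scheme `'c'`, so the empty-scheme test fails for absolute Windows paths, and the old `is_file()` fallback only helped when the file already existed. The sketch below isolates the path-detection logic of the fix; the helper name `is_local_path` is illustrative (not from the PR), and `PureWindowsPath` is used in place of `Path` so the drive-letter check can be demonstrated off-Windows (on Windows itself, plain `Path` behaves the same way):

```python
from pathlib import PureWindowsPath
from urllib.parse import urlparse

def is_local_path(path_or_url: str) -> bool:
    """Return True when the string should be handled as a local file path.

    `urlparse` mistakes a Windows drive letter for a URL scheme
    (e.g. 'C:\\model.ckpt' parses with scheme 'c'), so a drive-letter
    check is needed in addition to the empty-scheme check.
    """
    if urlparse(path_or_url).scheme == '':
        # no scheme at all: a plain POSIX-style path like '/home/me/model.ckpt'
        return True
    # a one-letter "scheme" is really a Windows drive letter ('C:\...')
    return bool(PureWindowsPath(path_or_url).drive)

print(urlparse(r'C:\Users\me\model.ckpt').scheme)  # drive letter parsed as a scheme
print(is_local_path(r'C:\Users\me\model.ckpt'))    # local: drive letter present
print(is_local_path('/home/me/model.ckpt'))        # local: no scheme
print(is_local_path('https://example.com/model.ckpt'))  # remote: real scheme
```

Note the check happens before any filesystem access, which is why it also works for paths that do not exist yet, unlike the previous `Path(path_or_url).is_file()` test.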