Add workaround for pytorch issue with serializing dtype #442

Merged: 3 commits, Oct 6, 2022
220 changes: 63 additions & 157 deletions examples/tutorials/tutorial_02_qm9.ipynb
@@ -74,17 +74,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:root:Downloading GDB-9 atom references...\n",
"INFO:root:Done.\n",
"INFO:root:Downloading GDB-9 data...\n",
"INFO:root:Done.\n",
"INFO:root:Extracting files...\n",
"INFO:root:Done.\n",
"INFO:root:Parse xyz files...\n",
"100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 133885/133885 [01:51<00:00, 1197.01it/s]\n",
"INFO:root:Write atoms to db...\n",
"INFO:root:Done.\n",
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00, 3.22it/s]\n"
"100%|██████████| 10/10 [00:06<00:00, 1.57it/s]\n"
]
}
],
@@ -180,8 +170,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Mean atomization energy / atom: -4.23855610005026\n",
"Std. dev. atomization energy / atom: 0.1926821633801207\n"
"Mean atomization energy / atom: -4.247325399125455\n",
"Std. dev. atomization energy / atom: 0.1801580985912772\n"
]
}
],
@@ -323,8 +313,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmpt2e_xbgs\n",
"INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmpt2e_xbgs/_remote_module_non_scriptable.py\n"
"INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmpsaotd4ge\n",
"INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmpsaotd4ge/_remote_module_non_sriptable.py\n",
"/home/kschuett/anaconda3/envs/spkdev/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py:268: UserWarning: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.\n",
" rank_zero_warn(\n"
]
}
],
@@ -366,13 +358,12 @@
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True, used: False\n",
"GPU available: True (cuda), used: False\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"/home/mitx/anaconda3/envs/spk2_test/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1814: PossibleUserWarning: GPU available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='gpu', devices=1)`.\n",
"/home/kschuett/anaconda3/envs/spkdev/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1764: PossibleUserWarning: GPU available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='gpu', devices=1)`.\n",
" rank_zero_warn(\n",
"Missing logger folder: ./qm9tut/lightning_logs\n",
"\n",
" | Name | Type | Params\n",
"---------------------------------------------------\n",
@@ -387,31 +378,12 @@
},
{
"data": {
"application/json": {
"ascii": false,
"bar_format": null,
"colour": null,
"elapsed": 0.030420780181884766,
"initial": 0,
"n": 0,
"ncols": 189,
"nrows": 17,
"postfix": null,
"prefix": "Sanity Checking",
"rate": null,
"total": null,
"unit": "it",
"unit_divisor": 1000,
"unit_scale": false
},
"text/plain": "Sanity Checking: 0it [00:00, ?it/s]",
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Sanity Checking: 0it [00:00, ?it/s]"
]
"version_minor": 0,
"model_id": "bb3ac9516922417a92ce5bbbc9906f28"
}
},
"metadata": {},
"output_type": "display_data"
@@ -420,139 +392,70 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/home/mitx/anaconda3/envs/spk2_test/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
"/home/kschuett/anaconda3/envs/spkdev/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:236: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
" rank_zero_warn(\n",
"/home/mitx/anaconda3/envs/spk2_test/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:72: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 100. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
"/home/kschuett/anaconda3/envs/spkdev/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:98: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 100. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
" warning_cache.warn(\n",
"/home/mitx/anaconda3/envs/spk2_test/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
"/home/kschuett/anaconda3/envs/spkdev/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:236: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
" rank_zero_warn(\n",
"/home/mitx/anaconda3/envs/spk2_test/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1933: PossibleUserWarning: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
"/home/kschuett/anaconda3/envs/spkdev/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1892: PossibleUserWarning: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
" rank_zero_warn(\n"
]
},
{
"data": {
"application/json": {
"ascii": false,
"bar_format": null,
"colour": null,
"elapsed": 0.06877684593200684,
"initial": 0,
"n": 0,
"ncols": 189,
"nrows": 17,
"postfix": null,
"prefix": "Training",
"rate": null,
"total": null,
"unit": "it",
"unit_divisor": 1000,
"unit_scale": false
},
"text/plain": "Training: 0it [00:00, ?it/s]",
"application/vnd.jupyter.widget-view+json": {
"model_id": "085ce80df8244dd3ac73effcda6407ee",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0it [00:00, ?it/s]"
]
"version_minor": 0,
"model_id": "a3dfe943394e46ea9c2f8650e1d21041"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/json": {
"ascii": false,
"bar_format": null,
"colour": null,
"elapsed": 0.01809835433959961,
"initial": 0,
"n": 0,
"ncols": 189,
"nrows": 17,
"postfix": null,
"prefix": "Validation",
"rate": null,
"total": null,
"unit": "it",
"unit_divisor": 1000,
"unit_scale": false
},
"text/plain": "Validation: 0it [00:00, ?it/s]",
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
"version_minor": 0,
"model_id": "fc72ad793efd412cbcb6f9faa830dfda"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/json": {
"ascii": false,
"bar_format": null,
"colour": null,
"elapsed": 0.0167539119720459,
"initial": 0,
"n": 0,
"ncols": 189,
"nrows": 17,
"postfix": null,
"prefix": "Validation",
"rate": null,
"total": null,
"unit": "it",
"unit_divisor": 1000,
"unit_scale": false
},
"text/plain": "Validation: 0it [00:00, ?it/s]",
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
"version_minor": 0,
"model_id": "420695c4425f4c7d8e7c38f8baed25a7"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/json": {
"ascii": false,
"bar_format": null,
"colour": null,
"elapsed": 0.018230438232421875,
"initial": 0,
"n": 0,
"ncols": 189,
"nrows": 17,
"postfix": null,
"prefix": "Validation",
"rate": null,
"total": null,
"unit": "it",
"unit_divisor": 1000,
"unit_scale": false
},
"text/plain": "Validation: 0it [00:00, ?it/s]",
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
"version_minor": 0,
"model_id": "febc0505aa174edfb9f3c5ab275f548b"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=3` reached.\n"
]
}
],
"source": [
@@ -647,24 +550,27 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Result dictionary: {'energy_U0': tensor([10414.7993, 12365.0673, 10338.4029, 11250.6555, 11797.4884, 11231.0736,\n",
" 10804.1721, 12671.1666, 10319.0031, 11211.2785, 11333.0383, 11408.8837,\n",
" 11796.9613, 10804.0342, 11758.7678, 11350.5577, 10662.9685, 10325.3857,\n",
" 10784.3073, 10319.5092, 12830.0989, 10765.0104, 8765.8276, 11759.1255,\n",
" 9778.4204, 11797.7638, 11778.4216, 8298.1309, 9809.9145, 11351.5860,\n",
" 11798.3985, 11778.0430, 11797.1257, 10802.4717, 10753.5472, 10325.7682,\n",
" 13238.6101, 10356.5500, 11331.8923, 12792.5905, 9790.3895, 10784.3336,\n",
" 9760.3332, 9758.1088, 11852.6609, 10772.1722, 10198.4599, 12225.7697,\n",
" 10326.0554, 10804.3006, 9809.5706, 11785.8515, 11211.2164, 12365.1869,\n",
" 11350.6455, 11351.2174, 10651.8422, 10357.2399, 10803.3964, 11779.4107,\n",
" 10803.2706, 9363.1532, 10784.1137, 10683.4049, 9401.3194, 9363.7182,\n",
" 11797.7716, 10376.6795, 11817.2000, 10216.2281, 10822.8966, 9311.9195,\n",
" 11370.8836, 10357.9132, 8765.8390, 11797.6732, 11350.4620, 11351.6353,\n",
" 11522.6320, 11351.7427, 9828.9803, 11696.6340, 9332.3702, 11796.8731,\n",
" 10395.7840, 11779.0953, 11370.6693, 10803.5162, 10317.9830, 11676.6975,\n",
" 11675.6500, 13116.9391, 10414.5667, 10783.8263, 10803.2187, 13377.5496,\n",
" 10356.8070, 11129.3613, 11370.7541, 11370.7576], dtype=torch.float64,\n",
" grad_fn=<SubBackward0>)}\n"
"Result dictionary: {'energy_U0': tensor([-11901.9678, -10829.1715, -10493.9096, -11365.3375, -9995.8116,\n",
" -10451.7412, -10851.2470, -11006.8458, -10494.9420, -11368.3874,\n",
" -8844.8229, -11902.6537, -9918.1257, -9956.3337, -10833.8384,\n",
" -12016.3231, -12344.6808, -11981.6151, -11842.0390, -10573.4037,\n",
" -10930.8419, -10414.9862, -10340.3468, -11508.4475, -10553.1781,\n",
" -11464.8257, -11010.9114, -10573.8298, -11546.2505, -10398.0184,\n",
" -11901.8865, -12382.0646, -11805.9859, -11468.2166, -12303.6954,\n",
" -11982.0471, -11942.9695, -10972.0845, -12742.0990, -12305.2618,\n",
" -9995.6813, -11326.8686, -13931.4072, -10534.4627, -11945.2061,\n",
" -12557.3998, -11943.9106, -10568.2193, -11538.0142, -10492.3288,\n",
" -9857.8994, -11368.2026, -11506.9391, -10965.4910, -10973.1663,\n",
" -11584.8918, -11503.7264, -12990.9329, -12518.4351, -11543.0566,\n",
" -11408.7530, -11942.5794, -13317.8285, -9597.8316, -10930.5504,\n",
" -12460.0102, -11802.8971, -10395.8514, -13355.2561, -9478.2067,\n",
" -5291.8420, -10411.8928, -11804.3231, -11766.3743, -10532.8525,\n",
" -9604.8805, -12478.7421, -11747.7678, -11368.4521, -9609.7054,\n",
" -12381.6398, -10635.5377, -11867.4939, -11767.7288, -10473.4594,\n",
" -11267.5563, -11845.0998, -12304.8664, -11582.9844, -11542.7391,\n",
" -10531.9801, -10973.6226, -11403.2258, -10489.5223, -11585.6760,\n",
" -10929.6288, -11908.0952, -12917.9566, -9458.0325, -13433.7984],\n",
" dtype=torch.float64, grad_fn=<AddBackward0>)}\n"
]
}
],
@@ -733,7 +639,7 @@
"output_type": "stream",
"text": [
"Keys: ['_n_atoms', '_atomic_numbers', '_positions', '_cell', '_pbc', '_idx', '_idx_i_local', '_idx_j_local', '_offsets', '_idx_m', '_idx_i', '_idx_j']\n",
"Prediction: tensor([1064.8288], dtype=torch.float64, grad_fn=<SubBackward0>)\n"
"Prediction: tensor([-1103.2246], dtype=torch.float64, grad_fn=<AddBackward0>)\n"
]
}
],
@@ -778,7 +684,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Prediction: 1064.8287982940674\n"
"Prediction: -1103.2246329784393\n"
]
}
],
21 changes: 11 additions & 10 deletions src/schnetpack/model/base.py
@@ -4,6 +4,7 @@

from schnetpack.transform import Transform
import schnetpack.properties as properties
from schnetpack.utils import as_dtype

import torch
import torch.nn as nn
@@ -30,11 +31,11 @@ def __init__(
representation: nn.Module,
output_module: nn.Module,
postprocessors: Optional[List[Transform]] = None,
input_dtype: torch.dtype = torch.float32,
input_dtype_str: str = "float32",
do_postprocessing: bool = True,
):
super().__init__(
input_dtype=input_dtype,
input_dtype_str=input_dtype_str,
postprocessors=postprocessors,
do_postprocessing=do_postprocessing,
)
@@ -58,18 +59,18 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
def __init__(
self,
postprocessors: Optional[List[Transform]] = None,
input_dtype: torch.dtype = torch.float32,
input_dtype_str: str = "float32",
do_postprocessing: bool = True,
):
"""
Args:
postprocessors: Post-processing transforms tha may be initialized using te `datamodule`, but are not
applied during training.
input_dtype: The dtype of real inputs.
postprocessors: Post-processing transforms that may be initialized using the
`datamodule`, but are not applied during training.
input_dtype_str: The dtype of real inputs as string.
do_postprocessing: If true, post-processing is activated.
"""
super().__init__()
self.input_dtype = input_dtype
self.input_dtype_str = input_dtype_str
self.do_postprocessing = do_postprocessing
self.postprocessors = nn.ModuleList(postprocessors)
self.required_derivatives: Optional[List[str]] = None
@@ -138,7 +139,7 @@ def __init__(
input_modules: List[nn.Module] = None,
output_modules: List[nn.Module] = None,
postprocessors: Optional[List[Transform]] = None,
input_dtype: torch.dtype = torch.float32,
input_dtype_str: str = "float32",
do_postprocessing: Optional[bool] = None,
):
"""
@@ -149,11 +150,11 @@ def __init__(
output_modules: Modules that predict output properties from the representation.
postprocessors: Post-processing transforms that may be initialized using te `datamodule`, but are not
applied during training.
input_dtype: The dtype of real inputs.
input_dtype_str: The dtype of real inputs.
do_postprocessing: If true, post-processing is activated.
"""
super().__init__(
input_dtype=input_dtype,
input_dtype_str=input_dtype_str,
postprocessors=postprocessors,
do_postprocessing=do_postprocessing,
)
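
The diff above replaces the `input_dtype: torch.dtype` constructor argument with `input_dtype_str: str`, so the value stored among the model's init arguments is a plain string rather than a `torch.dtype` object. A minimal sketch of the pattern, using a toy module rather than the SchNetPack class itself (the names `ToyAtomisticModel` and `_DTYPES` are illustrative, not part of the library):

```python
import torch
import torch.nn as nn

# Toy illustration of the pattern in base.py above (not the SchNetPack class):
# the constructor keeps only a string, so serializing the init arguments never
# touches a torch.dtype object; the real dtype is reconstructed on demand.
_DTYPES = {"float32": torch.float32, "float64": torch.float64}


class ToyAtomisticModel(nn.Module):
    def __init__(self, input_dtype_str: str = "float32"):
        super().__init__()
        self.input_dtype_str = input_dtype_str  # plain string, trivially serializable

    @property
    def input_dtype(self) -> torch.dtype:
        return _DTYPES[self.input_dtype_str]


model = ToyAtomisticModel("float64")
print(model.input_dtype)  # torch.float64
```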
2 changes: 1 addition & 1 deletion src/schnetpack/task.py
Original file line number Diff line number Diff line change
@@ -130,7 +130,7 @@ def __init__(
self.grad_enabled = len(self.model.required_derivatives) > 0
self.lr = optimizer_args["lr"]
self.warmup_steps = warmup_steps
self.save_hyperparameters(ignore=["model"])
self.save_hyperparameters()

def setup(self, stage=None):
if stage == "fit":
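
With the dtype now passed around as a string, `save_hyperparameters()` is called without `ignore=["model"]`, so the model configuration is captured in the checkpoint as well; the `UserWarning` about the `model` attribute in the tutorial output above is the expected side effect. A hedged sketch of how `save_hyperparameters` records plain-Python init arguments, using a toy LightningModule (`ToyTask` is hypothetical, not `schnetpack.task.AtomisticTask`):

```python
import pytorch_lightning as pl


# Toy LightningModule: with only plain-Python init arguments such as a dtype
# string, save_hyperparameters() can record everything into self.hparams and
# the checkpoint without hitting the torch.dtype serialization issue.
class ToyTask(pl.LightningModule):
    def __init__(self, lr: float = 1e-3, input_dtype_str: str = "float32"):
        super().__init__()
        self.save_hyperparameters()  # records lr and input_dtype_str


task = ToyTask()
print(dict(task.hparams))  # {'lr': 0.001, 'input_dtype_str': 'float32'}
```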
14 changes: 8 additions & 6 deletions src/schnetpack/transform/casting.py
@@ -1,5 +1,6 @@
from typing import Dict
from typing import Dict, Optional
from schnetpack.utils import as_dtype

import torch

@@ -16,10 +17,10 @@ class CastMap(Transform):
is_preprocessor: bool = True
is_postprocessor: bool = True

def __init__(self, type_map: Dict[torch.dtype, torch.dtype]):
def __init__(self, type_map: Dict[str, str]):
"""
Args:
type_map: dict with soource_type: target_type
type_map: dict with source_type: target_type (as strings)
"""
super().__init__()
self.type_map = type_map
@@ -29,20 +30,21 @@ def forward(
inputs: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
for k, v in inputs.items():
if v.dtype in self.type_map:
inputs[k] = v.to(dtype=self.type_map[v.dtype])
vdtype = str(v.dtype).split(".")[-1]
if vdtype in self.type_map:
inputs[k] = v.to(dtype=as_dtype(self.type_map[vdtype]))
return inputs


class CastTo32(CastMap):
"""Cast all float64 tensors to float32"""

def __init__(self):
super().__init__(type_map={torch.float64: torch.float32})
super().__init__(type_map={"float64": "float32"})


class CastTo64(CastMap):
"""Cast all float32 tensors to float64"""

def __init__(self):
super().__init__(type_map={torch.float32: torch.float64})
super().__init__(type_map={"float32": "float64"})
34 changes: 34 additions & 0 deletions src/schnetpack/utils/__init__.py
@@ -4,6 +4,40 @@

from schnetpack import properties as spk_properties

TORCH_DTYPES = {
"float32": torch.float32,
"float64": torch.float64,
"float": torch.float,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"half": torch.half,
"uint8": torch.uint8,
"int8": torch.int8,
"int16": torch.int16,
"short": torch.short,
"int32": torch.int32,
"int": torch.int,
"int64": torch.int64,
"long": torch.long,
"complex64": torch.complex64,
"cfloat": torch.cfloat,
"complex128": torch.complex128,
"cdouble": torch.cdouble,
"quint8": torch.quint8,
"qint8": torch.qint8,
"qint32": torch.qint32,
"bool": torch.bool,
"quint4x2": torch.quint4x2,
"quint2x4": torch.quint2x4,
}

TORCH_DTYPES.update({"torch." + k: v for k, v in TORCH_DTYPES.items()})


def as_dtype(dtype_str: str) -> torch.dtype:
"""Convert a string to torch.dtype"""
return TORCH_DTYPES[dtype_str]


def int2precision(precision: Union[int, torch.dtype]):
"""