
Commit bf66d59

Add workaround for pytorch issue with serializing dtype (#442)
* Add workaround for pytorch issue with serializing dtype
* Add some dtype aliases
1 parent 8cf3519 commit bf66d59
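
The gist of the workaround: model constructors take the dtype as a plain string (e.g. "float32") instead of a torch.dtype object, since dtype objects do not serialize cleanly through hyperparameter and config round-trips; a small helper resolves the string back to a real dtype where tensors are created. Below is a minimal sketch of such a helper with a few aliases. The alias table is an illustrative assumption; the actual implementation is `schnetpack.utils.as_dtype`.

    import torch

    # Illustrative sketch only: the real helper is schnetpack.utils.as_dtype,
    # and the exact alias set below is an assumption.
    _DTYPE_ALIASES = {
        "float16": torch.float16, "half": torch.float16,
        "float32": torch.float32, "float": torch.float32,
        "float64": torch.float64, "double": torch.float64,
    }

    def as_dtype(name: str) -> torch.dtype:
        """Resolve a dtype given as a string, accepting common aliases."""
        try:
            return _DTYPE_ALIASES[name]
        except KeyError:
            raise ValueError(f"unknown dtype string: {name!r}") from None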

File tree: 5 files changed (+117, -174 lines)


Diff for: examples/tutorials/tutorial_02_qm9.ipynb (+63, -157)
@@ -74,17 +74,7 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"INFO:root:Downloading GDB-9 atom references...\n",
-"INFO:root:Done.\n",
-"INFO:root:Downloading GDB-9 data...\n",
-"INFO:root:Done.\n",
-"INFO:root:Extracting files...\n",
-"INFO:root:Done.\n",
-"INFO:root:Parse xyz files...\n",
-"100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 133885/133885 [01:51<00:00, 1197.01it/s]\n",
-"INFO:root:Write atoms to db...\n",
-"INFO:root:Done.\n",
-"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00, 3.22it/s]\n"
+"100%|██████████| 10/10 [00:06<00:00, 1.57it/s]\n"
 ]
 }
 ],
@@ -180,8 +170,8 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Mean atomization energy / atom: -4.23855610005026\n",
-"Std. dev. atomization energy / atom: 0.1926821633801207\n"
+"Mean atomization energy / atom: -4.247325399125455\n",
+"Std. dev. atomization energy / atom: 0.1801580985912772\n"
 ]
 }
 ],
@@ -323,8 +313,10 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmpt2e_xbgs\n",
-"INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmpt2e_xbgs/_remote_module_non_scriptable.py\n"
+"INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmpsaotd4ge\n",
+"INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmpsaotd4ge/_remote_module_non_scriptable.py\n",
+"/home/kschuett/anaconda3/envs/spkdev/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py:268: UserWarning: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.\n",
+" rank_zero_warn(\n"
 ]
 }
 ],
@@ -366,13 +358,12 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"GPU available: True, used: False\n",
+"GPU available: True (cuda), used: False\n",
 "TPU available: False, using: 0 TPU cores\n",
 "IPU available: False, using: 0 IPUs\n",
 "HPU available: False, using: 0 HPUs\n",
-"/home/mitx/anaconda3/envs/spk2_test/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1814: PossibleUserWarning: GPU available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='gpu', devices=1)`.\n",
+"/home/kschuett/anaconda3/envs/spkdev/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1764: PossibleUserWarning: GPU available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='gpu', devices=1)`.\n",
 " rank_zero_warn(\n",
-"Missing logger folder: ./qm9tut/lightning_logs\n",
 "\n",
 " | Name | Type | Params\n",
 "---------------------------------------------------\n",
@@ -387,31 +378,12 @@
 },
 {
 "data": {
-"application/json": {
-"ascii": false,
-"bar_format": null,
-"colour": null,
-"elapsed": 0.030420780181884766,
-"initial": 0,
-"n": 0,
-"ncols": 189,
-"nrows": 17,
-"postfix": null,
-"prefix": "Sanity Checking",
-"rate": null,
-"total": null,
-"unit": "it",
-"unit_divisor": 1000,
-"unit_scale": false
-},
+"text/plain": "Sanity Checking: 0it [00:00, ?it/s]",
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "",
 "version_major": 2,
-"version_minor": 0
-},
-"text/plain": [
-"Sanity Checking: 0it [00:00, ?it/s]"
-]
+"version_minor": 0,
+"model_id": "bb3ac9516922417a92ce5bbbc9906f28"
+}
 },
 "metadata": {},
 "output_type": "display_data"
@@ -420,139 +392,70 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"/home/mitx/anaconda3/envs/spk2_test/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+"/home/kschuett/anaconda3/envs/spkdev/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:236: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
 " rank_zero_warn(\n",
-"/home/mitx/anaconda3/envs/spk2_test/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:72: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 100. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
+"/home/kschuett/anaconda3/envs/spkdev/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:98: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 100. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
 " warning_cache.warn(\n",
-"/home/mitx/anaconda3/envs/spk2_test/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+"/home/kschuett/anaconda3/envs/spkdev/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:236: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
 " rank_zero_warn(\n",
-"/home/mitx/anaconda3/envs/spk2_test/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1933: PossibleUserWarning: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
+"/home/kschuett/anaconda3/envs/spkdev/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1892: PossibleUserWarning: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
 " rank_zero_warn(\n"
 ]
 },
 {
 "data": {
-"application/json": {
-"ascii": false,
-"bar_format": null,
-"colour": null,
-"elapsed": 0.06877684593200684,
-"initial": 0,
-"n": 0,
-"ncols": 189,
-"nrows": 17,
-"postfix": null,
-"prefix": "Training",
-"rate": null,
-"total": null,
-"unit": "it",
-"unit_divisor": 1000,
-"unit_scale": false
-},
+"text/plain": "Training: 0it [00:00, ?it/s]",
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "085ce80df8244dd3ac73effcda6407ee",
 "version_major": 2,
-"version_minor": 0
-},
-"text/plain": [
-"Training: 0it [00:00, ?it/s]"
-]
+"version_minor": 0,
+"model_id": "a3dfe943394e46ea9c2f8650e1d21041"
+}
 },
 "metadata": {},
 "output_type": "display_data"
 },
 {
 "data": {
-"application/json": {
-"ascii": false,
-"bar_format": null,
-"colour": null,
-"elapsed": 0.01809835433959961,
-"initial": 0,
-"n": 0,
-"ncols": 189,
-"nrows": 17,
-"postfix": null,
-"prefix": "Validation",
-"rate": null,
-"total": null,
-"unit": "it",
-"unit_divisor": 1000,
-"unit_scale": false
-},
+"text/plain": "Validation: 0it [00:00, ?it/s]",
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "",
 "version_major": 2,
-"version_minor": 0
-},
-"text/plain": [
-"Validation: 0it [00:00, ?it/s]"
-]
+"version_minor": 0,
+"model_id": "fc72ad793efd412cbcb6f9faa830dfda"
+}
 },
 "metadata": {},
 "output_type": "display_data"
 },
 {
 "data": {
-"application/json": {
-"ascii": false,
-"bar_format": null,
-"colour": null,
-"elapsed": 0.0167539119720459,
-"initial": 0,
-"n": 0,
-"ncols": 189,
-"nrows": 17,
-"postfix": null,
-"prefix": "Validation",
-"rate": null,
-"total": null,
-"unit": "it",
-"unit_divisor": 1000,
-"unit_scale": false
-},
+"text/plain": "Validation: 0it [00:00, ?it/s]",
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "",
 "version_major": 2,
-"version_minor": 0
-},
-"text/plain": [
-"Validation: 0it [00:00, ?it/s]"
-]
+"version_minor": 0,
+"model_id": "420695c4425f4c7d8e7c38f8baed25a7"
+}
 },
 "metadata": {},
 "output_type": "display_data"
 },
 {
 "data": {
-"application/json": {
-"ascii": false,
-"bar_format": null,
-"colour": null,
-"elapsed": 0.018230438232421875,
-"initial": 0,
-"n": 0,
-"ncols": 189,
-"nrows": 17,
-"postfix": null,
-"prefix": "Validation",
-"rate": null,
-"total": null,
-"unit": "it",
-"unit_divisor": 1000,
-"unit_scale": false
-},
+"text/plain": "Validation: 0it [00:00, ?it/s]",
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "",
 "version_major": 2,
-"version_minor": 0
-},
-"text/plain": [
-"Validation: 0it [00:00, ?it/s]"
-]
+"version_minor": 0,
+"model_id": "febc0505aa174edfb9f3c5ab275f548b"
+}
 },
 "metadata": {},
 "output_type": "display_data"
+},
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"`Trainer.fit` stopped: `max_epochs=3` reached.\n"
+]
 }
 ],
 "source": [
@@ -647,24 +550,27 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Result dictionary: {'energy_U0': tensor([10414.7993, 12365.0673, 10338.4029, 11250.6555, 11797.4884, 11231.0736,\n",
-" 10804.1721, 12671.1666, 10319.0031, 11211.2785, 11333.0383, 11408.8837,\n",
-" 11796.9613, 10804.0342, 11758.7678, 11350.5577, 10662.9685, 10325.3857,\n",
-" 10784.3073, 10319.5092, 12830.0989, 10765.0104, 8765.8276, 11759.1255,\n",
-" 9778.4204, 11797.7638, 11778.4216, 8298.1309, 9809.9145, 11351.5860,\n",
-" 11798.3985, 11778.0430, 11797.1257, 10802.4717, 10753.5472, 10325.7682,\n",
-" 13238.6101, 10356.5500, 11331.8923, 12792.5905, 9790.3895, 10784.3336,\n",
-" 9760.3332, 9758.1088, 11852.6609, 10772.1722, 10198.4599, 12225.7697,\n",
-" 10326.0554, 10804.3006, 9809.5706, 11785.8515, 11211.2164, 12365.1869,\n",
-" 11350.6455, 11351.2174, 10651.8422, 10357.2399, 10803.3964, 11779.4107,\n",
-" 10803.2706, 9363.1532, 10784.1137, 10683.4049, 9401.3194, 9363.7182,\n",
-" 11797.7716, 10376.6795, 11817.2000, 10216.2281, 10822.8966, 9311.9195,\n",
-" 11370.8836, 10357.9132, 8765.8390, 11797.6732, 11350.4620, 11351.6353,\n",
-" 11522.6320, 11351.7427, 9828.9803, 11696.6340, 9332.3702, 11796.8731,\n",
-" 10395.7840, 11779.0953, 11370.6693, 10803.5162, 10317.9830, 11676.6975,\n",
-" 11675.6500, 13116.9391, 10414.5667, 10783.8263, 10803.2187, 13377.5496,\n",
-" 10356.8070, 11129.3613, 11370.7541, 11370.7576], dtype=torch.float64,\n",
-" grad_fn=<SubBackward0>)}\n"
+"Result dictionary: {'energy_U0': tensor([-11901.9678, -10829.1715, -10493.9096, -11365.3375, -9995.8116,\n",
+" -10451.7412, -10851.2470, -11006.8458, -10494.9420, -11368.3874,\n",
+" -8844.8229, -11902.6537, -9918.1257, -9956.3337, -10833.8384,\n",
+" -12016.3231, -12344.6808, -11981.6151, -11842.0390, -10573.4037,\n",
+" -10930.8419, -10414.9862, -10340.3468, -11508.4475, -10553.1781,\n",
+" -11464.8257, -11010.9114, -10573.8298, -11546.2505, -10398.0184,\n",
+" -11901.8865, -12382.0646, -11805.9859, -11468.2166, -12303.6954,\n",
+" -11982.0471, -11942.9695, -10972.0845, -12742.0990, -12305.2618,\n",
+" -9995.6813, -11326.8686, -13931.4072, -10534.4627, -11945.2061,\n",
+" -12557.3998, -11943.9106, -10568.2193, -11538.0142, -10492.3288,\n",
+" -9857.8994, -11368.2026, -11506.9391, -10965.4910, -10973.1663,\n",
+" -11584.8918, -11503.7264, -12990.9329, -12518.4351, -11543.0566,\n",
+" -11408.7530, -11942.5794, -13317.8285, -9597.8316, -10930.5504,\n",
+" -12460.0102, -11802.8971, -10395.8514, -13355.2561, -9478.2067,\n",
+" -5291.8420, -10411.8928, -11804.3231, -11766.3743, -10532.8525,\n",
+" -9604.8805, -12478.7421, -11747.7678, -11368.4521, -9609.7054,\n",
+" -12381.6398, -10635.5377, -11867.4939, -11767.7288, -10473.4594,\n",
+" -11267.5563, -11845.0998, -12304.8664, -11582.9844, -11542.7391,\n",
+" -10531.9801, -10973.6226, -11403.2258, -10489.5223, -11585.6760,\n",
+" -10929.6288, -11908.0952, -12917.9566, -9458.0325, -13433.7984],\n",
+" dtype=torch.float64, grad_fn=<AddBackward0>)}\n"
 ]
 }
 ],
@@ -733,7 +639,7 @@
 "output_type": "stream",
 "text": [
 "Keys: ['_n_atoms', '_atomic_numbers', '_positions', '_cell', '_pbc', '_idx', '_idx_i_local', '_idx_j_local', '_offsets', '_idx_m', '_idx_i', '_idx_j']\n",
-"Prediction: tensor([1064.8288], dtype=torch.float64, grad_fn=<SubBackward0>)\n"
+"Prediction: tensor([-1103.2246], dtype=torch.float64, grad_fn=<AddBackward0>)\n"
 ]
 }
 ],
@@ -778,7 +684,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Prediction: 1064.8287982940674\n"
+"Prediction: -1103.2246329784393\n"
 ]
 }
 ],
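
The new UserWarning captured in the training output above is PyTorch Lightning recommending a standard pattern: nn.Module attributes are stored in the checkpoint anyway, so they should be excluded from hyperparameter saving. A minimal sketch of what the warning suggests, using a generic wrapper class (the class and its arguments are illustrative, not schnetpack's API):

    import torch.nn as nn
    import pytorch_lightning as pl

    class LitWrapper(pl.LightningModule):
        def __init__(self, model: nn.Module, lr: float = 1e-3):
            super().__init__()
            # 'model' is an nn.Module and is checkpointed anyway; ignoring it
            # here avoids the UserWarning and keeps hparams.yaml small.
            self.save_hyperparameters(ignore=["model"])
            self.model = model
            self.lr = lr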

Diff for: src/schnetpack/model/base.py (+11, -10)
@@ -4,6 +4,7 @@
 
 from schnetpack.transform import Transform
 import schnetpack.properties as properties
+from schnetpack.utils import as_dtype
 
 import torch
 import torch.nn as nn
@@ -30,11 +31,11 @@ def __init__(
         representation: nn.Module,
         output_module: nn.Module,
         postprocessors: Optional[List[Transform]] = None,
-        input_dtype: torch.dtype = torch.float32,
+        input_dtype_str: str = "float32",
         do_postprocessing: bool = True,
     ):
         super().__init__(
-            input_dtype=input_dtype,
+            input_dtype_str=input_dtype_str,
             postprocessors=postprocessors,
             do_postprocessing=do_postprocessing,
         )
@@ -58,18 +59,18 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     def __init__(
         self,
         postprocessors: Optional[List[Transform]] = None,
-        input_dtype: torch.dtype = torch.float32,
+        input_dtype_str: str = "float32",
         do_postprocessing: bool = True,
     ):
         """
         Args:
-            postprocessors: Post-processing transforms that may be initialized using the `datamodule`, but are not
-                applied during training.
-            input_dtype: The dtype of real inputs.
+            postprocessors: Post-processing transforms that may be initialized using the
+                `datamodule`, but are not applied during training.
+            input_dtype_str: The dtype of real inputs as string.
             do_postprocessing: If true, post-processing is activated.
         """
         super().__init__()
-        self.input_dtype = input_dtype
+        self.input_dtype_str = input_dtype_str
         self.do_postprocessing = do_postprocessing
         self.postprocessors = nn.ModuleList(postprocessors)
         self.required_derivatives: Optional[List[str]] = None
@@ -138,7 +139,7 @@ def __init__(
         input_modules: List[nn.Module] = None,
         output_modules: List[nn.Module] = None,
         postprocessors: Optional[List[Transform]] = None,
-        input_dtype: torch.dtype = torch.float32,
+        input_dtype_str: str = "float32",
         do_postprocessing: Optional[bool] = None,
     ):
         """
@@ -149,11 +150,11 @@
             output_modules: Modules that predict output properties from the representation.
             postprocessors: Post-processing transforms that may be initialized using the `datamodule`, but are not
                 applied during training.
-            input_dtype: The dtype of real inputs.
+            input_dtype_str: The dtype of real inputs as string.
             do_postprocessing: If true, post-processing is activated.
         """
         super().__init__(
-            input_dtype=input_dtype,
+            input_dtype_str=input_dtype_str,
             postprocessors=postprocessors,
             do_postprocessing=do_postprocessing,
        )
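
The net effect of the base.py change: only the string is stored on the module (self.input_dtype_str), and it is resolved to a real torch.dtype at the point of use. A rough usage sketch, assuming `as_dtype` behaves as outlined after the commit message above (the casting site shown here is illustrative, not a line from the diff):

    import torch
    from schnetpack.utils import as_dtype

    # The constructor keeps a plain string, so checkpoints and configs never
    # have to serialize a torch.dtype object directly.
    input_dtype_str = "float32"

    # Resolved back to a real dtype only where a tensor is actually created:
    x = torch.zeros(10, dtype=as_dtype(input_dtype_str))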
