diff --git a/nemo/collections/asr/models/classification_models.py b/nemo/collections/asr/models/classification_models.py
index 08714d270787..6609dd8079ed 100644
--- a/nemo/collections/asr/models/classification_models.py
+++ b/nemo/collections/asr/models/classification_models.py
@@ -16,7 +16,7 @@
 from typing import Dict, List, Optional, Union
 
 import torch
-from omegaconf import DictConfig, ListConfig
+from omegaconf import DictConfig, ListConfig, OmegaConf
 from pytorch_lightning import Trainer
 
 from nemo.collections.asr.data.audio_to_text import AudioLabelDataset
@@ -246,7 +246,10 @@ def change_labels(self, new_labels: List[str]):
 
             # Update config
             self._cfg.labels = new_labels
-            self._cfg.decoder.params = new_decoder_config
+
+            OmegaConf.set_struct(self._cfg.decoder, False)
+            self._cfg.decoder = new_decoder_config
+            OmegaConf.set_struct(self._cfg.decoder, True)
 
             if 'train_ds' in self._cfg and self._cfg.train_ds is not None:
                 self._cfg.train_ds.labels = new_labels
diff --git a/nemo/collections/asr/models/ctc_bpe_models.py b/nemo/collections/asr/models/ctc_bpe_models.py
index bdfc672fc131..c4c1297a09b9 100644
--- a/nemo/collections/asr/models/ctc_bpe_models.py
+++ b/nemo/collections/asr/models/ctc_bpe_models.py
@@ -256,17 +256,20 @@ def change_vocabulary(self, new_tokenizer_dir: str, new_tokenizer_type: str):
 
             # Override number of classes if placeholder provided
             logging.info(
                 "\nReplacing old number of classes ({}) with new number of classes - {}".format(
-                    decoder_config.params['num_classes'], len(vocabulary)
+                    decoder_config['params']['num_classes'], len(vocabulary)
                 )
             )
-            decoder_config.params['num_classes'] = len(vocabulary)
+            decoder_config['params']['num_classes'] = len(vocabulary)
 
             del self.decoder
             self.decoder = EncDecCTCModelBPE.from_config_dict(decoder_config)
 
             self._wer = WERBPE(tokenizer=self.tokenizer, batch_dim_index=0, use_cer=False, ctc_decode=True)
 
             # Update config
-            self._cfg.decoder.params = decoder_config
+            OmegaConf.set_struct(self._cfg.decoder, False)
+            self._cfg.decoder = decoder_config
+            OmegaConf.set_struct(self._cfg.decoder, True)
+
             logging.info(f"Changed tokenizer to {self.decoder.vocabulary} vocabulary.")
diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py
index ebd2b2021cc2..9c7cfc67976b 100644
--- a/nemo/collections/asr/models/ctc_models.py
+++ b/nemo/collections/asr/models/ctc_models.py
@@ -19,7 +19,7 @@
 from typing import Dict, List, Optional, Union
 
 import torch
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning import Trainer
 
 from nemo.collections.asr.data.audio_to_text import AudioToCharDataset, TarredAudioToCharDataset
@@ -154,7 +154,10 @@ def change_vocabulary(self, new_vocabulary: List[str]):
             self._wer = WER(vocabulary=self.decoder.vocabulary, batch_dim_index=0, use_cer=False, ctc_decode=True)
 
             # Update config
-            self._cfg.decoder.params = new_decoder_config
+            OmegaConf.set_struct(self._cfg.decoder, False)
+            self._cfg.decoder = new_decoder_config
+            OmegaConf.set_struct(self._cfg.decoder, True)
+
             logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.")
 
     def _setup_dataloader_from_config(self, config: Optional[Dict]):
diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py
index 58c070f612a8..93f0b548ec7f 100644
--- a/nemo/core/classes/modelPT.py
+++ b/nemo/core/classes/modelPT.py
@@ -78,7 +78,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
 
         self._cfg = config
 
-        self.save_hyperparameters(self._cfg)
+        self.save_hyperparameters(OmegaConf.to_container(self._cfg, resolve=True))
         self._train_dl = None
         self._validation_dl = None
         self._test_dl = None
@@ -231,6 +231,7 @@ def load_from_checkpoint(
         Loads ModelPT from checkpoint, with some maintenance of restoration.
         For documentation, please refer to LightningModule.load_from_checkpoin() documentation.
         """
+        # TODO (@titu1994): When PTL 0.9+ is supported, add `strict=False` flag to constructor
        checkpoint = None
         try:
             cls.__set_model_restore_state(is_being_restored=True)
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 01a50798a661..d8fe9b25ba4a 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -1,12 +1,12 @@
 numpy>=1.18.2
 onnx>=1.7.0
-pytorch-lightning>=0.8.5
+pytorch-lightning==0.8.5
 python-dateutil
 torch
 wget
 wrapt
 ruamel.yaml
 scikit-learn
-omegaconf==2.0.1rc11
-hydra-core==1.0.0rc3
+omegaconf==2.0.1rc12
+hydra-core==1.0.0rc4
 transformers>=2.11.0
\ No newline at end of file