Lightning Migration #837

Merged: 79 commits, Nov 12, 2022

Commits
1b8bfce
initial version pytorch lightning
karl-richter Sep 6, 2022
32d5b6e
first version lightning fit()
karl-richter Sep 14, 2022
e5b0e7b
added lightning_logs to gitignore
karl-richter Sep 19, 2022
48f81b3
converted test function to lightning
karl-richter Sep 19, 2022
5390e41
converted predict function to lightning
karl-richter Sep 19, 2022
0c19ed5
added compute_components support for lightning
karl-richter Sep 26, 2022
15023dd
added minimal training support for lightning
karl-richter Sep 26, 2022
ca9ee5e
added epochs to lightning config
karl-richter Sep 26, 2022
b464fe0
Merge branch 'main' into lightning
karl-richter Sep 26, 2022
ed1d935
moved trainer config to utils
karl-richter Sep 26, 2022
76466af
handle multi-batch predictions
karl-richter Sep 26, 2022
f9430c1
refactoring
karl-richter Sep 26, 2022
341d422
added custom logger for metrics
karl-richter Oct 5, 2022
8a673b5
renamed metrics logger
karl-richter Oct 6, 2022
48a6574
added scheduler to lightning
karl-richter Oct 8, 2022
9237743
fixed uncertainty prediction metrics
karl-richter Oct 8, 2022
f4c2c14
added predict_mode flag in lightning model
karl-richter Oct 9, 2022
3642d3b
updated torchmetrics imports based on deprecation warnings
karl-richter Oct 10, 2022
214f7b2
add denormalization support for torchmetrics
karl-richter Oct 10, 2022
4b5a223
replace tqdm with RichProgressBar
karl-richter Oct 10, 2022
9613998
add rich in requirements.txt
karl-richter Oct 10, 2022
a95c846
custom colors in rich progress bar
karl-richter Oct 10, 2022
3f96363
re-added denormalization in metrics
karl-richter Oct 10, 2022
f05bb6d
refactored metrics in time_net and added docs
karl-richter Oct 10, 2022
e7ab7f0
refactored metrics in time_net
karl-richter Oct 10, 2022
3081fb8
support arbitrary loggers
karl-richter Oct 11, 2022
f658676
changed model saving loading test
karl-richter Oct 11, 2022
2c061aa
Merge branch 'main' into lightning
karl-richter Oct 11, 2022
110807a
support lightning lr finder
karl-richter Oct 12, 2022
37f9a32
refactored minimal training implementation
karl-richter Oct 12, 2022
fc77146
refactored minimal training implementation
karl-richter Oct 12, 2022
ed0d222
configure learning rate finder
karl-richter Oct 12, 2022
8c2bfcb
removed outdated pytests
karl-richter Oct 12, 2022
144ec67
add early stopping support
karl-richter Oct 12, 2022
854a4ea
dynamically choose early stopping target
karl-richter Oct 13, 2022
6b18376
deactivate additional loggers temporarily
karl-richter Oct 13, 2022
6b73b27
refactored saving loading
karl-richter Oct 13, 2022
14254a5
bugfix for small training batches
karl-richter Oct 13, 2022
47bd4ba
outsource metrics
karl-richter Oct 13, 2022
51505ad
support saving / loading model
karl-richter Oct 14, 2022
610b7de
include model in saving checkpoint
karl-richter Oct 14, 2022
e31c104
move optimizer and scheduler back to configure
karl-richter Oct 14, 2022
f64713e
pass denormalization as function + cleanup
karl-richter Oct 16, 2022
29de6ad
Merge branch 'main' into lightning
karl-richter Oct 16, 2022
126e539
Merge branch 'main' into lightning
karl-richter Oct 16, 2022
13a8d13
moved lr_finder_args to configure.py
karl-richter Oct 16, 2022
b460101
added checkpoints to gitignore
karl-richter Oct 16, 2022
02bdc83
removed legacy lr_finder
karl-richter Oct 16, 2022
509eaa0
removed unused imports
karl-richter Oct 16, 2022
1236abd
removed lightning notebook
karl-richter Oct 16, 2022
6a16ef0
Merge pull request #792 from karl-richter/lightning
karl-richter Oct 17, 2022
39388ee
Merge branch 'main' into lightning
karl-richter Oct 18, 2022
6162c03
Update log_every_n_steps
karl-richter Oct 18, 2022
d990de5
pass denormalization as parameter
karl-richter Oct 18, 2022
1e46a16
Merge branch 'main' into lightning
karl-richter Oct 18, 2022
2006c6c
Merge branch 'main' into lightning
karl-richter Oct 18, 2022
3683bd4
Merge branch 'main' into lightning
karl-richter Oct 19, 2022
51674a9
migrate custom logger to tensorboard
karl-richter Oct 19, 2022
c089b1e
Merge branch 'main' into lightning
karl-richter Oct 21, 2022
d81b12e
Resolve circular import dependencies
karl-richter Oct 21, 2022
0fa28cb
[lightning] Learning rate finder optimization (#892)
karl-richter Oct 25, 2022
0f9f06e
early stopping configuration
karl-richter Oct 27, 2022
ed66ed7
Merge branch 'main' into lightning
karl-richter Oct 27, 2022
77579f6
removed legacy metrics
karl-richter Oct 27, 2022
1b17c68
early stopping configuration
karl-richter Oct 27, 2022
fecf91e
silence tensorboard deprecation warnings
karl-richter Oct 27, 2022
09ea0f2
early stopping configuration
karl-richter Oct 27, 2022
57aad19
remove epoch plot from metrics
karl-richter Oct 27, 2022
89e84a0
added docs for trainer_config
karl-richter Nov 4, 2022
ccbac82
Merge branch 'main' into lightning
karl-richter Nov 7, 2022
43bdcb9
added support for layer visualization
karl-richter Nov 7, 2022
20ac087
fixed isort
karl-richter Nov 7, 2022
6f9c52f
fixed isort + flake8
karl-richter Nov 7, 2022
d5c3b28
reduce warning messages
karl-richter Nov 7, 2022
fcb4709
reduce warning messages
karl-richter Nov 7, 2022
e097853
fixed flake8
karl-richter Nov 7, 2022
a1db9aa
fixed flake8
karl-richter Nov 7, 2022
48a0e3d
fixed flake8
karl-richter Nov 8, 2022
2dc717e
Merge branch 'main' into lightning
karl-richter Nov 11, 2022
6 changes: 5 additions & 1 deletion .gitignore
@@ -13,8 +13,12 @@ site/
**/*.DS_Store
tests/test-logs/
**/test_save_model.np
lightning_logs/
logs/
*.ckpt
*.pt
tests/metrics/*.json
tests/metrics/*.png
tests/metrics/*.svg

# Byte-compiled / optimized / DLL files
__pycache__/
87 changes: 52 additions & 35 deletions neuralprophet/configure.py
@@ -95,6 +95,9 @@ class Train:
batch_size: Union[int, None]
loss_func: Union[str, torch.nn.modules.loss._Loss, Callable]
optimizer: Union[str, torch.optim.Optimizer]
optimizer_args: dict = field(default_factory=dict)
scheduler: torch.optim.lr_scheduler._LRScheduler = None
scheduler_args: dict = field(default_factory=dict)
newer_samples_weight: float = 1.0
newer_samples_start: float = 0.0
reg_delay_pct: float = 0.5
@@ -103,13 +106,20 @@ class Train:
reg_lambda_season: float = None
n_data: int = field(init=False)
loss_func_name: str = field(init=False)
early_stopping: bool = False
lr_finder_args: dict = field(default_factory=dict)

def __post_init__(self):
# assert the uncertainty estimation params and then finalize the quantiles
self.set_quantiles()
assert self.newer_samples_weight >= 1.0
assert self.newer_samples_start >= 0.0
assert self.newer_samples_start < 1.0
self.set_loss_func()
self.set_optimizer()
self.set_scheduler()

def set_loss_func(self):
if type(self.loss_func) == str:
if self.loss_func.lower() in ["huber", "smoothl1", "smoothl1loss"]:
self.loss_func = torch.nn.SmoothL1Loss(reduction="none")
@@ -171,19 +181,48 @@ def set_auto_batch_epoch(
# also set lambda_delay:
self.lambda_delay = int(self.reg_delay_pct * self.epochs)

def get_optimizer(self, model_parameters):
return utils_torch.create_optimizer_from_config(self.optimizer, model_parameters, self.learning_rate)

def get_scheduler(self, optimizer, steps_per_epoch):
return torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=self.learning_rate,
epochs=self.epochs,
steps_per_epoch=steps_per_epoch,
pct_start=0.3,
anneal_strategy="cos",
div_factor=100.0,
final_div_factor=5000.0,
def set_optimizer(self):
"""
Set the optimizer and optimizer args. If optimizer is a string, then it will be converted to the corresponding torch optimizer.
The optimizer is not initialized yet as this is done in configure_optimizers in TimeNet.
"""
self.optimizer, self.optimizer_args = utils_torch.create_optimizer_from_config(
self.optimizer, self.optimizer_args
)
Comment on lines +184 to +191 (review comment by the PR author, Collaborator):

@ourownstory I would propose to move the logic from create_optimizer_from_config directly into the set_optimizer function, since that's how we also handle it for set_loss_func and the other setter functions. What do you think?
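A rough sketch of what the proposed inlining could look like (illustrative only, not part of this PR; the string-to-class mapping and the default arguments are assumptions rather than the actual contents of create_optimizer_from_config, and torch is assumed to be imported at the top of configure.py):

def set_optimizer(self):
    # Sketch: resolve an optimizer name to the torch optimizer class plus default args,
    # mirroring how set_loss_func resolves loss function names.
    if isinstance(self.optimizer, str):
        if self.optimizer.lower() == "adamw":
            self.optimizer = torch.optim.AdamW
            self.optimizer_args.setdefault("weight_decay", 1e-3)  # assumed default
        elif self.optimizer.lower() == "sgd":
            self.optimizer = torch.optim.SGD
            self.optimizer_args.setdefault("momentum", 0.9)  # assumed default
        else:
            raise ValueError(f"Optimizer name {self.optimizer} is not supported.")
    elif not issubclass(self.optimizer, torch.optim.Optimizer):
        raise ValueError("The provided optimizer is not a torch.optim.Optimizer subclass.")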


def set_scheduler(self):
"""
Set the scheduler and scheduler args.
The scheduler is not initialized yet as this is done in configure_optimizers in TimeNet.
"""
self.scheduler = torch.optim.lr_scheduler.OneCycleLR
self.scheduler_args.update(
{
"pct_start": 0.3,
"anneal_strategy": "cos",
"div_factor": 100.0,
"final_div_factor": 5000.0,
}
)
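Since both setters only store the class and its keyword arguments, instantiation is deferred to the LightningModule. A minimal sketch of how configure_optimizers in TimeNet might consume them (illustrative; the attribute names config_train and learning_rate, and the use of Lightning's estimated_stepping_batches, are assumptions, not code from this PR):

def configure_optimizers(self):
    # Instantiate the optimizer class stored on the Train config
    optimizer = self.config_train.optimizer(
        self.parameters(), lr=self.learning_rate, **self.config_train.optimizer_args
    )
    # Instantiate the OneCycleLR scheduler with the stored defaults
    lr_scheduler = self.config_train.scheduler(
        optimizer,
        max_lr=self.learning_rate,
        total_steps=self.trainer.estimated_stepping_batches,
        **self.config_train.scheduler_args,
    )
    return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"}}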

def set_lr_finder_args(self, dataset_size, num_batches):
"""
Set the lr_finder_args.
This is the range of learning rates to test.
"""
num_training = 150 + int(np.log10(100 + dataset_size) * 25)
if num_batches < num_training:
log.warning(
f"Learning rate finder: The number of batches ({num_batches}) is too small than the required number for the learning rate finder ({num_training}). The results might not be optimal."
)
# num_training = num_batches
self.lr_finder_args.update(
{
"min_lr": 1e-6,
"max_lr": 10,
"num_training": num_training,
"early_stop_threshold": None,
}
)
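For a dataset of 1,000 samples this yields num_training = 150 + int(log10(1100) * 25) = 226 trial steps. A hedged sketch of how these arguments could be passed to PyTorch Lightning's learning rate finder (the names model, train_loader, and config_train are placeholders, not the PR's actual call site):

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=config_train.epochs)
# Lightning 1.x exposes the range test via trainer.tuner.lr_find;
# min_lr, max_lr, num_training and early_stop_threshold come from lr_finder_args.
lr_finder = trainer.tuner.lr_find(model, train_dataloaders=train_loader, **config_train.lr_finder_args)
config_train.learning_rate = lr_finder.suggestion()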

def get_reg_delay_weight(self, e, iter_progress, reg_start_pct: float = 0.66, reg_full_pct: float = 1.0):
@@ -200,28 +239,6 @@ def get_reg_delay_weight(self, e, iter_progress, reg_start_pct: float = 0.66, re
delay_weight = 1
return delay_weight

def find_learning_rate(self, model, dataset, repeat: int = 2):
if issubclass(self.loss_func.__class__, torch.nn.modules.loss._Loss):
try:
loss_func = getattr(torch.nn.modules.loss, self.loss_func_name)()
except AttributeError:
raise ValueError("automatic learning rate only supported for regular torch loss functions.")
else:
raise ValueError("automatic learning rate only supported for regular torch loss functions.")
lrs = [0.1]
for i in range(repeat):
lr = utils_torch.lr_range_test(
model,
dataset,
loss_func=loss_func,
optimizer=self.optimizer,
batch_size=self.batch_size,
)
lrs.append(lr)
lrs_log10_mean = sum([np.log10(x) for x in lrs]) / len(lrs)
learning_rate = 10**lrs_log10_mean
return learning_rate


@dataclass
class Trend: