# train_inpaint_imagenet_config.yaml
# lightning.pytorch==2.0.0
# Fixed global seed so repeated runs start from the same RNG state.
seed_everything: 42
# Trainer settings (Lightning `Trainer` arguments, per the version header).
trainer:
  # All run artifacts (logs, checkpoints) default to this directory.
  default_root_dir: './outputs/training/inpaint/imagenet'

  # Hardware / distribution: a single node with 8 GPU processes under DDP.
  accelerator: gpu
  strategy: ddp
  devices: 8
  num_nodes: 1
  precision: bf16-mixed

  # Stopping criteria: epochs bound the run; -1 leaves the step count
  # unlimited.
  max_epochs: 20
  max_steps: -1
  overfit_batches: 0.0

  # Validation cadence: once per training epoch, every epoch, after a
  # 2-batch sanity pass before training starts.
  val_check_interval: 1.0
  check_val_every_n_epoch: 1
  num_sanity_val_steps: 2

  log_every_n_steps: 50
  enable_checkpointing: true
  benchmark: true
  detect_anomaly: false
  sync_batchnorm: false
  reload_dataloaders_every_n_epochs: 0

  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: dirac-inpaint-imagenet

  callbacks:
    # Project-local checkpoint callback. The arguments mirror Lightning's
    # ModelCheckpoint (presumably a subclass - confirm in scripts/utils):
    # keep the single best checkpoint by lowest validation loss, plus the
    # most recent one.
    - class_path: scripts.utils.EMAModelCheckpoint
      init_args:
        save_top_k: 1
        verbose: true
        monitor: val/weighted_mse_loss
        mode: min
        # The metric name is spelled out in the template and
        # auto-insertion is off; quoted because the value contains
        # braces and a colon.
        filename: 'epoch-{epoch}-val-loss{val/weighted_mse_loss:.4f}'
        auto_insert_metric_name: false
        save_last: true
    # Project-local EMA callback, updated every step with decay 0.9999.
    # NOTE(review): `validate_original_weights: false` presumably means
    # validation runs on the EMA weights - confirm in scripts/utils.
    - class_path: scripts.utils.EMA
      init_args:
        decay: 0.9999
        validate_original_weights: false
        every_n_steps: 1
        cpu_offload: false
# Model hyperparameters. There is no `class_path` here, so these keys are
# presumably passed straight to a model class fixed by the training script
# (TODO confirm against the script's LightningCLI setup).
model:
  dt: 0.02
  # Path to a separate experiment definition file, relative to the
  # working directory the training script is launched from.
  experiment_config_file: configs/experiments/inpainting_gaussian_imagenet.yaml
  # Same name as the metric monitored by the checkpoint callback
  # (val/weighted_mse_loss); keep the two in sync.
  loss_type: 'weighted_mse_loss'
  lr: 0.0001
  residual_prediction: true
  model_arch: large
  weight_decay: 0.0
  # Matches the WandbLogger configured under `trainer.logger`.
  logger_type: wandb
  full_val_only_last_epoch: true
  num_log_images: 10
  model_conditioning: noise
# Data module: project-local ImageNet data module and its constructor args.
data:
  class_path: pl_modules.image_data_module.ImageNetDataModule
  init_args:
    # NOTE(review): presumably the per-process batch size; with 8 DDP
    # devices that would give an effective batch of 32 - confirm in the
    # data module.
    batch_size: 4
    # Sampling rate applied to each split; 1.0 for all three splits.
    sample_rate_dict:
      train: 1.0
      val: 1.0
      test: 1.0
    distributed_sampler: true