-
Notifications
You must be signed in to change notification settings - Fork 38
/
Copy pathexperiment_stereoPrediction_baseline.py
47 lines (39 loc) · 1.17 KB
/
experiment_stereoPrediction_baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import matplotlib
matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab!
# import train
import train_baseline as train
# Experiment configuration. `options` is the single dict handed to
# train.train(); model- and dataset-specific settings are nested inside it
# under the 'modelOptions' and 'datasetOptions' keys.
options = dict(
    # global setup settings, and checkpoints
    name='stereoPrediction_baseline',
    seed=123,
    checkpoint_output_directory='checkpoints',
    # model and dataset (dotted module paths; presumably imported by the
    # training module — TODO confirm)
    dataset_file='datasets.dataset_stereoPrediction',
    model_file='models.model_stereoPrediction_baseline',
    pretrained_model_path=None,
    # training parameters
    image_dim=64,
    batch_size=16,
    loss='squared_error',
    learning_rate=1e-3,
    decay_after=20,
    num_epochs=100,
    batches_per_epoch=2 * 100,
    save_after=10,
)

# Model hyperparameters. Batch size and image side length are read from
# `options` so the two dicts cannot drift apart.
modelOptions = dict(
    batch_size=options['batch_size'],
    npx=options['image_dim'],
    input_seqlen=1,
    target_seqlen=1,
    buffer_len=1,
    dynamic_filter_size=(13, 1),
)

# Dataset settings. `num_frames` spans both the input and the target
# sequence lengths declared for the model.
datasetOptions = dict(
    batch_size=options['batch_size'],
    image_size=options['image_dim'],
    num_frames=modelOptions['input_seqlen'] + modelOptions['target_seqlen'],
)

# Nest the sub-configs (keyword order keeps 'modelOptions' before
# 'datasetOptions'); the same dict objects are stored, not copies.
options.update(modelOptions=modelOptions, datasetOptions=datasetOptions)
# Launch training with the assembled `options` dict. `train` is the
# `train_baseline` module imported at the top of this file.
# This call sits at module top level, so merely importing this file (not only
# running it as a script) triggers it.
# NOTE(review): consider an `if __name__ == "__main__":` guard so other code
# can import `options` without starting a run — confirm nothing relies on
# import-time execution first.
train.train(options)