This repository was archived by the owner on Aug 20, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path experiment_args.py
135 lines (118 loc) · 3.97 KB
/
experiment_args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import logging
from dataclasses import dataclass, field
from os.path import join, isabs, exists, normpath
from typing import Optional

from util import (
    get_current_ts,
    create_dir_if_not_exists,
    get_root_path,
    string_to_underscore,
)
# Module-level logger, named after this module per the logging.getLogger(__name__) convention.
logger = logging.getLogger(__name__)
@dataclass
class ExperimentArguments:
    """
    Arguments describing an experiment: dataset, market, model, task and input paths.

    ``__post_init__`` resolves relative input paths against the project root,
    validates the configuration, and only then creates a timestamped log
    directory, whose path is stored in ``self.log_dir``.
    """
    dataset_name: str = field(metadata={
        'help': 'The name of the dataset'
    })
    market: str = field(
        metadata={'help': 'The market if classification'}
    )
    metapaths: Optional[list] = field(
        default=None,
        metadata={'help': 'The metapaths to use'}
    )
    market_source: str = field(
        default='diffbot',
        metadata={'help': 'The source of the market labels',
                  'choices': ['diffbot']}
    )
    experiment_name: str = field(
        default='default',
        metadata={'help': 'The name of the experiment'}
    )
    # Annotated bool (was str): a str-typed CLI flag would turn "False" into a truthy string.
    debug: bool = field(
        default=False,
        metadata={'help': 'Whether to use debug logging'}
    )
    node_embds_path: Optional[str] = field(
        default=None,
        metadata={'help': 'The path to the text node embeddings if any. Can be an absolute path '
                          'or path starting from the root of project directory'}
    )
    natts_path: Optional[str] = field(
        default=None,
        metadata={
            'help': 'The path to the node attribute specifications'
        })
    task: str = field(
        default='binary_classification',
        metadata={'help': 'The training and testing task'}
    )
    model: str = field(
        default='hrgcn',
        metadata={'help': 'The name of the model to use. Default is hrgcn',
                  'choices': ['han', 'hrgcn', 'mlp']}
    )

    @staticmethod
    def _resolve_project_path(path, root):
        """Return *path* normalized; relative paths are interpreted from *root*.

        Leading ``./`` (or ``../``) segments are handled by ``normpath`` rather than
        stripping dots, which previously turned ``./x`` into the absolute ``/x``.
        """
        if isabs(path):
            return normpath(path)
        return normpath(join(root, path))

    def __post_init__(self):
        root = get_root_path()

        # Resolve and validate inputs first so an invalid configuration does not
        # leave an empty log directory behind.
        if self.node_embds_path is not None:
            self.node_embds_path = self._resolve_project_path(self.node_embds_path, root)
        if self.natts_path is not None:
            self.natts_path = self._resolve_project_path(self.natts_path, root)
            assert exists(self.natts_path), f"The natts path {self.natts_path} does not exist"
        if self.model == 'han':
            assert self.metapaths is not None, 'Need to specify metapaths for han model'

        # e.g. market "Real-Estate Tech" -> "real_estate_tech"
        market_name = '_'.join(self.market.lower().replace('-', ' ').split())
        exp_str = string_to_underscore(self.experiment_name)
        save_path_name = f'{exp_str}_{self.dataset_name}_{market_name}_{self.model}'
        self.log_dir = join(root, 'logs', save_path_name + f'_{get_current_ts()}')
        create_dir_if_not_exists(self.log_dir)
@dataclass
class TrainingArguments:
    """Training and evaluation hyperparameters plus runtime options (device, logging cadence)."""
    # Annotated int (was str): a str-typed CLI option would yield e.g. '3' instead of 3.
    num_layers: int = field(
        default=2,
        metadata={'help': 'Number of layers. Default is 2'})
    device: int = field(
        default=0,
        metadata={'help': 'The cuda device'}
    )
    hidden_dim: int = field(
        default=64,
        metadata={'help': 'Dimension of the node hidden state. Default is 64.'}
    )
    num_heads: int = field(
        default=8,
        metadata={'help': 'Number of the attention heads. Default is 8.'}
    )
    num_epochs: int = field(
        default=100,
        metadata={'help': 'Number of epochs. Default is 100'}
    )
    patience: int = field(
        default=10,
        metadata={'help': 'How long to wait after last time validation loss improved. Default is 10'}
    )
    repeat: int = field(
        default=1,
        metadata={'help': 'Repeat the training and testing for N times. Default is 1.'}
    )
    dropout_rate: float = field(
        default=0.5,
        metadata={'help': 'Dropout rate. Default is 0.5.'}
    )
    val_metric: str = field(
        default='f1',
        metadata={'help': 'The validation metric in which early stopping is determined'}
    )
    lr: float = field(
        default=0.005,
        metadata={'help': 'Learning rate. Default is 0.005.'}
    )
    weight_decay: float = field(
        default=0.001,
        metadata={'help': 'Weight decay. Default is 0.001.'}
    )
    print_val_epochs: int = field(
        default=10,
        metadata={'help': 'Number of epochs between printing validation results'}
    )