-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy path__init__.py
108 lines (88 loc) · 3.37 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from typing import Dict, Any, Union
from omegaconf import DictConfig
from data.qa import (
QADataset,
QAwithIdkDataset,
)
from data.collators import (
DataCollatorForSupervisedDataset,
)
from data.unlearn import ForgetRetainDataset
from data.pretraining import PretrainingDataset, CompletionDataset
DATASET_REGISTRY: Dict[str, Any] = {}
COLLATOR_REGISTRY: Dict[str, Any] = {}
def _register_data(data_class):
DATASET_REGISTRY[data_class.__name__] = data_class
def _register_collator(collator_class):
COLLATOR_REGISTRY[collator_class.__name__] = collator_class
def _load_single_dataset(dataset_name, dataset_cfg: DictConfig, **kwargs):
    """Instantiate a single dataset from its config.

    Args:
        dataset_name: Name of the dataset (used only in error messages).
        dataset_cfg: Config with a "handler" key naming a class registered in
            DATASET_REGISTRY, and an "args" mapping of constructor kwargs.
        **kwargs: Extra keyword arguments forwarded to the handler constructor.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If the config does not specify a handler.
        NotImplementedError: If the named handler is not registered.
    """
    dataset_handler_name = dataset_cfg.get("handler")
    # Raise explicitly rather than `assert`: asserts are stripped under
    # `python -O`, and the original ValueError was only an assert message,
    # never actually raised.
    if dataset_handler_name is None:
        raise ValueError(f"{dataset_name} handler not set")
    dataset_handler = DATASET_REGISTRY.get(dataset_handler_name)
    if dataset_handler is None:
        raise NotImplementedError(
            f"{dataset_handler_name} not implemented or not registered"
        )
    dataset_args = dataset_cfg.args
    return dataset_handler(**dataset_args, **kwargs)
def get_datasets(dataset_cfgs: Union[Dict, DictConfig], **kwargs):
    """Build every dataset described in ``dataset_cfgs``.

    Returns the dataset object itself when exactly one dataset is configured;
    otherwise returns a dict mapping dataset names to dataset objects.
    """
    loaded = {
        name: _load_single_dataset(name, cfg, **kwargs)
        for name, cfg in dataset_cfgs.items()
    }
    if len(loaded) == 1:
        # Single dataset: unwrap it from the mapping.
        (only,) = loaded.values()
        return only
    # Multiple datasets: return the name -> dataset mapping.
    return loaded
def get_data(data_cfg: DictConfig, mode="train", **kwargs):
    """Build the datasets for every split in ``data_cfg``.

    Args:
        data_cfg: Mapping of split name -> dataset configs. May also contain
            an "anchor" entry (default "forget"), used only in unlearn mode.
        mode: "train" returns the per-split datasets unchanged; "unlearn"
            folds every split except "eval"/"test" into a single
            ForgetRetainDataset exposed under the "train" key.
        **kwargs: Forwarded to each dataset constructor.

    Returns:
        Dict mapping split names to dataset objects.

    Raises:
        ValueError: If ``mode`` is neither "train" nor "unlearn".
    """
    data = {}
    # Copy so popping "anchor" does not mutate the caller's config.
    data_cfg = dict(data_cfg)
    anchor = data_cfg.pop("anchor", "forget")
    for split, dataset_cfgs in data_cfg.items():
        data[split] = get_datasets(dataset_cfgs, **kwargs)
    if mode == "train":
        return data
    if mode == "unlearn":
        # Combine all non-evaluation splits into one composite dataset.
        unlearn_splits = {k: v for k, v in data.items() if k not in ("eval", "test")}
        unlearn_dataset = ForgetRetainDataset(**unlearn_splits, anchor=anchor)
        data["train"] = unlearn_dataset
        for split in unlearn_splits:
            data.pop(split)
        return data
    # Previously an unrecognized mode fell through and returned None silently;
    # fail loudly instead.
    raise ValueError(f"Unknown mode: {mode}. Expected 'train' or 'unlearn'.")
def _get_single_collator(collator_name: str, collator_cfg: DictConfig, **kwargs):
    """Instantiate a single collator from its config.

    Args:
        collator_name: Name of the collator (used only in error messages).
        collator_cfg: Config with a "handler" key naming a class registered in
            COLLATOR_REGISTRY, and an "args" mapping of constructor kwargs.
        **kwargs: Extra keyword arguments forwarded to the handler constructor.

    Returns:
        The constructed collator object.

    Raises:
        ValueError: If the config does not specify a handler.
        NotImplementedError: If the named handler is not registered.
    """
    collator_handler_name = collator_cfg.get("handler")
    # Raise explicitly rather than `assert`: asserts are stripped under
    # `python -O`, and the original ValueError was only an assert message,
    # never actually raised.
    if collator_handler_name is None:
        raise ValueError(f"{collator_name} handler not set")
    collator_handler = COLLATOR_REGISTRY.get(collator_handler_name)
    if collator_handler is None:
        raise NotImplementedError(
            f"{collator_handler_name} not implemented or not registered"
        )
    collator_args = collator_cfg.args
    return collator_handler(**collator_args, **kwargs)
def get_collators(collator_cfgs, **kwargs):
    """Build every collator described in ``collator_cfgs``.

    Returns the collator object itself when exactly one collator is
    configured; otherwise returns a dict mapping collator names to collators.
    """
    built = {
        name: _get_single_collator(name, cfg, **kwargs)
        for name, cfg in collator_cfgs.items()
    }
    if len(built) == 1:
        # Single collator: unwrap it from the mapping.
        (only,) = built.values()
        return only
    # Multiple collators: return the name -> collator mapping.
    return built
# Register dataset handlers. Registry keys are the class names, which is what
# the "handler" field of a dataset config must reference.
_register_data(QADataset)
_register_data(QAwithIdkDataset)
_register_data(PretrainingDataset)
_register_data(CompletionDataset)
# Register composite datasets used in unlearning
# groups: unlearn
_register_data(ForgetRetainDataset)
# Register collator handlers, likewise keyed by class name.
_register_collator(DataCollatorForSupervisedDataset)