
Commit 5472398

KohakuBlueleaf authored and kohya-ss committed

[Experimental] Add cache mechanism for dataset groups to avoid long waiting time for initialization (kohya-ss#1178)

* support meta cached dataset
* add cache meta scripts
* random ip_noise_gamma strength
* random noise_offset strength
* use correct settings for parser
* cache path/caption/size only
* revert mess up commit
* revert mess up commit
* Update requirements.txt
* Add arguments for meta cache.
* remove pickle implementation
* Return sizes when enable cache

Co-authored-by: Kohya S <[email protected]>

1 parent 742ebd1 · commit 5472398

5 files changed: +173 −22 lines changed

cache_dataset_meta.py (new file, +103)

@@ -0,0 +1,103 @@
+import argparse
+import random
+
+from accelerate.utils import set_seed
+
+import library.train_util as train_util
+import library.config_util as config_util
+from library.config_util import (
+    ConfigSanitizer,
+    BlueprintGenerator,
+)
+import library.custom_train_functions as custom_train_functions
+from library.utils import setup_logging, add_logging_arguments
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def make_dataset(args):
+    train_util.prepare_dataset_args(args, True)
+    setup_logging(args, reset=True)
+
+    use_dreambooth_method = args.in_json is None
+    use_user_config = args.dataset_config is not None
+
+    if args.seed is None:
+        args.seed = random.randint(0, 2**32)
+    set_seed(args.seed)
+
+    # データセットを準備する
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(
+            ConfigSanitizer(True, True, False, True)
+        )
+        if use_user_config:
+            logger.info(f"Loading dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "reg_data_dir", "in_json"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                logger.warning(
+                    "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
+                )
+        else:
+            if use_dreambooth_method:
+                logger.info("Using DreamBooth method.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
+                                args.train_data_dir, args.reg_data_dir
+                            )
+                        }
+                    ]
+                }
+            else:
+                logger.info("Training with captions.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": [
+                                {
+                                    "image_dir": args.train_data_dir,
+                                    "metadata_file": args.in_json,
+                                }
+                            ]
+                        }
+                    ]
+                }
+
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=None)
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(
+            blueprint.dataset_group
+        )
+    else:
+        # use arbitrary dataset class
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer=None)
+    return train_dataset_group
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+    add_logging_arguments(parser)
+    train_util.add_dataset_arguments(parser, True, True, True)
+    train_util.add_training_arguments(parser, True)
+    config_util.add_config_arguments(parser)
+    custom_train_functions.add_custom_train_arguments(parser)
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args, unknown = parser.parse_known_args()
+    args = train_util.read_config_from_file(args, parser)
+    if args.max_token_length is None:
+        args.max_token_length = 75
+    args.cache_meta = True
+
+    dataset_group = make_dataset(args)
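Usage sketch (not part of the commit; paths and file names are illustrative): because the script reuses the existing dataset and training argument parsers, it should accept the same dataset options as a training run, for example `python cache_dataset_meta.py --dataset_config my_dataset.toml` or `python cache_dataset_meta.py --train_data_dir train_imgs`. Since `__main__` forces `args.cache_meta = True`, building the dataset group writes a `dataset.txt` metadata cache into each subset's image directory, which later training runs can presumably consume via `--use_cached_meta`.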

library/config_util.py (+4)

@@ -110,6 +110,8 @@ class DreamBoothDatasetParams(BaseDatasetParams):
     bucket_reso_steps: int = 64
     bucket_no_upscale: bool = False
     prior_loss_weight: float = 1.0
+    cache_meta: bool = False
+    use_cached_meta: bool = False


 @dataclass
@@ -225,6 +227,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
         "min_bucket_reso": int,
         "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
         "network_multiplier": float,
+        "cache_meta": bool,
+        "use_cached_meta": bool,
     }

     # options handled by argparse but not handled by user config
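A minimal sketch (an assumption, not taken from the commit) of how the new options could be passed through a dataset config, using the same `{"datasets": [...]}` shape that `BlueprintGenerator` consumes in cache_dataset_meta.py; the directory name is made up:

# Hypothetical user_config dict; with the schema entries above, the flags
# should be validated as dataset-level booleans.
user_config = {
    "datasets": [
        {
            "cache_meta": True,        # write <image_dir>/dataset.txt while loading
            "use_cached_meta": False,  # set True on later runs to skip globbing/caption reads
            "subsets": [{"image_dir": "train_imgs"}],
        }
    ]
}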

library/train_util.py (+62 −21)

@@ -63,6 +63,7 @@
 from huggingface_hub import hf_hub_download
 import numpy as np
 from PIL import Image
+import imagesize
 import cv2
 import safetensors.torch
 import traceback
@@ -1033,8 +1034,7 @@ def cache_text_encoder_outputs(
         )

     def get_image_size(self, image_path):
-        image = Image.open(image_path)
-        return image.size
+        return imagesize.get(image_path)

     def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
         img = load_image(image_path)
@@ -1396,6 +1396,8 @@ def __init__(
         bucket_no_upscale: bool,
         prior_loss_weight: float,
         debug_dataset: bool,
+        cache_meta: bool,
+        use_cached_meta: bool,
     ) -> None:
         super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset, trust_cache)

@@ -1452,26 +1454,43 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                 logger.warning(f"not directory: {subset.image_dir}")
                 return [], []

-            img_paths = glob_images(subset.image_dir, "*")
+            sizes = None
+            if use_cached_meta:
+                logger.info(f"using cached metadata: {subset.image_dir}/dataset.txt")
+                # [img_path, caption, resolution]
+                with open(f"{subset.image_dir}/dataset.txt", "r", encoding="utf-8") as f:
+                    metas = f.readlines()
+                metas = [x.strip().split("<|##|>") for x in metas]
+                sizes = [tuple(int(res) for res in x[2].split(" ")) for x in metas]
+
+            if use_cached_meta:
+                img_paths = [x[0] for x in metas]
+            else:
+                img_paths = glob_images(subset.image_dir, "*")
+                sizes = [None]*len(img_paths)
             logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")

-            # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
-            captions = []
-            missing_captions = []
-            for img_path in img_paths:
-                cap_for_img = read_caption(img_path, subset.caption_extension)
-                if cap_for_img is None and subset.class_tokens is None:
-                    logger.warning(
-                        f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
-                    )
-                    captions.append("")
-                    missing_captions.append(img_path)
-                else:
-                    if cap_for_img is None:
-                        captions.append(subset.class_tokens)
+            if use_cached_meta:
+                captions = [x[1] for x in metas]
+                missing_captions = [x[0] for x in metas if x[1] == ""]
+            else:
+                # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
+                captions = []
+                missing_captions = []
+                for img_path in img_paths:
+                    cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
+                    if cap_for_img is None and subset.class_tokens is None:
+                        logger.warning(
+                            f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
+                        )
+                        captions.append("")
                         missing_captions.append(img_path)
                     else:
-                        captions.append(cap_for_img)
+                        if cap_for_img is None:
+                            captions.append(subset.class_tokens)
+                            missing_captions.append(img_path)
+                        else:
+                            captions.append(cap_for_img)

             self.set_tag_frequency(os.path.basename(subset.image_dir), captions)  # タグ頻度を記録

@@ -1488,7 +1507,21 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                        logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
                        break
                    logger.warning(missing_caption)
-            return img_paths, captions
+
+            if cache_meta:
+                logger.info(f"cache metadata for {subset.image_dir}")
+                if sizes is None or sizes[0] is None:
+                    sizes = [self.get_image_size(img_path) for img_path in img_paths]
+                # [img_path, caption, resolution]
+                data = [
+                    (img_path, caption, " ".join(str(x) for x in size))
+                    for img_path, caption, size in zip(img_paths, captions, sizes)
+                ]
+                with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f:
+                    f.write("\n".join(["<|##|>".join(x) for x in data]))
+                logger.info(f"cache metadata done for {subset.image_dir}")
+
+            return img_paths, captions, sizes

         logger.info("prepare images.")
         num_train_images = 0
@@ -1507,7 +1540,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                 )
                 continue

-            img_paths, captions = load_dreambooth_dir(subset)
+            img_paths, captions, sizes = load_dreambooth_dir(subset)
             if len(img_paths) < 1:
                 logger.warning(
                     f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
@@ -1519,8 +1552,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
             else:
                 num_train_images += subset.num_repeats * len(img_paths)

-            for img_path, caption in zip(img_paths, captions):
+            for img_path, caption, size in zip(img_paths, captions, sizes):
                 info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
+                if size is not None:
+                    info.image_size = size
                 if subset.is_reg:
                     reg_infos.append(info)
                 else:
@@ -3294,6 +3329,12 @@ def add_dataset_arguments(
     parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
 ):
     # dataset common
+    parser.add_argument(
+        "--cache_meta", action="store_true"
+    )
+    parser.add_argument(
+        "--use_cached_meta", action="store_true"
+    )
     parser.add_argument(
         "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
     )
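For reference, the `dataset.txt` cache written above stores one record per line: image path, caption, and the size as "width height", joined by the `<|##|>` separator (presumably chosen because it is unlikely to appear in paths or captions). A minimal parsing sketch mirroring the read path in `load_dreambooth_dir`; the sample line is made up:

# One cached record: path <|##|> caption <|##|> "width height"
line = "train_imgs/001.png<|##|>a photo of sks dog<|##|>768 1024"

img_path, caption, resolution = line.strip().split("<|##|>")
size = tuple(int(x) for x in resolution.split(" "))  # -> (768, 1024)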

requirements.txt (+2)

@@ -15,6 +15,8 @@ easygui==0.98.3
 toml==0.10.2
 voluptuous==0.13.1
 huggingface-hub==0.20.1
+# for Image utils
+imagesize==1.4.1
 # for BLIP captioning
 # requests==2.28.2
 # timm==0.6.12
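The new dependency backs the `get_image_size` change above: `imagesize.get` reads the dimensions from the file header without decoding the image, so collecting sizes for the metadata cache stays cheap. A one-line sketch (path illustrative):

import imagesize

width, height = imagesize.get("train_imgs/001.png")  # header-only read, returns (width, height)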

train_network.py (+2 −1)

@@ -6,6 +6,7 @@
 import random
 import time
 import json
+import pickle
 from multiprocessing import Value
 import toml

@@ -23,7 +24,7 @@

 import library.train_util as train_util
 from library.train_util import (
-    DreamBoothDataset,
+    DreamBoothDataset, DatasetGroup
 )
 import library.config_util as config_util
 from library.config_util import (
