Commit ae97c8b
[Experimental] Add cache mechanism for dataset groups to avoid long waiting time for initialization (#1178)
* support meta cached dataset
* add cache meta scripts
* random ip_noise_gamma strength
* random noise_offset strength
* use correct settings for parser
* cache path/caption/size only
* revert mess up commit
* revert mess up commit
* Update requirements.txt
* Add arguments for meta cache.
* remove pickle implementation
* Return sizes when enable cache

Co-authored-by: Kohya S <[email protected]>
1 parent 381c449 commit ae97c8b

5 files changed: +173 -22 lines
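In short: with --cache_meta enabled, DreamBoothDataset writes a plain-text dataset.txt into each image directory, one img_path<|##|>caption<|##|>width height record per image. A later run with --use_cached_meta reads that file back instead of globbing the directory, reading every caption file, and opening every image for its size, which is what makes initialization of large dataset groups slow.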

cache_dataset_meta.py (+103)
@@ -0,0 +1,103 @@
+import argparse
+import random
+
+from accelerate.utils import set_seed
+
+import library.train_util as train_util
+import library.config_util as config_util
+from library.config_util import (
+    ConfigSanitizer,
+    BlueprintGenerator,
+)
+import library.custom_train_functions as custom_train_functions
+from library.utils import setup_logging, add_logging_arguments
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def make_dataset(args):
+    train_util.prepare_dataset_args(args, True)
+    setup_logging(args, reset=True)
+
+    use_dreambooth_method = args.in_json is None
+    use_user_config = args.dataset_config is not None
+
+    if args.seed is None:
+        args.seed = random.randint(0, 2**32)
+    set_seed(args.seed)
+
+    # prepare the dataset
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(
+            ConfigSanitizer(True, True, False, True)
+        )
+        if use_user_config:
+            logger.info(f"Loading dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "reg_data_dir", "in_json"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                logger.warning(
+                    "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
+                )
+        else:
+            if use_dreambooth_method:
+                logger.info("Using DreamBooth method.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
+                                args.train_data_dir, args.reg_data_dir
+                            )
+                        }
+                    ]
+                }
+            else:
+                logger.info("Training with captions.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": [
+                                {
+                                    "image_dir": args.train_data_dir,
+                                    "metadata_file": args.in_json,
+                                }
+                            ]
+                        }
+                    ]
+                }
+
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=None)
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(
+            blueprint.dataset_group
+        )
+    else:
+        # use arbitrary dataset class
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer=None)
+    return train_dataset_group
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+    add_logging_arguments(parser)
+    train_util.add_dataset_arguments(parser, True, True, True)
+    train_util.add_training_arguments(parser, True)
+    config_util.add_config_arguments(parser)
+    custom_train_functions.add_custom_train_arguments(parser)
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args, unknown = parser.parse_known_args()
+    args = train_util.read_config_from_file(args, parser)
+    if args.max_token_length is None:
+        args.max_token_length = 75
+    args.cache_meta = True
+
+    dataset_group = make_dataset(args)
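Since the script reuses the standard dataset and training argument parsers, it should accept the same dataset options as the trainers, and it forces args.cache_meta = True before building the dataset group. A plausible invocation (paths hypothetical, not from the commit): python cache_dataset_meta.py --train_data_dir ./train_images --resolution 512,512. Building the dataset group then writes dataset.txt into each subset's image directory as a side effect.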

library/config_util.py (+4)
@@ -111,6 +111,8 @@ class DreamBoothDatasetParams(BaseDatasetParams):
     bucket_reso_steps: int = 64
     bucket_no_upscale: bool = False
     prior_loss_weight: float = 1.0
+    cache_meta: bool = False
+    use_cached_meta: bool = False
 
 
 @dataclass
@@ -228,6 +230,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
         "min_bucket_reso": int,
         "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
         "network_multiplier": float,
+        "cache_meta": bool,
+        "use_cached_meta": bool,
     }
 
     # options handled by argparse but not handled by user config
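Because these keys are added both to the DreamBoothDatasetParams dataclass and to the config schema, cache_meta and use_cached_meta can presumably also be set per dataset in a .toml dataset config, in addition to the new command-line flags added in train_util.py below.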

library/train_util.py (+62 -21)
@@ -63,6 +63,7 @@
 from huggingface_hub import hf_hub_download
 import numpy as np
 from PIL import Image
+import imagesize
 import cv2
 import safetensors.torch
 from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
@@ -1080,8 +1081,7 @@ def cache_text_encoder_outputs(
             )
 
     def get_image_size(self, image_path):
-        image = Image.open(image_path)
-        return image.size
+        return imagesize.get(image_path)
 
     def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
         img = load_image(image_path)
@@ -1425,6 +1425,8 @@ def __init__(
         bucket_no_upscale: bool,
         prior_loss_weight: float,
         debug_dataset: bool,
+        cache_meta: bool,
+        use_cached_meta: bool,
     ) -> None:
         super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
 
@@ -1484,26 +1486,43 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                 logger.warning(f"not directory: {subset.image_dir}")
                 return [], []
 
-            img_paths = glob_images(subset.image_dir, "*")
+            sizes = None
+            if use_cached_meta:
+                logger.info(f"using cached metadata: {subset.image_dir}/dataset.txt")
+                # [img_path, caption, resolution]
+                with open(f"{subset.image_dir}/dataset.txt", "r", encoding="utf-8") as f:
+                    metas = f.readlines()
+                metas = [x.strip().split("<|##|>") for x in metas]
+                sizes = [tuple(int(res) for res in x[2].split(" ")) for x in metas]
+
+            if use_cached_meta:
+                img_paths = [x[0] for x in metas]
+            else:
+                img_paths = glob_images(subset.image_dir, "*")
+                sizes = [None]*len(img_paths)
             logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
 
-            # load the caption for each image file, and use it if it exists
-            captions = []
-            missing_captions = []
-            for img_path in img_paths:
-                cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
-                if cap_for_img is None and subset.class_tokens is None:
-                    logger.warning(
-                        f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
-                    )
-                    captions.append("")
-                    missing_captions.append(img_path)
-                else:
-                    if cap_for_img is None:
-                        captions.append(subset.class_tokens)
+            if use_cached_meta:
+                captions = [x[1] for x in metas]
+                missing_captions = [x[0] for x in metas if x[1] == ""]
+            else:
+                # load the caption for each image file, and use it if it exists
+                captions = []
+                missing_captions = []
+                for img_path in img_paths:
+                    cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
+                    if cap_for_img is None and subset.class_tokens is None:
+                        logger.warning(
+                            f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
+                        )
+                        captions.append("")
                         missing_captions.append(img_path)
                     else:
-                        captions.append(cap_for_img)
+                        if cap_for_img is None:
+                            captions.append(subset.class_tokens)
+                            missing_captions.append(img_path)
+                        else:
+                            captions.append(cap_for_img)
 
             self.set_tag_frequency(os.path.basename(subset.image_dir), captions)  # record tag frequency
 
@@ -1520,7 +1539,21 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                     logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
                     break
                 logger.warning(missing_caption)
-            return img_paths, captions
+
+            if cache_meta:
+                logger.info(f"cache metadata for {subset.image_dir}")
+                if sizes is None or sizes[0] is None:
+                    sizes = [self.get_image_size(img_path) for img_path in img_paths]
+                # [img_path, caption, resolution]
+                data = [
+                    (img_path, caption, " ".join(str(x) for x in size))
+                    for img_path, caption, size in zip(img_paths, captions, sizes)
+                ]
+                with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f:
+                    f.write("\n".join(["<|##|>".join(x) for x in data]))
+                logger.info(f"cache metadata done for {subset.image_dir}")
+
+            return img_paths, captions, sizes
 
         logger.info("prepare images.")
         num_train_images = 0
@@ -1539,7 +1572,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                 )
                 continue
 
-            img_paths, captions = load_dreambooth_dir(subset)
+            img_paths, captions, sizes = load_dreambooth_dir(subset)
             if len(img_paths) < 1:
                 logger.warning(
                     f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
@@ -1551,8 +1584,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
             else:
                 num_train_images += subset.num_repeats * len(img_paths)
 
-            for img_path, caption in zip(img_paths, captions):
+            for img_path, caption, size in zip(img_paths, captions, sizes):
                 info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
+                if size is not None:
+                    info.image_size = size
                 if subset.is_reg:
                     reg_infos.append((info, subset))
                 else:
@@ -3355,6 +3390,12 @@ def add_dataset_arguments(
     parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
 ):
     # dataset common
+    parser.add_argument(
+        "--cache_meta", action="store_true"
+    )
+    parser.add_argument(
+        "--use_cached_meta", action="store_true"
+    )
     parser.add_argument(
         "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
     )
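To make the cache format concrete, here is a minimal standalone sketch. It is not code from the commit: the SEP constant, the helper names write_cache/read_cache, and the sample record are all illustrative. It mirrors the read/write logic in load_dreambooth_dir above.

    SEP = "<|##|>"  # field separator used by the commit; unlikely to occur in captions

    def write_cache(cache_path, records):
        # records: list of (img_path, caption, (width, height)) tuples
        lines = [SEP.join((p, c, f"{w} {h}")) for p, c, (w, h) in records]
        with open(cache_path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))

    def read_cache(cache_path):
        # returns parallel lists, like load_dreambooth_dir's img_paths/captions/sizes
        with open(cache_path, "r", encoding="utf-8") as f:
            metas = [line.strip().split(SEP) for line in f]
        img_paths = [m[0] for m in metas]
        captions = [m[1] for m in metas]
        sizes = [tuple(int(v) for v in m[2].split(" ")) for m in metas]
        return img_paths, captions, sizes

    # hypothetical round-trip:
    write_cache("dataset.txt", [("imgs/0001.png", "a photo of sks dog", (512, 768))])
    print(read_cache("dataset.txt"))  # (['imgs/0001.png'], ['a photo of sks dog'], [(512, 768)])

Because the format is line-oriented plain text, it assumes captions contain neither newlines nor the separator token; a caption that violated this would corrupt the cache.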

requirements.txt (+2)
@@ -15,6 +15,8 @@ easygui==0.98.3
 toml==0.10.2
 voluptuous==0.13.1
 huggingface-hub==0.20.1
+# for Image utils
+imagesize==1.4.1
 # for BLIP captioning
 # requests==2.28.2
 # timm==0.6.12
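The new imagesize dependency backs the get_image_size change above: imagesize.get() parses only the image header and returns a (width, height) tuple, which is lighter-weight than constructing a PIL Image. A short illustration (path hypothetical):

    import imagesize

    # header-only read; no pixel data is decoded
    width, height = imagesize.get("imgs/0001.png")  # e.g. (512, 768)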

train_network.py (+2 -1)
@@ -6,6 +6,7 @@
 import random
 import time
 import json
+import pickle
 from multiprocessing import Value
 import toml
 
@@ -23,7 +24,7 @@
 
 import library.train_util as train_util
 from library.train_util import (
-    DreamBoothDataset,
+    DreamBoothDataset, DatasetGroup
 )
 import library.config_util as config_util
 from library.config_util import (