from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
+import imagesize
import cv2
import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
@@ -1080,8 +1081,7 @@ def cache_text_encoder_outputs(
        )

    def get_image_size(self, image_path):
-        image = Image.open(image_path)
-        return image.size
+        return imagesize.get(image_path)

    def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
        img = load_image(image_path)
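The `get_image_size` change is the point of the new `imagesize` dependency: `imagesize.get` parses only the file header to obtain dimensions, whereas `PIL.Image.open(...).size` constructs a full `Image` object first. A minimal sketch of the equivalence, assuming a local `sample.png` exists and `imagesize` is installed (`pip install imagesize`):

```python
# Both calls return (width, height); imagesize reads only header bytes.
# "sample.png" is a hypothetical local file for illustration.
from PIL import Image
import imagesize

path = "sample.png"

with Image.open(path) as im:
    pil_size = im.size            # full open (lazy decode, but still heavier)
fast_size = imagesize.get(path)   # header-only read; (-1, -1) if unsupported

assert pil_size == fast_size
```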
@@ -1425,6 +1425,8 @@ def __init__(
        bucket_no_upscale: bool,
        prior_loss_weight: float,
        debug_dataset: bool,
+        cache_meta: bool,
+        use_cached_meta: bool,
    ) -> None:
        super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
@@ -1484,26 +1486,43 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                logger.warning(f"not directory: {subset.image_dir}")
                return [], []

-            img_paths = glob_images(subset.image_dir, "*")
+            sizes = None
+            if use_cached_meta:
+                logger.info(f"using cached metadata: {subset.image_dir}/dataset.txt")
+                # [img_path, caption, resolution]
+                with open(f"{subset.image_dir}/dataset.txt", "r", encoding="utf-8") as f:
+                    metas = f.readlines()
+                metas = [x.strip().split("<|##|>") for x in metas]
+                sizes = [tuple(int(res) for res in x[2].split(" ")) for x in metas]
+
+            if use_cached_meta:
+                img_paths = [x[0] for x in metas]
+            else:
+                img_paths = glob_images(subset.image_dir, "*")
+                sizes = [None] * len(img_paths)
            logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")

-            # load the prompt for each image file, preferring the caption file if it exists
-            captions = []
-            missing_captions = []
-            for img_path in img_paths:
-                cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
-                if cap_for_img is None and subset.class_tokens is None:
-                    logger.warning(
-                        f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
-                    )
-                    captions.append("")
-                    missing_captions.append(img_path)
-                else:
-                    if cap_for_img is None:
-                        captions.append(subset.class_tokens)
+            if use_cached_meta:
+                captions = [x[1] for x in metas]
+                missing_captions = [x[0] for x in metas if x[1] == ""]
+            else:
+                # load the prompt for each image file, preferring the caption file if it exists
+                captions = []
+                missing_captions = []
+                for img_path in img_paths:
+                    cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
+                    if cap_for_img is None and subset.class_tokens is None:
+                        logger.warning(
+                            f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
+                        )
+                        captions.append("")
                        missing_captions.append(img_path)
                    else:
-                    captions.append(cap_for_img)
+                        if cap_for_img is None:
+                            captions.append(subset.class_tokens)
+                            missing_captions.append(img_path)
+                        else:
+                            captions.append(cap_for_img)

            self.set_tag_frequency(os.path.basename(subset.image_dir), captions)  # record tag frequency
@@ -1520,7 +1539,21 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                        logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
                        break
                    logger.warning(missing_caption)
-            return img_paths, captions
+
+            if cache_meta:
+                logger.info(f"cache metadata for {subset.image_dir}")
+                if not sizes or sizes[0] is None:
+                    sizes = [self.get_image_size(img_path) for img_path in img_paths]
+                # [img_path, caption, resolution]
+                data = [
+                    (img_path, caption, " ".join(str(x) for x in size))
+                    for img_path, caption, size in zip(img_paths, captions, sizes)
+                ]
+                with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f:
+                    f.write("\n".join(["<|##|>".join(x) for x in data]))
+                logger.info(f"cache metadata done for {subset.image_dir}")
+
+            return img_paths, captions, sizes

        logger.info("prepare images.")
        num_train_images = 0
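For reference, `dataset.txt` holds one record per image: path, caption, and resolution joined by the `<|##|>` delimiter, with the resolution serialized as `width height`. A minimal round-trip sketch of the format; the entries are invented for illustration, only the serialization logic mirrors the diff:

```python
# Round-trip sketch of the dataset.txt format used above.
records = [
    ("/data/dog/001.png", "a photo of sks dog", (512, 768)),
    ("/data/dog/002.png", "", (1024, 1024)),  # empty caption counts as missing
]

# write side (mirrors the cache_meta branch)
text = "\n".join(
    "<|##|>".join((path, caption, " ".join(str(v) for v in size)))
    for path, caption, size in records
)

# read side (mirrors the use_cached_meta branch)
metas = [line.strip().split("<|##|>") for line in text.splitlines()]
img_paths = [m[0] for m in metas]
captions = [m[1] for m in metas]
sizes = [tuple(int(v) for v in m[2].split(" ")) for m in metas]

assert sizes == [(512, 768), (1024, 1024)]
```

Note that a caption containing a newline or the literal delimiter would corrupt a record; the unusual `<|##|>` token is presumably chosen to make such collisions unlikely.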
@@ -1539,7 +1572,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                )
                continue

-            img_paths, captions = load_dreambooth_dir(subset)
+            img_paths, captions, sizes = load_dreambooth_dir(subset)
            if len(img_paths) < 1:
                logger.warning(
                    f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
@@ -1551,8 +1584,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
            else:
                num_train_images += subset.num_repeats * len(img_paths)

-            for img_path, caption in zip(img_paths, captions):
+            for img_path, caption, size in zip(img_paths, captions, sizes):
                info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
+                if size is not None:
+                    info.image_size = size
                if subset.is_reg:
                    reg_infos.append((info, subset))
                else:
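Storing the cached size on `ImageInfo` lets later stages, such as aspect-ratio bucketing, avoid touching the file again; when `image_size` stays `None`, the consumer presumably falls back to reading it from disk. A hypothetical sketch of that consumer-side pattern (`resolve_size` and its fallback are illustrative, not part of this diff):

```python
# Hypothetical consumer-side fallback, not part of this diff:
# prefer the size cached on the ImageInfo, otherwise read the header.
import imagesize

def resolve_size(info) -> tuple:
    if getattr(info, "image_size", None) is not None:
        return info.image_size                 # (width, height) from dataset.txt
    return imagesize.get(info.absolute_path)   # header-only read as fallback
```

Here `absolute_path` is the last argument passed to the `ImageInfo` constructor above.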
@@ -3355,6 +3390,12 @@ def add_dataset_arguments(
    parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
):
    # dataset common
+    parser.add_argument(
+        "--cache_meta", action="store_true", help="cache dataset metadata (image paths, captions, sizes) to dataset.txt in each image directory"
+    )
+    parser.add_argument(
+        "--use_cached_meta", action="store_true", help="load dataset metadata from the cached dataset.txt instead of scanning image directories"
+    )
    parser.add_argument(
        "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
    )
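Both flags are plain `store_true` switches wired through to the dataset constructor above. The expected two-phase use (flag names from this diff, the rest illustrative): run once with `--cache_meta` to write `dataset.txt` into each image directory, then pass `--use_cached_meta` on later runs so `load_dreambooth_dir` skips globbing, per-file caption reads, and image-header reads. A minimal sketch of how the switches parse:

```python
# Sketch: the parser here is a stand-in for the full training argument parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cache_meta", action="store_true")
parser.add_argument("--use_cached_meta", action="store_true")

# first run: build the cache while training
args = parser.parse_args(["--cache_meta"])
assert args.cache_meta and not args.use_cached_meta

# later runs: consume the cache
args = parser.parse_args(["--use_cached_meta"])
assert args.use_cached_meta
```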