from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
+import imagesize
import cv2
import safetensors.torch
import traceback
@@ -1033,8 +1034,7 @@ def cache_text_encoder_outputs(
        )

    def get_image_size(self, image_path):
-        image = Image.open(image_path)
-        return image.size
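+        # imagesize reads the dimensions from the file header without decoding
+        # the whole image, so this avoids a full Image.open per file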
+        return imagesize.get(image_path)

    def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
        img = load_image(image_path)
@@ -1396,6 +1396,8 @@ def __init__(
        bucket_no_upscale: bool,
        prior_loss_weight: float,
        debug_dataset: bool,
+        cache_meta: bool,
+        use_cached_meta: bool,
    ) -> None:
        super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset, trust_cache)
@@ -1452,26 +1454,43 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                logger.warning(f"not directory: {subset.image_dir}")
                return [], []

-            img_paths = glob_images(subset.image_dir, "*")
+            sizes = None
+            if use_cached_meta:
+                logger.info(f"using cached metadata: {subset.image_dir}/dataset.txt")
+                # [img_path, caption, resolution]
+                with open(f"{subset.image_dir}/dataset.txt", "r", encoding="utf-8") as f:
+                    metas = f.readlines()
+                metas = [x.strip().split("<|##|>") for x in metas]
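+                # the third field is the resolution serialized as "width height";
+                # parse it back into an (int, int) tuple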
+                sizes = [tuple(int(res) for res in x[2].split(" ")) for x in metas]
+
+            if use_cached_meta:
+                img_paths = [x[0] for x in metas]
+            else:
+                img_paths = glob_images(subset.image_dir, "*")
+                sizes = [None] * len(img_paths)

            logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")

-            # read a caption for each image file and use it if it exists
-            captions = []
-            missing_captions = []
-            for img_path in img_paths:
-                cap_for_img = read_caption(img_path, subset.caption_extension)
-                if cap_for_img is None and subset.class_tokens is None:
-                    logger.warning(
-                        f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
-                    )
-                    captions.append("")
-                    missing_captions.append(img_path)
-                else:
-                    if cap_for_img is None:
-                        captions.append(subset.class_tokens)
+            if use_cached_meta:
+                captions = [x[1] for x in metas]
+                missing_captions = [x[0] for x in metas if x[1] == ""]
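+                # an empty caption field typically means neither a caption file nor
+                # class tokens existed when the cache was written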
+            else:
+                # read a caption for each image file and use it if it exists
+                captions = []
+                missing_captions = []
+                for img_path in img_paths:
+                    cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
+                    if cap_for_img is None and subset.class_tokens is None:
+                        logger.warning(
+                            f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
+                        )
+                        captions.append("")
                        missing_captions.append(img_path)
                    else:
-                    captions.append(cap_for_img)
+                        if cap_for_img is None:
+                            captions.append(subset.class_tokens)
+                            missing_captions.append(img_path)
+                        else:
+                            captions.append(cap_for_img)

            self.set_tag_frequency(os.path.basename(subset.image_dir), captions)  # record tag frequency
@@ -1488,7 +1507,21 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                        logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
                        break
                    logger.warning(missing_caption)
-            return img_paths, captions
+
+            if cache_meta:
+                logger.info(f"cache metadata for {subset.image_dir}")
+                if sizes is None or sizes[0] is None:
+                    sizes = [self.get_image_size(img_path) for img_path in img_paths]
+                # [img_path, caption, resolution]
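+                # "<|##|>" serves as the field delimiter since it is unlikely to
+                # appear inside a caption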
+                data = [
+                    (img_path, caption, " ".join(str(x) for x in size))
+                    for img_path, caption, size in zip(img_paths, captions, sizes)
+                ]
+                with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f:
+                    f.write("\n".join(["<|##|>".join(x) for x in data]))
+                logger.info(f"cache metadata done for {subset.image_dir}")
+
+            return img_paths, captions, sizes

        logger.info("prepare images.")
        num_train_images = 0
@@ -1507,7 +1540,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                )
                continue

-            img_paths, captions = load_dreambooth_dir(subset)
+            img_paths, captions, sizes = load_dreambooth_dir(subset)
            if len(img_paths) < 1:
                logger.warning(
                    f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
@@ -1519,8 +1552,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
            else:
                num_train_images += subset.num_repeats * len(img_paths)

-            for img_path, caption in zip(img_paths, captions):
+            for img_path, caption, size in zip(img_paths, captions, sizes):
                info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
+                if size is not None:
+                    info.image_size = size
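+                    # with a cached size, bucketing can use the stored resolution
+                    # instead of opening the image file again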
                if subset.is_reg:
                    reg_infos.append(info)
                else:
@@ -3294,6 +3329,12 @@ def add_dataset_arguments(
    parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
):
    # dataset common
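+    # --cache_meta writes a per-directory dataset.txt (path, caption, size);
+    # --use_cached_meta reads it back so startup can skip globbing and opening every image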
+    parser.add_argument(
+        "--cache_meta", action="store_true"
+    )
+    parser.add_argument(
+        "--use_cached_meta", action="store_true"
+    )
    parser.add_argument(
        "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
    )