
Dreambooth (single-subject training)

Background

The term Dreambooth refers to a technique developed by Google for injecting a subject into a model by fine-tuning it on a small set of high-quality images (paper).

In the context of fine-tuning, Dreambooth adds new techniques to help prevent model collapse due to, e.g., overfitting or artifacts.

Regularisation images

Regularisation images are typically generated by the model you are training, using a token that resembles your class.

They do not have to be synthetic images generated by the model, but synthetic images possibly perform better than real data (e.g. photographs of real persons).

Example: If you are training on images of a male subject, your regularisation data would be photographs or synthetically generated samples of random male subjects.

🟢 Regularisation images can be configured as a separate dataset, allowing them to mix evenly with your training data.

Rare token training

A concept of dubious value from the original paper was to do a reverse search through the model's tokenizer vocabulary to find a "rare" string that had very little training associated with it.

Since that time, the idea has evolved and been debated, with an opposing camp deciding to train against the name of a celebrity who looks similar enough, as this requires less compute.

🟡 Rare token training is supported in SimpleTuner, but there's no tool available to help you find one.

Prior preservation loss

The model contains something called a "prior" which could, in theory, be preserved during Dreambooth training. In experiments with Stable Diffusion, however, it didn't seem to help: the model just overfits on its own knowledge.

🟢 (#1031) Prior preservation loss is supported in SimpleTuner when training LyCORIS adapters by setting is_regularisation_data on that dataset.
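
As a rough illustration of the idea (not SimpleTuner's actual implementation), prior preservation on a regularisation sample can be pictured as using the frozen base model's prediction as the loss target for the adapter-equipped "student" model. The function names and the toy lambda "models" below are hypothetical stand-ins, and plain MSE is assumed as the loss.

```python
from typing import Callable, Sequence

Vector = Sequence[float]


def mse(prediction: Vector, target: Vector) -> float:
    """Mean squared error between two equal-length vectors."""
    if len(prediction) != len(target):
        raise ValueError("prediction and target must have the same length")
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(prediction)


def training_loss(
    student: Callable[[Vector], Vector],
    base: Callable[[Vector], Vector],
    noisy_input: Vector,
    true_target: Vector,
    is_regularisation_data: bool,
) -> float:
    """Pick the loss target depending on which dataset the sample came from."""
    student_prediction = student(noisy_input)
    if is_regularisation_data:
        # Regularisation sample: the frozen base model's own prediction
        # becomes the target, pulling the student back toward the prior.
        target = base(noisy_input)
    else:
        # Subject sample: train against the usual target.
        target = true_target
    return mse(student_prediction, target)


# Toy stand-ins for the real networks (hypothetical).
base_model = lambda x: [2.0 * v for v in x]
student_model = lambda x: [2.0 * v + 0.5 for v in x]

reg_loss = training_loss(student_model, base_model, [1.0, 2.0], [0.0, 0.0], True)
subject_loss = training_loss(student_model, base_model, [1.0, 2.0], [0.0, 0.0], False)
```

In this toy setup the regularisation loss only measures how far the student has drifted from the base model, regardless of the sample's nominal target.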

Masked loss

Image masks may be defined in pairs with image data. The dark portions of the mask will cause the loss calculations to ignore these parts of the image.
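
To make the effect of a mask concrete, here is a minimal sketch of a masked mean squared error over a flattened image. It is only an illustration under stated assumptions: the mask is taken to be scaled to [0, 1] (0 for dark, 1 for white), the loss is normalised by the sum of the mask, and `masked_mse` is a hypothetical name rather than a SimpleTuner function.

```python
from typing import Sequence


def masked_mse(
    prediction: Sequence[float],
    target: Sequence[float],
    mask: Sequence[float],
) -> float:
    """MSE over pixels weighted by a mask in [0, 1] (0 = dark = ignored)."""
    if not (len(prediction) == len(target) == len(mask)):
        raise ValueError("all inputs must have the same length")
    weighted_error = sum(
        m * (p - t) ** 2 for p, t, m in zip(prediction, target, mask)
    )
    mask_total = sum(mask)
    # If the mask is entirely dark there is nothing to learn from.
    return weighted_error / mask_total if mask_total > 0 else 0.0


prediction = [0.0, 1.0, 5.0, 5.0]
target = [0.0, 0.0, 0.0, 0.0]
mask = [1.0, 1.0, 0.0, 0.0]  # the last two pixels are masked out

loss = masked_mse(prediction, target, mask)
unmasked_loss = masked_mse(prediction, target, [1.0] * 4)
```

The large errors on the last two pixels contribute nothing to `loss`, while an all-white mask reduces to an ordinary MSE.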

An example script exists to generate these masks, given an input_dir and output_dir:

python generate_dataset_masks.py --input_dir /images/input \
                      --output_dir /images/output \
                      --text_input "person"

However, it does not offer any advanced functionality such as mask padding or blurring.

When defining your image mask dataset:

  • Every image must have a mask. Use an all-white image if you do not want to mask.
  • Set dataset_type=conditioning on your conditioning (mask) data folder
  • Set conditioning_type=mask on your mask dataset
  • Set conditioning_data= to your conditioning dataset id on your image dataset
[
    {
        "id": "dreambooth-data",
        "type": "local",
        "dataset_type": "image",
        "conditioning_data": "dreambooth-conditioning",
        "instance_data_dir": "/training/datasets/test_datasets/dreambooth",
        "cache_dir_vae": "/training/cache/vae/sdxl/dreambooth-data",
        "caption_strategy": "instanceprompt",
        "instance_prompt": "an dreambooth",
        "metadata_backend": "discovery",
        "resolution": 1024,
        "minimum_image_size": 1024,
        "maximum_image_size": 1024,
        "target_downsample_size": 1024,
        "crop": true,
        "crop_aspect": "square",
        "crop_style": "center",
        "resolution_type": "pixel_area"
    },
    {
        "id": "dreambooth-conditioning",
        "type": "local",
        "dataset_type": "conditioning",
        "instance_data_dir": "/training/datasets/test_datasets/dreambooth_mask",
        "resolution": 1024,
        "minimum_image_size": 1024,
        "maximum_image_size": 1024,
        "target_downsample_size": 1024,
        "crop": true,
        "crop_aspect": "square",
        "crop_style": "center",
        "resolution_type": "pixel_area",
        "conditioning_type": "mask"
    },
    {
        "id": "an example backend for text embeds.",
        "dataset_type": "text_embeds",
        "default": true,
        "type": "local",
        "cache_dir": "/training/cache/text/sdxl-base/masked_loss"
    }
]
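
Because the link between an image dataset and its mask dataset is just a matching id string, a typo is easy to miss. The following is a small sketch of a pre-flight check for the three settings listed above; `check_mask_links` is a hypothetical helper, not part of SimpleTuner, and it only inspects the keys shown in this document.

```python
import json


def check_mask_links(backends: list[dict]) -> list[str]:
    """Return a list of problems found in image -> mask dataset links."""
    by_id = {entry["id"]: entry for entry in backends if "id" in entry}
    problems = []
    for entry in backends:
        target_id = entry.get("conditioning_data")
        if target_id is None:
            continue  # this dataset does not use a mask
        target = by_id.get(target_id)
        if target is None:
            problems.append(f"{entry.get('id', '<no id>')}: no dataset with id {target_id!r}")
            continue
        if target.get("dataset_type") != "conditioning":
            problems.append(f"{target_id}: dataset_type should be 'conditioning'")
        if target.get("conditioning_type") != "mask":
            problems.append(f"{target_id}: conditioning_type should be 'mask'")
    return problems


example = json.loads("""
[
  {"id": "dreambooth-data", "dataset_type": "image",
   "conditioning_data": "dreambooth-conditioning"},
  {"id": "dreambooth-conditioning", "dataset_type": "conditioning",
   "conditioning_type": "mask"}
]
""")

print(check_mask_links(example))  # prints []
```

In practice you would load your own dataloader config with `json.load` and review any returned messages before starting a run.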

Setup

Following the tutorial is required before you can continue into Dreambooth-specific configuration.

For DeepFloyd tuning, it's recommended to visit this page for specific tips related to that model's setup.

Quantised model training

Hugging Face Optimum-Quanto, which has been tested on Apple and NVIDIA systems, can be used to reduce model precision and VRAM requirements.

Inside your SimpleTuner venv:

pip install optimum-quanto
# choices: int8-quanto, int4-quanto, int2-quanto, fp8-quanto
# int8-quanto was tested with a single subject dreambooth LoRA.
# fp8-quanto does not work on Apple systems. you must use int levels.
# int2-quanto is pretty extreme and gets the whole rank-1 LoRA down to about 13.9GB VRAM.
# may the gods have mercy on your soul, should you push things Too Far.
export TRAINER_EXTRA_ARGS="--base_model_precision=int8-quanto"

# Maybe you want the text encoders to remain full precision so your text embeds are cake.
# We unload the text encoders before training, so, that's not an issue during training time - only during pre-caching.
# Alternatively, you can go ham on quantisation here and run them in int4 or int8 mode, because no one can stop you.
export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --text_encoder_1_precision=no_change --text_encoder_2_precision=no_change"

# When you're quantising the model, we're not in pure bf16 anymore.
# Since adamw_bf16 will never work with this setup, select another optimiser.
# I know the spelling is different than everywhere else, but we're in too deep to fix it now.
export OPTIMIZER="optimi-lion" # or maybe optimi-stableadamw

Our dataloader config, multidatabackend-dreambooth.json, will look something like this:

[
    {
        "id": "subjectname-data-512px",
        "type": "local",
        "instance_data_dir": "/training/datasets/subjectname",
        "caption_strategy": "instanceprompt",
        "instance_prompt": "subjectname",
        "cache_dir_vae": "/training/vae_cache/subjectname",
        "repeats": 100,
        "crop": false,
        "resolution": 512,
        "resolution_type": "pixel_area",
        "minimum_image_size": 192
    },
    {
        "id": "subjectname-data-1024px",
        "type": "local",
        "instance_data_dir": "/training/datasets/subjectname",
        "caption_strategy": "instanceprompt",
        "instance_prompt": "subjectname",
        "cache_dir_vae": "/training/vae_cache/subjectname-1024px",
        "repeats": 100,
        "crop": false,
        "resolution": 1024,
        "resolution_type": "pixel_area",
        "minimum_image_size": 768
    },
    {
        "id": "regularisation-data",
        "type": "local",
        "instance_data_dir": "/training/datasets/regularisation",
        "caption_strategy": "instanceprompt",
        "instance_prompt": "a picture of a man",
        "cache_dir_vae": "/training/vae_cache/regularisation",
        "repeats": 0,
        "resolution": 512,
        "resolution_type": "pixel_area",
        "minimum_image_size": 192,
        "is_regularisation_data": true
    },
    {
        "id": "regularisation-data-1024px",
        "type": "local",
        "instance_data_dir": "/training/datasets/regularisation",
        "caption_strategy": "instanceprompt",
        "instance_prompt": "a picture of a man",
        "cache_dir_vae": "/training/vae_cache/regularisation-1024px",
        "repeats": 0,
        "resolution": 1024,
        "resolution_type": "pixel_area",
        "minimum_image_size": 768,
        "is_regularisation_data": true
    },
    {
        "id": "textembeds",
        "type": "local",
        "dataset_type": "text_embeds",
        "default": true,
        "cache_dir": "/training/text_cache/sdxl_base"
    }
]

Some key values have been tweaked to make training a single subject easier:

  • We now have two datasets, each configured twice (once per resolution), for a total of four datasets. Regularisation data is optional, and training may work better without it. You can remove those datasets from the list if desired.
  • Resolution is set to 512px and 1024px for mixed bucketing, which can help improve training speed and convergence.
  • Minimum image size is set to 192px or 768px, which allows us to upsample some smaller images. This might be needed for datasets with a few important but low-resolution images.
  • caption_strategy is now instanceprompt, which means we will use instance_prompt value for every image in the dataset as its caption.
    • Note: Using the instance prompt is the traditional method of Dreambooth training, but short captions may work better. If you find the model fails to generalise, it may be worth attempting to use captions.
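
Since the two subject entries differ only in resolution, minimum size and cache directory, they could be generated rather than written by hand. The sketch below assumes a hypothetical helper, `subject_datasets`, and uses only keys that appear in the example config above. Note that it gives every resolution a suffixed cache directory, which differs slightly from the 512px entry in the example.

```python
import json


def subject_datasets(
    name: str, data_dir: str, cache_root: str, repeats: int = 100
) -> list[dict]:
    """Build one dataset entry per resolution for the same image folder."""
    # (resolution, minimum_image_size) pairs taken from the example config.
    sizes = [(512, 192), (1024, 768)]
    entries = []
    for resolution, minimum in sizes:
        entries.append({
            "id": f"{name}-data-{resolution}px",
            "type": "local",
            "instance_data_dir": data_dir,
            "caption_strategy": "instanceprompt",
            "instance_prompt": name,
            # Each resolution gets its own VAE cache directory.
            "cache_dir_vae": f"{cache_root}/{name}-{resolution}px",
            "repeats": repeats,
            "crop": False,
            "resolution": resolution,
            "resolution_type": "pixel_area",
            "minimum_image_size": minimum,
        })
    return entries


config = subject_datasets(
    "subjectname", "/training/datasets/subjectname", "/training/vae_cache"
)
print(json.dumps(config, indent=4))
```

The printed JSON can be pasted into the dataloader config alongside the text embed backend.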

For a regularisation dataset:

  • Set repeats very high on your Dreambooth subject so that its image count, multiplied by repeats, surpasses the image count of your regularisation set.
    • If your regularisation set has 1000 images and your training set has 10, you'd want a repeats value of at least 100 to get fast results.
  • minimum_image_size has been increased to ensure we don't introduce too many low-quality artifacts
  • Similarly, using more descriptive captions may help avoid forgetting. Switching from instanceprompt to textfile or other strategies will require creating .txt files for each image.
  • When is_regularisation_data (or 🇺🇸 is_regularization_data with a z, for the American users) is set, the data from this set will be fed into the base model to obtain a prediction that can be used as a loss target for the student LyCORIS model.
    • Note, currently this only functions on a LyCORIS adapter.
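
The repeats guideline above is simple arithmetic, sketched here with a hypothetical helper. It treats repeats as a plain multiplier on the subject image count, as the bullet describes; how the trainer counts the first pass over the data is not covered here, so treat the result as a lower bound.

```python
def minimum_repeats(subject_images: int, regularisation_images: int) -> int:
    """Smallest repeats with subject_images * repeats >= regularisation_images."""
    if subject_images <= 0:
        raise ValueError("subject_images must be positive")
    # Integer ceiling division avoids floating-point rounding.
    return -(-regularisation_images // subject_images)


print(minimum_repeats(10, 1000))  # prints 100
```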

Selecting an instance prompt

As mentioned earlier, the original focus of Dreambooth was the selection of rare tokens to train on.

Alternatively, one might use the real name of their subject, or a 'similar enough' celebrity.

After a number of training experiments, it seems as though a 'similar enough' celebrity is the best choice, especially if prompting the model for the person's real name ends up looking dissimilar.

Refiner tuning

If you're a fan of the SDXL refiner, you may find that it "ruins" the results of your Dreamboothed model.

SimpleTuner supports training the SDXL refiner using either LoRA or full-rank training.

This requires a couple of considerations:

  • The images should be purely high-quality
  • The text embeds cannot be shared with the base model's
  • The VAE embeds can be shared with the base model

You'll need to update cache_dir in your dataloader configuration, multidatabackend.json:

[
    {
        "id": "textembeds",
        "type": "local",
        "dataset_type": "text_embeds",
        "default": true,
        "cache_dir": "/training/text_cache/sdxl_refiner"
    }
]
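
If you keep a single base-model dataloader config and want to derive a refiner variant from it, the change could be scripted along these lines. This is a sketch with a hypothetical helper name. It only touches backends whose dataset_type is text_embeds and leaves the VAE cache paths alone, matching the considerations listed above.

```python
import copy


def with_refiner_text_cache(backends: list[dict], refiner_cache_dir: str) -> list[dict]:
    """Return a copy of the config whose text_embeds backends use a new cache_dir."""
    updated = copy.deepcopy(backends)
    for entry in updated:
        if entry.get("dataset_type") == "text_embeds":
            entry["cache_dir"] = refiner_cache_dir
        # Image datasets (and their cache_dir_vae) are left untouched,
        # since the VAE embeds can be shared with the base model.
    return updated


base_config = [
    {"id": "subjectname-data-1024px", "cache_dir_vae": "/training/vae_cache/subjectname-1024px"},
    {"id": "textembeds", "dataset_type": "text_embeds", "cache_dir": "/training/text_cache/sdxl_base"},
]
refiner_config = with_refiner_text_cache(base_config, "/training/text_cache/sdxl_refiner")
```

The original list is not modified, so both configs can be written out side by side.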

If you wish to target a specific aesthetic score with your data, you can add this to config/config.json:

"--data_aesthetic_score": 5.6,

Update 5.6 to the score you would like to target. The default is 7.0.

⚠️ When training the SDXL refiner, your validation prompts will be ignored. Instead, random images from your datasets will be refined.