DreamboothConfig

DreamboothConfig dataclass

Bases: FinetuningConfig

Fine-tuning config for Dreambooth fine-tuning with Stable Diffusion.

Examples:

```python
from baseten.training import Dataset, DreamboothConfig

config = DreamboothConfig(
    instance_prompt="photo of olliedog",  # Dog's name is "Ollie"
    input_dataset=Dataset("DATASET_ID"),
    class_prompt="a photo of a dog",
    num_train_epochs=10
)
```
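
A config using prior preservation could look like the following. This is a sketch built only from the parameters documented below; the dataset ID, prompts, and seed are placeholders.

```python
from baseten.training import Dataset, DreamboothConfig

# Sketch: enable prior preservation so that generic "dog" prompts keep
# producing ordinary dogs while "olliedog" produces the fine-tuned instance.
# DATASET_ID and the prompts are placeholders.
config = DreamboothConfig(
    instance_prompt="photo of olliedog",
    input_dataset=Dataset("DATASET_ID"),
    class_prompt="a photo of a dog",
    with_prior_preservation=True,  # turn on the prior preservation loss
    prior_loss_weight=1.0,         # weight of that loss (default)
    num_class_images=100,          # missing class images are generated by the base model
    seed=42,
)
```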

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `instance_prompt` | `str` | The prompt containing an identifier that specifies the instance you're teaching Stable Diffusion. | required |
| `input_dataset` | `DatasetIdentifier` | An identifier, either an ID or a public URL, for the Dataset that Dreambooth should use. | required |
| `wandb_api_key` | `Optional[str]` | API key for Weights & Biases to monitor your model training. | `None` |
| `pretrained_model_name_or_path` | `str` | Path to a pretrained model, or a model identifier from huggingface.co/models. | `'CompVis/stable-diffusion-v1-4'` |
| `revision` | `Optional[str]` | Revision of the pretrained model identifier from huggingface.co/models. | `None` |
| `tokenizer_name` | `Optional[str]` | Pretrained tokenizer name or path, if different from the model name. | `None` |
| `class_prompt` | `Optional[str]` | The prompt specifying images in the same class as your instance images. This helps regularize the model (e.g. so that not every prompt containing "dog" produces your dog, but only "olliedog" does). | `None` |
| `with_prior_preservation` | `bool` | Flag to use the prior preservation loss. | `False` |
| `prior_loss_weight` | `float` | The weight of the prior preservation loss. | `1.0` |
| `gradient_accumulation_steps` | `int` | Number of gradient accumulation steps to use for training. | `1` |
| `num_class_images` | `int` | The number of class images to use for fine-tuning; only relevant when using prior preservation. If greater than the number of class images in your dataset, the remainder is generated with the base model. | `100` |
| `seed` | `Optional[int]` | The random seed to use for fine-tuning. | `None` |
| `resolution` | `int` | The resolution of your input images. Images are resized to this value. | `512` |
| `center_crop` | `bool` | Whether to center-crop the images to the resolution. | `False` |
| `train_text_encoder` | `bool` | Whether to train the text encoder alongside the UNet. | `False` |
| `train_batch_size` | `int` | The batch size to use for training. Too large a value can cause OOMs; we recommend a value between 1 and 4, depending on the resolution of your images. | `1` |
| `sample_batch_size` | `int` | The batch size to use for sampling images. | `1` |
| `num_train_epochs` | `int` | The number of epochs to train for. Ignored if `max_train_steps` is set. | `1` |
| `max_train_steps` | `int` | The total number of training steps. If set, `num_train_epochs` is ignored. | `1000` |
| `learning_rate` | `float` | Initial learning rate (after the potential warmup period). | `1e-06` |
| `lr_scheduler` | `str` | The scheduler type to use. Choose between `"linear"`, `"cosine"`, `"cosine_with_restarts"`, `"polynomial"`, `"constant"`, and `"constant_with_warmup"`. | `'constant'` |
| `lr_warmup_steps` | `int` | The number of warmup steps for the learning rate schedule. | `500` |
| `adam_beta1` | `float` | The beta1 parameter for the Adam optimizer. | `0.9` |
| `adam_beta2` | `float` | The beta2 parameter for the Adam optimizer. | `0.999` |
| `adam_weight_decay` | `float` | The weight decay value for the Adam optimizer. | `0.01` |
| `adam_epsilon` | `float` | The epsilon value for the Adam optimizer. | `1e-08` |
| `max_grad_norm` | `float` | The maximum norm to clip gradients to; helps prevent exploding gradients. | `1.0` |
| `mixed_precision` | `str` | Whether to use mixed precision: `fp16`, `bf16` (bfloat16), or `None` for no mixed precision. Combined with a large batch size, this can cause OOMs; we recommend `fp16`. | `'fp16'` |
| `image_log_steps` | `int` | The number of steps between logging sample images to Weights & Biases, so you can visually assess your model during training. Only relevant if `wandb_api_key` is set. | `20` |
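
To illustrate how the training parameters above combine, here is one more sketch. The dataset ID, prompt, and API key are placeholders, and the values simply follow the defaults and recommendations in the table.

```python
from baseten.training import Dataset, DreamboothConfig

# Sketch: training hyperparameters plus Weights & Biases monitoring.
# DATASET_ID and WANDB_API_KEY are placeholders.
config = DreamboothConfig(
    instance_prompt="photo of olliedog",
    input_dataset=Dataset("DATASET_ID"),
    max_train_steps=1000,           # takes precedence over num_train_epochs
    learning_rate=1e-6,
    lr_scheduler="constant",        # one of the scheduler names listed above
    lr_warmup_steps=500,
    train_batch_size=2,             # keep between 1 and 4 to avoid OOMs
    mixed_precision="fp16",         # recommended; "bf16" or None also accepted
    wandb_api_key="WANDB_API_KEY",  # enables W&B monitoring
    image_log_steps=20,             # log sample images every 20 steps
)
```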