Fine-tuning Dreambooth
To fine-tune with Dreambooth, make a `DreamboothConfig` and use it to create a `FinetuningRun`. You'll need a Dreambooth dataset.
Build a config
Create your `DreamboothConfig`:
```python
from baseten.training import DreamboothConfig

# `dataset` is your Dreambooth dataset, created beforehand.
config = DreamboothConfig(
    instance_prompt="photo of olliedog",  # Dog's name is "Ollie"
    input_dataset=dataset,
    num_train_epochs=10,
)
```
Run fine-tuning
Once your config is set, use it to create a `FinetuningRun`:
```python
from baseten.training import FinetuningRun

my_run = FinetuningRun.create(
    trained_model_name="Dog Dreambooth",
    fine_tuning_config=config,
)
```
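Fine-tuning runs take a while to start. If a script needs to wait until the run is actually running (for example, before streaming logs), a small polling loop works. This is a minimal sketch, assuming that each access to `my_run.status` fetches the current status and that it compares equal to the string `"RUNNING"` once the run starts; `wait_for_status` is a hypothetical helper, not part of the Baseten library.

```python
import time
from typing import Callable


def wait_for_status(
    get_status: Callable[[], object],
    target: str = "RUNNING",
    poll_seconds: float = 30.0,
    timeout_seconds: float = 3600.0,
) -> bool:
    """Poll get_status() until it equals target.

    Returns True once the target status is seen, or False if the timeout
    expires first. Assumes the status compares equal to a plain string.
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        if get_status() == target:
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_seconds)
```

With a run created as above, you would call it as `wait_for_status(lambda: my_run.status)`. Passing a callable rather than the status itself keeps the helper independent of the Baseten client, which also makes it easy to test with a fake status source.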
You can get your run's status with `my_run.status`.
Once the run starts (`my_run.status` returns `RUNNING`), stream logs with:
Tips
- If you want higher model quality at the expense of a longer fine-tuning run:
  - Increase `num_train_epochs`.
  - Set `train_text_encoder=True`.
  - Increase the number of images in your dataset.
- For a faster fine-tuning run but lower model quality, do the opposite.
- If your instance prompt uses words that Stable Diffusion might not be familiar with, set `train_text_encoder=True`.
- If you don't have a lot of data, keep the learning rate low. For Dreambooth, `5e-06` is a good choice.
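To make the quality-versus-speed trade-off concrete, the sketch below bundles the tips into keyword arguments for `DreamboothConfig`. Only `num_train_epochs` and `train_text_encoder` come from this page; doubling or halving the epoch count is an arbitrary illustration rather than a recommended value, and `dreambooth_tuning_kwargs` is a hypothetical helper. The learning-rate tip is left out because its parameter name isn't given here.

```python
from typing import Any, Dict


def dreambooth_tuning_kwargs(
    prefer_quality: bool,
    unfamiliar_prompt_words: bool = False,
    base_epochs: int = 10,
) -> Dict[str, Any]:
    """Return illustrative DreamboothConfig arguments for the tips above.

    prefer_quality=True: more epochs and a trained text encoder (slower run).
    prefer_quality=False: fewer epochs and no text-encoder training (faster run).
    unfamiliar_prompt_words=True forces text-encoder training either way.
    """
    # Arbitrary scaling of the epoch count, for illustration only.
    epochs = base_epochs * 2 if prefer_quality else max(1, base_epochs // 2)
    train_text_encoder = prefer_quality or unfamiliar_prompt_words
    return {"num_train_epochs": epochs, "train_text_encoder": train_text_encoder}
```

You could then pass the result into the config, for example `DreamboothConfig(instance_prompt="photo of olliedog", input_dataset=dataset, **dreambooth_tuning_kwargs(prefer_quality=True))`.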
Monitor your run with Weights & Biases
Blueprint fine-tuning integrates with Weights & Biases so you can monitor your fine-tuning run:
- Pass `wandb_api_key="YOUR_API_KEY"` to enable the integration.
- Use `image_log_steps` to control how often, in training steps, generated images are logged.

This lets you watch the images your model generates while it is being fine-tuned.
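Rather than pasting your API key into a script, you can read it from the environment. This is a minimal sketch, assuming `wandb_api_key` and `image_log_steps` are passed to `DreamboothConfig` alongside its other parameters. `WANDB_API_KEY` is the environment variable the W&B client conventionally uses, the default of 100 steps is an arbitrary placeholder, and `wandb_kwargs` is a hypothetical helper.

```python
import os
from typing import Any, Dict


def wandb_kwargs(image_log_steps: int = 100, env_var: str = "WANDB_API_KEY") -> Dict[str, Any]:
    """Build W&B-related config arguments from the environment.

    Returns an empty dict when the key isn't set, so the integration is
    simply left off instead of failing.
    """
    api_key = os.environ.get(env_var)
    if not api_key:
        return {}
    return {"wandb_api_key": api_key, "image_log_steps": image_log_steps}
```

Spreading the result into the config (`DreamboothConfig(..., **wandb_kwargs())`) keeps the secret out of source control and makes monitoring opt-in per environment.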
Check the `DreamboothConfig` reference for a complete list of parameters.
FAQ
What should my instance prompt be?
Dreambooth teaches Stable Diffusion to associate a word with a visual concept (that is, the object in your fine-tuning dataset). To do so, it needs a word that doesn't already have a concept associated with it, so make up a string that is meaningful to you but unique.
In the example, the dataset is 115 pictures of a dog named Ollie, so we use the word `olliedog` as a unique descriptor. The unique word in your instance prompt can be any string that isn't a real word and is unlikely to appear in the base model's training data.
How do I regularize the model?
To make sure your model can generate images of both your specific object and a generic one (for example, an image of your dog as well as an image of a random dog), regularize the model with:
- `class_prompt` set to the generic object (e.g. `"a photo of a dog"`)
- `with_prior_preservation=True`
- prior preservation images in your dataset
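As a sketch of how these settings fit together, the hypothetical helper below returns the prompt-related `DreamboothConfig` arguments for prior preservation and rejects a class prompt that matches the instance prompt, since the two are meant to describe different things. It does not supply prior preservation images; those still need to be in your dataset.

```python
from typing import Any, Dict


def prior_preservation_kwargs(instance_prompt: str, class_prompt: str) -> Dict[str, Any]:
    """Return illustrative DreamboothConfig arguments for regularized training."""
    # The class prompt should describe the generic object, not your instance.
    if class_prompt.strip().lower() == instance_prompt.strip().lower():
        raise ValueError("class_prompt must differ from instance_prompt")
    return {
        "instance_prompt": instance_prompt,
        "class_prompt": class_prompt,
        "with_prior_preservation": True,
    }
```

For the Ollie example, this would be used as `DreamboothConfig(input_dataset=dataset, num_train_epochs=10, **prior_preservation_kwargs("photo of olliedog", "a photo of a dog"))`.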
What's next?
Your model will be automatically deployed
Once the fine-tuning run is complete, your model will be automatically deployed. You'll receive an email when the deployment is finished and the model is ready to invoke.
You can turn off this behavior by setting `auto_deploy=False` in `FinetuningRun.create()` and instead deploy your model manually.
Once your model is deployed, you can invoke it:
```python
from baseten.models import StableDiffusionPipeline

model = StableDiffusionPipeline(model_id="MODEL_ID")
image, url = model("portrait of olliedog as an andy warhol painting")
image.save("ollie-warhol.png")
```
The model returns two values:
- The generated image, as a Pillow image
- A URL to the generated image
For more on your newly deployed model, see the `StableDiffusionPipeline` reference.
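To try several prompts at once, you can wrap the call in a loop. This is a minimal sketch that relies only on what is shown above: calling the model with a prompt returns a Pillow image and a URL. `generate_and_save` and `slugify` are hypothetical helpers, and the file-naming scheme is arbitrary.

```python
import os
import re
from typing import Any, Callable, Dict, Iterable, Tuple


def slugify(prompt: str) -> str:
    """Turn a prompt into a filesystem-friendly file stem."""
    slug = re.sub(r"[^a-z0-9]+", "-", prompt.lower()).strip("-")
    return slug or "image"


def generate_and_save(
    model: Callable[[str], Tuple[Any, str]],
    prompts: Iterable[str],
    out_dir: str = ".",
) -> Dict[str, str]:
    """Call the model once per prompt, save each image, and map path -> URL."""
    results: Dict[str, str] = {}
    for prompt in prompts:
        image, url = model(prompt)  # (Pillow image, URL)
        path = os.path.join(out_dir, f"{slugify(prompt)}.png")
        image.save(path)
        results[path] = url
    return results
```

With a deployed model, `generate_and_save(model, ["portrait of olliedog as an andy warhol painting", "side profile of olliedog as a van gogh painting"])` would write one PNG per prompt to the current directory.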
Example output: these images were generated for the following prompts:
- `portrait of olliedog as an andy warhol painting` (left)
- `side profile of olliedog as a van gogh painting` (right)