How to fine-tune full Stable Diffusion
Stable Diffusion is an image generation model that can create a wide variety of images. However, it becomes even more powerful and useful when fine-tuned for specific tasks.
Fine-tuning lets you get a lot of value from small to medium-size datasets by leveraging foundation models. As opposed to training a model from scratch, fine-tuning alters the behavior of an existing model to better match a specific use case.
Fine-tuning can radically alter the capabilities of Stable Diffusion. Riffusion used Stable Diffusion to create music by fine-tuning it on audio spectrograms. Stable Diffusion can be fine-tuned to generate anything that can be represented with pixels.
In this tutorial, we'll fine-tune Stable Diffusion to generate images of new Pokémon.
Step 0: Prerequisites
After signing up for Blueprint by Baseten, you'll need to do three things to complete this tutorial:
- Install the latest version of the Baseten Python client with `pip install --upgrade baseten`
- Create an API key
- Authenticate by running the Baseten CLI login command in your terminal and pasting your API key when prompted
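The exact terminal command isn't shown above; assuming the standard Baseten CLI login flow, it is likely:

```
baseten login
```

Paste your API key when prompted.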
Following this tutorial will consume credits/billable resources
The Stable Diffusion fine-tuning run in this tutorial will consume credits (if available on your account) or billable resources.
Step 1: Create your dataset
Fine-tuning Stable Diffusion is the process of teaching it to associate strings with certain visual objects. With enough training data, you can totally restructure the way Stable Diffusion generates images to make it specialize in a certain style or pattern.
The example dataset we use in the tutorial contains 833 images of Pokémon with associated text files containing brief descriptions of said Pokémon.
Your dataset must be similarly structured:
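The exact folder layout isn't reproduced here; based on the description above (each image paired with a caption text file of the same name), a hypothetical dataset folder might look like:

```
pokemon-dataset/
├── 0001.png
├── 0001.txt
├── 0002.png
├── 0002.txt
└── ...
```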
This folder will be zipped during the upload process, so make sure nothing else is in the folder.
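Before zipping, you can sanity-check the folder with a short script. This helper is hypothetical (it is not part of the Baseten client); it simply verifies that every image has a matching caption file:

```python
from pathlib import Path

IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}

def check_dataset(folder):
    """Return the names of image files missing a matching .txt caption."""
    missing = []
    for image in sorted(Path(folder).iterdir()):
        if image.suffix.lower() in IMAGE_EXTENSIONS:
            # Caption must sit next to the image with the same stem
            if not image.with_suffix(".txt").exists():
                missing.append(image.name)
    return missing
```

If `check_dataset("pokemon-dataset")` returns an empty list, every image has a caption and the folder is ready to zip.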
Here's a small sample of the example dataset, with captions shown left to right:
- a drawing of a green pokemon with red eyes
- a green and yellow toy with a red nose
- a red and white ball with an angry look on its face
Step 2: Upload dataset
There are three ways to provide a dataset to a fine-tuning run: a public URL, a local path, or a separately uploaded dataset.
A "public" URL means a link that you can access without logging in or providing an API key.
The dataset must be a zip file containing the folder structure explained in step 1.
If you want to follow the tutorial using a pre-built dataset (five pictures of a dog from Unsplash), use the example code as-is. Otherwise, replace the link with a link to your own hosted dataset zip file, or use one of the other dataset options below.
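The original snippet isn't shown here. Assuming Blueprint's training module exposes a `PublicUrl` helper for URL-hosted datasets (the class name is an assumption; check the client docs), the step might look like:

```python
from baseten.training import PublicUrl  # assumption: helper for URL-hosted datasets

# Replace with the pre-built dog dataset link or a link to your own hosted zip file
dataset = PublicUrl("https://example.com/my-dataset.zip")
```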
If you have your dataset on the local machine that you're running your Python code on, you can use the `LocalPath` option to upload it as part of your fine-tuning script.
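A minimal sketch of this option; the constructor argument and the folder name are assumptions:

```python
from baseten.training import LocalPath

# Path to the dataset folder described in step 1 (hypothetical folder name)
dataset = LocalPath("./pokemon-dataset")
```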
If your fine-tuning script is running on one machine and your dataset lives on another, or you want to upload a dataset once and use it in multiple fine-tuning runs, you'll want to upload the dataset separately.
`baseten dataset upload` is a bash command. Open a terminal window and run:
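The exact invocation isn't shown here; the sketch below is an assumption built from the flags mentioned in the notes that follow (`name` and `--training-type`), so the argument order may differ in your client version:

```
baseten dataset upload --name pokemon --training-type CLASSIC_STABLE_DIFFUSION ./pokemon-dataset
```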
Notes:
- If the `name` parameter is not provided, Blueprint will name your dataset based on the directory name.
- If you're doing a full Stable Diffusion run, instead use `--training-type CLASSIC_STABLE_DIFFUSION`.
You should see:
```
Upload Progress: 100% |█████████████████████████████████████████████████████████
INFO 🔮 Upload successful!🔮
Dataset ID: DATASET_ID
```
Then, for your fine-tuning config (your Python code), you'll use:
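The snippet isn't shown above. Assuming the training module exposes a `Dataset` helper for referencing an already-uploaded dataset by ID (the class name is an assumption; check the client docs), it might look like:

```python
from baseten.training import Dataset  # assumption: wraps an already-uploaded dataset

dataset = Dataset("DATASET_ID")  # the ID printed by `baseten dataset upload`
```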
Step 3: Assemble fine-tuning config
For the rest of this tutorial, we'll be using Python to configure, create, and deploy a fine-tuned model. Open up a Jupyter notebook or Python file in your local development environment to follow along.
Assembling the config is an opportunity to customize the fine-tuning run to meet our exact needs. For a complete reference of every configurable parameter, see the `FullStableDiffusionConfig` docs.
In the config for our example data, we're not setting any optional values, to demonstrate how Stable Diffusion can be fine-tuned with nothing but a dataset:
```python
from baseten.training import FullStableDiffusionConfig

config = FullStableDiffusionConfig(
    input_dataset=dataset
)
```
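If you do want to tune the run, the same config accepts training-length parameters such as `num_train_epochs` and `max_train_steps` (both referenced later in this guide). The value below is illustrative, not a recommendation:

```python
config = FullStableDiffusionConfig(
    input_dataset=dataset,
    num_train_epochs=10,  # illustrative value, not a recommendation
)
```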
Step 4: Run fine-tuning
Once your config is set, it's time to kick off the fine-tuning run. This process is straightforward; just use:
```python
from baseten.training import FinetuningRun

my_run = FinetuningRun.create(
    trained_model_name="Pokemaker",
    fine_tuning_config=config
)
```
The `trained_model_name` will be assigned to the deployed model.
Fine-tuning a model takes some time. Exactly how long depends on:
- the type of fine-tuning you're doing (full Stable Diffusion generally takes longer than Dreambooth)
- the size of your dataset (more images take longer)
- the configured `num_train_epochs` or `max_train_steps` (a higher number means a longer run)
While you wait, you can monitor the run's progress with:
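The monitoring snippet isn't shown here; assuming the run object exposes a log-streaming method (the method name is an assumption, so check the `FinetuningRun` reference), it might look like:

```python
my_run.stream_logs()  # assumption: streams training logs while the run executes
```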
Your model will be automatically deployed
Once the fine-tuning run is complete, your model will be automatically deployed. You'll receive an email when the deployment is finished and the model is ready to invoke.
You can turn off this behavior by setting `auto_deploy=False` in `FinetuningRun.create()` and instead deploying your model manually.
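A sketch of the manual path; the `deploy()` method name is an assumption, so check the `FinetuningRun` docs before relying on it:

```python
my_run = FinetuningRun.create(
    trained_model_name="Pokemaker",
    fine_tuning_config=config,
    auto_deploy=False,
)
# ...once the run completes:
my_run.deploy()  # assumption: deploys the trained model
```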
Step 5: Use fine-tuned model
It's time! You can finally invoke the model:

```python
image, url = model("a ghost-type cat pokemon with four legs and one tail")
image.save("poke-shark.png")
```
You'll get a Pillow image of your model output as well as a URL to access the output in the future.
If you want to access the model later, you can do so by instantiating a `StableDiffusionPipeline` with the model ID:

```python
from baseten.models import StableDiffusionPipeline

model = StableDiffusionPipeline(model_id="MODEL_ID")
image, url = model("a ghost-type cat pokemon with four legs and one tail")
image.save("poke-shark.png")
```
Here's an example output from the above invocation: