# Fine-tuning Stable Diffusion

To fine-tune Stable Diffusion, build a `FullStableDiffusionConfig` and use it to create a `FinetuningRun`. You'll need a Stable Diffusion dataset.
## Build a config

Create your `FullStableDiffusionConfig`:

```python
from baseten.training import FullStableDiffusionConfig

config = FullStableDiffusionConfig(
    input_dataset=dataset
)
```
## Run fine-tuning

Once your config is set, use it to create a `FinetuningRun`:

```python
from baseten.training import FinetuningRun

my_run = FinetuningRun.create(
    trained_model_name="Cool Stable Diffusion Model",
    fine_tuning_config=config,
    num_train_epochs=10
)
```
You can check your run's status with `my_run.status`. Once the run starts (`my_run.status` returns `RUNNING`), you can stream its logs.
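A minimal sketch of the polling pattern, wrapped in a helper so it works with any run object (the `stream_logs()` call at the end is an assumed method name — check the `FinetuningRun` reference for the exact API):

```python
import time

def wait_until_running(run, poll_seconds=10):
    """Poll a run object until its status is RUNNING, then return it."""
    while run.status != "RUNNING":
        time.sleep(poll_seconds)
    return run

# With a real run:
# wait_until_running(my_run)
# my_run.stream_logs()  # assumed method name; see the FinetuningRun reference
```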
## Tips

- For higher model quality at the expense of longer fine-tuning time:
  - Increase `num_train_epochs`.
  - Set `train_text_encoder=True`.
  - Increase the number of images in your dataset.
- For a faster fine-tuning run at the expense of model quality, do the opposite.
- If your instance prompt uses words that Stable Diffusion might not be familiar with, set `train_text_encoder=True`.
- If you don't have a lot of data, keep the learning rate low. For Stable Diffusion, `5e-05` is a good value.
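Applied to the config above, these tips might look like the following sketch. `train_text_encoder` appears in this guide; `learning_rate` is an assumed parameter name, so check the `FullStableDiffusionConfig` reference for the exact spelling:

```python
from baseten.training import FullStableDiffusionConfig

# Sketch of a higher-quality (slower) configuration.
# `learning_rate` is an assumed parameter name; verify it against
# the FullStableDiffusionConfig reference.
config = FullStableDiffusionConfig(
    input_dataset=dataset,
    train_text_encoder=True,  # helps with unfamiliar instance-prompt words
    learning_rate=5e-05,      # keep the learning rate low for small datasets
)
```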
## Monitor your run with Weights & Biases

Blueprint fine-tuning can integrate with Weights & Biases to monitor your fine-tuning run:

- Pass `wandb_api_key="YOUR_API_KEY"` to enable the integration.
- Use `image_log_steps` to control how often sample images are logged.

This lets you watch your model generate images as it is being fine-tuned. Check the `FullStableDiffusionConfig` reference for a complete list of parameters.
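For example, both Weights & Biases options slot into the same config; the values here are placeholders, and `100` for `image_log_steps` is an arbitrary choice:

```python
from baseten.training import FullStableDiffusionConfig

config = FullStableDiffusionConfig(
    input_dataset=dataset,
    wandb_api_key="YOUR_API_KEY",  # enables the W&B integration
    image_log_steps=100,           # log sample images every 100 steps (arbitrary value)
)
```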
## What's next?

### Your model will be automatically deployed

Once the fine-tuning run is complete, your model will be automatically deployed. You'll receive an email when the deployment is finished and the model is ready to invoke.

You can turn off this behavior by setting `auto_deploy=False` in `FinetuningRun.create()` and instead deploy your model manually.
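For instance, opting out of automatic deployment is the same `create()` call as above with the extra flag:

```python
from baseten.training import FinetuningRun

my_run = FinetuningRun.create(
    trained_model_name="Cool Stable Diffusion Model",
    fine_tuning_config=config,
    num_train_epochs=10,
    auto_deploy=False,  # skip automatic deployment; deploy manually later
)
```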
Once your model is deployed, you can invoke it:
```python
from baseten.models import StableDiffusionPipeline

model = StableDiffusionPipeline(model_id="MODEL_ID")
image, url = model("a ghost-type cat pokemon with four legs and one tail")
image.save("poke-shark.png")
```

The model returns:

- The generated image (as a Pillow image)
- A URL to the generated image
For more on your newly deployed model, see the `StableDiffusionPipeline` reference.