Fine-tuning Stable Diffusion

To fine-tune Stable Diffusion, create a FullStableDiffusionConfig and use it to start a FinetuningRun. You'll need a Stable Diffusion dataset first.

Build a config

Create your FullStableDiffusionConfig:

from baseten.training import FullStableDiffusionConfig

config = FullStableDiffusionConfig(
    input_dataset=dataset  # your Stable Diffusion dataset
)

Run fine-tuning

Once your config is set, use it to create a FinetuningRun:

from baseten.training import FinetuningRun

my_run = FinetuningRun.create(
    trained_model_name="Cool Stable Diffusion Model",
    fine_tuning_config=config,
    num_train_epochs=10
)

You can get your run's status with:

my_run.status

Once the run starts (my_run.status returns RUNNING), stream logs with:

my_run.stream_logs()
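If you want to block until the run is live before streaming logs, you can poll my_run.status with a small generic helper. This is a sketch, not part of the baseten SDK — the wait_for_status function and its parameters are illustrative:

```python
import time

def wait_for_status(get_status, target="RUNNING", poll_seconds=5, timeout=600):
    """Poll get_status() until it returns target; give up after timeout seconds."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if get_status() == target:
            return True
        time.sleep(poll_seconds)
    return False

# Usage with the run created above:
# if wait_for_status(lambda: my_run.status):
#     my_run.stream_logs()
```

Passing a callable rather than the run object keeps the helper independent of any particular SDK.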

Tips

  • If you want a higher model quality at the expense of longer fine-tuning time:
    • Increase num_train_epochs.
    • Set train_text_encoder=True.
    • Increase the number of images in your dataset.
  • For a faster fine-tuning run but lower model quality, do the opposite.
  • If your instance prompt uses words that Stable Diffusion might not be familiar with, set train_text_encoder=True.
  • If you don't have a lot of data, keep the learning rate low. For Stable Diffusion, 5e-05 works well.
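Taken together, a quality-leaning config following the tips above might look like this sketch. train_text_encoder is named in the tips; the learning_rate keyword is an assumption, so check the FullStableDiffusionConfig reference for the exact parameter name:

```python
from baseten.training import FullStableDiffusionConfig

config = FullStableDiffusionConfig(
    input_dataset=dataset,    # more images generally improves quality
    train_text_encoder=True,  # helps with unfamiliar instance-prompt words
    learning_rate=5e-05,      # assumed keyword; keep low with small datasets
)
```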

Monitor your run with Weights & Biases

Blueprint fine-tuning can integrate with Weights & Biases to monitor your fine-tuning run:

  • Pass wandb_api_key="YOUR_API_KEY" to enable the integration.
  • Use image_log_steps to control how often sample images are logged.

This will enable you to see your model generate images as it is being fine-tuned.

Check the FullStableDiffusionConfig reference for a complete list of parameters.

What's next?

Your model will be automatically deployed

Once the fine-tuning run is complete, your model will be automatically deployed. You'll receive an email when the deployment is finished and the model is ready to invoke.

You can turn off this behavior by setting auto_deploy=False in FinetuningRun.create() and instead deploy your model manually.
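For example, to create the run without auto-deployment, add auto_deploy=False to the same create() call shown above (a sketch; all other arguments are unchanged):

```python
from baseten.training import FinetuningRun

my_run = FinetuningRun.create(
    trained_model_name="Cool Stable Diffusion Model",
    fine_tuning_config=config,
    num_train_epochs=10,
    auto_deploy=False,  # skip automatic deployment; deploy manually later
)
```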

Once your model is deployed, you can invoke it:

from baseten.models import StableDiffusionPipeline

model = StableDiffusionPipeline(model_id="MODEL_ID")
image, url = model("a ghost-type cat pokemon with four legs and one tail")
image.save("poke-shark.png")

The model returns:

  • The generated image (as a Pillow Image)
  • A URL to the generated image

For more on your newly deployed model, see the StableDiffusionPipeline reference.

Here's an example output from the above invocation:

Example stable diffusion output