
Fine-tuning FLAN-T5

To fine-tune FLAN-T5, create a FlanT5BaseConfig and use it to create a FinetuningRun. You'll need a FLAN-T5 dataset.
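A FLAN-T5 dataset pairs a source (input) text column with a target (output) text column. As a rough, illustrative sketch of what such a CSV could look like (the column names and file path are placeholders; see the FLAN-T5 dataset docs for how to upload it):

import pandas as pd

# Each row pairs a source (input) text with a target (output) text.
# Column names here are placeholders matching the config example below.
df = pd.DataFrame({
    "my_source_column_name": [
        "Answer the following question: What is the capital of France?",
        "Summarize: The quick brown fox jumps over the lazy dog.",
    ],
    "my_target_column_name": [
        "Paris",
        "A fox jumps over a dog.",
    ],
})
df.to_csv("flan_t5_dataset.csv", index=False)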

Build a config

Create your FlanT5BaseConfig:

from baseten.training import FlanT5BaseConfig

config = FlanT5BaseConfig(
    input_dataset=dataset,
    source_col_name="my_source_column_name",
    target_col_name="my_target_column_name"
)

Run fine-tuning

Once your config is set, use it to create a FinetuningRun:

from baseten.training import FinetuningRun

my_run = FinetuningRun.create(
    trained_model_name="My FLAN Model",
    fine_tuning_config=config
)

You can get your run's status with:

my_run.status

Once the run starts (my_run.status returns RUNNING), stream logs with:

my_run.stream_logs()
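If you'd like to wait for the run to start before attaching to the logs, a simple polling loop works. This is only a sketch; it assumes my_run.status is refreshed from the backend on each access and compares equal to the string "RUNNING":

import time

# Wait until the run reports RUNNING, then attach to the log stream.
# (Assumption: status re-queries the backend each time it is accessed
# and is a plain string; adjust the comparison if it's an enum.)
while my_run.status != "RUNNING":
    time.sleep(30)

my_run.stream_logs()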

Tips

Monitor your run with Weights & Biases

Blueprint fine-tuning integrates with Weights & Biases so you can monitor your fine-tuning run:

  • Pass wandb_api_key="YOUR_API_KEY" to enable the integration.
  • Use the Weights & Biases dashboard to follow training metrics, such as loss, as the run progresses.
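For example, assuming wandb_api_key is accepted on the config (check the FlanT5BaseConfig reference for where the integration is configured):

config = FlanT5BaseConfig(
    input_dataset=dataset,
    source_col_name="my_source_column_name",
    target_col_name="my_target_column_name",
    wandb_api_key="YOUR_API_KEY",  # enables the Weights & Biases integration
)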

If your fine-tuned model doesn't generate the results you'd like, there are a few common issues:

  1. Learning Rate: Your learning rate may have been too high or too low. A smaller learning rate gives more gradual convergence but may require more training time, while a larger learning rate can cause the model to overshoot the optimal solution. To remedy this, set the learning rate with the learning_rate parameter and adjust its schedule with lr_scheduler_type in the FlanT5BaseConfig (see the sketch after this list).

  2. Dataset: Your dataset might not have enough examples, or the examples may be poorly formatted. FLAN-T5 can run into issues when your dataset contains non-English characters or special characters (like <). Remove these and preprocess your dataset before kicking off the job (see the sketch after this list). You may also find that using an existing instruction template found here improves your model's performance.

  3. Weight Decay: You may have overfit or underfit on the dataset. Overfitting is when the model begins to memorize your training data; underfitting is when the model struggles to find a function that maps your input data to your output data. You can apply weight decay (L2 regularization) to prevent overfitting and improve generalization. Start with a small weight decay value (e.g., 1e-5 or 1e-4) and experiment with different values to find the right balance between overfitting and underfitting (see the sketch after this list).

  4. Model Size: For certain tasks, smaller models can struggle; for example, language models are notoriously bad at math word problems. It may help to fine-tune a larger model with more parameters on your dataset. Larger models generalize better to more complex tasks, but they take longer to train.
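The sketch below shows how these fixes might look in practice. The preprocessing step and the hyperparameter values are illustrative, and the weight_decay and lr_scheduler_type values are assumptions; confirm the exact parameter names and accepted values against the FlanT5BaseConfig reference.

import pandas as pd
from baseten.training import FlanT5BaseConfig

# 2. Preprocess the dataset: strip special characters (like <) and
#    non-English/non-ASCII characters before uploading (illustrative only).
df = pd.read_csv("flan_t5_dataset.csv")
for col in ["my_source_column_name", "my_target_column_name"]:
    df[col] = df[col].str.replace(r"[<>]", "", regex=True)
    df[col] = df[col].str.replace(r"[^\x00-\x7F]", "", regex=True)
df.to_csv("flan_t5_dataset_clean.csv", index=False)

# 1. & 3. Tune the learning rate, scheduler, and weight decay on the config.
#    The values below are illustrative starting points, not recommendations.
config = FlanT5BaseConfig(
    input_dataset=dataset,
    source_col_name="my_source_column_name",
    target_col_name="my_target_column_name",
    learning_rate=1e-4,
    lr_scheduler_type="linear",
    weight_decay=1e-5,
)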

Check the FlanT5BaseConfig reference for a complete list of parameters.

What's next?

Your model will be automatically deployed

Once the fine-tuning run is complete, your model will be automatically deployed. You'll receive an email when the deployment is finished and the model is ready to invoke.

You can turn off this behavior by setting auto_deploy=False in FinetuningRun.create() and deploying your model manually instead.
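For example, a minimal sketch that reuses the create() call from above:

from baseten.training import FinetuningRun

my_run = FinetuningRun.create(
    trained_model_name="My FLAN Model",
    fine_tuning_config=config,
    auto_deploy=False,  # skip automatic deployment; deploy manually later
)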

Once your model is deployed, you can invoke it:

from baseten.models import FlanT5

model = FlanT5("MODEL_ID")
invocation = model("Answer the following question: What is 1 + 1?", max_length=512, early_stopping=True)

View our docs on the FlanT5() model object here.

The model returns:

["2"]