
LlamaConfig

LlamaConfig dataclass

Bases: FinetuningConfig

Training config for LLaMA-7B.

Examples:

from baseten.training import Dataset, LlamaConfig

config = LlamaConfig(
    input_dataset=Dataset("DATASET_ID"),
    epochs=3,
    learning_rate=5e-5,
    max_steps=1000,
    train_batch_size=8,
    sample_batch_size=8,
    report_to="wandb"
)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_id | str | Pretrained model to fine-tune | 'decapoda-research/llama-7b-hf' |
| source_col_name | str | Name of the source column in the input CSV | 'source' |
| target_col_name | str | Name of the target column in the input CSV | 'target' |
| evaluation_strategy | str | Interval for evaluation | 'epoch' |
| train_batch_size | int | Batch size for training | 16 |
| train_micro_batch_size | int | Micro batch size for training | 8 |
| sample_batch_size | int | Batch size for sampling | 16 |
| sample_micro_batch_size | int | Micro batch size for sampling | 8 |
| gradient_accumulation | bool | Whether to perform gradient accumulation | True |
| gradient_checkpointing | bool | Whether to use gradient checkpointing | False |
| learning_rate | float | Learning rate for the optimizer | 5e-05 |
| weight_decay | float | Weight decay for the optimizer | 0.0 |
| adam_beta1 | float | Beta1 parameter for the AdamW optimizer | 0.9 |
| adam_beta2 | float | Beta2 parameter for the AdamW optimizer | 0.999 |
| adam_epsilon | float | Epsilon parameter for the AdamW optimizer | 1e-08 |
| max_grad_norm | float | Maximum gradient norm | 1.0 |
| epochs | float | Number of epochs to train for | 3.0 |
| max_steps | int | Maximum number of steps to train for | -1 |
| warmup_steps | int | Number of warmup steps | 0 |
| logging_steps | int | Interval for logging | 10 |
| seed | int | Random seed | 42 |
| fp16 | bool | Whether to use mixed-precision training | True |
| run_name | str | Name for the run | 'blueprint_llama_7b' |
| disable_tqdm | bool | Whether to disable tqdm progress bars | False |
| label_smoothing_factor | float | Label smoothing factor to use | 0.0 |
| adafactor | bool | Whether to use the Adafactor optimizer | False |
| report_to | str | Destination for metrics reporting | 'wandb' |
| max_length | int | Maximum input length | 512 |
| lora_r | int | Rank of the LoRA update matrices | 8 |
| lora_target_modules | list | List of target modules for LoRA | ['q_proj', 'v_proj'] |
| lora_alpha | float | Alpha parameter for LoRA | 16.0 |
| lora_dropout | float | Dropout probability for LoRA | 0.05 |
| lora_bias | str | Bias setting for LoRA | 'none' |
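
For instance, a config that enables gradient checkpointing and adjusts the LoRA and batch-size fields might look like the sketch below. This is only an illustration: the dataset ID is a placeholder, and the comment about accumulation steps assumes the usual relationship between batch size and micro batch size, which is not stated in this reference.

from baseten.training import Dataset, LlamaConfig

# Sketch: override a few optimizer, batching, and LoRA fields; everything
# else keeps the defaults listed in the table above.
config = LlamaConfig(
    input_dataset=Dataset("DATASET_ID"),        # placeholder dataset ID
    train_batch_size=16,
    train_micro_batch_size=8,                   # assumption: with gradient_accumulation=True, this implies 16/8 = 2 accumulation steps
    gradient_checkpointing=True,                # trade extra compute for lower memory use
    learning_rate=5e-5,
    warmup_steps=100,
    lora_r=8,                                   # rank of the LoRA update matrices
    lora_alpha=16.0,
    lora_dropout=0.05,
    lora_target_modules=["q_proj", "v_proj"],
    report_to="wandb",
)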