# LlamaConfig
dataclass
Bases: `FinetuningConfig`
Training config for LLaMA-7B.
Examples:
```python
from baseten.training import Dataset, LlamaConfig

config = LlamaConfig(
    input_dataset=Dataset("DATASET_ID"),
    epochs=3,
    learning_rate=5e-5,
    max_steps=1000,
    train_batch_size=8,
    sample_batch_size=8,
    report_to="wandb",
)
```
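The batch-size fields come in macro/micro pairs. A common convention, and presumably what `gradient_accumulation` enables here, is to split each training batch into micro batches and accumulate gradients across them. A minimal sketch of that arithmetic, using the default values from the table below (the exact internal behavior of `baseten.training` is not documented on this page, so treat this as an assumption):

```python
# Assumed relationship: each optimizer step processes train_batch_size
# examples, executed as several forward/backward passes of
# train_micro_batch_size examples each.
train_batch_size = 16        # default from the parameter table
train_micro_batch_size = 8   # default from the parameter table

# The macro batch size should divide evenly into micro batches.
assert train_batch_size % train_micro_batch_size == 0
accumulation_steps = train_batch_size // train_micro_batch_size
print(accumulation_steps)  # 2 gradient-accumulation steps per optimizer step
```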
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_id` | `str` | Pretrained model to fine-tune | `'decapoda-research/llama-7b-hf'` |
| `source_col_name` | `str` | Name of the source column in the input CSV | `'source'` |
| `target_col_name` | `str` | Name of the target column in the input CSV | `'target'` |
| `evaluation_strategy` | `str` | When to run evaluation | `'epoch'` |
| `train_batch_size` | `int` | Batch size for training | `16` |
| `train_micro_batch_size` | `int` | Micro batch size for training | `8` |
| `sample_batch_size` | `int` | Batch size for sampling | `16` |
| `sample_micro_batch_size` | `int` | Micro batch size for sampling | `8` |
| `gradient_accumulation` | `bool` | Whether to perform gradient accumulation | `True` |
| `gradient_checkpointing` | `bool` | Whether to use gradient checkpointing | `False` |
| `learning_rate` | `float` | Learning rate for the optimizer | `5e-05` |
| `weight_decay` | `float` | Weight decay for the optimizer | `0.0` |
| `adam_beta1` | `float` | Beta1 parameter for the AdamW optimizer | `0.9` |
| `adam_beta2` | `float` | Beta2 parameter for the AdamW optimizer | `0.999` |
| `adam_epsilon` | `float` | Epsilon parameter for the AdamW optimizer | `1e-08` |
| `max_grad_norm` | `float` | Maximum gradient norm for clipping | `1.0` |
| `epochs` | `float` | Number of epochs to train for | `3.0` |
| `max_steps` | `int` | Maximum number of training steps (-1 for no limit) | `-1` |
| `warmup_steps` | `int` | Number of learning-rate warmup steps | `0` |
| `logging_steps` | `int` | Interval (in steps) between log entries | `10` |
| `seed` | `int` | Random seed | `42` |
| `fp16` | `bool` | Whether to use mixed-precision (FP16) training | `True` |
| `run_name` | `str` | Name for the run | `'blueprint_llama_7b'` |
| `disable_tqdm` | `bool` | Whether to disable tqdm progress bars | `False` |
| `label_smoothing_factor` | `float` | Label smoothing factor | `0.0` |
| `adafactor` | `bool` | Whether to use the Adafactor optimizer instead of AdamW | `False` |
| `report_to` | `str` | Destination for metrics reporting | `'wandb'` |
| `max_length` | `int` | Maximum input sequence length | `512` |
| `lora_r` | `int` | Rank of the LoRA update matrices | `8` |
| `lora_target_modules` | `list` | Target modules for LoRA | `['q_proj', 'v_proj']` |
| `lora_alpha` | `float` | Scaling (alpha) parameter for LoRA | `16.0` |
| `lora_dropout` | `float` | Dropout probability for LoRA layers | `0.05` |
| `lora_bias` | `str` | Bias setting for LoRA | `'none'` |
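For reference, a minimal sketch of a config that overrides the LoRA hyperparameters. All field names come from the table above; `"DATASET_ID"` is a placeholder as in the example at the top, and the extra target modules (`k_proj`, `o_proj`) assume the underlying model exposes the standard LLaMA attention projection names:

```python
from baseten.training import Dataset, LlamaConfig

# Sketch only: a higher-rank LoRA setup than the defaults.
config = LlamaConfig(
    input_dataset=Dataset("DATASET_ID"),  # placeholder dataset ID
    lora_r=16,                            # higher rank -> more trainable parameters
    lora_alpha=32.0,                      # scaling is often increased with the rank
    lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    fp16=True,
)
```

Raising `lora_r` grows the low-rank update matrices (and thus the trainable parameter count), while `lora_alpha` rescales their contribution; keeping the alpha-to-rank ratio roughly constant is a common heuristic when tuning these two together.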