Coverage for transformer_lens/train.py: 28%
75 statements
1"""Train.
3Utilities for training :class:`transformer_lens.HookedTransformer` models on autoregressive language
4modeling tasks.
5"""
7from dataclasses import dataclass
8from typing import Optional
10import torch
11import torch.optim as optim
12from torch.optim import Optimizer
13from torch.utils.data import DataLoader, Dataset
14from tqdm.auto import tqdm
16from transformer_lens import utils
17from transformer_lens.HookedTransformer import HookedTransformer
18from transformer_lens.utils import is_library_available


@dataclass
class HookedTransformerTrainConfig:
    """
    Configuration class to store training hyperparameters for a training run of
    a HookedTransformer model.

    Args:
        num_epochs (int): Number of epochs to train for
        batch_size (int): Size of batches to use for training
        lr (float): Learning rate to use for training
        seed (int): Random seed to use for training
        momentum (float): Momentum to use for training
        max_grad_norm (float, *optional*): Maximum gradient norm to use for gradient clipping
        weight_decay (float, *optional*): Weight decay to use for training
        optimizer_name (str): The name of the optimizer to use
        device (str, *optional*): Device to use for training
        warmup_steps (int, *optional*): Number of warmup steps to use for training
        save_every (int, *optional*): Save a checkpoint every this many batches
        save_dir (str, *optional*): Where to save checkpoints
        wandb (bool): Whether to use Weights and Biases for logging
        wandb_project_name (str, *optional*): Name of the Weights and Biases project to use
        print_every (int, *optional*): Print the loss every n steps
        max_steps (int, *optional*): Terminate the epoch after this many steps. Used for debugging.
    """

    num_epochs: int
    batch_size: int
    lr: float = 1e-3
    seed: int = 0
    momentum: float = 0.0
    max_grad_norm: Optional[float] = None
    weight_decay: Optional[float] = None
    optimizer_name: str = "Adam"
    device: Optional[str] = None
    warmup_steps: int = 0
    save_every: Optional[int] = None
    save_dir: Optional[str] = None
    wandb: bool = False
    wandb_project_name: Optional[str] = None
    print_every: Optional[int] = 50
    max_steps: Optional[int] = None
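
# Example (illustrative sketch; the values below are arbitrary placeholders, not
# defaults recommended by this module):
#
#     debug_config = HookedTransformerTrainConfig(
#         num_epochs=1,
#         batch_size=8,
#         lr=1e-4,
#         optimizer_name="AdamW",
#         weight_decay=0.01,
#         warmup_steps=10,
#         print_every=10,
#         max_steps=50,  # terminate each epoch early, useful for debugging
#     )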


def train(
    model: HookedTransformer,
    config: HookedTransformerTrainConfig,
    dataset: Dataset,
) -> HookedTransformer:
    """
    Trains a HookedTransformer model on an autoregressive language modeling task.

    Args:
        model: The model to train
        config: The training configuration
        dataset: The dataset to train on - this function assumes the dataset is set up for autoregressive language modeling.

    Returns:
        The trained model
    """

    torch.manual_seed(config.seed)
    model.train()

    if config.wandb:
        if not is_library_available("wandb"):
            raise ImportError("Wandb is not available")

        import wandb

        if config.wandb_project_name is None:
            config.wandb_project_name = "easy-transformer"
        wandb.init(project=config.wandb_project_name, config=vars(config))

    if config.device is None:
        config.device = utils.get_device()

    optimizer: Optimizer
    if config.optimizer_name in ["Adam", "AdamW"]:
        # Weight decay in Adam is implemented badly, so use AdamW instead (see PyTorch AdamW docs)
        if config.weight_decay is not None:
            optimizer = optim.AdamW(
                model.parameters(),
                lr=config.lr,
                weight_decay=config.weight_decay,
            )
        else:
            optimizer = optim.Adam(
                model.parameters(),
                lr=config.lr,
            )
    elif config.optimizer_name == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=config.lr,
            weight_decay=(config.weight_decay if config.weight_decay is not None else 0.0),
            momentum=config.momentum,
        )
    else:
        raise ValueError(f"Optimizer {config.optimizer_name} not supported")
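
    # Linear learning-rate warmup: the LambdaLR below scales the learning rate from 0
    # up to config.lr over the first config.warmup_steps optimizer steps, then holds it
    # constant at config.lr.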
    scheduler = None
    if config.warmup_steps > 0:
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: min(1.0, step / config.warmup_steps),
        )
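
    # Each batch produced by the DataLoader is expected to be a mapping with a "tokens"
    # entry holding a [batch, position] tensor of token ids (see batch["tokens"] in the
    # training loop below).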
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

    model.to(config.device)

    for epoch in tqdm(range(1, config.num_epochs + 1)):
        samples = 0
        for step, batch in tqdm(enumerate(dataloader)):
            tokens = batch["tokens"].to(config.device)
            loss = model(tokens, return_type="loss")
            loss.backward()
            if config.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            optimizer.step()
            if config.warmup_steps > 0:
                assert scheduler is not None
                scheduler.step()
            optimizer.zero_grad()

            samples += tokens.shape[0]

            if config.wandb:
                wandb.log({"train_loss": loss.item(), "samples": samples, "epoch": epoch})

            if config.print_every is not None and step % config.print_every == 0:
                print(f"Epoch {epoch} Samples {samples} Step {step} Loss {loss.item()}")

            if (
                config.save_every is not None
                and step % config.save_every == 0
                and config.save_dir is not None
            ):
                torch.save(model.state_dict(), f"{config.save_dir}/model_{step}.pt")

            if config.max_steps is not None and step >= config.max_steps:
                break

    return model
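

# Illustrative usage sketch (not part of train.py): the model name, dataset, and
# hyperparameters below are example choices. It assumes a dataset whose items are
# dicts with a "tokens" tensor, e.g. one built with
# transformer_lens.utils.tokenize_and_concatenate:
#
#     from datasets import load_dataset
#     from transformer_lens import HookedTransformer, utils
#     from transformer_lens.train import HookedTransformerTrainConfig, train
#
#     model = HookedTransformer.from_pretrained("gpt2")
#     raw_data = load_dataset("stas/openwebtext-10k", split="train")
#     dataset = utils.tokenize_and_concatenate(raw_data, model.tokenizer, max_length=256)
#
#     config = HookedTransformerTrainConfig(
#         num_epochs=1,
#         batch_size=8,
#         lr=1e-4,
#         optimizer_name="AdamW",
#         weight_decay=0.01,
#         max_steps=100,  # early cutoff per epoch, handy for a smoke test
#     )
#     trained_model = train(model, config, dataset)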