Coverage for transformer_lens/train.py: 29%
72 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
"""Train.

Utilities for training :class:`transformer_lens.HookedTransformer` models on autoregressive language
modeling tasks.
"""
7from dataclasses import dataclass
8from typing import Optional
10import torch
11import torch.optim as optim
12import wandb
13from torch.optim import Optimizer
14from torch.utils.data import DataLoader, Dataset
15from tqdm.auto import tqdm
17from transformer_lens import utils
18from transformer_lens.HookedTransformer import HookedTransformer
@dataclass
class HookedTransformerTrainConfig:
    """
    Configuration class to store training hyperparameters for a training run of
    a HookedTransformer model.

    Args:
        num_epochs (int): Number of epochs to train for
        batch_size (int): Size of batches to use for training
        lr (float): Learning rate to use for training
        seed (int): Random seed to use for training
        momentum (float): Momentum to use for training
        max_grad_norm (float, *optional*): Maximum gradient norm to use for
            gradient clipping. No clipping is applied when None.
        weight_decay (float, *optional*): Weight decay to use for training
        optimizer_name (str): The name of the optimizer to use
        device (str, *optional*): Device to use for training
        warmup_steps (int, *optional*): Number of warmup steps to use for training
        save_every (int, *optional*): After how many batches should a checkpoint be saved
        save_dir (str, *optional*): Where to save checkpoints
        wandb (bool): Whether to use Weights and Biases for logging
        wandb_project_name (str, *optional*): Name of the Weights and Biases project to use
        print_every (int, *optional*): Print the loss every n steps
        max_steps (int, *optional*): Terminate the epoch after this many steps. Used for debugging.
    """

    num_epochs: int
    batch_size: int
    lr: float = 1e-3
    seed: int = 0
    momentum: float = 0.0
    max_grad_norm: Optional[float] = None
    weight_decay: Optional[float] = None
    optimizer_name: str = "Adam"
    device: Optional[str] = None
    warmup_steps: int = 0
    save_every: Optional[int] = None
    save_dir: Optional[str] = None
    wandb: bool = False
    wandb_project_name: Optional[str] = None
    print_every: Optional[int] = 50
    max_steps: Optional[int] = None
def train(
    model: HookedTransformer,
    config: HookedTransformerTrainConfig,
    dataset: Dataset,
) -> HookedTransformer:
    """
    Trains a HookedTransformer model on an autoregressive language modeling task.

    Args:
        model: The model to train
        config: The training configuration. ``config.wandb_project_name`` (when wandb
            logging is enabled) and ``config.device`` are filled in on this object if
            they are None.
        dataset: The dataset to train on - this function assumes the dataset is set up for
            autoregressive language modeling. Each batch is indexed with ``batch["tokens"]``.

    Returns:
        The trained model

    Raises:
        ValueError: If ``config.optimizer_name`` is not one of "Adam", "AdamW" or "SGD".
    """
    torch.manual_seed(config.seed)
    model.train()
    if config.wandb:
        if config.wandb_project_name is None:
            config.wandb_project_name = "easy-transformer"
        wandb.init(project=config.wandb_project_name, config=vars(config))

    if config.device is None:
        config.device = utils.get_device()

    optimizer: Optimizer
    if config.optimizer_name in ["Adam", "AdamW"]:
        # Weight decay in Adam is implemented badly, so use AdamW instead (see PyTorch AdamW docs)
        if config.weight_decay is not None:
            optimizer = optim.AdamW(
                model.parameters(),
                lr=config.lr,
                weight_decay=config.weight_decay,
            )
        else:
            optimizer = optim.Adam(
                model.parameters(),
                lr=config.lr,
            )
    elif config.optimizer_name == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=config.lr,
            weight_decay=(config.weight_decay if config.weight_decay is not None else 0.0),
            momentum=config.momentum,
        )
    else:
        raise ValueError(f"Optimizer {config.optimizer_name} not supported")

    # Linear warmup: the LR multiplier grows from 0 to 1 over `warmup_steps` scheduler steps.
    scheduler = None
    if config.warmup_steps > 0:
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: min(1.0, step / config.warmup_steps),
        )

    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

    model.to(config.device)

    for epoch in tqdm(range(1, config.num_epochs + 1)):
        samples = 0
        # Wrap the dataloader (not the enumerate object) so tqdm can read its length
        # and display a total for the inner progress bar.
        for step, batch in enumerate(tqdm(dataloader)):
            tokens = batch["tokens"].to(config.device)
            loss = model(tokens, return_type="loss")
            loss.backward()
            if config.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            optimizer.step()
            # The scheduler only exists when warmup is enabled; an explicit check is used
            # instead of `assert`, which is stripped when Python runs with -O.
            if scheduler is not None:
                scheduler.step()
            optimizer.zero_grad()

            samples += tokens.shape[0]

            if config.wandb:
                wandb.log({"train_loss": loss.item(), "samples": samples, "epoch": epoch})

            if config.print_every is not None and step % config.print_every == 0:
                print(f"Epoch {epoch} Samples {samples} Step {step} Loss {loss.item()}")

            # NOTE: `step` restarts at 0 every epoch and the filename only contains `step`,
            # so checkpoints written in later epochs overwrite those of earlier epochs.
            if (
                config.save_every is not None
                and step % config.save_every == 0
                and config.save_dir is not None
            ):
                torch.save(model.state_dict(), f"{config.save_dir}/model_{step}.pt")

            # `step` is 0-based, so `step + 1` batches have been processed at this point;
            # comparing `step` itself would run one batch more than `max_steps`.
            if config.max_steps is not None and step + 1 >= config.max_steps:
                break

    return model