Coverage for transformer_lens/train.py: 29%

72 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-01-21 00:15 +0000

1"""Train. 

2 

3Utilities for training :class:`transformer_lens.HookedTransformer` models on autoregressive language 

4modeling tasks. 

5""" 

6 

7from dataclasses import dataclass 

8from typing import Optional 

9 

10import torch 

11import torch.optim as optim 

12import wandb 

13from torch.optim import Optimizer 

14from torch.utils.data import DataLoader, Dataset 

15from tqdm.auto import tqdm 

16 

17from transformer_lens import utils 

18from transformer_lens.HookedTransformer import HookedTransformer 

19 

20 

@dataclass
class HookedTransformerTrainConfig:
    """
    Configuration class to store training hyperparameters for a training run of
    a HookedTransformer model.

    Args:
        num_epochs (int): Number of epochs to train for
        batch_size (int): Size of batches to use for training
        lr (float): Learning rate to use for training
        seed (int): Random seed to use for training
        momentum (float): Momentum to use for training
        max_grad_norm (float, *optional*): Maximum gradient norm to use for gradient clipping
        weight_decay (float, *optional*): Weight decay to use for training
        optimizer_name (str): The name of the optimizer to use
        device (str, *optional*): Device to use for training
        warmup_steps (int, *optional*): Number of warmup steps to use for training
        save_every (int, *optional*): After how many batches should a checkpoint be saved
        save_dir (str, *optional*): Where to save checkpoints
        wandb (bool): Whether to use Weights and Biases for logging
        wandb_project_name (str, *optional*): Name of the Weights and Biases project to use
        print_every (int, *optional*): Print the loss every n steps
        max_steps (int, *optional*): Terminate the epoch after this many steps. Used for debugging.
    """

    # Required fields (no defaults) must stay first: they are the positional
    # arguments of the generated ``__init__``.
    num_epochs: int
    batch_size: int
    lr: float = 1e-3
    seed: int = 0
    momentum: float = 0.0
    max_grad_norm: Optional[float] = None
    weight_decay: Optional[float] = None
    optimizer_name: str = "Adam"
    device: Optional[str] = None
    warmup_steps: int = 0
    save_every: Optional[int] = None
    save_dir: Optional[str] = None
    wandb: bool = False
    wandb_project_name: Optional[str] = None
    print_every: Optional[int] = 50
    max_steps: Optional[int] = None

61 

62 

def train(
    model: HookedTransformer,
    config: HookedTransformerTrainConfig,
    dataset: Dataset,
) -> HookedTransformer:
    """
    Trains a HookedTransformer model on an autoregressive language modeling task.

    Args:
        model: The model to train
        config: The training configuration. ``config.device`` is filled in when it is
            ``None``, and ``config.wandb_project_name`` is filled in when wandb logging is
            enabled and no project name was given.
        dataset: The dataset to train on - this function assumes the dataset is set up for
            autoregressive language modeling. Every batch is indexed with ``"tokens"``.

    Returns:
        The trained model

    Raises:
        ValueError: If ``config.optimizer_name`` is not "Adam", "AdamW" or "SGD".
    """
    torch.manual_seed(config.seed)
    model.train()
    if config.wandb:
        if config.wandb_project_name is None:
            config.wandb_project_name = "easy-transformer"
        wandb.init(project=config.wandb_project_name, config=vars(config))

    if config.device is None:
        config.device = utils.get_device()

    optimizer: Optimizer
    if config.optimizer_name in ["Adam", "AdamW"]:
        # Weight decay in Adam is implemented badly, so use AdamW instead (see PyTorch AdamW docs)
        if config.weight_decay is not None:
            optimizer = optim.AdamW(
                model.parameters(),
                lr=config.lr,
                weight_decay=config.weight_decay,
            )
        else:
            optimizer = optim.Adam(
                model.parameters(),
                lr=config.lr,
            )
    elif config.optimizer_name == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=config.lr,
            weight_decay=(config.weight_decay if config.weight_decay is not None else 0.0),
            momentum=config.momentum,
        )
    else:
        raise ValueError(f"Optimizer {config.optimizer_name} not supported")

    # Linear warmup: the learning-rate multiplier ramps from 0 to 1 over warmup_steps
    # scheduler steps and stays at 1 afterwards. No scheduler is used without warmup.
    scheduler: Optional[optim.lr_scheduler.LambdaLR] = None
    if config.warmup_steps > 0:
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: min(1.0, step / config.warmup_steps),
        )

    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

    model.to(config.device)

    for epoch in tqdm(range(1, config.num_epochs + 1)):
        # Number of sequences seen so far in the current epoch.
        samples = 0
        # Wrap the dataloader (not the enumerate object) so tqdm can see its length.
        for step, batch in enumerate(tqdm(dataloader)):
            tokens = batch["tokens"].to(config.device)
            loss = model(tokens, return_type="loss")
            loss.backward()
            if config.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            optimizer.zero_grad()

            samples += tokens.shape[0]

            if config.wandb:
                wandb.log({"train_loss": loss.item(), "samples": samples, "epoch": epoch})

            if config.print_every is not None and step % config.print_every == 0:
                print(f"Epoch {epoch} Samples {samples} Step {step} Loss {loss.item()}")

            # NOTE: the file name only contains the within-epoch step, so a checkpoint
            # written in a later epoch overwrites the one from the same step of an
            # earlier epoch.
            if (
                config.save_every is not None
                and step % config.save_every == 0
                and config.save_dir is not None
            ):
                torch.save(model.state_dict(), f"{config.save_dir}/model_{step}.pt")

            # ``step`` is zero-based, so ``step + 1`` batches have been processed here.
            if config.max_steps is not None and step + 1 >= config.max_steps:
                break

    return model