transformer_lens.train#
Utilities for training transformer_lens.HookedTransformer models on autoregressive language modeling tasks.
- class transformer_lens.train.HookedTransformerTrainConfig(num_epochs: int, batch_size: int, lr: float = 0.001, seed: int = 0, momentum: float = 0.0, max_grad_norm: Optional[float] = None, weight_decay: Optional[float] = None, optimizer_name: str = 'Adam', device: Optional[str] = None, warmup_steps: int = 0, save_every: Optional[int] = None, save_dir: Optional[str] = None, wandb: bool = False, wandb_project_name: Optional[str] = None, print_every: Optional[int] = 50, max_steps: Optional[int] = None)#
Bases:
object
Configuration class to store training hyperparameters for a training run of a HookedTransformer model.
- Parameters:
  - num_epochs (int) – Number of epochs to train for
  - batch_size (int) – Size of batches to use for training
  - lr (float) – Learning rate to use for training
  - seed (int) – Random seed to use for training
  - momentum (float) – Momentum to use for training
  - max_grad_norm (float, optional) – Maximum norm to clip gradients to during training
  - weight_decay (float, optional) – Weight decay to use for training
  - optimizer_name (str) – The name of the optimizer to use
  - device (str, optional) – Device to use for training
  - warmup_steps (int, optional) – Number of warmup steps to use for training
  - save_every (int, optional) – After how many batches a checkpoint should be saved
  - save_dir (str, optional) – Where to save checkpoints
  - wandb (bool) – Whether to use Weights and Biases for logging
  - wandb_project_name (str, optional) – Name of the Weights and Biases project to use
  - print_every (int, optional) – Print the loss every n steps
  - max_steps (int, optional) – Terminate the epoch after this many steps. Used for debugging.
- batch_size: int#
- device: Optional[str] = None#
- lr: float = 0.001#
- max_grad_norm: Optional[float] = None#
- max_steps: Optional[int] = None#
- momentum: float = 0.0#
- num_epochs: int#
- optimizer_name: str = 'Adam'#
- print_every: Optional[int] = 50#
- save_dir: Optional[str] = None#
- save_every: Optional[int] = None#
- seed: int = 0#
- wandb: bool = False#
- wandb_project_name: Optional[str] = None#
- warmup_steps: int = 0#
- weight_decay: Optional[float] = None#
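As a sketch, a minimal configuration might be built as follows. The hyperparameter values here are purely illustrative, and the set of supported optimizer names may vary by library version:

```python
from transformer_lens.train import HookedTransformerTrainConfig

# Illustrative hyperparameter values -- tune these for your own task.
config = HookedTransformerTrainConfig(
    num_epochs=1,
    batch_size=8,
    lr=1e-3,
    optimizer_name="Adam",
    max_grad_norm=1.0,   # clip gradient norm to stabilise training
    warmup_steps=100,    # learning-rate warmup over the first 100 steps
    print_every=50,      # log the loss every 50 steps
    device="cuda",       # or "cpu"
)
```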
- transformer_lens.train.train(model: HookedTransformer, config: HookedTransformerTrainConfig, dataset: Dataset) HookedTransformer #
Trains a HookedTransformer model on an autoregressive language modeling task.
- Parameters:
  - model – The model to train
  - config – The training configuration
  - dataset – The dataset to train on. This function assumes the dataset is set up for autoregressive language modeling.
- Returns:
The trained model
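Putting the pieces together, a hedged usage sketch follows. The dataset name `my_tokenized_dataset` is a placeholder, not part of the API; it stands for a Dataset already prepared for autoregressive language modeling (e.g. batches of token IDs):

```python
from transformer_lens import HookedTransformer
from transformer_lens.train import HookedTransformerTrainConfig, train

# `my_tokenized_dataset` is a placeholder for a Dataset of pre-tokenized
# examples suitable for autoregressive language modeling.
model = HookedTransformer.from_pretrained("gpt2")  # or a freshly initialised model
config = HookedTransformerTrainConfig(num_epochs=1, batch_size=8, lr=1e-3)
trained_model = train(model, config, my_tokenized_dataset)
```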