transformer_lens.train#

Train.

Utilities for training transformer_lens.HookedTransformer models on autoregressive language modeling tasks.

class transformer_lens.train.HookedTransformerTrainConfig(num_epochs: int, batch_size: int, lr: float = 0.001, seed: int = 0, momentum: float = 0.0, max_grad_norm: Optional[float] = None, weight_decay: Optional[float] = None, optimizer_name: str = 'Adam', device: Optional[str] = None, warmup_steps: int = 0, save_every: Optional[int] = None, save_dir: Optional[str] = None, wandb: bool = False, wandb_project_name: Optional[str] = None, print_every: Optional[int] = 50, max_steps: Optional[int] = None)#

Bases: object

Configuration class to store training hyperparameters for a training run of a HookedTransformer model.

Parameters:

* num_epochs (int) – Number of epochs to train for
* batch_size (int) – Size of batches to use for training
* lr (float) – Learning rate to use for training
* seed (int) – Random seed to use for training
* momentum (float) – Momentum to use for training
* max_grad_norm (float, optional) – Maximum gradient norm to use for gradient clipping
* weight_decay (float, optional) – Weight decay to use for training
* optimizer_name (str) – The name of the optimizer to use
* device (str, optional) – Device to use for training
* warmup_steps (int, optional) – Number of warmup steps to use for training
* save_every (int, optional) – Save a checkpoint after this many batches
* save_dir (str, optional) – Where to save checkpoints
* wandb (bool) – Whether to use Weights and Biases for logging
* wandb_project_name (str, optional) – Name of the Weights and Biases project to use
* print_every (int, optional) – Print the loss every n steps
* max_steps (int, optional) – Terminate the epoch after this many steps. Used for debugging.

batch_size: int#
device: Optional[str] = None#
lr: float = 0.001#
max_grad_norm: Optional[float] = None#
max_steps: Optional[int] = None#
momentum: float = 0.0#
num_epochs: int#
optimizer_name: str = 'Adam'#
print_every: Optional[int] = 50#
save_dir: Optional[str] = None#
save_every: Optional[int] = None#
seed: int = 0#
wandb: bool = False#
wandb_project_name: Optional[str] = None#
warmup_steps: int = 0#
weight_decay: Optional[float] = None#
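
For orientation, here is a minimal sketch of constructing a training config. The field values are illustrative choices, not recommended settings; only the field names and defaults shown in the signature above come from this page.

from transformer_lens.train import HookedTransformerTrainConfig

# Illustrative values only; field names match the signature above.
config = HookedTransformerTrainConfig(
    num_epochs=1,            # required: number of passes over the dataset
    batch_size=8,            # required: training batch size
    lr=1e-3,                 # the documented default learning rate
    optimizer_name="Adam",   # the documented default optimizer
    warmup_steps=100,        # warmup steps before the full learning rate
    device="cuda",           # or None to defer to the library's default
    print_every=50,          # log the loss every 50 steps
)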
transformer_lens.train.train(model: HookedTransformer, config: HookedTransformerTrainConfig, dataset: Dataset) HookedTransformer#

Trains a HookedTransformer model on an autoregressive language modeling task.

Parameters:

* model – The model to train
* config – The training configuration
* dataset – The dataset to train on. This function assumes the dataset is set up for autoregressive language modeling.

Returns:

The trained model
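
A minimal end-to-end sketch follows. The model name, the toy dataset, and the assumption that the dataset exposes fixed-length token ids under a "tokens" column are illustrative guesses about the expected setup, not documented behavior.

from datasets import Dataset
from transformer_lens import HookedTransformer
from transformer_lens.train import HookedTransformerTrainConfig, train

# Assumption: the dataset yields fixed-length token id sequences under a
# "tokens" column, the usual shape for autoregressive language modeling.
dataset = Dataset.from_dict({"tokens": [[0] * 128 for _ in range(64)]})
dataset.set_format(type="torch")

model = HookedTransformer.from_pretrained("gpt2")  # assumed model choice
config = HookedTransformerTrainConfig(num_epochs=1, batch_size=8)
trained_model = train(model, config, dataset)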