transformer_lens package
Subpackages
- transformer_lens.benchmarks package
- Submodules
- transformer_lens.benchmarks.activation_cache module
- transformer_lens.benchmarks.audio module
- transformer_lens.benchmarks.backward_gradients module
- transformer_lens.benchmarks.component_benchmark module
- transformer_lens.benchmarks.component_outputs module
- transformer_lens.benchmarks.forward_pass module
- transformer_lens.benchmarks.generation module
- transformer_lens.benchmarks.granular_weight_processing module
- transformer_lens.benchmarks.hook_registration module
- transformer_lens.benchmarks.hook_structure module
- transformer_lens.benchmarks.main_benchmark module
- transformer_lens.benchmarks.multimodal module
- transformer_lens.benchmarks.text_quality module
- transformer_lens.benchmarks.utils module
- transformer_lens.benchmarks.weight_processing module
benchmark_attention_output_centering(), benchmark_layer_norm_folding(), benchmark_mlp_output_centering(), benchmark_no_nan_inf(), benchmark_unembed_centering(), benchmark_value_bias_folding(), benchmark_weight_magnitudes(), benchmark_weight_modification(), benchmark_weight_processing(), benchmark_weight_sharing()
- Module contents
BenchmarkResult, BenchmarkSeverity, PhaseReferenceData, benchmark_activation_cache(), benchmark_activation_cache_structure(), benchmark_backward_hooks(), benchmark_backward_hooks_structure(), benchmark_critical_backward_hooks(), benchmark_critical_forward_hooks(), benchmark_forward_hooks(), benchmark_forward_hooks_structure(), benchmark_forward_pass(), benchmark_gated_hooks_fire(), benchmark_generation(), benchmark_generation_with_kv_cache(), benchmark_gradient_computation(), benchmark_hook_functionality(), benchmark_hook_registry(), benchmark_logits_equivalence(), benchmark_loss_equivalence(), benchmark_multiple_generation_calls(), benchmark_run_with_cache(), benchmark_text_quality(), benchmark_weight_modification(), benchmark_weight_processing(), benchmark_weight_sharing(), run_benchmark_suite()
- Submodules
- transformer_lens.cache package
- transformer_lens.components package
- Submodules
- transformer_lens.components.abstract_attention module
- transformer_lens.components.attention module
- transformer_lens.components.bert_block module
- transformer_lens.components.bert_embed module
- transformer_lens.components.bert_mlm_head module
- transformer_lens.components.bert_nsp_head module
- transformer_lens.components.bert_pooler module
- transformer_lens.components.embed module
- transformer_lens.components.grouped_query_attention module
- transformer_lens.components.layer_norm module
- transformer_lens.components.layer_norm_pre module
- transformer_lens.components.pos_embed module
- transformer_lens.components.rms_norm module
- transformer_lens.components.rms_norm_pre module
- transformer_lens.components.t5_attention module
- transformer_lens.components.t5_block module
- transformer_lens.components.token_typed_embed module
- transformer_lens.components.transformer_block module
- transformer_lens.components.unembed module
- Module contents
- Submodules
- transformer_lens.config package
- transformer_lens.conversion_utils package
- transformer_lens.factories package
- transformer_lens.lit package
- Submodules
- transformer_lens.lit.constants module
- transformer_lens.lit.dataset module
- transformer_lens.lit.model module
- transformer_lens.lit.utils module
batch_examples(), check_lit_installed(), clean_token_string(), clean_token_strings(), compute_token_gradients(), extract_attention_from_cache(), extract_embeddings_from_cache(), filter_cache_by_pattern(), get_hook_name_for_layer(), get_model_info(), get_tokens_from_model(), get_top_k_predictions(), numpy_to_tensor(), tensor_to_numpy(), unbatch_outputs(), validate_input_example()
- Module contents
HookedTransformerLIT, HookedTransformerLIT.__init__(), HookedTransformerLIT.description(), HookedTransformerLIT.from_pretrained(), HookedTransformerLIT.get_embedding_table(), HookedTransformerLIT.init_spec(), HookedTransformerLIT.input_spec(), HookedTransformerLIT.max_minibatch_size(), HookedTransformerLIT.output_spec(), HookedTransformerLIT.predict(), HookedTransformerLIT.supports_concurrent_predictions
HookedTransformerLITConfig, HookedTransformerLITConfig.batch_size, HookedTransformerLITConfig.compute_gradients, HookedTransformerLITConfig.device, HookedTransformerLITConfig.embedding_layers, HookedTransformerLITConfig.max_seq_length, HookedTransformerLITConfig.output_all_layers, HookedTransformerLITConfig.output_attention, HookedTransformerLITConfig.output_embeddings, HookedTransformerLITConfig.prepend_bos, HookedTransformerLITConfig.top_k
IOIDataset, InductionDataset, LITWidget, PromptCompletionDataset, PromptCompletionDataset.COMPLETION_FIELD, PromptCompletionDataset.FULL_TEXT_FIELD, PromptCompletionDataset.PROMPT_FIELD, PromptCompletionDataset.__init__(), PromptCompletionDataset.__iter__(), PromptCompletionDataset.__len__(), PromptCompletionDataset.description(), PromptCompletionDataset.examples, PromptCompletionDataset.from_pairs(), PromptCompletionDataset.spec()
SimpleTextDataset, check_lit_installed(), serve(), wrap_for_lit()
- Submodules
- transformer_lens.model_bridge package
- Subpackages
- Submodules
- transformer_lens.model_bridge.architecture_adapter module
- transformer_lens.model_bridge.bridge module
- transformer_lens.model_bridge.compat module
- transformer_lens.model_bridge.component_setup module
- transformer_lens.model_bridge.composition_scores module
- transformer_lens.model_bridge.exceptions module
- transformer_lens.model_bridge.get_params_util module
- transformer_lens.model_bridge.types module
- Module contents
ArchitectureAdapter, ArchitectureAdapter.__init__(), ArchitectureAdapter.applicable_phases, ArchitectureAdapter.convert_hf_key_to_tl_key(), ArchitectureAdapter.create_stateful_cache(), ArchitectureAdapter.default_cfg, ArchitectureAdapter.get_component(), ArchitectureAdapter.get_component_from_list_module(), ArchitectureAdapter.get_component_mapping(), ArchitectureAdapter.get_generalized_component(), ArchitectureAdapter.get_remote_component(), ArchitectureAdapter.prepare_loading(), ArchitectureAdapter.prepare_model(), ArchitectureAdapter.preprocess_weights(), ArchitectureAdapter.required_libraries, ArchitectureAdapter.required_libraries_group, ArchitectureAdapter.setup_component_testing(), ArchitectureAdapter.supports_generation, ArchitectureAdapter.translate_transformer_lens_path()
AttentionBridge, AttentionBridge.W_K, AttentionBridge.W_O, AttentionBridge.W_Q, AttentionBridge.W_V, AttentionBridge.__init__(), AttentionBridge.b_K, AttentionBridge.b_O, AttentionBridge.b_Q, AttentionBridge.b_V, AttentionBridge.forward(), AttentionBridge.get_random_inputs(), AttentionBridge.hook_aliases, AttentionBridge.property_aliases, AttentionBridge.set_original_component(), AttentionBridge.setup_hook_compatibility(), AttentionBridge.supports_split_qkv_fork
BlockBridge, EmbeddingBridge, JointGateUpMLPBridge, JointQKVAttentionBridge, LinearBridge, MLPBridge, MoEBridge, NormalizationBridge, RemoteComponent, RemoteModel, RemotePath, TransformerBridge, TransformerBridge.OV, TransformerBridge.OV_for_attn_layers(), TransformerBridge.QK, TransformerBridge.QK_for_attn_layers(), TransformerBridge.W_E, TransformerBridge.W_K, TransformerBridge.W_O, TransformerBridge.W_Q, TransformerBridge.W_U, TransformerBridge.W_V, TransformerBridge.W_gate, TransformerBridge.W_in, TransformerBridge.W_out, TransformerBridge.__init__(), TransformerBridge.accumulated_bias(), TransformerBridge.add_hook(), TransformerBridge.add_perma_hook(), TransformerBridge.all_composition_scores(), TransformerBridge.all_head_labels, TransformerBridge.attn_head_labels, TransformerBridge.b_K, TransformerBridge.b_O, TransformerBridge.b_Q, TransformerBridge.b_U, TransformerBridge.b_V, TransformerBridge.b_in, TransformerBridge.b_out, TransformerBridge.block_hooks(), TransformerBridge.block_submodules(), TransformerBridge.blocks_with(), TransformerBridge.boot_native(), TransformerBridge.boot_transformers(), TransformerBridge.check_model_support(), TransformerBridge.clear_hook_registry(), TransformerBridge.composition_layer_indices(), TransformerBridge.cpu(), TransformerBridge.cuda(), TransformerBridge.enable_compatibility_mode(), TransformerBridge.forward(), TransformerBridge.generate(), TransformerBridge.generate_stream(), TransformerBridge.get_hook_point(), TransformerBridge.get_params(), TransformerBridge.get_token_position(), TransformerBridge.hf_generate(), TransformerBridge.hook_aliases, TransformerBridge.hook_dict, TransformerBridge.hooks(), TransformerBridge.layer_types(), TransformerBridge.list_supported_models(), TransformerBridge.load_state_dict(), TransformerBridge.loss_fn(), TransformerBridge.mps(), TransformerBridge.n_params_total, TransformerBridge.named_parameters(), TransformerBridge.original_model, TransformerBridge.parameters(), TransformerBridge.prepare_multimodal_inputs(), TransformerBridge.process_weights(), TransformerBridge.reset_hooks(), TransformerBridge.run_with_cache(), TransformerBridge.run_with_hooks(), TransformerBridge.set_use_attn_in(), TransformerBridge.set_use_attn_result(), TransformerBridge.set_use_hook_mlp_in(), TransformerBridge.set_use_split_qkv_input(), TransformerBridge.stack_params_for(), TransformerBridge.state_dict(), TransformerBridge.tl_named_parameters(), TransformerBridge.tl_parameters(), TransformerBridge.to(), TransformerBridge.to_single_str_token(), TransformerBridge.to_single_token(), TransformerBridge.to_str_tokens(), TransformerBridge.to_string(), TransformerBridge.to_tokens(), TransformerBridge.tokens_to_residual_directions()
TransformerLensPath, UnembeddingBridge, replace_remote_component(), set_original_components(), setup_blocks_bridge(), setup_components(), setup_submodules()
- transformer_lens.pretrained package
- transformer_lens.tools package
- transformer_lens.utilities package
- Submodules
- transformer_lens.utilities.activation_functions module
- transformer_lens.utilities.addmm module
- transformer_lens.utilities.aliases module
- transformer_lens.utilities.architectures module
- transformer_lens.utilities.attention module
- transformer_lens.utilities.attribute_utils module
- transformer_lens.utilities.bridge_components module
- transformer_lens.utilities.cache module
- transformer_lens.utilities.components_utils module
- transformer_lens.utilities.defaults_utils module
- transformer_lens.utilities.devices module
- transformer_lens.utilities.exploratory_utils module
- transformer_lens.utilities.gpu_utils module
- transformer_lens.utilities.hf_utils module
- transformer_lens.utilities.initialization_utils module
- transformer_lens.utilities.library_utils module
- transformer_lens.utilities.lm_utils module
- transformer_lens.utilities.logits_utils module
- transformer_lens.utilities.matrix module
- transformer_lens.utilities.multi_gpu module
AvailableDeviceMemory, calculate_available_device_cuda_memory(), count_unique_devices(), determine_available_memory_for_available_devices(), find_embedding_device(), get_best_available_cuda_device(), get_best_available_device(), get_device_for_block_index(), resolve_device_map(), sort_devices_based_on_available_memory()
- transformer_lens.utilities.slice module
- transformer_lens.utilities.tensors module
- transformer_lens.utilities.tokenize_utils module
- Module contents
- Submodules
Submodules
- transformer_lens.HookedAudioEncoder module
HookedAudioEncoder, HookedAudioEncoder.OV, HookedAudioEncoder.QK, HookedAudioEncoder.W_K, HookedAudioEncoder.W_O, HookedAudioEncoder.W_Q, HookedAudioEncoder.W_V, HookedAudioEncoder.W_in, HookedAudioEncoder.W_out, HookedAudioEncoder.all_head_labels(), HookedAudioEncoder.b_K, HookedAudioEncoder.b_O, HookedAudioEncoder.b_Q, HookedAudioEncoder.b_V, HookedAudioEncoder.b_in, HookedAudioEncoder.b_out, HookedAudioEncoder.cpu(), HookedAudioEncoder.cuda(), HookedAudioEncoder.encoder_output(), HookedAudioEncoder.forward(), HookedAudioEncoder.from_pretrained(), HookedAudioEncoder.hubert_model, HookedAudioEncoder.mps(), HookedAudioEncoder.processor, HookedAudioEncoder.run_with_cache(), HookedAudioEncoder.to(), HookedAudioEncoder.to_frames()
- transformer_lens.HookedRootModule module
HookedRootModule, HookedRootModule.add_caching_hooks(), HookedRootModule.add_hook(), HookedRootModule.add_perma_hook(), HookedRootModule.cache_all(), HookedRootModule.cache_some(), HookedRootModule.check_and_add_hook(), HookedRootModule.check_hooks_to_add(), HookedRootModule.clear_contexts(), HookedRootModule.get_caching_hooks(), HookedRootModule.hook_dict, HookedRootModule.hook_points(), HookedRootModule.hooks(), HookedRootModule.mod_dict, HookedRootModule.name, HookedRootModule.remove_all_hook_fns(), HookedRootModule.reset_hooks(), HookedRootModule.run_with_cache(), HookedRootModule.run_with_hooks(), HookedRootModule.setup(), HookedRootModule.training
- transformer_lens.evals module
- transformer_lens.head_detector module
- transformer_lens.hook_points module
- transformer_lens.loading_from_pretrained module
- transformer_lens.patching module
generic_activation_patch(), get_act_patch_attn_head_all_pos_every(), get_act_patch_attn_head_by_pos_every(), get_act_patch_attn_head_k_all_pos(), get_act_patch_attn_head_k_by_pos(), get_act_patch_attn_head_out_all_pos(), get_act_patch_attn_head_out_by_pos(), get_act_patch_attn_head_pattern_all_pos(), get_act_patch_attn_head_pattern_by_pos(), get_act_patch_attn_head_pattern_dest_src_pos(), get_act_patch_attn_head_q_all_pos(), get_act_patch_attn_head_q_by_pos(), get_act_patch_attn_head_v_all_pos(), get_act_patch_attn_head_v_by_pos(), get_act_patch_attn_out(), get_act_patch_block_every(), get_act_patch_mlp_out(), get_act_patch_resid_mid(), get_act_patch_resid_pre(), layer_head_dest_src_pos_pattern_patch_setter(), layer_head_pattern_patch_setter(), layer_head_pos_pattern_patch_setter(), layer_head_vector_patch_setter(), layer_pos_head_vector_patch_setter(), layer_pos_patch_setter()
- transformer_lens.supported_models module
- transformer_lens.train module
HookedTransformerTrainConfig, HookedTransformerTrainConfig.batch_size, HookedTransformerTrainConfig.device, HookedTransformerTrainConfig.lr, HookedTransformerTrainConfig.max_grad_norm, HookedTransformerTrainConfig.max_steps, HookedTransformerTrainConfig.momentum, HookedTransformerTrainConfig.num_epochs, HookedTransformerTrainConfig.optimizer_name, HookedTransformerTrainConfig.print_every, HookedTransformerTrainConfig.save_dir, HookedTransformerTrainConfig.save_every, HookedTransformerTrainConfig.seed, HookedTransformerTrainConfig.wandb, HookedTransformerTrainConfig.wandb_project_name, HookedTransformerTrainConfig.warmup_steps, HookedTransformerTrainConfig.weight_decay
train()
- transformer_lens.utils module
LocallyOverridenDefaults, Slice, calc_fan_in_and_fan_out(), composition_scores(), download_file_from_hf(), filter_dict_by_prefix(), gelu_fast(), gelu_new(), get_act_name(), get_attention_mask(), get_corner(), get_cumsum_along_dim(), get_dataset(), get_device(), get_input_with_manually_prepended_bos(), get_nested_attr(), get_offset_position_ids(), get_tokenizer_with_bos(), get_tokens_with_bos_removed(), init_kaiming_normal_(), init_kaiming_uniform_(), init_xavier_normal_(), init_xavier_uniform_(), is_library_available(), is_lower_triangular(), is_square(), keep_single_column(), lm_accuracy(), lm_cross_entropy_loss(), override_or_use_default_value(), print_gpu_mem(), remove_batch_dim(), repeat_along_head_dimension(), sample_logits(), set_nested_attr(), solu(), test_prompt(), to_numpy(), tokenize_and_concatenate(), transpose(), warn_if_mps()
- transformer_lens.weight_processing module
ProcessWeights, ProcessWeights.center_attention_weights(), ProcessWeights.center_unembed(), ProcessWeights.center_weight_single(), ProcessWeights.center_writing_weights(), ProcessWeights.convert_tensor_to_hf_format(), ProcessWeights.convert_tensor_to_tl_format(), ProcessWeights.distribute_weights_to_components(), ProcessWeights.extract_attention_tensors_for_folding(), ProcessWeights.fold_layer_norm(), ProcessWeights.fold_layer_norm_bias_single(), ProcessWeights.fold_layer_norm_biases(), ProcessWeights.fold_layer_norm_weight_single(), ProcessWeights.fold_layer_norm_weights(), ProcessWeights.fold_value_biases(), ProcessWeights.process_weights(), ProcessWeights.refactor_factored_attn_matrices()
Module contents
- class transformer_lens.HookedRootModule(*args: Any)
Bases: HookIntrospectionMixin, Module
A class building on nn.Module to interface nicely with HookPoints.
Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks, and run_with_cache to run the model on some input and return a cache of all activations.
Notes:
The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to a module and then run the module many times, the hook persists. If you debug a broken hook and add a fixed version, the broken one is still attached. To avoid this, run_with_hooks removes hooks at the end by default, and I recommend using its API and that of run_with_cache. If you do want to add hooks to global state, do so intentionally, and use reset_hooks liberally to remove any accidentally remaining global state.
The main case where this goes wrong is backward hooks (used to cache or intervene on gradients). Here the hooks must stay in global state until you have run loss.backward(), so you need to pass reset_hooks_end=False to run_with_hooks.
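The persistence problem can be reproduced without PyTorch. The sketch below is a toy, pure-Python hook point; ToyHookPoint and its methods are hypothetical names, not TransformerLens classes. Its is_permanent / including_permanent flags only loosely mirror the parameters of add_hook and reset_hooks documented below, and the rule that a hook may return a replacement value follows the PyTorch forward-hook convention.

```python
from typing import Callable, List, Tuple


class ToyHookPoint:
    """Hypothetical stand-in for a hook point: an ordered list of (hook, is_permanent)."""

    def __init__(self) -> None:
        self.hooks: List[Tuple[Callable, bool]] = []

    def add_hook(self, fn: Callable, is_permanent: bool = False) -> None:
        self.hooks.append((fn, is_permanent))

    def remove_hooks(self, including_permanent: bool = False) -> None:
        # Keep permanent hooks unless explicitly asked to drop them too.
        self.hooks = [(f, p) for f, p in self.hooks if p and not including_permanent]

    def __call__(self, x):
        for fn, _ in self.hooks:
            out = fn(x)
            if out is not None:  # a hook may return a replacement value
                x = out
        return x


point = ToyHookPoint()
point.add_hook(lambda x: x + 1)   # the "broken" first attempt
point.add_hook(lambda x: x * 10)  # the "fixed" version, added on top
print(point(1))  # both hooks still run: (1 + 1) * 10 = 20

point.add_hook(lambda x: x - 5, is_permanent=True)
point.remove_hooks()  # like reset_hooks(): drops temporary hooks only
print(point(1))  # only the permanent hook remains: 1 - 5 = -4
```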
- add_caching_hooks(names_filter: Callable[[str], bool] | Sequence[str] | str | None = None, incl_bwd: bool = False, device: device | None = None, remove_batch_dim: bool = False, cache: dict | None = None) → dict
Adds hooks to the model to cache activations. Note: it does NOT run the model; that must be done separately.
- Parameters:
names_filter (NamesFilter, optional) – Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
incl_bwd (bool, optional) – Whether to also do backwards hooks. Defaults to False.
device (torch.device, optional) – The device to store on. Defaults to the same device as the model.
remove_batch_dim (bool, optional) – Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
cache (Optional[dict], optional) – The cache to store activations in, a new dict is created by default. Defaults to None.
- Returns:
The cache where activations will be stored.
- Return type:
cache (dict)
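names_filter accepts several forms. A minimal sketch of normalizing them into one predicate, assuming that None means "cache everything", a single string matches one exact hook name, a sequence matches any listed name, and a callable is used as-is; to_predicate is a hypothetical helper, not the library's internal code.

```python
from typing import Callable, Optional, Sequence, Union

NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]]


def to_predicate(names_filter: NamesFilter) -> Callable[[str], bool]:
    """Turn any accepted names_filter form into a hook-name predicate."""
    if names_filter is None:
        return lambda name: True  # cache everything
    if isinstance(names_filter, str):  # check before Sequence: a str is a Sequence
        target = names_filter
        return lambda name: name == target  # one exact hook name
    if callable(names_filter):
        return names_filter  # already a predicate
    allowed = set(names_filter)  # a list or tuple of hook names
    return lambda name: name in allowed


hook_names = ["hook_embed", "blocks.0.hook_resid_pre", "blocks.0.attn.hook_pattern"]
only_patterns = to_predicate(lambda n: n.endswith("hook_pattern"))
print([n for n in hook_names if only_patterns(n)])  # ['blocks.0.attn.hook_pattern']
```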
- add_hook(name: str | Callable[[str], bool], hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd', is_permanent: bool = False, level: int | None = None, prepend: bool = False) → None
- add_perma_hook(name: str | Callable[[str], bool], hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd') → None
- cache_all(cache: dict | None, incl_bwd: bool = False, device: device | None = None, remove_batch_dim: bool = False)
- cache_some(cache: dict | None, names: Callable[[str], bool], incl_bwd: bool = False, device: device | None = None, remove_batch_dim: bool = False)
Caches the hooks selected by names, a Boolean function on hook names.
- check_and_add_hook(hook_point: HookPoint, hook_point_name: str, hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd', is_permanent: bool = False, level: int | None = None, prepend: bool = False) → None
Runs checks on the hook, then adds it to the hook point.
- check_hooks_to_add(hook_point: HookPoint, hook_point_name: str, hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd', is_permanent: bool = False, prepend: bool = False) → None
Override this function to add checks on which hooks may be added.
- clear_contexts()
- get_caching_hooks(names_filter: Callable[[str], bool] | Sequence[str] | str | None = None, incl_bwd: bool = False, device: device | None = None, remove_batch_dim: bool = False, cache: dict | None = None, pos_slice: Slice | int | Tuple[int] | Tuple[int, int] | Tuple[int, int, int] | List[int] | Tensor | ndarray | None = None) → tuple[dict, list, list]
Creates hooks to cache activations. Note: it does not add the hooks to the model.
- Parameters:
names_filter (NamesFilter, optional) – Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
incl_bwd (bool, optional) – Whether to also do backwards hooks. Defaults to False.
device (torch.device, optional) – The device to store on. Keeps activations on the same device as the layer if None.
remove_batch_dim (bool, optional) – Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
cache (Optional[dict], optional) – The cache to store activations in, a new dict is created by default. Defaults to None.
pos_slice (optional) – The slice to apply to the cached activations. Defaults to None (no slicing).
- Returns:
cache (dict): The cache where activations will be stored. fwd_hooks (list): The forward hooks. bwd_hooks (list): The backward hooks; empty if incl_bwd is False.
- Return type:
tuple[dict, list, list]
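To make the "creates hooks but does not add them" contract concrete, here is a toy analogue of get_caching_hooks over plain lists. make_caching_hooks, the (activation, hook_name) hook signature and the "_grad" key suffix for gradients are assumptions for illustration; real TransformerLens hooks receive a tensor and a HookPoint.

```python
from typing import Callable, Dict, List, Tuple

Hook = Callable[[list, str], None]


def make_caching_hooks(
    names: List[str],
    names_filter: Callable[[str], bool],
    incl_bwd: bool = False,
    remove_batch_dim: bool = False,
) -> Tuple[Dict[str, list], List[Tuple[str, Hook]], List[Tuple[str, Hook]]]:
    """Toy analogue of get_caching_hooks: build hooks, attach nothing."""
    cache: Dict[str, list] = {}

    def store(key: str, value: list) -> None:
        # With remove_batch_dim, keep only the single batch element (batch size 1).
        cache[key] = value[0] if remove_batch_dim else value

    def save_fwd(activation: list, hook_name: str) -> None:
        store(hook_name, activation)

    def save_bwd(grad: list, hook_name: str) -> None:
        store(hook_name + "_grad", grad)  # gradients get a separate key (assumed suffix)

    selected = [name for name in names if names_filter(name)]
    fwd_hooks = [(name, save_fwd) for name in selected]
    bwd_hooks = [(name, save_bwd) for name in selected] if incl_bwd else []
    return cache, fwd_hooks, bwd_hooks


cache, fwd_hooks, bwd_hooks = make_caching_hooks(
    ["hook_embed", "blocks.0.hook_resid_pre"],
    lambda n: n.startswith("blocks"),
    remove_batch_dim=True,
)
print(cache)  # {} : nothing is cached until the hooks actually run
for name, hook in fwd_hooks:
    hook([[1.0, 2.0]], name)  # fake activation with batch size 1
print(cache)  # {'blocks.0.hook_resid_pre': [1.0, 2.0]}
```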
- hook_points()
- hooks(fwd_hooks: list[tuple[str | Callable, Callable]] = [], bwd_hooks: list[tuple[str | Callable, Callable]] = [], reset_hooks_end: bool = True, clear_contexts: bool = False)
A context manager for adding temporary hooks to the model.
- Parameters:
fwd_hooks – List[Tuple[name, hook]], where name is either the name of a hook point or a Boolean function on hook names and hook is the function to add to that hook point.
bwd_hooks – Same as fwd_hooks, but for the backward pass.
reset_hooks_end (bool) – If True, removes all hooks added by this context manager when the context manager exits.
clear_contexts (bool) – If True, clears hook contexts whenever hooks are reset.
Example:
    with model.hooks(fwd_hooks=my_hooks):
        hooked_loss = model(text, return_type="loss")
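The context manager's guarantee (hooks are detached on exit even if the body raises) can be sketched with contextlib. This toy keeps hooks in a module-level dict; temporary_hooks and active_hooks are hypothetical names, and the sketch only models the reset_hooks_end=True case.

```python
from contextlib import contextmanager
from typing import Callable, Dict, Iterator, List, Tuple

# Toy model state: hook-point name -> list of currently attached hook functions.
active_hooks: Dict[str, List[Callable]] = {"hook_embed": [], "hook_out": []}


@contextmanager
def temporary_hooks(fwd_hooks: List[Tuple[str, Callable]]) -> Iterator[None]:
    """Attach hooks on entry and always detach them on exit, even after an error."""
    for name, fn in fwd_hooks:
        active_hooks[name].append(fn)
    try:
        yield
    finally:
        for name, fn in fwd_hooks:
            active_hooks[name].remove(fn)


def double(x):
    return 2 * x


try:
    with temporary_hooks([("hook_out", double)]):
        assert active_hooks["hook_out"] == [double]  # attached inside the block
        raise RuntimeError("simulated failure inside the block")
except RuntimeError:
    pass

print(active_hooks["hook_out"])  # [] : the hook did not leak into global state
```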
- mod_dict: dict[str, nn.Module]
- name: str | None
- remove_all_hook_fns(direction: Literal['fwd', 'bwd', 'both'] = 'both', including_permanent: bool = False, level: int | None = None)
- reset_hooks(clear_contexts: bool = True, direction: Literal['fwd', 'bwd', 'both'] = 'both', including_permanent: bool = False, level: int | None = None)
- run_with_cache(*model_args: Any, names_filter: Callable[[str], bool] | Sequence[str] | str | None = None, device: device | None = None, remove_batch_dim: bool = False, incl_bwd: bool = False, reset_hooks_end: bool = True, clear_contexts: bool = False, pos_slice: Slice | int | Tuple[int] | Tuple[int, int] | Tuple[int, int, int] | List[int] | Tensor | ndarray | None = None, **model_kwargs: Any)
Runs the model and returns the model output and a Cache object.
- Parameters:
*model_args – Positional arguments for the model.
names_filter (NamesFilter, optional) – A filter for which activations to cache. Accepts None, str, list of str, or a function that takes a string and returns a bool. Defaults to None, which means cache everything.
device (str or torch.device, optional) – The device to cache activations on. Defaults to the model device. WARNING: Setting a device different from the one used by the model leads to significant performance degradation.
remove_batch_dim (bool, optional) – If True, removes the batch dimension when caching. Only makes sense with batch_size=1 inputs. Defaults to False.
incl_bwd (bool, optional) – If True, calls backward on the model output and caches gradients as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss functions are not supported. Defaults to False.
reset_hooks_end (bool, optional) – If True, removes all hooks added by this function at the end of the run. Defaults to True.
clear_contexts (bool, optional) – If True, clears hook contexts whenever hooks are reset. Defaults to False.
pos_slice – The slice to apply to the cache output. Defaults to None (no slicing).
**model_kwargs – Keyword arguments for the model’s forward function. See the forward pass of the corresponding model for the arguments it accepts.
- Returns:
A tuple containing the model output and a Cache object.
- Return type:
tuple
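pos_slice accepts an int, a tuple, a list, an array, a Slice, or None. A minimal sketch of selecting positions from a toy one-dimensional sequence, assuming a tuple is read as slice(*tuple), an int picks a single position, a list picks several, and None keeps everything; apply_pos_slice is hypothetical, and real cached activations also carry batch and feature dimensions.

```python
from typing import List, Optional, Sequence, Tuple, Union

PosSlice = Optional[Union[int, Tuple[int, ...], List[int], slice]]


def apply_pos_slice(positions: Sequence, pos_slice: PosSlice):
    """Select entries along a toy 1-D 'position' axis."""
    if pos_slice is None:
        return list(positions)  # identity: keep every position
    if isinstance(pos_slice, int):
        return positions[pos_slice]  # a single position
    if isinstance(pos_slice, tuple):
        return list(positions[slice(*pos_slice)])  # (stop), (start, stop) or (start, stop, step)
    if isinstance(pos_slice, slice):
        return list(positions[pos_slice])
    return [positions[i] for i in pos_slice]  # explicit list of positions


tokens = ["<bos>", "The", "cat", "sat", "down"]
print(apply_pos_slice(tokens, -1))      # down
print(apply_pos_slice(tokens, (1, 3)))  # ['The', 'cat']
print(apply_pos_slice(tokens, [0, 2]))  # ['<bos>', 'cat']
```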
- run_with_hooks(*model_args: Any, fwd_hooks: list[tuple[str | Callable, Callable]] = [], bwd_hooks: list[tuple[str | Callable, Callable]] = [], reset_hooks_end: bool = True, clear_contexts: bool = False, **model_kwargs: Any)
Runs the model with specified forward and backward hooks.
- Parameters:
fwd_hooks (List[Tuple[Union[str, Callable], Callable]]) – A list of (name, hook) pairs, where name is either the name of a hook point or a Boolean function on hook names, and hook is the function to add to that hook point. When name is a function, hook is added to every hook point whose name it maps to True.
bwd_hooks (List[Tuple[Union[str, Callable], Callable]]) – Same as fwd_hooks, but for the backward pass.
reset_hooks_end (bool) – If True, all hooks are removed at the end, including those added during this run. Default is True.
clear_contexts (bool) – If True, clears hook contexts whenever hooks are reset. Default is False.
*model_args – Positional arguments for the model.
**model_kwargs – Keyword arguments for the model’s forward function. See the forward pass of the corresponding model for the arguments it accepts.
Note
If you want to use backward hooks, set reset_hooks_end to False so that the backward hooks remain attached when you later call backward(); this function only runs a forward pass.
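A hook can either observe an activation or replace it. The toy run_toy_model below (hypothetical, two scalar stages) follows the PyTorch forward-hook convention that a non-None return value replaces the activation; in TransformerLens the second hook argument is a HookPoint rather than the plain name string used here.

```python
from typing import Callable, Dict, List, Optional, Tuple

HookFn = Callable[[float, str], Optional[float]]


def run_toy_model(x: float, fwd_hooks: Optional[List[Tuple[str, HookFn]]] = None) -> float:
    """Two-stage toy 'model'; each stage output passes through a named hook point."""
    stages = [("hook_a", lambda v: v + 1.0), ("hook_b", lambda v: v * 3.0)]
    active: Dict[str, List[HookFn]] = {}
    for name, fn in fwd_hooks or []:  # hooks exist only for this call
        active.setdefault(name, []).append(fn)
    for name, stage in stages:
        x = stage(x)
        for fn in active.get(name, []):
            out = fn(x, name)
            if out is not None:  # a returned value replaces the activation
                x = out
    return x


def zero_ablate(value: float, hook_name: str) -> float:
    return 0.0


def observe(value: float, hook_name: str) -> None:
    print(f"{hook_name}: {value}")  # returning None leaves the value unchanged


print(run_toy_model(1.0))                                        # (1 + 1) * 3 = 6.0
print(run_toy_model(1.0, fwd_hooks=[("hook_a", zero_ablate)]))  # 0 * 3 = 0.0
run_toy_model(1.0, fwd_hooks=[("hook_b", observe)])              # prints hook_b: 6.0
```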
- setup()
Sets up the model.
This function must be called in the model’s __init__ method AFTER defining all layers. It sets a name attribute on each module containing its dotted module name, builds a dictionary mapping module names to module instances, and builds a hook dictionary for modules of type HookPoint.
- class transformer_lens.TransformerBridgeConfig(d_model: int, d_head: int, n_layers: int, n_ctx: int, n_heads: int = -1, d_vocab: int = -1, architecture: str | None = None, tokenizer_prepends_bos: bool = True, tokenizer_appends_eos: bool = False, default_padding_side: str | None = None, model_name: str = 'custom', act_fn: str = 'relu', eps: float = 1e-05, use_attn_scale: bool = True, attn_scale: float = -1.0, use_hook_mlp_in: bool = False, use_attn_in: bool = False, use_qk_norm: bool = False, use_local_attn: bool = False, ungroup_grouped_query_attention: bool = False, original_architecture: str | None = None, from_checkpoint: bool = False, checkpoint_index: int | None = None, checkpoint_label_type: str | None = None, checkpoint_value: int | None = None, tokenizer_name: str | None = None, window_size: int | None = None, attn_types: list | None = None, init_mode: str = 'gpt2', normalization_type: str = 'LN', n_devices: int = 1, attention_dir: str = 'causal', attn_only: bool = False, seed: int | None = None, initializer_range: float = -1.0, init_weights: bool = True, scale_attn_by_inverse_layer_idx: bool = False, final_rms: bool = False, d_vocab_out: int = -1, parallel_attn_mlp: bool = False, rotary_dim: int | None = None, n_params: int | None = None, use_hook_tokens: bool = False, gated_mlp: bool = False, dtype: dtype | None = torch.float32, post_embedding_ln: bool = False, rotary_base: int | float = 10000, trust_remote_code: bool = False, rotary_adjacent_pairs: bool = False, load_in_4bit: bool = False, num_experts: int | None = None, experts_per_token: int | None = None, n_key_value_heads: int | None = None, relative_attention_max_distance: int | None = None, relative_attention_num_buckets: int | None = None, decoder_start_token_id: int | None = None, tie_word_embeddings: bool = False, use_normalization_before_and_after: bool = False, attn_scores_soft_cap: float = -1.0, output_logits_soft_cap: float = -1.0, use_NTK_by_parts_rope: bool = False, 
NTK_by_parts_low_freq_factor: float = 1.0, NTK_by_parts_high_freq_factor: float = 4.0, NTK_by_parts_factor: float = 8.0, eps_attr: str = 'eps', rmsnorm_uses_offset: bool = False, attn_implementation: str | None = None, is_audio_model: bool = False, is_stateful: bool = False, is_multimodal: bool = False, vision_hidden_size: int | None = None, vision_num_layers: int | None = None, vision_num_heads: int | None = None, mm_tokens_per_image: int | None = None, **kwargs)
Bases: TransformerLensConfig
Configuration for TransformerBridge.
This extends TransformerLensConfig with bridge-specific properties, particularly architecture information needed for adapter selection. Also includes all HookedTransformerConfig fields for compatibility.
- __init__(d_model: int, d_head: int, n_layers: int, n_ctx: int, n_heads: int = -1, d_vocab: int = -1, architecture: str | None = None, tokenizer_prepends_bos: bool = True, tokenizer_appends_eos: bool = False, default_padding_side: str | None = None, model_name: str = 'custom', act_fn: str = 'relu', eps: float = 1e-05, use_attn_scale: bool = True, attn_scale: float = -1.0, use_hook_mlp_in: bool = False, use_attn_in: bool = False, use_qk_norm: bool = False, use_local_attn: bool = False, ungroup_grouped_query_attention: bool = False, original_architecture: str | None = None, from_checkpoint: bool = False, checkpoint_index: int | None = None, checkpoint_label_type: str | None = None, checkpoint_value: int | None = None, tokenizer_name: str | None = None, window_size: int | None = None, attn_types: list | None = None, init_mode: str = 'gpt2', normalization_type: str = 'LN', n_devices: int = 1, attention_dir: str = 'causal', attn_only: bool = False, seed: int | None = None, initializer_range: float = -1.0, init_weights: bool = True, scale_attn_by_inverse_layer_idx: bool = False, final_rms: bool = False, d_vocab_out: int = -1, parallel_attn_mlp: bool = False, rotary_dim: int | None = None, n_params: int | None = None, use_hook_tokens: bool = False, gated_mlp: bool = False, dtype: dtype | None = torch.float32, post_embedding_ln: bool = False, rotary_base: int | float = 10000, trust_remote_code: bool = False, rotary_adjacent_pairs: bool = False, load_in_4bit: bool = False, num_experts: int | None = None, experts_per_token: int | None = None, n_key_value_heads: int | None = None, relative_attention_max_distance: int | None = None, relative_attention_num_buckets: int | None = None, decoder_start_token_id: int | None = None, tie_word_embeddings: bool = False, use_normalization_before_and_after: bool = False, attn_scores_soft_cap: float = -1.0, output_logits_soft_cap: float = -1.0, use_NTK_by_parts_rope: bool = False, NTK_by_parts_low_freq_factor: float = 1.0, 
NTK_by_parts_high_freq_factor: float = 4.0, NTK_by_parts_factor: float = 8.0, eps_attr: str = 'eps', rmsnorm_uses_offset: bool = False, attn_implementation: str | None = None, is_audio_model: bool = False, is_stateful: bool = False, is_multimodal: bool = False, vision_hidden_size: int | None = None, vision_num_layers: int | None = None, vision_num_heads: int | None = None, mm_tokens_per_image: int | None = None, **kwargs)
Initialize TransformerBridgeConfig.
- property head_dim: int
Alias for d_head to match HuggingFace config naming convention.
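A sketch of this alias pattern on a hypothetical two-field dataclass. The property is read-only in this sketch; whether TransformerBridgeConfig also defines a setter is not stated here.

```python
from dataclasses import dataclass


@dataclass
class MiniConfig:
    """Hypothetical two-field config; only illustrates the alias pattern."""

    d_model: int
    d_head: int

    @property
    def head_dim(self) -> int:
        # Read-only alias so code written against HuggingFace-style names still works.
        return self.d_head


cfg = MiniConfig(d_model=768, d_head=64)
print(cfg.head_dim)  # 64
```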
- class transformer_lens.TransformerLensKeyValueCache(entries: List[TransformerLensKeyValueCacheEntry], previous_attention_mask: Int[Tensor, 'batch pos_so_far'], frozen: bool = False)
Bases: object
A cache for storing past keys and values for the Transformer. This is important for generating text: past computation can be cached and reused instead of being repeated.
This cache is a list of TransformerLensKeyValueCacheEntry objects, one for each layer in the Transformer. Each entry stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and has an append method that adds the keys and values for new tokens.
The cache can be frozen so that it is not updated during the forward pass. This is useful when we want to run many inputs with the same prefix.
- __getitem__(idx)
- append_attention_mask(attention_mask: Int[Tensor, 'batch new_tokens'])
- entries: List[TransformerLensKeyValueCacheEntry]
- freeze()
- frozen: bool = False
- classmethod init_cache(cfg: TransformerLensConfig | HookedTransformerConfig, device: device | str | None, batch_size: int = 1)
- previous_attention_mask: Int[Tensor, 'batch pos_so_far']
- unfreeze()
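Freezing pays off when many continuations share one prefix. The toy below models only the attention-mask half of the cache with nested lists, assuming that a frozen cache still returns the extended mask but does not store it; ToyKVCache is hypothetical, and the real mask is an Int[Tensor, 'batch pos_so_far'].

```python
from typing import List


class ToyKVCache:
    """Hypothetical toy: only the growing attention mask and the frozen flag."""

    def __init__(self, batch_size: int = 1) -> None:
        # One (initially empty) row of mask values per batch element.
        self.previous_attention_mask: List[List[int]] = [[] for _ in range(batch_size)]
        self.frozen = False

    def freeze(self) -> None:
        self.frozen = True

    def unfreeze(self) -> None:
        self.frozen = False

    def append_attention_mask(self, attention_mask: List[List[int]]) -> List[List[int]]:
        # Concatenate along the position axis; store the result only when not frozen.
        updated = [old + new for old, new in zip(self.previous_attention_mask, attention_mask)]
        if not self.frozen:
            self.previous_attention_mask = updated
        return updated


cache = ToyKVCache()
cache.append_attention_mask([[1, 1, 1]])   # shared prefix of 3 tokens
cache.freeze()
a = cache.append_attention_mask([[1]])     # continuation A sees 4 positions
b = cache.append_attention_mask([[1, 1]])  # continuation B sees 5 positions
print(len(a[0]), len(b[0]), len(cache.previous_attention_mask[0]))  # 4 5 3
```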
- class transformer_lens.TransformerLensKeyValueCacheEntry(past_keys: jaxtyping.Float[Tensor, 'batch pos_so_far n_heads d_head'], past_values: jaxtyping.Float[Tensor, 'batch pos_so_far n_heads d_head'], frozen: bool = False)
Bases: object
- append(new_keys: Float[Tensor, 'batch new_tokens n_heads d_head'], new_values: Float[Tensor, 'batch new_tokens n_heads d_head'])
- frozen: bool = False
- classmethod init_cache_entry(cfg: TransformerLensConfig, device: device | str | None, batch_size: int = 1)
- past_keys: Float[Tensor, 'batch pos_so_far n_heads d_head']
- past_values: Float[Tensor, 'batch pos_so_far n_heads d_head']
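A per-layer entry grows along the position axis of its [batch, pos_so_far, n_heads, d_head] keys and values. This toy uses nested Python lists instead of tensors (ToyKVCacheEntry is hypothetical) and assumes that a frozen entry returns the concatenated keys and values without storing them.

```python
from typing import List, Tuple

Vector = List[float]       # one head's d_head-dimensional key or value
Position = List[Vector]    # n_heads vectors at one position
Batch = List[List[Position]]  # [batch][pos][head] -> vector


class ToyKVCacheEntry:
    """Per-layer toy entry; past_keys[b][p][h] is a d_head vector."""

    def __init__(self, batch_size: int = 1) -> None:
        # Start with zero cached positions for every batch element.
        self.past_keys: Batch = [[] for _ in range(batch_size)]
        self.past_values: Batch = [[] for _ in range(batch_size)]
        self.frozen = False

    def append(self, new_keys: Batch, new_values: Batch) -> Tuple[Batch, Batch]:
        # Concatenate along the pos axis (dim 1 of [batch, pos, n_heads, d_head]).
        keys = [old + new for old, new in zip(self.past_keys, new_keys)]
        values = [old + new for old, new in zip(self.past_values, new_values)]
        if not self.frozen:
            self.past_keys, self.past_values = keys, values
        return keys, values


entry = ToyKVCacheEntry(batch_size=1)
k1 = [[[[0.1, 0.2]], [[0.3, 0.4]]]]  # batch=1, new_tokens=2, n_heads=1, d_head=2
entry.append(k1, k1)
keys, _ = entry.append([[[[0.5, 0.6]]]], [[[[0.5, 0.6]]]])  # one more token
print(len(keys[0]), len(entry.past_keys[0]))  # 3 3
```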