transformer_lens package¶
Subpackages¶
- transformer_lens.benchmarks package
- Submodules
- transformer_lens.benchmarks.activation_cache module
- transformer_lens.benchmarks.audio module
- transformer_lens.benchmarks.backward_gradients module
- transformer_lens.benchmarks.component_benchmark module
- transformer_lens.benchmarks.component_outputs module
- transformer_lens.benchmarks.forward_pass module
- transformer_lens.benchmarks.generation module
- transformer_lens.benchmarks.granular_weight_processing module
- transformer_lens.benchmarks.hook_registration module
- transformer_lens.benchmarks.hook_structure module
- transformer_lens.benchmarks.main_benchmark module
- transformer_lens.benchmarks.multimodal module
- transformer_lens.benchmarks.text_quality module
- transformer_lens.benchmarks.utils module
- transformer_lens.benchmarks.weight_processing module
benchmark_attention_output_centering(), benchmark_layer_norm_folding(), benchmark_mlp_output_centering(), benchmark_no_nan_inf(), benchmark_unembed_centering(), benchmark_value_bias_folding(), benchmark_weight_magnitudes(), benchmark_weight_modification(), benchmark_weight_processing(), benchmark_weight_sharing()
- Module contents
BenchmarkResult, BenchmarkSeverity, PhaseReferenceData, benchmark_activation_cache(), benchmark_activation_cache_structure(), benchmark_backward_hooks(), benchmark_backward_hooks_structure(), benchmark_critical_backward_hooks(), benchmark_critical_forward_hooks(), benchmark_forward_hooks(), benchmark_forward_hooks_structure(), benchmark_forward_pass(), benchmark_gated_hooks_fire(), benchmark_generation(), benchmark_generation_with_kv_cache(), benchmark_gradient_computation(), benchmark_hook_functionality(), benchmark_hook_registry(), benchmark_logits_equivalence(), benchmark_loss_equivalence(), benchmark_multiple_generation_calls(), benchmark_run_with_cache(), benchmark_text_quality(), benchmark_weight_modification(), benchmark_weight_processing(), benchmark_weight_sharing(), run_benchmark_suite()
- Submodules
- transformer_lens.cache package
- transformer_lens.components package
- Submodules
- transformer_lens.components.abstract_attention module
- transformer_lens.components.attention module
- transformer_lens.components.bert_block module
- transformer_lens.components.bert_embed module
- transformer_lens.components.bert_mlm_head module
- transformer_lens.components.bert_nsp_head module
- transformer_lens.components.bert_pooler module
- transformer_lens.components.embed module
- transformer_lens.components.grouped_query_attention module
- transformer_lens.components.layer_norm module
- transformer_lens.components.layer_norm_pre module
- transformer_lens.components.pos_embed module
- transformer_lens.components.rms_norm module
- transformer_lens.components.rms_norm_pre module
- transformer_lens.components.t5_attention module
- transformer_lens.components.t5_block module
- transformer_lens.components.token_typed_embed module
- transformer_lens.components.transformer_block module
- transformer_lens.components.unembed module
- Module contents
- Submodules
- transformer_lens.config package
- transformer_lens.conversion_utils package
- transformer_lens.factories package
- transformer_lens.lit package
- Submodules
- transformer_lens.lit.constants module
- transformer_lens.lit.dataset module
- transformer_lens.lit.model module
- transformer_lens.lit.utils module
batch_examples(), check_lit_installed(), clean_token_string(), clean_token_strings(), compute_token_gradients(), extract_attention_from_cache(), extract_embeddings_from_cache(), filter_cache_by_pattern(), get_hook_name_for_layer(), get_model_info(), get_tokens_from_model(), get_top_k_predictions(), numpy_to_tensor(), tensor_to_numpy(), unbatch_outputs(), validate_input_example()
- Module contents
HookedTransformerLIT, HookedTransformerLIT.__init__(), HookedTransformerLIT.description(), HookedTransformerLIT.from_pretrained(), HookedTransformerLIT.get_embedding_table(), HookedTransformerLIT.init_spec(), HookedTransformerLIT.input_spec(), HookedTransformerLIT.max_minibatch_size(), HookedTransformerLIT.output_spec(), HookedTransformerLIT.predict(), HookedTransformerLIT.supports_concurrent_predictions
HookedTransformerLITConfig, HookedTransformerLITConfig.batch_size, HookedTransformerLITConfig.compute_gradients, HookedTransformerLITConfig.device, HookedTransformerLITConfig.embedding_layers, HookedTransformerLITConfig.max_seq_length, HookedTransformerLITConfig.output_all_layers, HookedTransformerLITConfig.output_attention, HookedTransformerLITConfig.output_embeddings, HookedTransformerLITConfig.prepend_bos, HookedTransformerLITConfig.top_k
IOIDataset, InductionDataset, LITWidget, PromptCompletionDataset, PromptCompletionDataset.COMPLETION_FIELD, PromptCompletionDataset.FULL_TEXT_FIELD, PromptCompletionDataset.PROMPT_FIELD, PromptCompletionDataset.__init__(), PromptCompletionDataset.__iter__(), PromptCompletionDataset.__len__(), PromptCompletionDataset.description(), PromptCompletionDataset.examples, PromptCompletionDataset.from_pairs(), PromptCompletionDataset.spec()
SimpleTextDataset, check_lit_installed(), serve(), wrap_for_lit()
- Submodules
- transformer_lens.model_bridge package
- Subpackages
- Submodules
- transformer_lens.model_bridge.architecture_adapter module
- transformer_lens.model_bridge.bridge module
- transformer_lens.model_bridge.compat module
- transformer_lens.model_bridge.component_setup module
- transformer_lens.model_bridge.composition_scores module
- transformer_lens.model_bridge.exceptions module
- transformer_lens.model_bridge.get_params_util module
- transformer_lens.model_bridge.types module
- Module contents
ArchitectureAdapter, ArchitectureAdapter.__init__(), ArchitectureAdapter.applicable_phases, ArchitectureAdapter.convert_hf_key_to_tl_key(), ArchitectureAdapter.create_stateful_cache(), ArchitectureAdapter.default_cfg, ArchitectureAdapter.get_component(), ArchitectureAdapter.get_component_from_list_module(), ArchitectureAdapter.get_component_mapping(), ArchitectureAdapter.get_generalized_component(), ArchitectureAdapter.get_remote_component(), ArchitectureAdapter.prepare_loading(), ArchitectureAdapter.prepare_model(), ArchitectureAdapter.preprocess_weights(), ArchitectureAdapter.setup_component_testing(), ArchitectureAdapter.translate_transformer_lens_path()
AttentionBridge, AttentionBridge.W_K, AttentionBridge.W_O, AttentionBridge.W_Q, AttentionBridge.W_V, AttentionBridge.__init__(), AttentionBridge.b_K, AttentionBridge.b_O, AttentionBridge.b_Q, AttentionBridge.b_V, AttentionBridge.forward(), AttentionBridge.get_random_inputs(), AttentionBridge.hook_aliases, AttentionBridge.property_aliases, AttentionBridge.set_original_component(), AttentionBridge.setup_hook_compatibility()
BlockBridge, EmbeddingBridge, JointGateUpMLPBridge, JointQKVAttentionBridge, LinearBridge, MLPBridge, MoEBridge, NormalizationBridge, RemoteComponent, RemoteModel, RemotePath, TransformerBridge, TransformerBridge.OV, TransformerBridge.OV_for_attn_layers(), TransformerBridge.QK, TransformerBridge.QK_for_attn_layers(), TransformerBridge.W_E, TransformerBridge.W_K, TransformerBridge.W_O, TransformerBridge.W_Q, TransformerBridge.W_U, TransformerBridge.W_V, TransformerBridge.W_gate, TransformerBridge.W_in, TransformerBridge.W_out, TransformerBridge.__init__(), TransformerBridge.accumulated_bias(), TransformerBridge.add_hook(), TransformerBridge.all_composition_scores(), TransformerBridge.all_head_labels, TransformerBridge.attn_head_labels, TransformerBridge.b_K, TransformerBridge.b_O, TransformerBridge.b_Q, TransformerBridge.b_U, TransformerBridge.b_V, TransformerBridge.b_in, TransformerBridge.b_out, TransformerBridge.block_hooks(), TransformerBridge.block_submodules(), TransformerBridge.blocks_with(), TransformerBridge.boot_transformers(), TransformerBridge.check_model_support(), TransformerBridge.clear_hook_registry(), TransformerBridge.composition_layer_indices(), TransformerBridge.cpu(), TransformerBridge.cuda(), TransformerBridge.enable_compatibility_mode(), TransformerBridge.forward(), TransformerBridge.generate(), TransformerBridge.generate_stream(), TransformerBridge.get_hook_point(), TransformerBridge.get_params(), TransformerBridge.get_token_position(), TransformerBridge.hf_generate(), TransformerBridge.hook_aliases, TransformerBridge.hook_dict, TransformerBridge.hooks(), TransformerBridge.layer_types(), TransformerBridge.list_supported_models(), TransformerBridge.load_state_dict(), TransformerBridge.loss_fn(), TransformerBridge.mps(), TransformerBridge.named_parameters(), TransformerBridge.original_model, TransformerBridge.parameters(), TransformerBridge.prepare_multimodal_inputs(), TransformerBridge.process_weights(), TransformerBridge.reset_hooks(), TransformerBridge.run_with_cache(), TransformerBridge.run_with_hooks(), TransformerBridge.set_use_attn_in(), TransformerBridge.set_use_attn_result(), TransformerBridge.set_use_split_qkv_input(), TransformerBridge.stack_params_for(), TransformerBridge.state_dict(), TransformerBridge.tl_named_parameters(), TransformerBridge.tl_parameters(), TransformerBridge.to(), TransformerBridge.to_single_str_token(), TransformerBridge.to_single_token(), TransformerBridge.to_str_tokens(), TransformerBridge.to_string(), TransformerBridge.to_tokens(), TransformerBridge.tokens_to_residual_directions()
TransformerLensPath, UnembeddingBridge, replace_remote_component(), set_original_components(), setup_blocks_bridge(), setup_components(), setup_submodules()
- transformer_lens.pretrained package
- transformer_lens.tools package
- transformer_lens.utilities package
- Submodules
- transformer_lens.utilities.activation_functions module
- transformer_lens.utilities.addmm module
- transformer_lens.utilities.aliases module
- transformer_lens.utilities.architectures module
- transformer_lens.utilities.attention module
- transformer_lens.utilities.attribute_utils module
- transformer_lens.utilities.bridge_components module
- transformer_lens.utilities.cache module
- transformer_lens.utilities.components_utils module
- transformer_lens.utilities.defaults_utils module
- transformer_lens.utilities.devices module
- transformer_lens.utilities.exploratory_utils module
- transformer_lens.utilities.gpu_utils module
- transformer_lens.utilities.hf_utils module
- transformer_lens.utilities.initialization_utils module
- transformer_lens.utilities.library_utils module
- transformer_lens.utilities.lm_utils module
- transformer_lens.utilities.logits_utils module
- transformer_lens.utilities.matrix module
- transformer_lens.utilities.multi_gpu module
AvailableDeviceMemory, calculate_available_device_cuda_memory(), count_unique_devices(), determine_available_memory_for_available_devices(), find_embedding_device(), get_best_available_cuda_device(), get_best_available_device(), get_device_for_block_index(), resolve_device_map(), sort_devices_based_on_available_memory()
- transformer_lens.utilities.slice module
- transformer_lens.utilities.tensors module
- transformer_lens.utilities.tokenize_utils module
- Module contents
Submodules¶
- transformer_lens.HookedAudioEncoder module
HookedAudioEncoder, HookedAudioEncoder.OV, HookedAudioEncoder.QK, HookedAudioEncoder.W_K, HookedAudioEncoder.W_O, HookedAudioEncoder.W_Q, HookedAudioEncoder.W_V, HookedAudioEncoder.W_in, HookedAudioEncoder.W_out, HookedAudioEncoder.all_head_labels(), HookedAudioEncoder.b_K, HookedAudioEncoder.b_O, HookedAudioEncoder.b_Q, HookedAudioEncoder.b_V, HookedAudioEncoder.b_in, HookedAudioEncoder.b_out, HookedAudioEncoder.cpu(), HookedAudioEncoder.cuda(), HookedAudioEncoder.encoder_output(), HookedAudioEncoder.forward(), HookedAudioEncoder.from_pretrained(), HookedAudioEncoder.hubert_model, HookedAudioEncoder.mps(), HookedAudioEncoder.processor, HookedAudioEncoder.run_with_cache(), HookedAudioEncoder.to(), HookedAudioEncoder.to_frames()
- transformer_lens.evals module
- transformer_lens.head_detector module
- transformer_lens.hook_points module
HookFunction, HookPoint, HookedRootModule, HookedRootModule.add_caching_hooks(), HookedRootModule.add_hook(), HookedRootModule.add_perma_hook(), HookedRootModule.cache_all(), HookedRootModule.cache_some(), HookedRootModule.check_and_add_hook(), HookedRootModule.check_hooks_to_add(), HookedRootModule.clear_contexts(), HookedRootModule.get_caching_hooks(), HookedRootModule.hook_dict, HookedRootModule.hook_points(), HookedRootModule.hooks(), HookedRootModule.mod_dict, HookedRootModule.name, HookedRootModule.remove_all_hook_fns(), HookedRootModule.reset_hooks(), HookedRootModule.run_with_cache(), HookedRootModule.run_with_hooks(), HookedRootModule.setup(), LensHandle
- transformer_lens.loading_from_pretrained module
- transformer_lens.patching module
generic_activation_patch(), get_act_patch_attn_head_all_pos_every(), get_act_patch_attn_head_by_pos_every(), get_act_patch_attn_head_k_all_pos(), get_act_patch_attn_head_k_by_pos(), get_act_patch_attn_head_out_all_pos(), get_act_patch_attn_head_out_by_pos(), get_act_patch_attn_head_pattern_all_pos(), get_act_patch_attn_head_pattern_by_pos(), get_act_patch_attn_head_pattern_dest_src_pos(), get_act_patch_attn_head_q_all_pos(), get_act_patch_attn_head_q_by_pos(), get_act_patch_attn_head_v_all_pos(), get_act_patch_attn_head_v_by_pos(), get_act_patch_attn_out(), get_act_patch_block_every(), get_act_patch_mlp_out(), get_act_patch_resid_mid(), get_act_patch_resid_pre(), layer_head_dest_src_pos_pattern_patch_setter(), layer_head_pattern_patch_setter(), layer_head_pos_pattern_patch_setter(), layer_head_vector_patch_setter(), layer_pos_head_vector_patch_setter(), layer_pos_patch_setter()
- transformer_lens.supported_models module
- transformer_lens.train module
HookedTransformerTrainConfig, HookedTransformerTrainConfig.batch_size, HookedTransformerTrainConfig.device, HookedTransformerTrainConfig.lr, HookedTransformerTrainConfig.max_grad_norm, HookedTransformerTrainConfig.max_steps, HookedTransformerTrainConfig.momentum, HookedTransformerTrainConfig.num_epochs, HookedTransformerTrainConfig.optimizer_name, HookedTransformerTrainConfig.print_every, HookedTransformerTrainConfig.save_dir, HookedTransformerTrainConfig.save_every, HookedTransformerTrainConfig.seed, HookedTransformerTrainConfig.wandb, HookedTransformerTrainConfig.wandb_project_name, HookedTransformerTrainConfig.warmup_steps, HookedTransformerTrainConfig.weight_decay
train()
- transformer_lens.utils module
LocallyOverridenDefaults, Slice, calc_fan_in_and_fan_out(), composition_scores(), download_file_from_hf(), filter_dict_by_prefix(), gelu_fast(), gelu_new(), get_act_name(), get_attention_mask(), get_corner(), get_cumsum_along_dim(), get_dataset(), get_device(), get_input_with_manually_prepended_bos(), get_nested_attr(), get_offset_position_ids(), get_tokenizer_with_bos(), get_tokens_with_bos_removed(), init_kaiming_normal_(), init_kaiming_uniform_(), init_xavier_normal_(), init_xavier_uniform_(), is_library_available(), is_lower_triangular(), is_square(), keep_single_column(), lm_accuracy(), lm_cross_entropy_loss(), override_or_use_default_value(), print_gpu_mem(), remove_batch_dim(), repeat_along_head_dimension(), sample_logits(), set_nested_attr(), solu(), test_prompt(), to_numpy(), tokenize_and_concatenate(), transpose(), warn_if_mps()
- transformer_lens.weight_processing module
ProcessWeights, ProcessWeights.center_attention_weights(), ProcessWeights.center_unembed(), ProcessWeights.center_weight_single(), ProcessWeights.center_writing_weights(), ProcessWeights.convert_tensor_to_hf_format(), ProcessWeights.convert_tensor_to_tl_format(), ProcessWeights.distribute_weights_to_components(), ProcessWeights.extract_attention_tensors_for_folding(), ProcessWeights.fold_layer_norm(), ProcessWeights.fold_layer_norm_bias_single(), ProcessWeights.fold_layer_norm_biases(), ProcessWeights.fold_layer_norm_weight_single(), ProcessWeights.fold_layer_norm_weights(), ProcessWeights.fold_value_biases(), ProcessWeights.process_weights(), ProcessWeights.refactor_factored_attn_matrices()
Module contents¶
- class transformer_lens.TransformerLensKeyValueCache(entries: List[TransformerLensKeyValueCacheEntry], previous_attention_mask: Int[Tensor, 'batch pos_so_far'], frozen: bool = False)¶
Bases: object
A cache for storing past keys and values for the Transformer. This is important for generating text - we can cache a lot of past computation and avoid repeating ourselves!
This cache is a list of TransformerLensKeyValueCacheEntry objects, one for each layer in the Transformer. Each object stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and each entry has an append method to add a single new key and value.
The cache can be frozen so that it is not updated during the forward pass. This is useful when we want to run many inputs with the same prefix; a usage sketch follows the member list below.
- __getitem__(idx)¶
- append_attention_mask(attention_mask: Int[Tensor, 'batch new_tokens'])¶
- entries: List[TransformerLensKeyValueCacheEntry]¶
- freeze()¶
- frozen: bool = False¶
- classmethod init_cache(cfg: TransformerLensConfig | HookedTransformerConfig, device: device | str | None, batch_size: int = 1)¶
- previous_attention_mask: Int[Tensor, 'batch pos_so_far']¶
- unfreeze()¶
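A minimal sketch of prefix reuse with a frozen cache. This assumes a HookedTransformer-style model whose forward() accepts a past_kv_cache keyword (as in recent TransformerLens releases) and uses "gpt2" as a stand-in model name; exact keyword names may vary across versions.

```python
# Sketch only: reuse a cached prefix across several continuations.
from transformer_lens import HookedTransformer, TransformerLensKeyValueCache

model = HookedTransformer.from_pretrained("gpt2")

# One empty entry per layer, each holding [batch, pos_so_far, n_heads, d_head]
# tensors for keys and values.
cache = TransformerLensKeyValueCache.init_cache(
    model.cfg, device=model.cfg.device, batch_size=1
)

# Run the shared prefix once; the cache accumulates its keys and values.
prefix = model.to_tokens("The capital of France is")
model(prefix, past_kv_cache=cache)  # keyword name assumed; check your version

# Freeze so later passes attend to the cached prefix without extending it.
cache.freeze()
for continuation in [" Paris.", " a large city."]:
    tokens = model.to_tokens(continuation, prepend_bos=False)
    logits = model(tokens, past_kv_cache=cache)
cache.unfreeze()
```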
- class transformer_lens.TransformerLensKeyValueCacheEntry(past_keys: Float[Tensor, 'batch pos_so_far n_heads d_head'], past_values: Float[Tensor, 'batch pos_so_far n_heads d_head'], frozen: bool = False)¶
Bases: object
- append(new_keys: Float[Tensor, 'batch new_tokens n_heads d_head'], new_values: Float[Tensor, 'batch new_tokens n_heads d_head'])¶
- frozen: bool = False¶
- classmethod init_cache_entry(cfg: TransformerLensConfig, device: device | str | None, batch_size: int = 1)¶
- past_keys: Float[Tensor, 'batch pos_so_far n_heads d_head']¶
- past_values: Float[Tensor, 'batch pos_so_far n_heads d_head']¶
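A shape-level sketch of a single entry, constructed directly with hypothetical dimensions. It assumes append() returns the updated key/value tensors while extending past_keys/past_values along the position axis, matching the signature above.

```python
# Sketch only: illustrate the [batch, pos_so_far, n_heads, d_head] layout.
import torch
from transformer_lens import TransformerLensKeyValueCacheEntry

batch, n_heads, d_head = 1, 12, 64  # hypothetical sizes
entry = TransformerLensKeyValueCacheEntry(
    past_keys=torch.zeros(batch, 0, n_heads, d_head),
    past_values=torch.zeros(batch, 0, n_heads, d_head),
)

new_keys = torch.randn(batch, 5, n_heads, d_head)   # 5 new positions
new_values = torch.randn(batch, 5, n_heads, d_head)
keys, values = entry.append(new_keys, new_values)   # return value assumed

print(entry.past_keys.shape)  # torch.Size([1, 5, 12, 64])
```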