transformer_lens.config.TransformerBridgeConfig module¶
Configuration class for TransformerBridge.
- class transformer_lens.config.TransformerBridgeConfig.TransformerBridgeConfig(d_model: int, d_head: int, n_layers: int, n_ctx: int, n_heads: int = -1, d_vocab: int = -1, architecture: str | None = None, tokenizer_prepends_bos: bool = True, tokenizer_appends_eos: bool = False, default_padding_side: str | None = None, model_name: str = 'custom', act_fn: str = 'relu', eps: float = 1e-05, use_attn_scale: bool = True, attn_scale: float = -1.0, use_hook_mlp_in: bool = False, use_attn_in: bool = False, use_qk_norm: bool = False, use_local_attn: bool = False, ungroup_grouped_query_attention: bool = False, original_architecture: str | None = None, from_checkpoint: bool = False, checkpoint_index: int | None = None, checkpoint_label_type: str | None = None, checkpoint_value: int | None = None, tokenizer_name: str | None = None, window_size: int | None = None, attn_types: list | None = None, init_mode: str = 'gpt2', normalization_type: str = 'LN', n_devices: int = 1, attention_dir: str = 'causal', attn_only: bool = False, seed: int | None = None, initializer_range: float = -1.0, init_weights: bool = True, scale_attn_by_inverse_layer_idx: bool = False, final_rms: bool = False, d_vocab_out: int = -1, parallel_attn_mlp: bool = False, rotary_dim: int | None = None, n_params: int | None = None, use_hook_tokens: bool = False, gated_mlp: bool = False, dtype: dtype | None = torch.float32, post_embedding_ln: bool = False, rotary_base: int | float = 10000, trust_remote_code: bool = False, rotary_adjacent_pairs: bool = False, load_in_4bit: bool = False, num_experts: int | None = None, experts_per_token: int | None = None, n_key_value_heads: int | None = None, relative_attention_max_distance: int | None = None, relative_attention_num_buckets: int | None = None, decoder_start_token_id: int | None = None, tie_word_embeddings: bool = False, use_normalization_before_and_after: bool = False, attn_scores_soft_cap: float = -1.0, output_logits_soft_cap: float = -1.0, use_NTK_by_parts_rope: bool = False, NTK_by_parts_low_freq_factor: float = 1.0, NTK_by_parts_high_freq_factor: float = 4.0, NTK_by_parts_factor: float = 8.0, eps_attr: str = 'eps', rmsnorm_uses_offset: bool = False, attn_implementation: str | None = None, is_audio_model: bool = False, is_stateful: bool = False, is_multimodal: bool = False, vision_hidden_size: int | None = None, vision_num_layers: int | None = None, vision_num_heads: int | None = None, mm_tokens_per_image: int | None = None, **kwargs)¶
Bases: TransformerLensConfig
Configuration for TransformerBridge.
This extends TransformerLensConfig with bridge-specific properties, particularly architecture information needed for adapter selection. Also includes all HookedTransformerConfig fields for compatibility.
- __init__(d_model: int, d_head: int, n_layers: int, n_ctx: int, n_heads: int = -1, d_vocab: int = -1, architecture: str | None = None, tokenizer_prepends_bos: bool = True, tokenizer_appends_eos: bool = False, default_padding_side: str | None = None, model_name: str = 'custom', act_fn: str = 'relu', eps: float = 1e-05, use_attn_scale: bool = True, attn_scale: float = -1.0, use_hook_mlp_in: bool = False, use_attn_in: bool = False, use_qk_norm: bool = False, use_local_attn: bool = False, ungroup_grouped_query_attention: bool = False, original_architecture: str | None = None, from_checkpoint: bool = False, checkpoint_index: int | None = None, checkpoint_label_type: str | None = None, checkpoint_value: int | None = None, tokenizer_name: str | None = None, window_size: int | None = None, attn_types: list | None = None, init_mode: str = 'gpt2', normalization_type: str = 'LN', n_devices: int = 1, attention_dir: str = 'causal', attn_only: bool = False, seed: int | None = None, initializer_range: float = -1.0, init_weights: bool = True, scale_attn_by_inverse_layer_idx: bool = False, final_rms: bool = False, d_vocab_out: int = -1, parallel_attn_mlp: bool = False, rotary_dim: int | None = None, n_params: int | None = None, use_hook_tokens: bool = False, gated_mlp: bool = False, dtype: dtype | None = torch.float32, post_embedding_ln: bool = False, rotary_base: int | float = 10000, trust_remote_code: bool = False, rotary_adjacent_pairs: bool = False, load_in_4bit: bool = False, num_experts: int | None = None, experts_per_token: int | None = None, n_key_value_heads: int | None = None, relative_attention_max_distance: int | None = None, relative_attention_num_buckets: int | None = None, decoder_start_token_id: int | None = None, tie_word_embeddings: bool = False, use_normalization_before_and_after: bool = False, attn_scores_soft_cap: float = -1.0, output_logits_soft_cap: float = -1.0, use_NTK_by_parts_rope: bool = False, NTK_by_parts_low_freq_factor: float = 1.0, NTK_by_parts_high_freq_factor: float = 4.0, NTK_by_parts_factor: float = 8.0, eps_attr: str = 'eps', rmsnorm_uses_offset: bool = False, attn_implementation: str | None = None, is_audio_model: bool = False, is_stateful: bool = False, is_multimodal: bool = False, vision_hidden_size: int | None = None, vision_num_layers: int | None = None, vision_num_heads: int | None = None, mm_tokens_per_image: int | None = None, **kwargs)¶
Initialize TransformerBridgeConfig.
- property head_dim: int¶
Alias for d_head to match HuggingFace config naming convention.
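The alias is a read-only property that simply returns d_head, so code written against HuggingFace-style configs (which use head_dim) can consume this config unchanged. A minimal stand-in sketch of the pattern, using a hypothetical dataclass rather than the library class:

```python
from dataclasses import dataclass


@dataclass
class BridgeConfigSketch:
    """Illustrative stand-in mirroring a few TransformerBridgeConfig fields."""

    d_model: int
    d_head: int
    n_layers: int
    n_ctx: int

    @property
    def head_dim(self) -> int:
        # Alias for d_head, matching the HuggingFace config naming convention.
        return self.d_head


cfg = BridgeConfigSketch(d_model=768, d_head=64, n_layers=12, n_ctx=1024)
print(cfg.head_dim)  # prints 64, same value as cfg.d_head
```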