transformer_lens.model_bridge.generalized_components.mla_attention module

Multi-Head Latent Attention (MLA) bridge component for DeepSeek models.

MLA compresses Q and KV into lower-dimensional latent spaces via LoRA-style projections before standard attention. This component reimplements the MLA forward path step-by-step with hooks at each meaningful stage, exposing:

  • hook_q_latent / hook_kv_latent: compressed representations (the information bottleneck)

  • hook_q / hook_k / hook_v: final Q/K/V entering attention (post-decompression, post-RoPE)

  • hook_rot_q / hook_rot_k: after RoPE on the rope portion splits

  • hook_attn_scores / hook_pattern: pre/post-softmax attention weights

  • hook_z: pre-output-projection (alias for o.hook_in)
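The stages above can be sketched numerically. The following is a minimal numpy illustration of the MLA data flow, with comments marking roughly where each hook fires; all dimension names and sizes are illustrative assumptions, not DeepSeek's actual config, and RoPE is elided.

```python
import numpy as np

rng = np.random.default_rng(0)
seq, d_model, n_heads = 4, 64, 2
q_rank, kv_rank = 16, 8          # latent (LoRA-style) bottleneck dims
d_nope, d_rope, d_v = 8, 4, 8    # per-head dims: non-rotary, rotary, value
d_head = d_nope + d_rope

x = rng.standard_normal((seq, d_model))

# Down-projections into the latent bottlenecks (≈ hook_q_latent / hook_kv_latent)
q_latent = x @ rng.standard_normal((d_model, q_rank))
kv_latent = x @ rng.standard_normal((d_model, kv_rank))

# Up-projections back to per-head Q/K/V (≈ hook_q / hook_k / hook_v)
q = (q_latent @ rng.standard_normal((q_rank, n_heads * d_head))).reshape(seq, n_heads, d_head)
k = (kv_latent @ rng.standard_normal((kv_rank, n_heads * d_head))).reshape(seq, n_heads, d_head)
v = (kv_latent @ rng.standard_normal((kv_rank, n_heads * d_v))).reshape(seq, n_heads, d_v)

# (RoPE would rotate the last d_rope dims of q and k here ≈ hook_rot_q / hook_rot_k)

# Causal attention scores and post-softmax pattern (≈ hook_attn_scores / hook_pattern)
scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(d_head)
scores = np.where(np.tril(np.ones((seq, seq), bool)), scores, -np.inf)
pattern = np.exp(scores - scores.max(-1, keepdims=True))
pattern /= pattern.sum(-1, keepdims=True)

# Weighted values concatenated across heads (≈ hook_z, i.e. o.hook_in)
z = np.einsum("hqk,khd->qhd", pattern, v).reshape(seq, n_heads * d_v)
print(z.shape)  # (4, 16)
```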

class transformer_lens.model_bridge.generalized_components.mla_attention.MLAAttentionBridge(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = None, **kwargs: Any)

Bases: PositionEmbeddingHooksMixin, AttentionBridge

Bridge for DeepSeek’s Multi-Head Latent Attention (MLA).

Reimplements the MLA forward path with hooks at each computation stage. Standard W_Q/W_K/W_V properties are not available on MLA models — use the submodule weight access (q_a_proj, q_b_proj, etc.) instead.
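Why no single W_Q property? The query weight only exists in factored form. A toy numpy sketch (sizes are illustrative assumptions; `W_a`/`W_b` stand in for the q_a_proj/q_b_proj submodule weights) shows that an effective W_Q can be reconstructed, but its rank is capped by the latent bottleneck:

```python
import numpy as np

rng = np.random.default_rng(1)
d_model, q_rank, d_q = 64, 16, 32   # illustrative sizes, not DeepSeek's real config

# Stand-ins for the factored weights exposed as submodules (q_a_proj, q_b_proj)
W_a = rng.standard_normal((d_model, q_rank))   # down-projection into the latent space
W_b = rng.standard_normal((q_rank, d_q))       # up-projection out of it

# An "effective" W_Q exists mathematically, but its rank is bounded by q_rank
W_q_eff = W_a @ W_b
print(W_q_eff.shape, np.linalg.matrix_rank(W_q_eff))  # (64, 32) 16
```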

forward(*args: Any, **kwargs: Any) Any

Reimplemented MLA forward with hooks at each computation stage.

Follows the DeepseekV3Attention forward path, calling into HF submodules individually and firing hooks at each meaningful stage.

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]

Generate test inputs with hidden_states, position_embeddings, and attention_mask.
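As a rough picture of what such a dict contains, here is a self-contained numpy sketch; the key names come from the docstring above, while the exact shapes and the (cos, sin) form of the RoPE tables are assumptions about the HF-style interface, not guaranteed by this method.

```python
import numpy as np

def make_random_inputs(batch_size=2, seq_len=8, d_model=64, d_rope=4):
    """Sketch of inputs like get_random_inputs returns (shapes are assumptions)."""
    # Standard RoPE frequency table over positions
    pos = np.arange(seq_len)[:, None] / 10000 ** (np.arange(d_rope)[None, :] / d_rope)
    return {
        "hidden_states": np.random.randn(batch_size, seq_len, d_model),
        # RoPE tables as a (cos, sin) pair, as HF attention modules expect
        "position_embeddings": (np.cos(pos), np.sin(pos)),
        # Additive causal mask: 0 where attention is allowed, -inf above the diagonal
        "attention_mask": np.where(np.tril(np.ones((seq_len, seq_len), bool)), 0.0, -np.inf),
    }

inputs = make_random_inputs()
print(inputs["hidden_states"].shape)  # (2, 8, 64)
```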

hook_aliases: Dict[str, str | List[str]] = {'hook_result': 'hook_out', 'hook_z': 'o.hook_in'}
property_aliases: Dict[str, str] = {}
real_components: Dict[str, tuple]
training: bool