transformer_lens.model_bridge.generalized_components.mla_attention module¶
Multi-Head Latent Attention (MLA) bridge component for DeepSeek models.
MLA compresses Q and KV into lower-dimensional latent spaces via LoRA-style projections before standard attention. This component reimplements the MLA forward path step-by-step with hooks at each meaningful stage, exposing:
- hook_q_latent / hook_kv_latent: the compressed latent representations (the information bottleneck)
- hook_q / hook_k / hook_v: the final Q/K/V entering attention (post-decompression, post-RoPE)
- hook_rot_q / hook_rot_k: the RoPE sub-dimensions of Q/K after rotary embedding is applied
- hook_attn_scores / hook_pattern: pre- and post-softmax attention weights
- hook_z: the attention output before the output projection (alias for o.hook_in)
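The staged computation that these hooks expose can be sketched end to end. The following is a minimal NumPy illustration of the MLA forward path, not the bridge's actual implementation; all dimensions are hypothetical, and RoPE is noted but omitted for brevity. Comments mark where each hook would fire.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, q_lat, kv_lat, n_heads, d_head, seq = 16, 8, 4, 2, 6, 3

x = rng.standard_normal((seq, d_model))

# LoRA-style down-projections into low-rank latents
# -> hook_q_latent / hook_kv_latent (the information bottleneck)
W_dq = rng.standard_normal((d_model, q_lat))
W_dkv = rng.standard_normal((d_model, kv_lat))
q_latent = x @ W_dq
kv_latent = x @ W_dkv

# Up-projections back to per-head Q/K/V -> hook_q / hook_k / hook_v
W_uq = rng.standard_normal((q_lat, n_heads * d_head))
W_uk = rng.standard_normal((kv_lat, n_heads * d_head))
W_uv = rng.standard_normal((kv_lat, n_heads * d_head))
q = (q_latent @ W_uq).reshape(seq, n_heads, d_head)
k = (kv_latent @ W_uk).reshape(seq, n_heads, d_head)
v = (kv_latent @ W_uv).reshape(seq, n_heads, d_head)

# (In real MLA, RoPE is applied to a sub-slice of q/k here
#  -> hook_rot_q / hook_rot_k; omitted in this sketch.)

# Scaled dot-product with a causal mask
# -> hook_attn_scores (pre-softmax) / hook_pattern (post-softmax)
scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(d_head)
scores += np.triu(np.full((seq, seq), -np.inf), k=1)
pattern = np.exp(scores - scores.max(-1, keepdims=True))
pattern /= pattern.sum(-1, keepdims=True)

# Weighted values, flattened per position -> hook_z (o.hook_in)
z = np.einsum("hqk,khd->qhd", pattern, v).reshape(seq, n_heads * d_head)
```

The output projection (not shown) would then map z back to d_model, firing hook_out.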
- class transformer_lens.model_bridge.generalized_components.mla_attention.MLAAttentionBridge(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = None, **kwargs: Any)¶
Bases: PositionEmbeddingHooksMixin, AttentionBridge
Bridge for DeepSeek’s Multi-Head Latent Attention (MLA).
Reimplements the MLA forward path with hooks at each computation stage. Standard W_Q/W_K/W_V properties are not available on MLA models because the projections are factored; access weights through the submodules (q_a_proj, q_b_proj, etc.) instead.
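To see why a standalone W_Q cannot exist, note that MLA factors the query projection into a down-projection and an up-projection. The sketch below (NumPy, hypothetical dimensions; the variable names mirror the q_a_proj/q_b_proj submodules mentioned above, but this is an illustration, not the bridge API) shows that an "effective" W_Q can be recovered only as a rank-limited product of the two factors.

```python
import numpy as np

rng = np.random.default_rng(1)
d_model, q_lat, d_qk = 16, 8, 12

W_a = rng.standard_normal((d_model, q_lat))  # q_a_proj: down-projection
W_b = rng.standard_normal((q_lat, d_qk))     # q_b_proj: up-projection

x = rng.standard_normal((4, d_model))
q_factored = (x @ W_a) @ W_b  # what the factored MLA path computes
W_q_eff = W_a @ W_b           # "effective" W_Q, rank bounded by q_lat
q_direct = x @ W_q_eff        # identical result via the composed matrix
```

Because W_q_eff is a product through the latent space, its rank is at most q_lat; that low-rank structure is exactly what the standard W_Q property cannot represent as an independent parameter.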
- forward(*args: Any, **kwargs: Any) → Any¶
Reimplemented MLA forward with hooks at each computation stage.
Follows the DeepseekV3Attention forward path, calling into HF submodules individually and firing hooks at each meaningful stage.
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) → Dict[str, Any]¶
Generate test inputs with hidden_states, position_embeddings, and attention_mask.
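A hedged sketch of the kind of input dict this method describes, using NumPy: the key names come from the docstring above, but the shapes, the cos/sin form of position_embeddings, and the helper itself are assumptions for illustration, not the bridge's exact output.

```python
import numpy as np

def make_test_inputs(batch_size=2, seq_len=8, d_model=16, d_rope=4):
    # Hypothetical stand-in for get_random_inputs: random hidden states,
    # RoPE-style (cos, sin) position embeddings, and a causal mask.
    rng = np.random.default_rng(0)
    freqs = np.arange(seq_len)[:, None] / (
        10000.0 ** (np.arange(d_rope)[None, :] / d_rope)
    )
    return {
        "hidden_states": rng.standard_normal((batch_size, seq_len, d_model)),
        "position_embeddings": (np.cos(freqs), np.sin(freqs)),
        "attention_mask": np.tril(np.ones((seq_len, seq_len))),
    }

inputs = make_test_inputs()
```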
- hook_aliases: Dict[str, str | List[str]] = {'hook_result': 'hook_out', 'hook_z': 'o.hook_in'}¶
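The alias table above maps legacy hook names to canonical ones, with dotted targets like 'o.hook_in' pointing into a submodule. A small sketch of how such a table could be resolved (the resolve_alias helper is hypothetical, not part of the bridge):

```python
hook_aliases = {"hook_result": "hook_out", "hook_z": "o.hook_in"}

def resolve_alias(name: str, aliases: dict) -> str:
    # Follow alias chains until a canonical (non-aliased) name is reached.
    # Values may be a string or a list of strings per the type annotation;
    # for a list, take the first target. Guard against cycles with `seen`.
    seen = set()
    while name in aliases and name not in seen:
        seen.add(name)
        target = aliases[name]
        name = target if isinstance(target, str) else target[0]
    return name
```

So hook_z resolves to o.hook_in (the pre-output-projection activation), while names absent from the table pass through unchanged.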
- property_aliases: Dict[str, str] = {}¶
- real_components: Dict[str, tuple]¶
- training: bool¶