transformer_lens.model_bridge.sources.native.model module
TL-native transformer for TransformerBridge: minimal, with no Hugging Face (HF) or HookedTransformer (HT) dependency.
Config-driven features: normalization_type (LN / RMS / RMSPre), final_rms,
gated_mlp, attn_only, n_key_value_heads (GQA), attn_scores_soft_cap,
output_logits_soft_cap, positional_embedding_type (standard / rotary),
rotary_dim / rotary_base / rope_scaling (linear PI, dynamic/NTK,
llama3 by-parts).
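The two soft-cap fields are named on this page but not defined. A minimal sketch, assuming they control the Gemma-2-style tanh cap `cap * tanh(x / cap)` (an assumption, not confirmed here); `soft_cap` is a hypothetical helper, not part of this module:

```python
from typing import Optional

import torch


def soft_cap(x: torch.Tensor, cap: Optional[float]) -> torch.Tensor:
    """Squash x smoothly into (-cap, cap); cap=None (or 0) leaves x unchanged.

    Assumed tanh formulation (Gemma-2 style); not confirmed by this page.
    """
    if not cap:
        return x
    return cap * torch.tanh(x / cap)


scores = torch.tensor([-100.0, 0.0, 1.0, 100.0])
capped = soft_cap(scores, cap=50.0)
# Large magnitudes saturate below the cap; small values pass through almost unchanged.
print(capped)
```

Unlike hard clipping, a tanh cap is smooth, so values near the cap still receive a nonzero gradient.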
- class transformer_lens.model_bridge.sources.native.model.NativeAttention(cfg: TransformerBridgeConfig, rotary: NativeRotary | None = None)
Bases: Module
Split-QKV causal self-attention. Returns (out, pattern); AttentionBridge fires hook_pattern off the second element (the attention pattern).
- causal_mask: torch.Tensor
- forward(hidden_states: Tensor, attention_mask: Tensor | None = None, position_ids: Tensor | None = None, **kwargs) → tuple[Tensor, Tensor]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: although the recipe for the forward pass is defined within this function, call the Module instance rather than forward() directly. Calling the instance runs the registered hooks, while calling forward() silently ignores them.
- training: bool
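The split-QKV, causal-mask, and GQA behavior described for NativeAttention can be sketched as follows. This is a minimal illustration under assumptions: the class `SketchAttention`, its projection names `q`/`k`/`v`/`o`, and the explicit constructor arguments are hypothetical stand-ins for the cfg fields, and RoPE, padding masks, and soft-capping are omitted. It is not the module's actual code.

```python
import math
from typing import Tuple

import torch
import torch.nn as nn


class SketchAttention(nn.Module):
    """Hypothetical split-QKV causal attention with grouped-query K/V heads."""

    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int, n_ctx: int = 64):
        super().__init__()
        assert n_heads % n_kv_heads == 0, "GQA needs n_heads divisible by n_kv_heads"
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.d_head = d_model // n_heads
        # Separate Q, K, V projections rather than one fused QKV matrix.
        self.q = nn.Linear(d_model, n_heads * self.d_head, bias=False)
        self.k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.o = nn.Linear(n_heads * self.d_head, d_model, bias=False)
        # Lower-triangular mask: query position i may attend to key positions <= i.
        self.register_buffer(
            "causal_mask", torch.tril(torch.ones(n_ctx, n_ctx, dtype=torch.bool)), persistent=False
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        b, s, _ = x.shape
        # [batch, seq, heads * d_head] -> [batch, heads, seq, d_head]
        q = self.q(x).view(b, s, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k(x).view(b, s, self.n_kv_heads, self.d_head).transpose(1, 2)
        v = self.v(x).view(b, s, self.n_kv_heads, self.d_head).transpose(1, 2)
        # GQA: each K/V head is shared by n_heads // n_kv_heads consecutive query heads.
        rep = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(rep, dim=1)
        v = v.repeat_interleave(rep, dim=1)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)
        scores = scores.masked_fill(~self.causal_mask[:s, :s], float("-inf"))
        pattern = scores.softmax(dim=-1)  # [batch, heads, seq_q, seq_k]
        z = (pattern @ v).transpose(1, 2).reshape(b, s, self.n_heads * self.d_head)
        return self.o(z), pattern


attn = SketchAttention(d_model=32, n_heads=4, n_kv_heads=2)
out, pattern = attn(torch.randn(2, 5, 32))
print(out.shape, pattern.shape)  # torch.Size([2, 5, 32]) torch.Size([2, 4, 5, 5])
```

Returning the pattern alongside the output mirrors the (out, pattern) contract described above, so a wrapper can read the pattern without recomputing it.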
- class transformer_lens.model_bridge.sources.native.model.NativeBlock(cfg: TransformerBridgeConfig, rotary: NativeRotary | None = None)
Bases: Module
Pre-LN transformer block. The sublayer layout adapts to cfg.attn_only (no MLP sublayer) and cfg.gated_mlp (gated or plain MLP).
- forward(hidden_states: Tensor, attention_mask: Tensor | None = None, position_ids: Tensor | None = None, **kwargs) → tuple[Tensor]
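A pre-LN block normalizes the input to each sublayer and adds the sublayer's output back onto the residual stream. A minimal sketch, assuming plain nn.LayerNorm for both norms (the real block follows cfg.normalization_type) and an attention module that returns (out, pattern) as NativeAttention does; `SketchBlock` and `ZeroAttn` are hypothetical:

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn


class SketchBlock(nn.Module):
    """Hypothetical pre-LN block: x + attn(ln1(x)), then x + mlp(ln2(x)) unless attention-only."""

    def __init__(self, d_model: int, attn: nn.Module, mlp: Optional[nn.Module]):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = attn  # assumed to return (out, pattern)
        # Attention-only layout: no second norm and no MLP sublayer.
        self.ln2 = nn.LayerNorm(d_model) if mlp is not None else None
        self.mlp = mlp  # a plain or gated MLP; the block does not care which

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
        attn_out, _pattern = self.attn(self.ln1(x))
        x = x + attn_out  # residual connection around attention
        if self.mlp is not None:
            x = x + self.mlp(self.ln2(x))  # residual connection around the MLP
        return (x,)


class ZeroAttn(nn.Module):
    """Stand-in attention that contributes nothing, to make the residual path visible."""

    def forward(self, x: torch.Tensor):
        return torch.zeros_like(x), None


block = SketchBlock(8, ZeroAttn(), mlp=None)
x = torch.randn(1, 3, 8)
(y,) = block(x)  # zero attention and no MLP: y equals x
```

Passing `mlp=None` reproduces the attention-only layout; swapping a gated MLP in for a plain one changes nothing else in the block.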
- class transformer_lens.model_bridge.sources.native.model.NativeGatedMLP(cfg: TransformerBridgeConfig)
Bases: Module
SwiGLU / ReGLU / GeGLU gated MLP (variant picked by cfg.act_fn). Submodules gate/in/out match GatedMLPBridge's expected slots.
- act: Callable[[torch.Tensor], torch.Tensor]
- forward(hidden_states: Tensor, **kwargs) → Tensor
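The gated variants differ only in the activation applied to the gate branch. A minimal sketch, assuming bias-free projections and the common GLU form `out(act(gate(x)) * in(x))`; `SketchGatedMLP` is hypothetical, and because `in` is a Python keyword the sketch names that projection `in_proj`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchGatedMLP(nn.Module):
    """Hypothetical GLU-family MLP: out(act(gate(x)) * in_proj(x))."""

    def __init__(self, d_model: int, d_mlp: int, act=F.silu):
        super().__init__()
        self.gate = nn.Linear(d_model, d_mlp, bias=False)
        # "in" is a Python keyword, so the input projection is called in_proj here.
        self.in_proj = nn.Linear(d_model, d_mlp, bias=False)
        self.out = nn.Linear(d_mlp, d_model, bias=False)
        self.act = act  # F.silu -> SwiGLU, F.relu -> ReGLU, F.gelu -> GeGLU

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Elementwise product: the activated gate scales the linear branch.
        return self.out(self.act(self.gate(x)) * self.in_proj(x))


swiglu = SketchGatedMLP(8, 16)
geglu = SketchGatedMLP(8, 16, act=F.gelu)
y = swiglu(torch.randn(2, 3, 8))
```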
- class transformer_lens.model_bridge.sources.native.model.NativeMLP(cfg: TransformerBridgeConfig)
Bases: Module
Two-layer MLP with configurable activation.
- act: Callable[[torch.Tensor], torch.Tensor]
- forward(hidden_states: Tensor, **kwargs) → Tensor
- training: bool
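A minimal sketch of a two-layer MLP whose activation is resolved by name once, at construction time. The `ACTIVATIONS` table and its keys are hypothetical illustrations, not the set of strings cfg.act_fn actually accepts; `SketchMLP` is likewise hypothetical:

```python
from typing import Callable, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical name -> activation table; the strings cfg.act_fn really accepts may differ.
ACTIVATIONS: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
    "relu": F.relu,
    "gelu": F.gelu,  # exact (erf-based) GELU
    "gelu_new": lambda t: F.gelu(t, approximate="tanh"),  # tanh approximation
    "silu": F.silu,
}


class SketchMLP(nn.Module):
    """Hypothetical two-layer MLP: out(act(in_proj(x)))."""

    def __init__(self, d_model: int, d_mlp: int, act_fn: str = "gelu"):
        super().__init__()
        self.in_proj = nn.Linear(d_model, d_mlp)
        self.out = nn.Linear(d_mlp, d_model)
        # Resolve the string once; an unknown name raises KeyError here, not mid-forward.
        self.act = ACTIVATIONS[act_fn]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out(self.act(self.in_proj(x)))


mlp = SketchMLP(8, 32, act_fn="relu")
y = mlp(torch.randn(4, 8))
```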
- class transformer_lens.model_bridge.sources.native.model.NativeModel(cfg: TransformerBridgeConfig)
Bases: Module
TL-native transformer. See the module docstring for the supported feature set.
- forward(input_ids: Tensor, attention_mask: Tensor | None = None, position_ids: Tensor | None = None, **kwargs) → Tensor
Returns logits directly.
- pos: nn.Embedding | None
- rotary: NativeRotary | None
- training: bool
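The overall forward path ending in raw logits can be sketched as below. Assumptions: learned absolute positions when `pos` is present (mirroring `pos: nn.Embedding | None`, which is presumably None under rotary embeddings), LayerNorm as the final norm, an untied unembedding, and a tanh logit cap. All names (`SketchModel`, `PassBlock`, `logit_cap`) are hypothetical:

```python
from typing import Optional

import torch
import torch.nn as nn


class SketchModel(nn.Module):
    """Hypothetical top-level path: token ids -> residual stream -> logits."""

    def __init__(
        self,
        d_vocab: int,
        d_model: int,
        n_ctx: int,
        blocks: nn.ModuleList,
        use_pos_embed: bool = True,
        logit_cap: Optional[float] = None,
    ):
        super().__init__()
        self.embed = nn.Embedding(d_vocab, d_model)
        # Learned absolute positions; left as None when positions come from RoPE instead.
        self.pos = nn.Embedding(n_ctx, d_model) if use_pos_embed else None
        self.blocks = blocks
        self.ln_final = nn.LayerNorm(d_model)
        self.unembed = nn.Linear(d_model, d_vocab, bias=False)
        self.logit_cap = logit_cap

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.embed(input_ids)  # [batch, seq, d_model]
        if self.pos is not None:
            positions = torch.arange(input_ids.shape[1], device=input_ids.device)
            x = x + self.pos(positions)  # [seq, d_model] broadcasts over batch
        for block in self.blocks:
            (x,) = block(x)  # each block returns a 1-tuple
        logits = self.unembed(self.ln_final(x))
        if self.logit_cap:
            logits = self.logit_cap * torch.tanh(logits / self.logit_cap)
        return logits  # raw logits, no softmax applied


class PassBlock(nn.Module):
    """Stand-in block that leaves the residual stream unchanged."""

    def forward(self, x: torch.Tensor):
        return (x,)


model = SketchModel(50, 16, 32, nn.ModuleList([PassBlock(), PassBlock()]), logit_cap=5.0)
logits = model(torch.randint(0, 50, (2, 7)))  # shape [2, 7, 50]
```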
- class transformer_lens.model_bridge.sources.native.model.NativeRMSNorm(d_model: int, eps: float = 1e-05)
Bases: Module
Llama-style RMSNorm. Computes the variance in fp32 regardless of input dtype, then casts back to the input dtype before applying the per-channel scale (matches HF LlamaRMSNorm).
- forward(x: Tensor) → Tensor
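The dtype handling described for NativeRMSNorm can be sketched as follows, modeled only on the documented behavior (fp32 variance, cast back, then scale). `SketchRMSNorm` and its weight name `w` are hypothetical:

```python
import torch
import torch.nn as nn


class SketchRMSNorm(nn.Module):
    """Hypothetical Llama-style RMSNorm: fp32 statistics, scale applied in the input dtype."""

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.w = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        x32 = x.to(torch.float32)  # mean square in fp32 even for fp16/bf16 inputs
        variance = x32.pow(2).mean(-1, keepdim=True)  # no mean subtraction, unlike LayerNorm
        x32 = x32 * torch.rsqrt(variance + self.eps)
        # Cast back before the per-channel scale, as described above.
        return self.w * x32.to(in_dtype)


x = torch.tensor([[2.0, -2.0, 2.0, -2.0]])
y = SketchRMSNorm(4)(x)  # each entry is divided by rms(x) = 2
yb = SketchRMSNorm(4).to(torch.bfloat16)(x.to(torch.bfloat16))  # output stays bf16
```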
- class transformer_lens.model_bridge.sources.native.model.NativeRotary(cfg: TransformerBridgeConfig)
Bases: Module
Shared cos/sin tables for RoPE. Honors cfg.rope_scaling.
- apply_rope(q: Tensor, k: Tensor, *, position_ids: Tensor | None = None) → tuple[Tensor, Tensor]
Apply RoPE to Q/K of shape [batch, heads, seq, d_head].
Named apply_rope rather than apply so that nn.Module.apply(fn), PyTorch's recursive function-application utility used by bridge.apply(init_fn), isn't shadowed.
- cos_cached: torch.Tensor
- sin_cached: torch.Tensor
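A minimal RoPE sketch with precomputed cos/sin buffers and an `apply_rope` method following the same shape contract ([batch, heads, seq, d_head]). Assumptions: rotate-half pairing (the GPT-NeoX / HF Llama convention), rotary_dim equal to d_head, and only linear position interpolation out of the rope_scaling options; `SketchRotary` and `linear_factor` are hypothetical. The last line shows that nn.Module.apply remains available under its own name:

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn


class SketchRotary(nn.Module):
    """Hypothetical shared RoPE tables (rotate-half convention, full-width rotation)."""

    def __init__(self, d_head: int, n_ctx: int, base: float = 10000.0, linear_factor: float = 1.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2, dtype=torch.float32) / d_head))
        # Linear position interpolation: positions are divided by the scaling factor.
        t = torch.arange(n_ctx, dtype=torch.float32) / linear_factor
        freqs = torch.outer(t, inv_freq)  # [n_ctx, d_head // 2]
        emb = torch.cat([freqs, freqs], dim=-1)  # [n_ctx, d_head]
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def apply_rope(
        self, q: torch.Tensor, k: torch.Tensor, *, position_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate q/k of shape [batch, heads, seq, d_head]."""
        if position_ids is None:
            seq = q.shape[-2]
            cos, sin = self.cos_cached[:seq], self.sin_cached[:seq]  # [seq, d_head]
        else:
            # position_ids: [batch, seq] -> [batch, 1, seq, d_head], broadcast over heads
            cos = self.cos_cached[position_ids].unsqueeze(1)
            sin = self.sin_cached[position_ids].unsqueeze(1)
        q_rot = q * cos + self._rotate_half(q) * sin
        k_rot = k * cos + self._rotate_half(k) * sin
        return q_rot, k_rot


rope = SketchRotary(d_head=8, n_ctx=16)
q, k = torch.randn(2, 4, 5, 8), torch.randn(2, 4, 5, 8)
q_rot, k_rot = rope.apply_rope(q, k)
# nn.Module.apply(fn) is untouched: it visits every submodule and returns the module itself.
same = rope.apply(lambda m: None) is rope  # True
```

Each (i, i + d_head/2) pair is rotated by a position-dependent angle, so vector norms are preserved and position 0 is left unchanged.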