transformer_lens.model_bridge.sources.native package
Module contents
TL-native model source for TransformerBridge.
- class transformer_lens.model_bridge.sources.native.NativeAttention(cfg: TransformerBridgeConfig, rotary: NativeRotary | None = None)
Bases: torch.nn.Module
Split-QKV causal self-attention. Returns (out, pattern); AttentionBridge fires hook_pattern off the second element.
- causal_mask: torch.Tensor
- forward(hidden_states: Tensor, attention_mask: Tensor | None = None, position_ids: Tensor | None = None, **kwargs) → tuple[Tensor, Tensor]
Define the computation performed at every call. Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
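The (out, pattern) contract can be illustrated with a framework-free sketch. It is not NativeAttention's implementation: it uses plain Python lists instead of torch tensors, a single head, no output projection, and it ignores attention_mask and position_ids. The names matmul and causal_self_attention are hypothetical. It shows three separate Q/K/V projections, a causal mask that blocks attention to later positions, and a returned pattern in the second slot, which is the element AttentionBridge hooks as hook_pattern.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    n_cols = len(b[0])
    return [[sum(a_ik * b[k][j] for k, a_ik in enumerate(row)) for j in range(n_cols)] for row in a]


def causal_self_attention(
    x: Matrix, w_q: Matrix, w_k: Matrix, w_v: Matrix
) -> Tuple[Matrix, Matrix]:
    """Single-head causal self-attention returning (out, pattern)."""
    # Split QKV: three separate projections instead of one fused matrix.
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    scale = 1.0 / math.sqrt(len(q[0]))
    seq_len = len(x)
    pattern: Matrix = []
    for i in range(seq_len):
        # Causal mask: query i may only attend to keys j <= i.
        scores = [
            sum(qd * kd for qd, kd in zip(q[i], k[j])) * scale if j <= i else -math.inf
            for j in range(seq_len)
        ]
        # Numerically stable softmax; exp(-inf) is 0.0, so masked keys get zero weight.
        peak = max(scores)
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        pattern.append([e / total for e in exps])
    # Output: attention-weighted sum of the value vectors.
    out = matmul(pattern, v)
    return out, pattern


if __name__ == "__main__":
    identity = [[1.0, 0.0], [0.0, 1.0]]
    x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    out, pattern = causal_self_attention(x, identity, identity, identity)
    for row in pattern:
        print([round(p, 3) for p in row])
```

Every row of pattern sums to 1 and is zero above the diagonal, which is the property a hook on the pattern output would observe.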
- class transformer_lens.model_bridge.sources.native.NativeBlock(cfg: TransformerBridgeConfig, rotary: NativeRotary | None = None)
Bases: torch.nn.Module
Pre-LN transformer block. Layout adapts to cfg.attn_only and cfg.gated_mlp.
- forward(hidden_states: Tensor, attention_mask: Tensor | None = None, position_ids: Tensor | None = None, **kwargs) → tuple[Tensor]
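A minimal sketch of a Pre-LN layout, assuming the usual ordering (normalize, run the sublayer, add the result back to the residual stream) and modeling only the attn_only switch. BlockConfig, layer_norm, add and pre_ln_block are hypothetical names; layer_norm omits the learned scale and bias, and the attention and MLP sublayers are passed in as callables. How cfg.gated_mlp changes the MLP is not described here, so in this sketch it corresponds only to which mlp callable is supplied. The 1-tuple return mirrors the documented tuple[Tensor] signature.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

Vector = List[float]
Resid = List[Vector]  # one d_model vector per position


@dataclass
class BlockConfig:
    """Hypothetical stand-in for the one flag modeled here (not TransformerBridgeConfig)."""

    attn_only: bool = False


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """LayerNorm over one position, without learned scale or bias."""
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]


def add(a: Resid, b: Resid) -> Resid:
    """Elementwise residual addition (builds new lists, inputs are untouched)."""
    return [[p + q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]


def pre_ln_block(
    resid: Resid,
    cfg: BlockConfig,
    attn: Callable[[Resid], Resid],
    mlp: Optional[Callable[[Vector], Vector]] = None,
) -> Tuple[Resid]:
    # Attention sublayer: normalize first, then add its output to the residual stream.
    resid = add(resid, attn([layer_norm(x) for x in resid]))
    # MLP sublayer, skipped entirely when the block is attention-only.
    if not cfg.attn_only:
        if mlp is None:
            raise ValueError("mlp is required unless cfg.attn_only is True")
        resid = add(resid, [mlp(layer_norm(x)) for x in resid])
    # 1-tuple, mirroring the documented tuple[Tensor] return type.
    return (resid,)


if __name__ == "__main__":
    def zero_attn(xs: Resid) -> Resid:
        return [[0.0] * len(x) for x in xs]

    (out,) = pre_ln_block([[1.0, 2.0, 3.0]], BlockConfig(attn_only=True), zero_attn)
    print(out)
```

Normalizing only the sublayer inputs, never the residual stream itself, is what distinguishes Pre-LN from Post-LN.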
- class transformer_lens.model_bridge.sources.native.NativeMLP(cfg: TransformerBridgeConfig)
Bases: torch.nn.Module
Two-layer MLP with configurable activation.
- act: Callable[[torch.Tensor], torch.Tensor]
- forward(hidden_states: Tensor, **kwargs) → Tensor
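A two-layer MLP with a configurable activation is in-projection, elementwise activation, then out-projection. The plain-Python sketch below illustrates that shape; it is not NativeMLP. The activation names in the hypothetical ACTIVATIONS registry and the two_layer_mlp signature are assumptions, and the gated variant selected by cfg.gated_mlp is not shown. The looked-up act plays a role similar to the documented act attribute, although that attribute maps whole tensors rather than scalars.

```python
import math
from typing import Callable, Dict, List

Vector = List[float]
Matrix = List[List[float]]

# Hypothetical activation registry keyed by name.
ACTIVATIONS: Dict[str, Callable[[float], float]] = {
    "relu": lambda z: max(0.0, z),
    # Exact GELU: z * Phi(z), written with the error function.
    "gelu": lambda z: 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0))),
}


def vec_mat(x: Vector, w: Matrix) -> Vector:
    """Row vector times matrix (x @ w), with w of shape [len(x)][d_out]."""
    return [sum(xi * w[i][j] for i, xi in enumerate(x)) for j in range(len(w[0]))]


def two_layer_mlp(
    x: Vector, w_in: Matrix, b_in: Vector, w_out: Matrix, b_out: Vector, act_name: str = "relu"
) -> Vector:
    act = ACTIVATIONS[act_name]  # the configurable part
    # Layer 1: project up to d_mlp and apply the activation elementwise.
    hidden = [act(h + b) for h, b in zip(vec_mat(x, w_in), b_in)]
    # Layer 2: project back down to d_model.
    return [o + b for o, b in zip(vec_mat(hidden, w_out), b_out)]


if __name__ == "__main__":
    w_in = [[1.0, 0.0], [0.0, 1.0]]
    w_out = [[1.0], [1.0]]
    for name in ACTIVATIONS:
        print(name, two_layer_mlp([1.0, -1.0], w_in, [0.0, 0.0], w_out, [0.0], name))
```

Only the activation lookup changes between configurations; both linear layers are identical, which is why a single act attribute is enough to make the MLP configurable.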
- class transformer_lens.model_bridge.sources.native.NativeModel(cfg: TransformerBridgeConfig)
Bases: torch.nn.Module
TL-native transformer. See the module docstring for the supported feature set.
- forward(input_ids: Tensor, attention_mask: Tensor | None = None, position_ids: Tensor | None = None, **kwargs) → Tensor
Returns logits directly.
- pos: nn.Embedding | None
- rotary: NativeRotary | None
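Because forward returns logits directly, the end-to-end data flow can be sketched as: embed tokens, optionally add a learned positional embedding, run the blocks, then unembed. This sketch assumes, without confirmation from the reference above, that when pos is None the positional information comes from rotary inside attention instead. It omits any final LayerNorm and treats each block as a callable returning a 1-tuple, like NativeBlock.forward. forward_to_logits and its weight arguments are hypothetical names.

```python
from typing import Callable, List, Optional, Tuple

Vector = List[float]
Resid = List[Vector]
Block = Callable[[Resid], Tuple[Resid]]


def forward_to_logits(
    input_ids: List[int],
    w_embed: List[Vector],  # [d_vocab][d_model]
    w_pos: Optional[List[Vector]],  # [n_ctx][d_model]; None when positions come from rotary
    blocks: List[Block],
    w_unembed: List[Vector],  # [d_model][d_vocab]
) -> List[Vector]:
    # Token embedding lookup: one d_model vector per input position.
    resid = [list(w_embed[tok]) for tok in input_ids]
    # Learned absolute positions are added only when a positional table exists.
    if w_pos is not None:
        resid = [[r + p for r, p in zip(row, w_pos[i])] for i, row in enumerate(resid)]
    # Each block returns a 1-tuple, like NativeBlock.forward.
    for block in blocks:
        (resid,) = block(resid)
    # Unembed: logits[pos][v] = sum over d of resid[pos][d] * w_unembed[d][v].
    d_vocab = len(w_unembed[0])
    return [
        [sum(r * w_unembed[d][v] for d, r in enumerate(row)) for v in range(d_vocab)]
        for row in resid
    ]


if __name__ == "__main__":
    w_embed = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    w_unembed = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
    print(forward_to_logits([0, 2], w_embed, None, [], w_unembed))
```

The result has one row of d_vocab logits per input position, with no softmax applied.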
- transformer_lens.model_bridge.sources.native.initialize_native_model(model: NativeModel, cfg: TransformerBridgeConfig, seed: int | None = None) → None
Initialize model weights in place. Honors cfg.init_mode and cfg.seed.
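Seeded in-place initialization can be sketched in plain Python. The init_mode values "normal" and "zeros", the std of 0.02, and the rule that an explicit seed argument overrides the config seed are all assumptions for illustration, not documented behavior of initialize_native_model; initialize_in_place is a hypothetical name. A private random.Random instance plus a fixed iteration order is what makes the same seed reproduce the same weights.

```python
import random
from typing import Dict, List, Optional

Weights = Dict[str, List[float]]


def initialize_in_place(
    weights: Weights,
    init_mode: str = "normal",
    cfg_seed: Optional[int] = None,
    seed: Optional[int] = None,
    std: float = 0.02,
) -> None:
    """Overwrite every parameter list in place and return None."""
    # Assumption: an explicit seed argument overrides the config-level seed.
    effective_seed = seed if seed is not None else cfg_seed
    # A private RNG makes results depend only on the seed, not on global random state.
    rng = random.Random(effective_seed)
    # Sorted names give a fixed draw order regardless of dict insertion order.
    for name in sorted(weights):
        param = weights[name]
        for i in range(len(param)):
            if init_mode == "normal":
                param[i] = rng.gauss(0.0, std)
            elif init_mode == "zeros":
                param[i] = 0.0
            else:
                raise ValueError(f"unknown init_mode: {init_mode!r}")


if __name__ == "__main__":
    params = {"w": [1.0] * 3, "b": [1.0] * 2}
    initialize_in_place(params, seed=0)
    print(params)
```

Mutating the existing lists rather than returning new ones matches the in-place, None-returning contract.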