transformer_lens.model_bridge.generalized_components.base module¶
Base class for generalized transformer components.
- class transformer_lens.model_bridge.generalized_components.base.GeneralizedComponent(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)¶
Bases: Module
Base class for generalized transformer components.
This class provides a standardized interface for transformer components and handles hook registration and execution.
- __init__(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)¶
Initialize the generalized component.
- Parameters:
name – The name of this component (None if component has no container in remote model)
config – Optional configuration object for the component
submodules – Dictionary of GeneralizedComponent submodules to register
conversion_rule – Optional conversion rule for this component’s hooks
hook_alias_overrides – Optional dictionary to override default hook aliases. For example, {"hook_attn_out": "ln1_post.hook_out"} will make hook_attn_out point to ln1_post.hook_out instead of the default value in self.hook_aliases.
optional – If True, setup skips this component's subtree when it is absent from the remote model (used for hybrid architectures).
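To make the `__init__` contract concrete, here is a minimal, hypothetical stand-in that mirrors the documented semantics (the real `GeneralizedComponent` subclasses `torch.nn.Module` and does considerably more); `MiniComponent` and its attribute handling are illustrative assumptions, not the library's implementation:

```python
from typing import Any, Dict, Optional

# Hypothetical sketch of GeneralizedComponent.__init__ semantics.
# The real class also registers submodules as nn.Module children and
# wires up hook points; this only shows the documented bookkeeping.
class MiniComponent:
    # Class-level defaults, mirroring the hook_aliases class attribute.
    hook_aliases: Dict[str, str] = {"hook_attn_out": "attn.hook_out"}

    def __init__(
        self,
        name: Optional[str],
        config: Any = None,
        submodules: Optional[Dict[str, "MiniComponent"]] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
        optional: bool = False,
    ) -> None:
        self.name = name
        self.config = config
        self.submodules = submodules or {}
        self.optional = optional
        # Instance-level overrides take precedence over class-level defaults.
        self.hook_aliases = {
            **type(self).hook_aliases,
            **(hook_alias_overrides or {}),
        }

block = MiniComponent(
    "blocks.0",
    hook_alias_overrides={"hook_attn_out": "ln1_post.hook_out"},
)
```

After construction, `block.hook_aliases["hook_attn_out"]` resolves to `"ln1_post.hook_out"` rather than the class default.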
- add_hook(hook_fn: Callable[[...], Tensor], hook_name: str = 'output') → None¶
Add a hook function (HookedTransformer-compatible interface).
- Parameters:
hook_fn – Function to call at this hook point
hook_name – Name of the hook point (defaults to "output")
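The bookkeeping behind `add_hook` can be sketched as a per-hook-point registry; this `HookRegistry` is a hypothetical simplification (real hooks receive the tensor plus a hook-point object, and registration goes through PyTorch hook machinery):

```python
from collections import defaultdict
from typing import Any, Callable, Dict, List

# Hypothetical sketch: hooks are grouped by hook-point name,
# with "output" as the default point, as in add_hook's signature.
class HookRegistry:
    def __init__(self) -> None:
        self._hooks: Dict[str, List[Callable]] = defaultdict(list)

    def add_hook(self, hook_fn: Callable, hook_name: str = "output") -> None:
        self._hooks[hook_name].append(hook_fn)

    def run(self, hook_name: str, value: Any) -> Any:
        # A hook may return a replacement value, mirroring
        # HookedTransformer-style hooks; None means "leave unchanged".
        for fn in self._hooks[hook_name]:
            out = fn(value)
            if out is not None:
                value = out
        return value

reg = HookRegistry()
reg.add_hook(lambda t: t * 2)      # registered on "output" by default
print(reg.run("output", 3))        # prints 6
```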
- compatibility_mode: bool = False¶
- disable_warnings: bool = False¶
- forward(*args: Any, **kwargs: Any) → Any¶
Generic forward pass for bridge components with input/output hooks.
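The control flow described here (input hooks, delegation to the wrapped component, output hooks) can be sketched as follows; `MiniBridge` is a hypothetical stand-in and the real forward additionally handles arbitrary `*args`/`**kwargs` and tensor conversion rules:

```python
from typing import Any, Callable, List

# Hypothetical sketch of the bridge forward pass: run input hooks on the
# input, delegate to the wrapped original component, run output hooks.
class MiniBridge:
    def __init__(self, original: Callable) -> None:
        self.original = original
        self.input_hooks: List[Callable] = []
        self.output_hooks: List[Callable] = []

    def forward(self, x: Any) -> Any:
        for hook in self.input_hooks:
            x = hook(x)                # input hooks may edit the input
        out = self.original(x)         # delegate to the wrapped component
        for hook in self.output_hooks:
            out = hook(out)            # output hooks may edit the output
        return out

bridge = MiniBridge(lambda x: x + 1)       # stand-in "original component"
bridge.input_hooks.append(lambda x: x * 2)
print(bridge.forward(3))                   # prints 7: (3 * 2) + 1
```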
- hook_aliases: Dict[str, str | List[str]] = {}¶
- is_list_item: bool = False¶
- property original_component: Module | None¶
Get the original component.
- property_aliases: Dict[str, str] = {}¶
- remove_hooks(hook_name: str | None = None) → None¶
Remove hooks (HookedTransformer-compatible interface).
- Parameters:
hook_name – Name of the hook point to remove. If None, removes all hooks.
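The two removal modes (one named hook point, or everything when `hook_name` is None) can be sketched together with registration; this `MiniHooks` class is a hypothetical simplification of the interface described above:

```python
from typing import Callable, Dict, List, Optional

# Hypothetical sketch of remove_hooks semantics: a named hook point can
# be cleared individually, or every hook point at once.
class MiniHooks:
    def __init__(self) -> None:
        self._hooks: Dict[str, List[Callable]] = {}

    def add_hook(self, hook_fn: Callable, hook_name: str = "output") -> None:
        self._hooks.setdefault(hook_name, []).append(hook_fn)

    def remove_hooks(self, hook_name: Optional[str] = None) -> None:
        if hook_name is None:
            self._hooks.clear()               # remove all hooks
        else:
            self._hooks.pop(hook_name, None)  # remove just this hook point

h = MiniHooks()
h.add_hook(print, "input")
h.add_hook(print)                # defaults to "output"
h.remove_hooks("input")
print(sorted(h._hooks))          # prints ['output']
h.remove_hooks()
print(sorted(h._hooks))          # prints []
```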
- set_original_component(original_component: Module) → None¶
Set the original component that this bridge wraps.
- Parameters:
original_component – The original transformer component to wrap
- set_processed_weights(weights: Dict[str, Tensor], verbose: bool = False) → None¶
Set the processed weights for use in compatibility mode.
This method stores processed weights as attributes on the component so they can be used directly in the forward pass without modifying the original component.
Components should override this method to handle their specific weight structure. The weights dict contains keys like "weight", "bias", "W_in", "W_out", etc.
If this component has submodules, this method will automatically distribute the weights to those subcomponents using ProcessWeights.distribute_weights_to_components.
- Parameters:
weights – Dictionary of processed weight tensors
verbose – If True, print detailed information about weight setting
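The storage-plus-distribution behavior described above can be sketched as routing prefixed keys to submodules and keeping un-prefixed keys on the component itself; this is a hypothetical approximation (the real implementation delegates the routing to `ProcessWeights.distribute_weights_to_components` and stores torch tensors):

```python
from typing import Any, Dict, Optional

# Hypothetical sketch: keys with a "submodule." prefix are forwarded to
# that submodule's set_processed_weights; bare keys stay on this component.
class MiniComponent:
    def __init__(self, submodules: Optional[Dict[str, "MiniComponent"]] = None) -> None:
        self.submodules = submodules or {}
        self.weights: Dict[str, Any] = {}

    def set_processed_weights(self, weights: Dict[str, Any], verbose: bool = False) -> None:
        for key, tensor in weights.items():
            head, _, rest = key.partition(".")
            if rest and head in self.submodules:
                # Distribute the prefixed weight to the matching subcomponent.
                self.submodules[head].set_processed_weights({rest: tensor}, verbose)
            else:
                # Store the processed weight on this component for use
                # in the forward pass during compatibility mode.
                self.weights[key] = tensor
                if verbose:
                    print(f"set {key}")

mlp = MiniComponent()
block = MiniComponent(submodules={"mlp": mlp})
block.set_processed_weights({"mlp.W_in": [1.0], "bias": [0.0]})
```

After the call, `mlp.weights` holds `W_in` while `block.weights` holds the un-prefixed `bias`.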