transformer_lens.model_bridge.generalized_components.base module

Base class for generalized transformer components.

class transformer_lens.model_bridge.generalized_components.base.GeneralizedComponent(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)

Bases: Module

Base class for generalized transformer components.

This class provides a standardized interface for transformer components and handles hook registration and execution.

__init__(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)

Initialize the generalized component.

Parameters:
  • name – The name of this component (None if the component has no container in the remote model)

  • config – Optional configuration object for the component

  • submodules – Dictionary of GeneralizedComponent submodules to register

  • conversion_rule – Optional conversion rule for this component’s hooks

  • hook_alias_overrides – Optional dictionary to override default hook aliases. For example, {"hook_attn_out": "ln1_post.hook_out"} makes hook_attn_out point to ln1_post.hook_out instead of the default value in self.hook_aliases.

  • optional – If True, setup skips this subtree when it is absent from the wrapped model (used for hybrid architectures).
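The alias-override behavior described above can be pictured as a dictionary merge where overrides shadow the defaults in self.hook_aliases. The sketch below is an illustration of that resolution logic, not the library's implementation; the alias names are taken from the docstring's example, and resolve_hook is a hypothetical helper.

```python
# Illustration only: how hook_alias_overrides might shadow default
# aliases when a hook name is resolved. Not the library's code.

default_aliases = {"hook_attn_out": "attn.hook_out"}  # hypothetical default

def resolve_hook(name, overrides=None):
    """Return the target hook path for an alias, preferring overrides."""
    merged = {**default_aliases, **(overrides or {})}
    return merged.get(name, name)

# Without overrides, the default alias applies:
resolve_hook("hook_attn_out")  # -> "attn.hook_out"

# With the docstring's example override, the alias is redirected:
resolve_hook("hook_attn_out", {"hook_attn_out": "ln1_post.hook_out"})
```

Names with no alias pass through unchanged, so overrides only need to list the hooks being redirected.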

add_hook(hook_fn: Callable[..., Tensor], hook_name: str = 'output') None

Add a hook function (HookedTransformer-compatible interface).

Parameters:
  • hook_fn – Function to call at this hook point

  • hook_name – Name of the hook point (defaults to “output”)

compatibility_mode: bool = False
disable_warnings: bool = False
forward(*args: Any, **kwargs: Any) Any

Generic forward pass for bridge components with input/output hooks.

get_hooks() Dict[str, HookPoint]

Get all hooks registered in this component.

hook_aliases: Dict[str, str | List[str]] = {}
is_list_item: bool = False
property original_component: Module | None

Get the original component.

property_aliases: Dict[str, str] = {}
remove_hooks(hook_name: str | None = None) None

Remove hooks (HookedTransformer-compatible interface).

Parameters:
  • hook_name – Name of the hook point to remove. If None, removes all hooks.

set_original_component(original_component: Module) None

Set the original component that this bridge wraps.

Parameters:
  • original_component – The original transformer component to wrap

set_processed_weights(weights: Dict[str, Tensor], verbose: bool = False) None

Set the processed weights for use in compatibility mode.

This method stores processed weights as attributes on the component so they can be used directly in the forward pass without modifying the original component.

Components should override this method to handle their specific weight structure. The weights dict contains keys like “weight”, “bias”, “W_in”, “W_out”, etc.

If this component has submodules, this method will automatically distribute the weights to those subcomponents using ProcessWeights.distribute_weights_to_components.

Parameters:
  • weights – Dictionary of processed weight tensors

  • verbose – If True, print detailed information about weight setting
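Since components are expected to override set_processed_weights for their own weight structure, a subclass override might look like the sketch below. This is a hedged illustration, not library code: the class name, the "weight"/"bias" keys (taken from the key examples above), and the plain-list stand-ins for tensors are all assumptions made to keep the demo self-contained.

```python
# Hedged sketch: a hypothetical subclass storing processed weights as
# attributes so the forward pass can use them without modifying the
# wrapped original component. Plain lists stand in for torch tensors.
from typing import Any, Dict


class LinearBridgeSketch:
    def set_processed_weights(self, weights: Dict[str, Any],
                              verbose: bool = False) -> None:
        # Stash processed tensors directly on the component, as the
        # base-class description above prescribes.
        self.W = weights.get("weight")
        self.b = weights.get("bias")
        if verbose:
            print(f"set processed weights: {sorted(weights)}")


comp = LinearBridgeSketch()
comp.set_processed_weights({"weight": [[1.0, 2.0]], "bias": [0.0]})
```

A component with submodules would instead let the base-class method distribute the dict via ProcessWeights.distribute_weights_to_components, as noted above.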