transformer_lens.model_bridge.supported_architectures.neox module

NeoX architecture adapter.

class transformer_lens.model_bridge.supported_architectures.neox.NeoxArchitectureAdapter(cfg: Any)

Bases: ArchitectureAdapter

Architecture adapter for NeoX models.

__init__(cfg: Any) → None

Initialize the NeoX architecture adapter.

Parameters:

cfg – The configuration object.

setup_component_testing(hf_model: Any, bridge_model: Any = None) → None

Set up rotary embedding references for GPT-NeoX/StableLM component testing.

GPT-NeoX models use RoPE (Rotary Position Embeddings), which must be set on all attention bridge instances for component testing.

Parameters:
  • hf_model – The HuggingFace GPT-NeoX model instance

  • bridge_model – The TransformerBridge model (if available, set rotary_emb on actual instances)
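As background for why rotary references must be distributed to the attention bridges, here is a minimal sketch of RoPE in the rotate-half convention (assumed to match the HuggingFace GPT-NeoX formulation; toy dimensions, and the partial rotary fraction some NeoX configs use is omitted):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap-and-negate the two halves of the last dim (rotate-half convention).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q: torch.Tensor, k: torch.Tensor, base: float = 10000.0):
    # q, k: [batch, seq, n_heads, d_head]; d_head assumed even.
    seq, d_head = q.shape[1], q.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
    angles = torch.outer(torch.arange(seq).float(), inv_freq)  # [seq, d_head/2]
    emb = torch.cat((angles, angles), dim=-1)                  # [seq, d_head]
    cos = emb.cos()[None, :, None, :]                          # broadcast over batch, heads
    sin = emb.sin()[None, :, None, :]
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

q = torch.randn(1, 6, 2, 8)
k = torch.randn(1, 6, 2, 8)
q_rot, k_rot = apply_rope(q, k)
# RoPE applies a per-position rotation, so vector norms are unchanged.
assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)
```

Because the rotation depends on absolute position, attention scores between rotated queries and keys depend only on their relative offset, which is why each attention bridge needs access to the same rotary embedding instance.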

split_qkv_matrix(original_attention_component: Any) → tuple[Linear, Linear, Linear]

Split the QKV matrix into separate linear transformations.

GPT-NeoX/StableLM uses an interleaved QKV format: the weights are stored as [Q_h0, K_h0, V_h0, Q_h1, K_h1, V_h1, …], i.e., Q, K, and V are interleaved per head.

The weight shape is [n_heads * 3 * d_head, d_model] and the output is reshaped by HuggingFace as [batch, seq, n_heads, 3*d_head] then split on the last dim.

Parameters:

original_attention_component – The original attention layer component

Returns:

Tuple of nn.Linear modules for Q, K, and V transformations
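The splitting described above can be sketched as follows (a minimal illustration with hypothetical toy dimensions, not the library's implementation):

```python
import torch
import torch.nn as nn

# Toy dimensions for illustration only.
n_heads, d_head, d_model = 4, 8, 32
torch.manual_seed(0)

# Fused projection with NeoX-interleaved rows: [Q_h0, K_h0, V_h0, Q_h1, ...]
fused = nn.Linear(d_model, n_heads * 3 * d_head)

# View rows as [n_heads, 3, d_head, d_model] so axis 1 selects Q (0), K (1), V (2).
w = fused.weight.detach().view(n_heads, 3, d_head, d_model)
b = fused.bias.detach().view(n_heads, 3, d_head)

def extract(idx: int) -> nn.Linear:
    # Build a standalone linear for one of Q/K/V from the fused weights.
    lin = nn.Linear(d_model, n_heads * d_head)
    lin.weight.data = w[:, idx].reshape(n_heads * d_head, d_model).clone()
    lin.bias.data = b[:, idx].reshape(-1).clone()
    return lin

q_lin, k_lin, v_lin = extract(0), extract(1), extract(2)

# Sanity check against the HuggingFace-style reshape-then-split of the fused output.
x = torch.randn(2, 5, d_model)                   # [batch, seq, d_model]
qkv = fused(x).view(2, 5, n_heads, 3 * d_head)   # [batch, seq, n_heads, 3*d_head]
q, k, v = qkv.split(d_head, dim=-1)              # each [batch, seq, n_heads, d_head]
assert torch.allclose(q.reshape(2, 5, -1), q_lin(x), atol=1e-5)
assert torch.allclose(v.reshape(2, 5, -1), v_lin(x), atol=1e-5)
```

The key point is that a plain `chunk(3)` on the rows would be wrong for this layout; the rows must first be viewed per head before selecting the Q, K, or V slices.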