transformer_lens.model_bridge.supported_architectures.gemma4 module

Gemma 4 architecture adapter.

Bridges the text path of Gemma4ForConditionalGeneration (model.language_model + lm_head) and the vision pipeline. For the standard variants (E2B / E4B / 31B / 26B-A4B) the vision encoder (model.vision_tower) and projector (model.embed_vision) are both bridged, enabling Phase 7 multimodal testing.

The same adapter also covers Gemma4UnifiedForConditionalGeneration (the encoder-free 12B variant, transformers >= 5.10). Its text decoder is a strict structural subset: same module paths, with no PLE and no MoE, both of which are optional here. It is still multimodal but has no vision_tower; model.embed_vision is the full vision pipeline (raw-patch projection) and is mapped as the projector only.

Per-layer structure is heterogeneous across the family, so all math is deferred to HF and submodules are decomposed only for hooks (parity-safe delegation):

  • KV sharing (E2B/E4B): the last num_kv_shared_layers layers reuse earlier KV states and drop their own k_proj / v_proj / k_norm / v_norm.

  • K==V attention (31B / 26B-A4B): global-attention layers share key and value weights (attention_k_eq_v) and have no v_proj.

  • Per-Layer Embeddings (E2B/E4B): each layer mixes in a per-layer input via per_layer_input_gate / per_layer_projection / post_per_layer_input_norm.

  • MoE (26B-A4B): layers add a router + batched experts block in parallel with the dense MLP, sandwiched by three extra norms.
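The attention-side variations above can be pictured with a small sketch. It is illustrative only and is not the adapter's code: the helper name attention_submodules and its parameters layer_idx, num_layers and is_global are hypothetical, and the presence of q_proj / q_norm / o_proj on every layer is an assumption. Only num_kv_shared_layers, attention_k_eq_v and the k/v module names come from the description above.

```python
def attention_submodules(layer_idx, num_layers, num_kv_shared_layers=0,
                         is_global=False, attention_k_eq_v=False):
    """Return the attention submodule names a layer is expected to own.

    Hypothetical helper for illustration; not part of transformer_lens.
    """
    # Assumption: query/output modules exist on every layer.
    names = ["q_proj", "q_norm", "o_proj"]

    # KV sharing: the last num_kv_shared_layers layers reuse earlier KV
    # states and own no k_proj / v_proj / k_norm / v_norm.
    first_shared = num_layers - num_kv_shared_layers
    if layer_idx >= first_shared:
        return names

    # Keeping v_norm on K==V layers is an assumption; the description
    # above only states that v_proj is absent there.
    names += ["k_proj", "k_norm", "v_norm"]

    # K==V attention: global layers share key and value weights,
    # so there is no separate v_proj.
    if not (attention_k_eq_v and is_global):
        names.append("v_proj")
    return names


# Toy 6-layer model with 2 KV-shared layers (layers 4 and 5).
shared = attention_submodules(5, num_layers=6, num_kv_shared_layers=2)
regular = attention_submodules(0, num_layers=6, num_kv_shared_layers=2)
k_eq_v = attention_submodules(0, num_layers=6, is_global=True,
                              attention_k_eq_v=True)
```

A bridge that treats the k/v modules as optional can wrap whichever of these names exist on a given layer for hooks while leaving the forward computation to HF.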

Unlike Gemma 1-3, Gemma4RMSNorm multiplies by weight directly — there is no (1.0 + weight) offset.
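The difference between the two scaling conventions can be shown with a minimal pure-Python sketch. The function rms_norm, its offset parameter and the eps value are illustrative assumptions, not the HF implementation; only the "weight" versus "(1.0 + weight)" distinction comes from the note above.

```python
import math


def rms_norm(x, weight, eps=1e-6, offset=0.0):
    """RMS-normalise x, then scale elementwise by (offset + weight).

    offset=0.0 models the direct-weight convention described for Gemma 4;
    offset=1.0 models the (1.0 + weight) convention of Gemma 1-3.
    eps is an assumed value chosen for illustration.
    """
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * (offset + w) for v, w in zip(x, weight)]


x = [3.0, 4.0]
weight = [0.5, 0.5]

direct_scale = rms_norm(x, weight)              # scale = weight
offset_scale = rms_norm(x, weight, offset=1.0)  # scale = 1.0 + weight

# With a zero weight vector the two conventions diverge completely:
# direct scaling zeroes the output, offset scaling leaves it normalised.
zero_direct = rms_norm(x, [0.0, 0.0])
zero_offset = rms_norm(x, [0.0, 0.0], offset=1.0)
```

The same stored weight therefore produces different outputs under the two conventions, so any weight handling written for the earlier (1.0 + weight) form would not carry over unchanged.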

class transformer_lens.model_bridge.supported_architectures.gemma4.Gemma4ArchitectureAdapter(cfg: Any)

Bases: ArchitectureAdapter

Adapter for Gemma 4: Gemma4ForConditionalGeneration (multimodal) or Gemma4UnifiedForConditionalGeneration (the encoder-free 12B variant).

applicable_phases: list[int] = [1, 2, 4]
component_mapping: ComponentMapping | None
setup_component_testing(hf_model: Any, bridge_model: Any = None) → None

Force eager attention so that bridge and HF outputs match across the mix of sliding-window and full-attention layers.

uses_split_attention: bool
weight_processing_conversions: dict