Coverage for transformer_lens/model_bridge/supported_architectures/granite.py: 61%

41 statements  


1"""Granite architecture adapter. 

2 

3Base adapter for the IBM Granite model family. Provides shared config setup and 

4helper methods used by GraniteMoe and GraniteMoeHybrid variants. 

5""" 

6 

7from typing import Any 

8 

9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

10from transformer_lens.model_bridge.generalized_components import ( 

11 BlockBridge, 

12 EmbeddingBridge, 

13 GatedMLPBridge, 

14 LinearBridge, 

15 PositionEmbeddingsAttentionBridge, 

16 RMSNormalizationBridge, 

17 RotaryEmbeddingBridge, 

18 UnembeddingBridge, 

19) 

20 

21 

22class GraniteArchitectureAdapter(ArchitectureAdapter): 

23 """Architecture adapter for IBM Granite models (dense). 

24 

25 Granite is a Llama-like architecture with RMSNorm, rotary position embeddings 

26 (RoPE), GQA, and a gated MLP (SiLU activation). Granite-specific scaling 

27 multipliers are handled by the HF model's native forward pass. 

28 

29 Optional Parameters (may not exist in state_dict): 

30 ------------------------------------------------- 

31 Granite models do NOT have biases on attention and MLP projections: 

32 

33 - blocks.{i}.attn.b_Q/b_K/b_V/b_O - No bias on attention projections 

34 - blocks.{i}.mlp.b_in/b_gate/b_out - No bias on MLP projections 

35 - blocks.{i}.ln1.b, blocks.{i}.ln2.b, ln_final.b - RMSNorm has no bias 

36 """ 

37 

38 def __init__(self, cfg: Any) -> None: 

39 """Initialize the Granite architecture adapter.""" 

40 super().__init__(cfg) 

41 

42 self._setup_common_config(cfg) 

43 self.weight_processing_conversions = {**self._qkvo_weight_conversions()} 

44 self.component_mapping = self._build_component_mapping() 

45 

46 def _setup_common_config(self, cfg: Any) -> None: 

47 """Set up config variables shared across all Granite variants.""" 

48 self.cfg.normalization_type = "RMS" 

49 self.cfg.positional_embedding_type = "rotary" 

50 self.cfg.final_rms = True 

51 self.cfg.gated_mlp = True 

52 self.cfg.attn_only = False 

53 self.cfg.uses_rms_norm = True 

54 self.cfg.default_prepend_bos = False 

55 self.cfg.eps_attr = "variance_epsilon" 

56 

57 self.default_config = { 

58 "d_model": cfg.d_model, 

59 "d_head": cfg.d_model // cfg.n_heads, 

60 "n_heads": cfg.n_heads, 

61 "n_layers": cfg.n_layers, 

62 "d_vocab": cfg.d_vocab, 

63 } 

64 

65 if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None: 65 ↛ exitline 65 didn't return from function '_setup_common_config' because the condition on line 65 was always true

66 self.default_config["n_key_value_heads"] = cfg.n_key_value_heads 

67 self.cfg.n_key_value_heads = cfg.n_key_value_heads 
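        # Granite configs define n_key_value_heads (the class docstring notes GQA),
        # so the branch above is normally taken for real checkpoints.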

    def _build_attention_bridge(self, optional: bool = False) -> PositionEmbeddingsAttentionBridge:
        """Build the standard Granite attention bridge."""
        return PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=self.cfg,
            optional=optional,
            submodules={
                "q": LinearBridge(name="q_proj"),
                "k": LinearBridge(name="k_proj"),
                "v": LinearBridge(name="v_proj"),
                "o": LinearBridge(name="o_proj"),
            },
            requires_attention_mask=True,
            requires_position_embeddings=True,
        )

    def _build_mlp_bridge(self) -> GatedMLPBridge:
        """Build the dense gated MLP bridge."""
        return GatedMLPBridge(
            name="mlp",
            config=self.cfg,
            submodules={
                "gate": LinearBridge(name="gate_proj"),
                "in": LinearBridge(name="up_proj"),
                "out": LinearBridge(name="down_proj"),
            },
        )

    def _build_component_mapping(self) -> dict:
        """Build the full component mapping for dense Granite."""
        return {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": self._build_attention_bridge(),
                    "mlp": self._build_mlp_bridge(),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Granite component testing.

        Args:
            hf_model: The HuggingFace Granite model instance
            bridge_model: The TransformerBridge model (if available)
        """
        if not hasattr(hf_model.model, "rotary_emb"):
            return

        rotary_emb = hf_model.model.rotary_emb

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if "attn" in block._modules:
                    block.attn.set_rotary_emb(rotary_emb)

        try:
            attn_bridge = self.get_generalized_component("blocks.0.attn")
            attn_bridge.set_rotary_emb(rotary_emb)
        except (AttributeError, KeyError, ValueError):
            pass
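
A minimal usage sketch (not part of the file above): it assumes the ArchitectureAdapter base class stores cfg on self.cfg and supplies _qkvo_weight_conversions(), as __init__ implies, and the numeric values are illustrative rather than taken from any real Granite checkpoint.

from types import SimpleNamespace

from transformer_lens.model_bridge.supported_architectures.granite import (
    GraniteArchitectureAdapter,
)

# Only the config fields read in _setup_common_config need to exist.
cfg = SimpleNamespace(
    d_model=2048,
    n_heads=32,
    n_layers=40,
    d_vocab=49152,
    n_key_value_heads=8,  # GQA: fewer key/value heads than query heads
)

adapter = GraniteArchitectureAdapter(cfg)

# d_head is derived as d_model // n_heads; the GQA field is copied through.
assert adapter.default_config["d_head"] == 64
assert adapter.default_config["n_key_value_heads"] == 8

# The component mapping describes the dense Granite layout: embeddings, rotary
# embeddings, the per-layer blocks, the final RMSNorm, and the unembedding.
print(sorted(adapter.component_mapping))  # ['blocks', 'embed', 'ln_final', 'rotary_emb', 'unembed']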