Coverage for transformer_lens/model_bridge/generalized_components/bloom_block.py: 24%

42 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""BLOOM-specific block bridge component. 

2 

3BLOOM blocks require special arguments (alibi, attention_mask, residual) that standard 

4BlockBridge doesn't handle. This custom component generates and passes these arguments. 

5""" 

6from typing import Any, Dict, Optional 

7 

8import torch 

9 

10from transformer_lens.model_bridge.generalized_components.alibi_utils import ( 

11 build_alibi_tensor as _build_alibi_tensor, 

12) 

13from transformer_lens.model_bridge.generalized_components.base import ( 

14 GeneralizedComponent, 

15) 

16from transformer_lens.model_bridge.generalized_components.block import BlockBridge 

17 

18 

class BloomBlockBridge(BlockBridge):
    """Block bridge for BLOOM models that handles ALiBi positional encoding.

    BLOOM uses ALiBi (Attention with Linear Biases) instead of standard positional
    embeddings. This requires generating an ALiBi tensor and passing it to each
    block along with the attention_mask.
    """

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
    ):
        """Initialize the BLOOM block bridge.

        Args:
            name: The name of the component in the model.
            config: Model configuration (used to get n_heads for ALiBi).
            submodules: Dictionary of submodules to register.
            hook_alias_overrides: Optional dictionary to override default hook aliases.
        """
        super().__init__(name, config, submodules, hook_alias_overrides)
        self.config = config

    @staticmethod
    def build_alibi_tensor(
        attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """Build the ALiBi tensor for attention biasing.

        Delegates to the shared ALiBi utility in alibi_utils.py.

        Args:
            attention_mask: Attention mask of shape [batch_size, seq_length].
            num_heads: Number of attention heads.
            dtype: Data type for the tensor.

        Returns:
            ALiBi tensor of shape [batch_size, num_heads, 1, seq_length].
        """
        return _build_alibi_tensor(attention_mask, num_heads, dtype)
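    # For reference (an assumption: that the shared utility mirrors HF's BLOOM
    # reference implementation): with n heads, n a power of two, head i is
    # assigned the slope (2 ** (-8 / n)) ** (i + 1), and its bias row is that
    # slope times the key-position index, which acts as a linear distance
    # penalty on the attention logits.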

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through the BLOOM block.

        BLOOM blocks require `alibi` and `attention_mask` arguments. When HF's
        BloomModel.forward() calls this block, it generates these and passes
        them through. If they are missing (e.g., when the block is called
        standalone), they are generated here.

        Args:
            *args: Positional arguments (the first should be hidden_states).
            **kwargs: Keyword arguments.

        Returns:
            Output from the original BLOOM block.
        """

        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        self._check_stop_at_layer(*args, **kwargs)
        args, kwargs = self._hook_input_hidden_states(args, kwargs)

        # BLOOM blocks require 'alibi' and 'attention_mask' arguments.
        # If HF's BloomModel.forward() is calling us, these will already be present.
        # Only generate them if they're missing (e.g., standalone block testing).
        if "alibi" not in kwargs or kwargs["alibi"] is None:
            # Get hidden_states to determine shape.
            if len(args) > 0 and isinstance(args[0], torch.Tensor):
                hidden_states = args[0]
            elif "hidden_states" in kwargs:
                hidden_states = kwargs["hidden_states"]
            else:
                raise ValueError("Could not find hidden_states in args or kwargs")

            batch_size, seq_length, _ = hidden_states.shape
            device = hidden_states.device
            dtype = hidden_states.dtype

            # Generate attention_mask if missing.
            if "attention_mask" not in kwargs or kwargs["attention_mask"] is None:
                # Create a default attention mask (all ones); it is already 2D.
                attention_mask_2d = torch.ones(
                    batch_size, seq_length, dtype=torch.long, device=device
                )
            else:
                attention_mask = kwargs["attention_mask"]
                # Ensure it's 2D [batch, seq_length] for ALiBi generation.
                if attention_mask.dim() == 4:
                    # If 4D, we need a 2D version for ALiBi generation.
                    # Extract the last row, which tells us which positions are valid.
                    attention_mask_2d = attention_mask[:, 0, -1, :].long()
                elif attention_mask.dim() == 2:
                    attention_mask_2d = attention_mask
                else:
                    raise ValueError(
                        f"Unexpected attention_mask dimensions: {attention_mask.dim()}"
                    )
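            # Illustrative shape note: a causal 4D mask is typically
            # [batch, 1, query_len, key_len], so its last query row marks every
            # key position the final token may attend to, recovering a
            # [batch, key_len] padding mask suitable for ALiBi.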

            # Generate the ALiBi bias: first determine the number of heads.
            if self.config and hasattr(self.config, "n_heads"):
                num_heads = self.config.n_heads
            else:
                # Fallback when no config was provided: assume 16 heads
                # (BLOOM-560M's head count).
                num_heads = 16

            # Generate alibi: the shared utility returns [batch, heads, 1, seq];
            # reshape to [batch * heads, 1, seq] to match HF's format for baddbmm.
            alibi = self.build_alibi_tensor(attention_mask_2d, num_heads, dtype)
            alibi = alibi.reshape(batch_size * num_heads, 1, seq_length)
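            # (HF's BloomAttention applies this bias with a baddbmm over
            # [batch * heads, ...] batched matrices, which is why the 4D bias
            # is flattened to 3D here.)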

            # Add alibi to kwargs.
            kwargs["alibi"] = alibi
        # else: alibi is already present from HF; don't overwrite it.

        # Call the original component.
        output = self.original_component(*args, **kwargs)
        return self._apply_output_hook(output, wrap_single_element=False)
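
# Minimal illustrative sketch (not part of the module's API): exercises only the
# standalone ALiBi path, with the expected shapes taken from the docstrings above.
if __name__ == "__main__":
    mask = torch.ones(2, 5, dtype=torch.long)  # two unpadded sequences of length 5
    bias = BloomBlockBridge.build_alibi_tensor(mask, 4, torch.float32)
    print(bias.shape)  # per the docstring: torch.Size([2, 4, 1, 5])
    # forward() flattens the same bias before handing it to the HF block:
    print(bias.reshape(2 * 4, 1, 5).shape)  # torch.Size([8, 1, 5])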