Coverage for transformer_lens/model_bridge/generalized_components/bloom_block.py: 24%
42 statements
1"""BLOOM-specific block bridge component.
3BLOOM blocks require special arguments (alibi, attention_mask, residual) that standard
4BlockBridge doesn't handle. This custom component generates and passes these arguments.
5"""
6from typing import Any, Dict, Optional
8import torch
10from transformer_lens.model_bridge.generalized_components.alibi_utils import (
11 build_alibi_tensor as _build_alibi_tensor,
12)
13from transformer_lens.model_bridge.generalized_components.base import (
14 GeneralizedComponent,
15)
16from transformer_lens.model_bridge.generalized_components.block import BlockBridge
19class BloomBlockBridge(BlockBridge):
20 """Block bridge for BLOOM models that handles ALiBi positional encoding.
22 BLOOM uses ALiBi (Attention with Linear Biases) instead of standard positional
23 embeddings. This requires generating an alibi tensor and passing it to each block
24 along with the attention_mask.
25 """

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
    ):
        """Initialize the BLOOM block bridge.

        Args:
            name: The name of the component in the model
            config: Model configuration (used to get n_heads for ALiBi)
            submodules: Dictionary of submodules to register
            hook_alias_overrides: Optional dictionary to override default hook aliases
        """
        super().__init__(name, config, submodules, hook_alias_overrides)
        self.config = config

    @staticmethod
    def build_alibi_tensor(
        attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """Build ALiBi tensor for attention biasing.

        Delegates to the shared ALiBi utility in alibi_utils.py.

        Args:
            attention_mask: Attention mask of shape [batch_size, seq_length]
            num_heads: Number of attention heads
            dtype: Data type for the tensor

        Returns:
            ALiBi tensor of shape [batch_size, num_heads, 1, seq_length].
        """
        return _build_alibi_tensor(attention_mask, num_heads, dtype)
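
    # Usage sketch (illustrative; shapes follow the docstring above):
    #   mask = torch.ones(2, 10, dtype=torch.long)   # [batch=2, seq=10]
    #   bias = BloomBlockBridge.build_alibi_tensor(mask, num_heads=16, dtype=torch.float32)
    #   # bias.shape == torch.Size([2, 16, 1, 10])    # [batch, heads, 1, seq]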

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through the BLOOM block.

        BLOOM blocks require `alibi` and `attention_mask` arguments. If the HF model's
        BloomModel.forward() is being called, it will generate these and pass them through.
        If they're missing (e.g., when called standalone), we generate them here.

        Args:
            *args: Positional arguments (first should be hidden_states)
            **kwargs: Keyword arguments

        Returns:
            Output from the original BLOOM block
        """
        # Debug: Check if alibi is being passed
        # print(f"BloomBlockBridge.forward() called with kwargs keys: {list(kwargs.keys())}")

        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        self._check_stop_at_layer(*args, **kwargs)
        args, kwargs = self._hook_input_hidden_states(args, kwargs)

        # BLOOM blocks require 'alibi' and 'attention_mask' arguments.
        # If HF's BloomModel.forward() is calling us, these will already be present.
        # Only generate them if they're missing (e.g., standalone block testing).
        if "alibi" not in kwargs or kwargs["alibi"] is None:
            # Get hidden_states to determine shape
            if len(args) > 0 and isinstance(args[0], torch.Tensor):
                hidden_states = args[0]
            elif "hidden_states" in kwargs:
                hidden_states = kwargs["hidden_states"]
            else:
                raise ValueError("Could not find hidden_states in args or kwargs")

            batch_size, seq_length, _ = hidden_states.shape
            device = hidden_states.device
            dtype = hidden_states.dtype

            # Generate attention_mask if missing
            if "attention_mask" not in kwargs or kwargs["attention_mask"] is None:
                # Create default attention mask (all ones)
                attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=device)
            else:
                attention_mask = kwargs["attention_mask"]

            # Ensure it's 2D [batch, seq_length] for ALiBi generation
            if attention_mask.dim() == 4:
                # If 4D, we need the 2D version for ALiBi generation.
                # Extract the last row, which tells us which positions are valid.
                attention_mask_2d = attention_mask[:, 0, -1, :].long()
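                # Example: for a padding-free causal 4D mask of shape
                # [batch, 1, query_len, key_len], the last query row is all ones,
                # so attention_mask_2d becomes an all-ones [batch, key_len] mask.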
            elif attention_mask.dim() == 2:
                attention_mask_2d = attention_mask
            else:
                raise ValueError(
                    f"Unexpected attention_mask dimensions: {attention_mask.dim()}"
                )

            # Generate ALiBi bias
            if self.config and hasattr(self.config, "n_heads"):
                num_heads = self.config.n_heads
            else:
                # Fallback when no config is available: default to 16 heads
                num_heads = 16  # BLOOM-560M has 16 heads

            # Generate alibi: the shared utility returns [batch, heads, 1, seq];
            # reshape to [batch*heads, 1, seq] to match HF's format for baddbmm.
            alibi = self.build_alibi_tensor(attention_mask_2d, num_heads, dtype)
            alibi = alibi.reshape(batch_size * num_heads, 1, seq_length)
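            # Shape example (illustrative): with batch_size=2, num_heads=16 and
            # seq_length=10, build_alibi_tensor returns [2, 16, 1, 10] and the
            # reshape above yields [32, 1, 10].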

            # Add alibi to kwargs
            kwargs["alibi"] = alibi
        # else: alibi is already present from HF, don't overwrite it!

        # Call original component
        output = self.original_component(*args, **kwargs)
        return self._apply_output_hook(output, wrap_single_element=False)