Coverage for transformer_lens/model_bridge/generalized_components/alibi_utils.py: 100%
19 statements
1"""Shared ALiBi (Attention with Linear Biases) utility functions.
3Used by Bloom and Falcon ALiBi attention bridges to generate positional bias tensors.
4"""
6import math
8import torch


def build_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
    """Compute ALiBi per-head slope values.

    For power-of-2 head counts, slopes form the geometric series 2^(-8/n), 2^(-16/n), ...
    For non-power-of-2 head counts, the extra slopes are interleaved from a finer
    geometric series. Matches the HuggingFace implementation.

    Args:
        num_heads: Number of attention heads.
        device: Device for the output tensor.

    Returns:
        Slopes tensor of shape [num_heads].
    """
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
    slopes = torch.pow(torch.tensor(base, device=device, dtype=torch.float32), powers)

    if closest_power_of_2 != num_heads:
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(
            1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32
        )
        slopes = torch.cat(
            [
                slopes,
                torch.pow(
                    torch.tensor(extra_base, device=device, dtype=torch.float32), extra_powers
                ),
            ],
            dim=0,
        )

    return slopes
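

# A minimal usage sketch (the helper below is hypothetical, added only for illustration):
# with a power-of-2 head count such as num_heads=4, the slopes reduce to the geometric
# series 2^(-8k/n) for k = 1..n, i.e. 2^-2, 2^-4, 2^-6, 2^-8.
def _example_build_alibi_slopes() -> None:
    """Hypothetical example: check the documented slope values for num_heads=4."""
    slopes = build_alibi_slopes(num_heads=4, device=torch.device("cpu"))
    expected = torch.tensor([2.0**-2, 2.0**-4, 2.0**-6, 2.0**-8])
    assert slopes.shape == (4,)
    assert torch.allclose(slopes, expected)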


def build_alibi_tensor(
    attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
) -> torch.Tensor:
    """Build the ALiBi positional bias tensor.

    Computes per-head linear biases from token positions, matching HuggingFace's
    ALiBi implementation used in Bloom and Falcon models.

    Args:
        attention_mask: Binary mask of shape [batch_size, seq_length].
        num_heads: Number of attention heads.
        dtype: Output dtype.

    Returns:
        ALiBi tensor of shape [batch_size, num_heads, 1, seq_length].
    """
    batch_size, seq_length = attention_mask.shape
    slopes = build_alibi_slopes(num_heads, attention_mask.device)

    # Position indices: 0-indexed cumulative positions, zeroed out where the mask is 0.
    positions = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    # Broadcast [1, heads, 1, 1] slopes against [batch, 1, 1, seq] positions
    # to get a [batch, heads, 1, seq] bias tensor.
    alibi = slopes[None, :, None, None] * positions[:, None, :, :]
    return alibi.to(dtype)
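

# A minimal usage sketch (the helper below is hypothetical, added only for illustration):
# building the bias for two fully-attended sequences of length 3. Along the sequence the
# positions are 0, 1, 2, so each head's bias row is its slope times [0, 1, 2].
def _example_build_alibi_tensor() -> None:
    """Hypothetical example: check the documented output shape for a toy mask."""
    attention_mask = torch.ones(2, 3, dtype=torch.int64)
    alibi = build_alibi_tensor(attention_mask, num_heads=4, dtype=torch.float32)
    assert alibi.shape == (2, 4, 1, 3)
    slopes = build_alibi_slopes(4, attention_mask.device)
    expected_row = slopes[0] * torch.arange(3, dtype=torch.float32)
    assert torch.allclose(alibi[0, 0, 0], expected_row)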