Coverage for transformer_lens/model_bridge/generalized_components/alibi_utils.py: 100%

19 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Shared ALiBi (Attention with Linear Biases) utility functions. 

2 

3Used by Bloom and Falcon ALiBi attention bridges to generate positional bias tensors. 

4""" 

5 

6import math 

7 

8import torch 

9 

10 

11def build_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor: 

12 """Compute ALiBi per-head slope values. 

13 

14 For power-of-2 head counts, slopes are geometric: 2^(-8/n), 2^(-16/n), ... 

15 For non-power-of-2, extra slopes are interleaved from a finer geometric series. 

16 Matches the HuggingFace implementation. 

17 

18 Args: 

19 num_heads: Number of attention heads. 

20 device: Device for the output tensor. 

21 

22 Returns: 

23 Slopes tensor of shape [num_heads]. 

24 """ 

25 closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) 

26 base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))) 

27 powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32) 

28 slopes = torch.pow(torch.tensor(base, device=device, dtype=torch.float32), powers) 

29 

30 if closest_power_of_2 != num_heads: 

31 extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))) 

32 num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) 

33 extra_powers = torch.arange( 

34 1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32 

35 ) 

36 slopes = torch.cat( 

37 [ 

38 slopes, 

39 torch.pow( 

40 torch.tensor(extra_base, device=device, dtype=torch.float32), extra_powers 

41 ), 

42 ], 

43 dim=0, 

44 ) 

45 

46 return slopes 

47 

48 
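As a quick illustration of the slope series described in the docstring, the sketch below (not part of the covered module) checks the values for a power-of-2 and a non-power-of-2 head count; it assumes build_alibi_slopes is in scope.

# Illustrative only: expected slope values for 8 and 12 heads.
import torch

# num_heads = 8: base = 2 ** (-8 / 8) = 0.5, so slopes are 0.5 ** 1 ... 0.5 ** 8.
slopes_8 = build_alibi_slopes(8, torch.device("cpu"))
assert torch.allclose(slopes_8, torch.tensor([2.0 ** -(i + 1) for i in range(8)]))

# num_heads = 12: the first 8 slopes reuse the base-0.5 series; the remaining 4
# come from the finer base 2 ** -0.5 series at odd powers 1, 3, 5, 7.
slopes_12 = build_alibi_slopes(12, torch.device("cpu"))
extra = torch.tensor([(2.0 ** -0.5) ** p for p in (1, 3, 5, 7)])
assert torch.allclose(slopes_12[8:], extra)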

def build_alibi_tensor(
    attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
) -> torch.Tensor:
    """Build ALiBi positional bias tensor.

    Computes per-head linear biases from token positions, matching HuggingFace's
    ALiBi implementation used in Bloom and Falcon models.

    Args:
        attention_mask: Binary mask of shape [batch_size, seq_length].
        num_heads: Number of attention heads.
        dtype: Output dtype.

    Returns:
        ALiBi tensor of shape [batch_size, num_heads, 1, seq_length].
    """
    batch_size, seq_length = attention_mask.shape
    slopes = build_alibi_slopes(num_heads, attention_mask.device)

    # Position indices: 0-indexed cumulative positions masked by attention_mask
    positions = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    # Broadcast [1, heads, 1, 1] * [batch, 1, 1, seq] -> [batch, heads, 1, seq]
    alibi = slopes[None, :, None, None] * positions[:, None, :, :]
    return alibi.to(dtype)
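A minimal usage sketch (illustrative, not part of the covered module) showing how the bias tracks cumulative token positions and skips left padding; in Bloom- and Falcon-style attention the resulting tensor is added to the pre-softmax attention scores. It assumes build_alibi_tensor is in scope.

# Illustrative only.
import torch

attention_mask = torch.tensor([[1, 1, 1, 1],   # full sequence
                               [0, 0, 1, 1]])  # left-padded sequence
alibi = build_alibi_tensor(attention_mask, num_heads=4, dtype=torch.float32)
print(alibi.shape)  # torch.Size([2, 4, 1, 4])

# Row 0 uses positions [0, 1, 2, 3]; row 1 restarts at the first real token,
# giving positions [0, 0, 0, 1], so each head's bias is slope_h * positions.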