Coverage for transformer_lens/utilities/components_utils.py: 87%

29 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""components_utils. 

2 

3This module contains utility functions related to model components 

4""" 

5 

6from __future__ import annotations 

7 

8import re 

9from typing import Optional, Union 

10 

11 

def get_act_name(
    name: str,
    layer: Optional[Union[int, str]] = None,
    layer_type: Optional[str] = None,
) -> str:
    """Convert a shorthand into the name of an activation hook point.

    Pretty hacky, intended to be useful for short feedback loop hacking stuff
    together, more so than writing good, readable code. But it is deterministic!

    Args:
        name: The name of the activation. This can be used to specify any
            activation name by itself. If ``name`` looks like
            ``<lowercase letters><digits><rest>``, the digits are taken as the
            layer number and the rest (if any) as the layer type; values parsed
            this way take precedence over the ``layer`` / ``layer_type``
            arguments.

            Given only a word and number, it leaves ``layer_type`` as is.
            Given only a word, it leaves ``layer`` and ``layer_type`` as is.

            The word part must consist of lowercase letters only, so names
            containing an underscore (e.g. ``'mlp_pre5'``) are not split.

            A name containing ``"."`` or starting with ``"hook_"`` is treated
            as an already complete name and returned unchanged, provided
            neither ``layer`` nor ``layer_type`` was given.

            Examples:
                get_act_name('embed') = get_act_name('embed', None, None)
                get_act_name('k6') = get_act_name('k', 6, None)
                get_act_name('scale4ln1') = get_act_name('scale', 4, 'ln1')

        layer: The layer number. Used for activations that appear in every
            block.
        layer_type: Used to distinguish between activations that appear
            multiple times in one block. The aliases ``a``/``attention``
            (-> ``attn``), ``m`` (-> ``mlp``) and ``b``/``block``/``blocks``
            (-> no sub-module) are accepted. It is ignored for names that
            always belong to the attention or MLP sub-module.

    Returns:
        The full activation name, e.g. ``'blocks.6.attn.hook_k'``.

    Full Examples:

        get_act_name('k', 6, 'a')=='blocks.6.attn.hook_k'
        get_act_name('pre', 2)=='blocks.2.mlp.hook_pre'
        get_act_name('embed')=='hook_embed'
        get_act_name('normalized', 27, 'ln2')=='blocks.27.ln2.hook_normalized'
        get_act_name('k6')=='blocks.6.attn.hook_k'
        get_act_name('scale4ln1')=='blocks.4.ln1.hook_scale'
        get_act_name('pre5')=='blocks.5.mlp.hook_pre'
    """
    looks_complete = "." in name or name.startswith("hook_")
    if looks_complete and layer is None and layer_type is None:
        # If this was called on a full name, just return it
        return name

    match = re.match(r"([a-z]+)(\d+)([a-z]?.*)", name)
    if match is not None:
        name, layer, parsed_layer_type = match.groups()
        # The third group always participates and is "" when nothing follows
        # the digits. Only override the caller's layer_type when a suffix was
        # actually present, so a bare word+number leaves layer_type as is.
        if parsed_layer_type:
            layer_type = parsed_layer_type

    layer_type_alias = {
        "a": "attn",
        "m": "mlp",
        "b": "",
        "block": "",
        "blocks": "",
        "attention": "attn",
    }

    act_name_alias = {
        "attn": "pattern",
        "attn_logits": "attn_scores",
        "key": "k",
        "query": "q",
        "value": "v",
        "mlp_pre": "pre",
        "mlp_mid": "mid",
        "mlp_post": "post",
    }

    attn_act_names = (
        "k",
        "v",
        "q",
        "z",
        "rot_k",
        "rot_q",
        "result",
        "pattern",
        "attn_scores",
    )
    mlp_act_names = ("pre", "post", "mid", "pre_linear")
    layer_norm_names = ("scale", "normalized")

    name = act_name_alias.get(name, name)

    full_act_name = ""
    if layer is not None:
        full_act_name += f"blocks.{layer}."

    # Activations that only exist in one sub-module force the layer type,
    # regardless of what the caller supplied.
    if name in attn_act_names:
        layer_type = "attn"
    elif name in mlp_act_names:
        layer_type = "mlp"
    elif layer_type in layer_type_alias:
        layer_type = layer_type_alias[layer_type]

    if layer_type:
        full_act_name += f"{layer_type}."
    full_act_name += f"hook_{name}"

    # A layer-norm activation requested without a layer is placed under the
    # ``ln_final`` prefix.
    if name in layer_norm_names and layer is None:
        full_act_name = f"ln_final.{full_act_name}"
    return full_act_name