transformer_lens.lit.constants
Constants for the LIT integration module.
This module defines constants used throughout the LIT integration with TransformerLens. These include default configuration values, field names, and other settings that ensure consistency across the integration.
Note: LIT (Learning Interpretability Tool) is Google’s framework-agnostic tool for ML model interpretability. See: https://pair-code.github.io/lit/
References
LIT Documentation: https://pair-code.github.io/lit/documentation/
TransformerLens: https://github.com/TransformerLensOrg/TransformerLens
- class transformer_lens.lit.constants.DefaultConfig(MAX_SEQ_LENGTH: int = 512, BATCH_SIZE: int = 8, TOP_K: int = 10, COMPUTE_GRADIENTS: bool = True, OUTPUT_ATTENTION: bool = True, OUTPUT_EMBEDDINGS: bool = True, OUTPUT_ALL_LAYERS: bool = False, EMBEDDING_LAYERS: tuple | None = None, PREPEND_BOS: bool = True, DEVICE: str | None = None, USE_FP16: bool = False)#
Bases: object

Default configuration values for the LIT wrapper.
- BATCH_SIZE: int = 8#
- COMPUTE_GRADIENTS: bool = True#
- DEVICE: str | None = None#
- EMBEDDING_LAYERS: tuple | None = None#
- MAX_SEQ_LENGTH: int = 512#
- OUTPUT_ALL_LAYERS: bool = False#
- OUTPUT_ATTENTION: bool = True#
- OUTPUT_EMBEDDINGS: bool = True#
- PREPEND_BOS: bool = True#
- TOP_K: int = 10#
- USE_FP16: bool = False#
- class transformer_lens.lit.constants.ErrorMessages(NO_TOKENIZER: str = 'HookedTransformer has no tokenizer. Please load a model with a tokenizer or set one manually.', INVALID_MODEL: str = 'Model must be an instance of HookedTransformer. Got: {model_type}', LIT_NOT_INSTALLED: str = 'LIT (lit-nlp) is not installed. Please install it with: pip install lit-nlp', INCOMPATIBLE_INPUT: str = 'Input does not match the expected input_spec. Expected fields: {expected}, got: {actual}', BATCH_SIZE_MISMATCH: str = 'Batch size mismatch. Expected {expected}, got {actual}')#
Bases: object

Standard error messages for the LIT integration.
- BATCH_SIZE_MISMATCH: str = 'Batch size mismatch. Expected {expected}, got {actual}'#
- INCOMPATIBLE_INPUT: str = 'Input does not match the expected input_spec. Expected fields: {expected}, got: {actual}'#
- INVALID_MODEL: str = 'Model must be an instance of HookedTransformer. Got: {model_type}'#
- LIT_NOT_INSTALLED: str = 'LIT (lit-nlp) is not installed. Please install it with: pip install lit-nlp'#
- NO_TOKENIZER: str = 'HookedTransformer has no tokenizer. Please load a model with a tokenizer or set one manually.'#
- class transformer_lens.lit.constants.HookPointNames(HOOK_EMBED: str = 'hook_embed', HOOK_POS_EMBED: str = 'hook_pos_embed', HOOK_TOKENS: str = 'hook_tokens', RESID_PRE_TEMPLATE: str = 'blocks.{layer}.hook_resid_pre', RESID_POST_TEMPLATE: str = 'blocks.{layer}.hook_resid_post', RESID_MID_TEMPLATE: str = 'blocks.{layer}.hook_resid_mid', ATTN_OUT_TEMPLATE: str = 'blocks.{layer}.hook_attn_out', ATTN_PATTERN_TEMPLATE: str = 'blocks.{layer}.attn.hook_pattern', ATTN_SCORES_TEMPLATE: str = 'blocks.{layer}.attn.hook_attn_scores', Q_TEMPLATE: str = 'blocks.{layer}.attn.hook_q', K_TEMPLATE: str = 'blocks.{layer}.attn.hook_k', V_TEMPLATE: str = 'blocks.{layer}.attn.hook_v', MLP_OUT_TEMPLATE: str = 'blocks.{layer}.hook_mlp_out', MLP_PRE_TEMPLATE: str = 'blocks.{layer}.mlp.hook_pre', MLP_POST_TEMPLATE: str = 'blocks.{layer}.mlp.hook_post', LN_FINAL: str = 'ln_final.hook_normalized')#
Bases: object

Common hook point names used in TransformerLens.
These correspond to the hook points defined in HookedTransformer where we can intercept and extract intermediate activations.
- ATTN_OUT_TEMPLATE: str = 'blocks.{layer}.hook_attn_out'#
- ATTN_PATTERN_TEMPLATE: str = 'blocks.{layer}.attn.hook_pattern'#
- ATTN_SCORES_TEMPLATE: str = 'blocks.{layer}.attn.hook_attn_scores'#
- HOOK_EMBED: str = 'hook_embed'#
- HOOK_POS_EMBED: str = 'hook_pos_embed'#
- HOOK_TOKENS: str = 'hook_tokens'#
- K_TEMPLATE: str = 'blocks.{layer}.attn.hook_k'#
- LN_FINAL: str = 'ln_final.hook_normalized'#
- MLP_OUT_TEMPLATE: str = 'blocks.{layer}.hook_mlp_out'#
- MLP_POST_TEMPLATE: str = 'blocks.{layer}.mlp.hook_post'#
- MLP_PRE_TEMPLATE: str = 'blocks.{layer}.mlp.hook_pre'#
- Q_TEMPLATE: str = 'blocks.{layer}.attn.hook_q'#
- RESID_MID_TEMPLATE: str = 'blocks.{layer}.hook_resid_mid'#
- RESID_POST_TEMPLATE: str = 'blocks.{layer}.hook_resid_post'#
- RESID_PRE_TEMPLATE: str = 'blocks.{layer}.hook_resid_pre'#
- V_TEMPLATE: str = 'blocks.{layer}.attn.hook_v'#
- class transformer_lens.lit.constants.InputFieldNames(TEXT: str = 'text', TOKENS: str = 'tokens', TOKEN_EMBEDDINGS: str = 'token_embeddings', TARGET: str = 'target', TARGET_MASK: str = 'target_mask')#
Bases: object

Field names for model inputs in LIT.
- TARGET: str = 'target'#
- TARGET_MASK: str = 'target_mask'#
- TEXT: str = 'text'#
- TOKENS: str = 'tokens'#
- TOKEN_EMBEDDINGS: str = 'token_embeddings'#
- class transformer_lens.lit.constants.OutputFieldNames(TOKENS: str = 'tokens', TOKEN_IDS: str = 'token_ids', LOGITS: str = 'logits', TOP_K_TOKENS: str = 'top_k_tokens', GENERATED_TEXT: str = 'generated_text', PROBAS: str = 'probas', LOSS: str = 'loss', LAYER_EMB_TEMPLATE: str = 'layer_{layer}/embeddings', CLS_EMBEDDING: str = 'cls_embedding', MEAN_EMBEDDING: str = 'mean_embedding', ATTENTION_TEMPLATE: str = 'layer_{layer}/head_{head}/attention', LAYER_ATTENTION_TEMPLATE: str = 'layer_{layer}/attention', TOKEN_GRADIENTS: str = 'token_gradients', GRAD_L2: str = 'grad_l2', GRAD_DOT_INPUT: str = 'grad_dot_input', INPUT_EMBEDDINGS: str = 'input_embeddings')#
Bases: object

Field names for model outputs in LIT.
- ATTENTION_TEMPLATE: str = 'layer_{layer}/head_{head}/attention'#
- CLS_EMBEDDING: str = 'cls_embedding'#
- GENERATED_TEXT: str = 'generated_text'#
- GRAD_DOT_INPUT: str = 'grad_dot_input'#
- GRAD_L2: str = 'grad_l2'#
- INPUT_EMBEDDINGS: str = 'input_embeddings'#
- LAYER_ATTENTION_TEMPLATE: str = 'layer_{layer}/attention'#
- LAYER_EMB_TEMPLATE: str = 'layer_{layer}/embeddings'#
- LOGITS: str = 'logits'#
- LOSS: str = 'loss'#
- MEAN_EMBEDDING: str = 'mean_embedding'#
- PROBAS: str = 'probas'#
- TOKENS: str = 'tokens'#
- TOKEN_GRADIENTS: str = 'token_gradients'#
- TOKEN_IDS: str = 'token_ids'#
- TOP_K_TOKENS: str = 'top_k_tokens'#
- class transformer_lens.lit.constants.ServerConfig(DEFAULT_PORT: int = 5432, DEFAULT_HOST: str = 'localhost', DEFAULT_TITLE: str = 'TransformerLens + LIT', DEV_MODE: bool = False, WARM_START: bool = True, MAX_EXAMPLES: int = 1000)#
Bases: object

Default configuration for the LIT server.
- DEFAULT_HOST: str = 'localhost'#
- DEFAULT_PORT: int = 5432#
- DEFAULT_TITLE: str = 'TransformerLens + LIT'#
- DEV_MODE: bool = False#
- MAX_EXAMPLES: int = 1000#
- WARM_START: bool = True#