transformer_lens.loading_from_pretrained#

Loading Pretrained Models Utilities.

This module contains functions for loading pretrained models from the Hugging Face Hub.

class transformer_lens.loading_from_pretrained.Config(d_model: int = 768, debug: bool = True, layer_norm_eps: float = 1e-05, d_vocab: int = 50257, init_range: float = 0.02, n_ctx: int = 1024, d_head: int = 64, d_mlp: int = 3072, n_heads: int = 12, n_layers: int = 12)#

Bases: object

d_head: int = 64#
d_mlp: int = 3072#
d_model: int = 768#
d_vocab: int = 50257#
debug: bool = True#
init_range: float = 0.02#
layer_norm_eps: float = 1e-05#
n_ctx: int = 1024#
n_heads: int = 12#
n_layers: int = 12#
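The defaults above describe GPT-2 small (12 layers, d_model 768, 12 heads of dimension 64). As an illustration, the sketch below reconstructs the same dataclass in plain Python — the real class lives in transformer_lens.loading_from_pretrained; this copy only mirrors the documented fields and defaults:

```python
from dataclasses import dataclass

# Illustrative reconstruction of the Config dataclass documented above.
# Defaults match GPT-2 small; the real class is
# transformer_lens.loading_from_pretrained.Config.
@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-05
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

# Override fields for a smaller model; note that for the defaults
# n_heads * d_head == d_model, as is conventional.
small = Config(n_layers=2, d_model=128, d_head=32, n_heads=4, d_mlp=512)
```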
transformer_lens.loading_from_pretrained.MODEL_ALIASES = {'01-ai/Yi-34B': ['yi-34b', 'Yi-34B'], '01-ai/Yi-34B-Chat': ['yi-34b-chat', 'Yi-34B-Chat'], '01-ai/Yi-6B': ['yi-6b', 'Yi-6B'], '01-ai/Yi-6B-Chat': ['yi-6b-chat', 'Yi-6B-Chat'], 'ArthurConmy/redwood_attn_2l': ['redwood_attn_2l'], 'Baidicoot/Othello-GPT-Transformer-Lens': ['othello-gpt'], 'CodeLlama-7b-Instruct-hf': ['CodeLlama-7b-instruct', 'codellama/CodeLlama-7b-Instruct-hf'], 'CodeLlama-7b-Python-hf': ['CodeLlama-7b-python', 'codellama/CodeLlama-7b-Python-hf'], 'CodeLlama-7b-hf': ['CodeLlamallama-2-7b', 'codellama/CodeLlama-7b-hf'], 'EleutherAI/gpt-j-6B': ['gpt-j-6B', 'gpt-j', 'gptj'], 'EleutherAI/gpt-neo-1.3B': ['gpt-neo-1.3B', 'gpt-neo-medium', 'neo-medium'], 'EleutherAI/gpt-neo-125M': ['gpt-neo-125M', 'gpt-neo-small', 'neo-small', 'neo'], 'EleutherAI/gpt-neo-2.7B': ['gpt-neo-2.7B', 'gpt-neo-large', 'neo-large'], 'EleutherAI/gpt-neox-20b': ['gpt-neox-20b', 'gpt-neox', 'neox'], 'EleutherAI/pythia-1.4b': ['pythia-1.4b', 'EleutherAI/pythia-1.3b', 'pythia-1.3b'], 'EleutherAI/pythia-1.4b-deduped': ['pythia-1.4b-deduped', 'EleutherAI/pythia-1.3b-deduped', 'pythia-1.3b-deduped'], 'EleutherAI/pythia-1.4b-deduped-v0': ['pythia-1.4b-deduped-v0', 'EleutherAI/pythia-1.3b-deduped-v0', 'pythia-1.3b-deduped-v0'], 'EleutherAI/pythia-1.4b-v0': ['pythia-1.4b-v0', 'EleutherAI/pythia-1.3b-v0', 'pythia-1.3b-v0'], 'EleutherAI/pythia-12b': ['pythia-12b', 'EleutherAI/pythia-13b', 'pythia-13b'], 'EleutherAI/pythia-12b-deduped': ['pythia-12b-deduped', 'EleutherAI/pythia-13b-deduped', 'pythia-13b-deduped'], 'EleutherAI/pythia-12b-deduped-v0': ['pythia-12b-deduped-v0', 'EleutherAI/pythia-13b-deduped-v0', 'pythia-13b-deduped-v0'], 'EleutherAI/pythia-12b-v0': ['pythia-12b-v0', 'EleutherAI/pythia-13b-v0', 'pythia-13b-v0'], 'EleutherAI/pythia-14m': ['pythia-14m'], 'EleutherAI/pythia-160m': ['pythia-160m', 'EleutherAI/pythia-125m', 'pythia-125m'], 'EleutherAI/pythia-160m-deduped': ['pythia-160m-deduped', 'EleutherAI/pythia-125m-deduped', 
'pythia-125m-deduped'], 'EleutherAI/pythia-160m-deduped-v0': ['pythia-160m-deduped-v0', 'EleutherAI/pythia-125m-deduped-v0', 'pythia-125m-deduped-v0'], 'EleutherAI/pythia-160m-seed1': ['pythia-160m-seed1', 'EleutherAI/pythia-125m-seed1', 'pythia-125m-seed1'], 'EleutherAI/pythia-160m-seed2': ['pythia-160m-seed2', 'EleutherAI/pythia-125m-seed2', 'pythia-125m-seed2'], 'EleutherAI/pythia-160m-seed3': ['pythia-160m-seed3', 'EleutherAI/pythia-125m-seed3', 'pythia-125m-seed3'], 'EleutherAI/pythia-160m-v0': ['pythia-160m-v0', 'EleutherAI/pythia-125m-v0', 'pythia-125m-v0'], 'EleutherAI/pythia-1b': ['pythia-1b', 'EleutherAI/pythia-800m', 'pythia-800m'], 'EleutherAI/pythia-1b-deduped': ['pythia-1b-deduped', 'EleutherAI/pythia-800m-deduped', 'pythia-800m-deduped'], 'EleutherAI/pythia-1b-deduped-v0': ['pythia-1b-deduped-v0', 'EleutherAI/pythia-800m-deduped-v0', 'pythia-800m-deduped-v0'], 'EleutherAI/pythia-1b-v0': ['pythia-1b-v0', 'EleutherAI/pythia-800m-v0', 'pythia-800m-v0'], 'EleutherAI/pythia-2.8b': ['pythia-2.8b', 'EleutherAI/pythia-2.7b', 'pythia-2.7b'], 'EleutherAI/pythia-2.8b-deduped': ['pythia-2.8b-deduped', 'EleutherAI/pythia-2.7b-deduped', 'pythia-2.7b-deduped'], 'EleutherAI/pythia-2.8b-deduped-v0': ['pythia-2.8b-deduped-v0', 'EleutherAI/pythia-2.7b-deduped-v0', 'pythia-2.7b-deduped-v0'], 'EleutherAI/pythia-2.8b-v0': ['pythia-2.8b-v0', 'EleutherAI/pythia-2.7b-v0', 'pythia-2.7b-v0'], 'EleutherAI/pythia-31m': ['pythia-31m'], 'EleutherAI/pythia-410m': ['pythia-410m', 'EleutherAI/pythia-350m', 'pythia-350m'], 'EleutherAI/pythia-410m-deduped': ['pythia-410m-deduped', 'EleutherAI/pythia-350m-deduped', 'pythia-350m-deduped'], 'EleutherAI/pythia-410m-deduped-v0': ['pythia-410m-deduped-v0', 'EleutherAI/pythia-350m-deduped-v0', 'pythia-350m-deduped-v0'], 'EleutherAI/pythia-410m-v0': ['pythia-410m-v0', 'EleutherAI/pythia-350m-v0', 'pythia-350m-v0'], 'EleutherAI/pythia-6.9b': ['pythia-6.9b', 'EleutherAI/pythia-6.7b', 'pythia-6.7b'], 'EleutherAI/pythia-6.9b-deduped': 
['pythia-6.9b-deduped', 'EleutherAI/pythia-6.7b-deduped', 'pythia-6.7b-deduped'], 'EleutherAI/pythia-6.9b-deduped-v0': ['pythia-6.9b-deduped-v0', 'EleutherAI/pythia-6.7b-deduped-v0', 'pythia-6.7b-deduped-v0'], 'EleutherAI/pythia-6.9b-v0': ['pythia-6.9b-v0', 'EleutherAI/pythia-6.7b-v0', 'pythia-6.7b-v0'], 'EleutherAI/pythia-70m': ['pythia-70m', 'pythia', 'EleutherAI/pythia-19m', 'pythia-19m'], 'EleutherAI/pythia-70m-deduped': ['pythia-70m-deduped', 'EleutherAI/pythia-19m-deduped', 'pythia-19m-deduped'], 'EleutherAI/pythia-70m-deduped-v0': ['pythia-70m-deduped-v0', 'EleutherAI/pythia-19m-deduped-v0', 'pythia-19m-deduped-v0'], 'EleutherAI/pythia-70m-v0': ['pythia-70m-v0', 'pythia-v0', 'EleutherAI/pythia-19m-v0', 'pythia-19m-v0'], 'NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr': ['attn-only-2l-demo', 'attn-only-2l-shortformer-6b-big-lr', 'attn-only-2l-induction-demo', 'attn-only-demo'], 'NeelNanda/Attn_Only_1L512W_C4_Code': ['attn-only-1l', 'attn-only-1l-new', 'attn-only-1l-c4-code'], 'NeelNanda/Attn_Only_2L512W_C4_Code': ['attn-only-2l', 'attn-only-2l-new', 'attn-only-2l-c4-code'], 'NeelNanda/Attn_Only_3L512W_C4_Code': ['attn-only-3l', 'attn-only-3l-new', 'attn-only-3l-c4-code'], 'NeelNanda/Attn_Only_4L512W_C4_Code': ['attn-only-4l', 'attn-only-4l-new', 'attn-only-4l-c4-code'], 'NeelNanda/GELU_1L512W_C4_Code': ['gelu-1l', 'gelu-1l-new', 'gelu-1l-c4-code'], 'NeelNanda/GELU_2L512W_C4_Code': ['gelu-2l', 'gelu-2l-new', 'gelu-2l-c4-code'], 'NeelNanda/GELU_3L512W_C4_Code': ['gelu-3l', 'gelu-3l-new', 'gelu-3l-c4-code'], 'NeelNanda/GELU_4L512W_C4_Code': ['gelu-4l', 'gelu-4l-new', 'gelu-4l-c4-code'], 'NeelNanda/SoLU_10L1280W_C4_Code': ['solu-10l', 'solu-10l-new', 'solu-10l-c4-code'], 'NeelNanda/SoLU_10L_v22_old': ['solu-10l-pile', 'solu-10l-old'], 'NeelNanda/SoLU_12L1536W_C4_Code': ['solu-12l', 'solu-12l-new', 'solu-12l-c4-code'], 'NeelNanda/SoLU_12L_v23_old': ['solu-12l-pile', 'solu-12l-old'], 'NeelNanda/SoLU_1L512W_C4_Code': ['solu-1l', 'solu-1l-new', 'solu-1l-c4-code'], 
'NeelNanda/SoLU_1L512W_Wiki_Finetune': ['solu-1l-wiki', 'solu-1l-wiki-finetune', 'solu-1l-finetune'], 'NeelNanda/SoLU_1L_v9_old': ['solu-1l-pile', 'solu-1l-old'], 'NeelNanda/SoLU_2L512W_C4_Code': ['solu-2l', 'solu-2l-new', 'solu-2l-c4-code'], 'NeelNanda/SoLU_2L_v10_old': ['solu-2l-pile', 'solu-2l-old'], 'NeelNanda/SoLU_3L512W_C4_Code': ['solu-3l', 'solu-3l-new', 'solu-3l-c4-code'], 'NeelNanda/SoLU_4L512W_C4_Code': ['solu-4l', 'solu-4l-new', 'solu-4l-c4-code'], 'NeelNanda/SoLU_4L512W_Wiki_Finetune': ['solu-4l-wiki', 'solu-4l-wiki-finetune', 'solu-4l-finetune'], 'NeelNanda/SoLU_4L_v11_old': ['solu-4l-pile', 'solu-4l-old'], 'NeelNanda/SoLU_6L768W_C4_Code': ['solu-6l', 'solu-6l-new', 'solu-6l-c4-code'], 'NeelNanda/SoLU_6L_v13_old': ['solu-6l-pile', 'solu-6l-old'], 'NeelNanda/SoLU_8L1024W_C4_Code': ['solu-8l', 'solu-8l-new', 'solu-8l-c4-code'], 'NeelNanda/SoLU_8L_v21_old': ['solu-8l-pile', 'solu-8l-old'], 'Qwen/Qwen-14B': ['qwen-14b'], 'Qwen/Qwen-14B-Chat': ['qwen-14b-chat'], 'Qwen/Qwen-1_8B': ['qwen-1.8b'], 'Qwen/Qwen-1_8B-Chat': ['qwen-1.8b-chat'], 'Qwen/Qwen-7B': ['qwen-7b'], 'Qwen/Qwen-7B-Chat': ['qwen-7b-chat'], 'Qwen/Qwen1.5-0.5B': ['qwen1.5-0.5b'], 'Qwen/Qwen1.5-0.5B-Chat': ['qwen1.5-0.5b-chat'], 'Qwen/Qwen1.5-1.8B': ['qwen1.5-1.8b'], 'Qwen/Qwen1.5-1.8B-Chat': ['qwen1.5-1.8b-chat'], 'Qwen/Qwen1.5-14B': ['qwen1.5-14b'], 'Qwen/Qwen1.5-14B-Chat': ['qwen1.5-14b-chat'], 'Qwen/Qwen1.5-4B': ['qwen1.5-4b'], 'Qwen/Qwen1.5-4B-Chat': ['qwen1.5-4b-chat'], 'Qwen/Qwen1.5-7B': ['qwen1.5-7b'], 'Qwen/Qwen1.5-7B-Chat': ['qwen1.5-7b-chat'], 'ai-forever/mGPT': ['mGPT'], 'bigcode/santacoder': ['santacoder'], 'bigscience/bloom-1b1': ['bloom-1b1'], 'bigscience/bloom-1b7': ['bloom-1b7'], 'bigscience/bloom-3b': ['bloom-3b'], 'bigscience/bloom-560m': ['bloom-560m'], 'bigscience/bloom-7b1': ['bloom-7b1'], 'distilgpt2': ['distillgpt2', 'distill-gpt2', 'distil-gpt2', 'gpt2-xs'], 'facebook/opt-1.3b': ['opt-1.3b', 'opt-medium'], 'facebook/opt-125m': ['opt-125m', 'opt-small', 'opt'], 
'facebook/opt-13b': ['opt-13b', 'opt-xxl'], 'facebook/opt-2.7b': ['opt-2.7b', 'opt-large'], 'facebook/opt-30b': ['opt-30b', 'opt-xxxl'], 'facebook/opt-6.7b': ['opt-6.7b', 'opt-xl'], 'facebook/opt-66b': ['opt-66b', 'opt-xxxxl'], 'google-t5/t5-base': ['t5-base'], 'google-t5/t5-large': ['t5-large'], 'google-t5/t5-small': ['t5-small'], 'google/gemma-2-27b': ['gemma-2-27b'], 'google/gemma-2-27b-it': ['gemma-2-27b-it'], 'google/gemma-2-2b': ['gemma-2-2b'], 'google/gemma-2-2b-it': ['gemma-2-2b-it'], 'google/gemma-2-9b': ['gemma-2-9b'], 'google/gemma-2-9b-it': ['gemma-2-9b-it'], 'google/gemma-2b': ['gemma-2b'], 'google/gemma-2b-it': ['gemma-2b-it'], 'google/gemma-7b': ['gemma-7b'], 'google/gemma-7b-it': ['gemma-7b-it'], 'gpt2': ['gpt2-small'], 'llama-13b-hf': ['llama-13b'], 'llama-30b-hf': ['llama-30b'], 'llama-65b-hf': ['llama-65b'], 'llama-7b-hf': ['llama-7b'], 'meta-llama/Llama-2-13b-chat-hf': ['Llama-2-13b-chat', 'meta-llama/Llama-2-13b-chat-hf'], 'meta-llama/Llama-2-13b-hf': ['Llama-2-13b', 'meta-llama/Llama-2-13b-hf'], 'meta-llama/Llama-2-70b-chat-hf': ['Llama-2-70b-chat', 'meta-llama-2-70b-chat-hf'], 'meta-llama/Llama-2-7b-chat-hf': ['Llama-2-7b-chat', 'meta-llama/Llama-2-7b-chat-hf'], 'meta-llama/Llama-2-7b-hf': ['Llama-2-7b', 'meta-llama/Llama-2-7b-hf'], 'microsoft/Phi-3-mini-4k-instruct': ['phi-3'], 'microsoft/phi-1': ['phi-1'], 'microsoft/phi-1_5': ['phi-1_5'], 'microsoft/phi-2': ['phi-2'], 'mistralai/Mistral-7B-Instruct-v0.1': ['mistral-7b-instruct'], 'mistralai/Mistral-7B-v0.1': ['mistral-7b'], 'mistralai/Mixtral-8x7B-Instruct-v0.1': ['mixtral-instruct', 'mixtral-8x7b-instruct'], 'mistralai/Mixtral-8x7B-v0.1': ['mixtral', 'mixtral-8x7b'], 'roneneldan/TinyStories-1Layer-21M': ['tiny-stories-1L-21M'], 'roneneldan/TinyStories-1M': ['tiny-stories-1M'], 'roneneldan/TinyStories-28M': ['tiny-stories-28M'], 'roneneldan/TinyStories-2Layers-33M': ['tiny-stories-2L-33M'], 'roneneldan/TinyStories-33M': ['tiny-stories-33M'], 'roneneldan/TinyStories-3M': 
['tiny-stories-3M'], 'roneneldan/TinyStories-8M': ['tiny-stories-8M'], 'roneneldan/TinyStories-Instruct-1M': ['tiny-stories-instruct-1M'], 'roneneldan/TinyStories-Instruct-28M': ['tiny-stories-instruct-28M'], 'roneneldan/TinyStories-Instruct-2Layers-33M': ['tiny-stories-instruct-2L-33M'], 'roneneldan/TinyStories-Instruct-33M': ['tiny-stories-instruct-33M'], 'roneneldan/TinyStories-Instruct-3M': ['tiny-stories-instruct-3M'], 'roneneldan/TinyStories-Instruct-8M': ['tiny-stories-instruct-8M'], 'roneneldan/TinyStories-Instuct-1Layer-21M': ['tiny-stories-instruct-1L-21M'], 'stabilityai/stablelm-base-alpha-3b': ['stablelm-base-alpha-3b', 'stablelm-base-3b'], 'stabilityai/stablelm-base-alpha-7b': ['stablelm-base-alpha-7b', 'stablelm-base-7b'], 'stabilityai/stablelm-tuned-alpha-3b': ['stablelm-tuned-alpha-3b', 'stablelm-tuned-3b'], 'stabilityai/stablelm-tuned-alpha-7b': ['stablelm-tuned-alpha-7b', 'stablelm-tuned-7b'], 'stanford-crfm/alias-gpt2-small-x21': ['stanford-gpt2-small-a', 'alias-gpt2-small-x21', 'gpt2-mistral-small-a', 'gpt2-stanford-small-a'], 'stanford-crfm/arwen-gpt2-medium-x21': ['stanford-gpt2-medium-a', 'arwen-gpt2-medium-x21', 'gpt2-medium-small-a', 'gpt2-stanford-medium-a'], 'stanford-crfm/battlestar-gpt2-small-x49': ['stanford-gpt2-small-b', 'battlestar-gpt2-small-x49', 'gpt2-mistral-small-b', 'gpt2-mistral-small-b'], 'stanford-crfm/beren-gpt2-medium-x49': ['stanford-gpt2-medium-b', 'beren-gpt2-medium-x49', 'gpt2-medium-small-b', 'gpt2-stanford-medium-b'], 'stanford-crfm/caprica-gpt2-small-x81': ['stanford-gpt2-small-c', 'caprica-gpt2-small-x81', 'gpt2-mistral-small-c', 'gpt2-stanford-small-c'], 'stanford-crfm/celebrimbor-gpt2-medium-x81': ['stanford-gpt2-medium-c', 'celebrimbor-gpt2-medium-x81', 'gpt2-medium-small-c', 'gpt2-medium-small-c'], 'stanford-crfm/darkmatter-gpt2-small-x343': ['stanford-gpt2-small-d', 'darkmatter-gpt2-small-x343', 'gpt2-mistral-small-d', 'gpt2-mistral-small-d'], 'stanford-crfm/durin-gpt2-medium-x343': ['stanford-gpt2-medium-d', 
'durin-gpt2-medium-x343', 'gpt2-medium-small-d', 'gpt2-stanford-medium-d'], 'stanford-crfm/eowyn-gpt2-medium-x777': ['stanford-gpt2-medium-e', 'eowyn-gpt2-medium-x777', 'gpt2-medium-small-e', 'gpt2-stanford-medium-e'], 'stanford-crfm/expanse-gpt2-small-x777': ['stanford-gpt2-small-e', 'expanse-gpt2-small-x777', 'gpt2-mistral-small-e', 'gpt2-mistral-small-e']}#

Model aliases for models on HuggingFace.
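Each key in the table is an official HuggingFace name and each value is the list of shorthand aliases that resolve to it. The sketch below shows how such a table can be inverted for lookup, using a small subset copied from the table above; the helper name `resolve_alias` is illustrative, not part of the library's API:

```python
# Toy subset of MODEL_ALIASES, copied from the table above.
MODEL_ALIASES_SUBSET = {
    "gpt2": ["gpt2-small"],
    "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"],
    "EleutherAI/pythia-70m": ["pythia-70m", "pythia", "EleutherAI/pythia-19m", "pythia-19m"],
}

# Invert the table: map every alias to its official name.
ALIAS_TO_OFFICIAL = {
    alias: official
    for official, aliases in MODEL_ALIASES_SUBSET.items()
    for alias in aliases
}

def resolve_alias(name: str) -> str:
    """Map an alias (or an official name) to the official HuggingFace name.

    Illustrative helper, not the library's API.
    """
    if name in MODEL_ALIASES_SUBSET:
        return name  # already an official name
    return ALIAS_TO_OFFICIAL[name]
```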

transformer_lens.loading_from_pretrained.NON_HF_HOSTED_MODEL_NAMES = ['llama-7b-hf', 'llama-13b-hf', 'llama-30b-hf', 'llama-65b-hf']#

Official model names for models not hosted on HuggingFace.

transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES = ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl', 'distilgpt2', 'facebook/opt-125m', 'facebook/opt-1.3b', 'facebook/opt-2.7b', 'facebook/opt-6.7b', 'facebook/opt-13b', 'facebook/opt-30b', 'facebook/opt-66b', 'EleutherAI/gpt-neo-125M', 'EleutherAI/gpt-neo-1.3B', 'EleutherAI/gpt-neo-2.7B', 'EleutherAI/gpt-j-6B', 'EleutherAI/gpt-neox-20b', 'stanford-crfm/alias-gpt2-small-x21', 'stanford-crfm/battlestar-gpt2-small-x49', 'stanford-crfm/caprica-gpt2-small-x81', 'stanford-crfm/darkmatter-gpt2-small-x343', 'stanford-crfm/expanse-gpt2-small-x777', 'stanford-crfm/arwen-gpt2-medium-x21', 'stanford-crfm/beren-gpt2-medium-x49', 'stanford-crfm/celebrimbor-gpt2-medium-x81', 'stanford-crfm/durin-gpt2-medium-x343', 'stanford-crfm/eowyn-gpt2-medium-x777', 'EleutherAI/pythia-14m', 'EleutherAI/pythia-31m', 'EleutherAI/pythia-70m', 'EleutherAI/pythia-160m', 'EleutherAI/pythia-410m', 'EleutherAI/pythia-1b', 'EleutherAI/pythia-1.4b', 'EleutherAI/pythia-2.8b', 'EleutherAI/pythia-6.9b', 'EleutherAI/pythia-12b', 'EleutherAI/pythia-70m-deduped', 'EleutherAI/pythia-160m-deduped', 'EleutherAI/pythia-410m-deduped', 'EleutherAI/pythia-1b-deduped', 'EleutherAI/pythia-1.4b-deduped', 'EleutherAI/pythia-2.8b-deduped', 'EleutherAI/pythia-6.9b-deduped', 'EleutherAI/pythia-12b-deduped', 'EleutherAI/pythia-70m-v0', 'EleutherAI/pythia-160m-v0', 'EleutherAI/pythia-410m-v0', 'EleutherAI/pythia-1b-v0', 'EleutherAI/pythia-1.4b-v0', 'EleutherAI/pythia-2.8b-v0', 'EleutherAI/pythia-6.9b-v0', 'EleutherAI/pythia-12b-v0', 'EleutherAI/pythia-70m-deduped-v0', 'EleutherAI/pythia-160m-deduped-v0', 'EleutherAI/pythia-410m-deduped-v0', 'EleutherAI/pythia-1b-deduped-v0', 'EleutherAI/pythia-1.4b-deduped-v0', 'EleutherAI/pythia-2.8b-deduped-v0', 'EleutherAI/pythia-6.9b-deduped-v0', 'EleutherAI/pythia-12b-deduped-v0', 'EleutherAI/pythia-160m-seed1', 'EleutherAI/pythia-160m-seed2', 'EleutherAI/pythia-160m-seed3', 'NeelNanda/SoLU_1L_v9_old', 
'NeelNanda/SoLU_2L_v10_old', 'NeelNanda/SoLU_4L_v11_old', 'NeelNanda/SoLU_6L_v13_old', 'NeelNanda/SoLU_8L_v21_old', 'NeelNanda/SoLU_10L_v22_old', 'NeelNanda/SoLU_12L_v23_old', 'NeelNanda/SoLU_1L512W_C4_Code', 'NeelNanda/SoLU_2L512W_C4_Code', 'NeelNanda/SoLU_3L512W_C4_Code', 'NeelNanda/SoLU_4L512W_C4_Code', 'NeelNanda/SoLU_6L768W_C4_Code', 'NeelNanda/SoLU_8L1024W_C4_Code', 'NeelNanda/SoLU_10L1280W_C4_Code', 'NeelNanda/SoLU_12L1536W_C4_Code', 'NeelNanda/GELU_1L512W_C4_Code', 'NeelNanda/GELU_2L512W_C4_Code', 'NeelNanda/GELU_3L512W_C4_Code', 'NeelNanda/GELU_4L512W_C4_Code', 'NeelNanda/Attn_Only_1L512W_C4_Code', 'NeelNanda/Attn_Only_2L512W_C4_Code', 'NeelNanda/Attn_Only_3L512W_C4_Code', 'NeelNanda/Attn_Only_4L512W_C4_Code', 'NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr', 'NeelNanda/SoLU_1L512W_Wiki_Finetune', 'NeelNanda/SoLU_4L512W_Wiki_Finetune', 'ArthurConmy/redwood_attn_2l', 'llama-7b-hf', 'llama-13b-hf', 'llama-30b-hf', 'llama-65b-hf', 'meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-2-7b-chat-hf', 'meta-llama/Llama-2-13b-hf', 'meta-llama/Llama-2-13b-chat-hf', 'meta-llama/Llama-2-70b-chat-hf', 'CodeLlama-7b-hf', 'CodeLlama-7b-Python-hf', 'CodeLlama-7b-Instruct-hf', 'meta-llama/Meta-Llama-3-8B', 'meta-llama/Meta-Llama-3-8B-Instruct', 'meta-llama/Meta-Llama-3-70B', 'meta-llama/Meta-Llama-3-70B-Instruct', 'Baidicoot/Othello-GPT-Transformer-Lens', 'bert-base-cased', 'roneneldan/TinyStories-1M', 'roneneldan/TinyStories-3M', 'roneneldan/TinyStories-8M', 'roneneldan/TinyStories-28M', 'roneneldan/TinyStories-33M', 'roneneldan/TinyStories-Instruct-1M', 'roneneldan/TinyStories-Instruct-3M', 'roneneldan/TinyStories-Instruct-8M', 'roneneldan/TinyStories-Instruct-28M', 'roneneldan/TinyStories-Instruct-33M', 'roneneldan/TinyStories-1Layer-21M', 'roneneldan/TinyStories-2Layers-33M', 'roneneldan/TinyStories-Instuct-1Layer-21M', 'roneneldan/TinyStories-Instruct-2Layers-33M', 'stabilityai/stablelm-base-alpha-3b', 'stabilityai/stablelm-base-alpha-7b', 
'stabilityai/stablelm-tuned-alpha-3b', 'stabilityai/stablelm-tuned-alpha-7b', 'mistralai/Mistral-7B-v0.1', 'mistralai/Mistral-7B-Instruct-v0.1', 'mistralai/Mixtral-8x7B-v0.1', 'mistralai/Mixtral-8x7B-Instruct-v0.1', 'bigscience/bloom-560m', 'bigscience/bloom-1b1', 'bigscience/bloom-1b7', 'bigscience/bloom-3b', 'bigscience/bloom-7b1', 'bigcode/santacoder', 'Qwen/Qwen-1_8B', 'Qwen/Qwen-7B', 'Qwen/Qwen-14B', 'Qwen/Qwen-1_8B-Chat', 'Qwen/Qwen-7B-Chat', 'Qwen/Qwen-14B-Chat', 'Qwen/Qwen1.5-0.5B', 'Qwen/Qwen1.5-0.5B-Chat', 'Qwen/Qwen1.5-1.8B', 'Qwen/Qwen1.5-1.8B-Chat', 'Qwen/Qwen1.5-4B', 'Qwen/Qwen1.5-4B-Chat', 'Qwen/Qwen1.5-7B', 'Qwen/Qwen1.5-7B-Chat', 'Qwen/Qwen1.5-14B', 'Qwen/Qwen1.5-14B-Chat', 'Qwen/Qwen2-0.5B', 'Qwen/Qwen2-0.5B-Instruct', 'Qwen/Qwen2-1.5B', 'Qwen/Qwen2-1.5B-Instruct', 'Qwen/Qwen2-7B', 'Qwen/Qwen2-7B-Instruct', 'microsoft/phi-1', 'microsoft/phi-1_5', 'microsoft/phi-2', 'microsoft/Phi-3-mini-4k-instruct', 'google/gemma-2b', 'google/gemma-7b', 'google/gemma-2b-it', 'google/gemma-7b-it', 'google/gemma-2-2b', 'google/gemma-2-2b-it', 'google/gemma-2-9b', 'google/gemma-2-9b-it', 'google/gemma-2-27b', 'google/gemma-2-27b-it', '01-ai/Yi-6B', '01-ai/Yi-34B', '01-ai/Yi-6B-Chat', '01-ai/Yi-34B-Chat', 'google-t5/t5-small', 'google-t5/t5-base', 'google-t5/t5-large', 'ai-forever/mGPT']#

Official model names for models on HuggingFace.

transformer_lens.loading_from_pretrained.get_checkpoint_labels(model_name: str, **kwargs)#

Returns the checkpoint labels for a given model, and the label_type (step or token). Raises an error for models that are not checkpointed.

transformer_lens.loading_from_pretrained.get_num_params_of_pretrained(model_name)#

Returns the number of parameters of a pretrained model. Useful for filtering so that code only runs on sufficiently small models.
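As a rough sense of the scale this returns: for a GPT-2-style model, each layer contributes about 4·d_model² attention weights (W_Q, W_K, W_V, W_O) and 2·d_model·d_mlp MLP weights (W_in, W_out). The back-of-envelope estimate below ignores embeddings, biases, and layer norm parameters, and is only an illustration — the actual function reads the count from the pretrained config rather than computing a formula:

```python
def approx_n_params(n_layers: int, d_model: int, d_mlp: int) -> int:
    """Back-of-envelope non-embedding parameter count for a
    GPT-2-style transformer. Ignores embeddings, biases, and
    layer norms; illustrative only."""
    attn = 4 * d_model * d_model   # W_Q, W_K, W_V, W_O
    mlp = 2 * d_model * d_mlp      # W_in, W_out (d_mlp is usually 4 * d_model)
    return n_layers * (attn + mlp)

# GPT-2 small (the Config defaults above): ~85M non-embedding parameters.
print(approx_n_params(n_layers=12, d_model=768, d_mlp=3072))
```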

transformer_lens.loading_from_pretrained.get_pretrained_model_config(model_name: str, hf_cfg: Optional[dict] = None, checkpoint_index: Optional[int] = None, checkpoint_value: Optional[int] = None, fold_ln: bool = False, device: Optional[Union[str, device]] = None, n_devices: int = 1, default_prepend_bos: bool = True, dtype: dtype = torch.float32, first_n_layers: Optional[int] = None, **kwargs)#

Returns the pretrained model config as a HookedTransformerConfig object.

There are two types of pretrained models: HuggingFace models (where AutoModel and AutoConfig work), and models trained by me (NeelNanda) which aren’t as integrated with HuggingFace infrastructure.

Parameters:
  • model_name – The name of the model. This can be either the official HuggingFace model name, or the name of a model trained by me (NeelNanda).

  • hf_cfg (dict, optional) – Config of a loaded pretrained HF model, converted to a dictionary.

  • checkpoint_index (int, optional) – If loading from a checkpoint, the index of the checkpoint to load. Defaults to None.

  • checkpoint_value (int, optional) – If loading from a checkpoint, the value of the checkpoint to load, i.e. the step or token number (each model has checkpoints labelled with exactly one of these). Defaults to None.

  • fold_ln (bool, optional) – Whether to fold the layer norm into the subsequent linear layers (see HookedTransformer.fold_layer_norm for details). Defaults to False.

  • device (str, optional) – The device to load the model onto. By default will load to CUDA if available, else CPU.

  • n_devices (int, optional) – The number of devices to split the model across. Defaults to 1.

  • default_prepend_bos (bool, optional) – Whether the methods of HookedTransformer prepend the BOS token when tokenizing input text (only when the input is a string). Defaults to True. Even for models not explicitly trained with a BOS token, heads often use the first position as a resting position and accordingly lose information from the first token, so prepending empirically seems to give better results. To change the default behavior to False, pass in default_prepend_bos=False. Note that you can also locally override the default by passing prepend_bos=True/False when you call a method that processes the input string.

  • dtype (torch.dtype, optional) – The dtype to load the TransformerLens model in. Defaults to torch.float32.

  • first_n_layers (int, optional) – If specified, only load the first n layers of the model. Defaults to None.

  • kwargs – Other optional arguments passed to HuggingFace’s from_pretrained. Also given to other HuggingFace functions when compatible.
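The interplay of checkpoint_index and checkpoint_value can be sketched as follows: given the sorted list of checkpoint labels for a model (as returned by get_checkpoint_labels), checkpoint_index selects by position in that list, while checkpoint_value must match a label exactly. The label list and helper below are illustrative, not the library's internals:

```python
from typing import Optional

# Illustrative checkpoint labels, e.g. step numbers for a
# Pythia-style checkpointed model (not a real label list).
CHECKPOINT_LABELS = [0, 1, 2, 4, 8, 16, 1000, 10000, 143000]

def select_checkpoint(
    checkpoint_index: Optional[int] = None,
    checkpoint_value: Optional[int] = None,
) -> int:
    """Pick a checkpoint label by index or by value, mirroring the two
    parameters of get_pretrained_model_config. Illustrative sketch only."""
    if checkpoint_index is not None:
        return CHECKPOINT_LABELS[checkpoint_index]
    if checkpoint_value is not None:
        if checkpoint_value not in CHECKPOINT_LABELS:
            raise ValueError(f"No checkpoint with label {checkpoint_value}")
        return checkpoint_value
    # If neither is given, the real function loads the final trained model;
    # this sketch returns the last label to stand in for that.
    return CHECKPOINT_LABELS[-1]
```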