transformer_lens.loading_from_pretrained#
Utilities for loading pretrained models.
This module contains functions for loading pretrained models from the Hugging Face Hub.
- class transformer_lens.loading_from_pretrained.Config(d_model: int = 768, debug: bool = True, layer_norm_eps: float = 1e-05, d_vocab: int = 50257, init_range: float = 0.02, n_ctx: int = 1024, d_head: int = 64, d_mlp: int = 3072, n_heads: int = 12, n_layers: int = 12)#
Bases:
object
- d_head: int = 64#
- d_mlp: int = 3072#
- d_model: int = 768#
- d_vocab: int = 50257#
- debug: bool = True#
- init_range: float = 0.02#
- layer_norm_eps: float = 1e-05#
- n_ctx: int = 1024#
- n_heads: int = 12#
- n_layers: int = 12#
- transformer_lens.loading_from_pretrained.MODEL_ALIASES = {'01-ai/Yi-34B': ['yi-34b', 'Yi-34B'], '01-ai/Yi-34B-Chat': ['yi-34b-chat', 'Yi-34B-Chat'], '01-ai/Yi-6B': ['yi-6b', 'Yi-6B'], '01-ai/Yi-6B-Chat': ['yi-6b-chat', 'Yi-6B-Chat'], 'ArthurConmy/redwood_attn_2l': ['redwood_attn_2l'], 'Baidicoot/Othello-GPT-Transformer-Lens': ['othello-gpt'], 'EleutherAI/gpt-j-6B': ['gpt-j-6B', 'gpt-j', 'gptj'], 'EleutherAI/gpt-neo-1.3B': ['gpt-neo-1.3B', 'gpt-neo-medium', 'neo-medium'], 'EleutherAI/gpt-neo-125M': ['gpt-neo-125M', 'gpt-neo-small', 'neo-small', 'neo'], 'EleutherAI/gpt-neo-2.7B': ['gpt-neo-2.7B', 'gpt-neo-large', 'neo-large'], 'EleutherAI/gpt-neox-20b': ['gpt-neox-20b', 'gpt-neox', 'neox'], 'EleutherAI/pythia-1.4b': ['pythia-1.4b', 'EleutherAI/pythia-1.3b', 'pythia-1.3b'], 'EleutherAI/pythia-1.4b-deduped': ['pythia-1.4b-deduped', 'EleutherAI/pythia-1.3b-deduped', 'pythia-1.3b-deduped'], 'EleutherAI/pythia-1.4b-deduped-v0': ['pythia-1.4b-deduped-v0', 'EleutherAI/pythia-1.3b-deduped-v0', 'pythia-1.3b-deduped-v0'], 'EleutherAI/pythia-1.4b-v0': ['pythia-1.4b-v0', 'EleutherAI/pythia-1.3b-v0', 'pythia-1.3b-v0'], 'EleutherAI/pythia-12b': ['pythia-12b', 'EleutherAI/pythia-13b', 'pythia-13b'], 'EleutherAI/pythia-12b-deduped': ['pythia-12b-deduped', 'EleutherAI/pythia-13b-deduped', 'pythia-13b-deduped'], 'EleutherAI/pythia-12b-deduped-v0': ['pythia-12b-deduped-v0', 'EleutherAI/pythia-13b-deduped-v0', 'pythia-13b-deduped-v0'], 'EleutherAI/pythia-12b-v0': ['pythia-12b-v0', 'EleutherAI/pythia-13b-v0', 'pythia-13b-v0'], 'EleutherAI/pythia-14m': ['pythia-14m'], 'EleutherAI/pythia-160m': ['pythia-160m', 'EleutherAI/pythia-125m', 'pythia-125m'], 'EleutherAI/pythia-160m-deduped': ['pythia-160m-deduped', 'EleutherAI/pythia-125m-deduped', 'pythia-125m-deduped'], 'EleutherAI/pythia-160m-deduped-v0': ['pythia-160m-deduped-v0', 'EleutherAI/pythia-125m-deduped-v0', 'pythia-125m-deduped-v0'], 'EleutherAI/pythia-160m-seed1': ['pythia-160m-seed1', 'EleutherAI/pythia-125m-seed1', 
'pythia-125m-seed1'], 'EleutherAI/pythia-160m-seed2': ['pythia-160m-seed2', 'EleutherAI/pythia-125m-seed2', 'pythia-125m-seed2'], 'EleutherAI/pythia-160m-seed3': ['pythia-160m-seed3', 'EleutherAI/pythia-125m-seed3', 'pythia-125m-seed3'], 'EleutherAI/pythia-160m-v0': ['pythia-160m-v0', 'EleutherAI/pythia-125m-v0', 'pythia-125m-v0'], 'EleutherAI/pythia-1b': ['pythia-1b', 'EleutherAI/pythia-800m', 'pythia-800m'], 'EleutherAI/pythia-1b-deduped': ['pythia-1b-deduped', 'EleutherAI/pythia-800m-deduped', 'pythia-800m-deduped'], 'EleutherAI/pythia-1b-deduped-v0': ['pythia-1b-deduped-v0', 'EleutherAI/pythia-800m-deduped-v0', 'pythia-800m-deduped-v0'], 'EleutherAI/pythia-1b-v0': ['pythia-1b-v0', 'EleutherAI/pythia-800m-v0', 'pythia-800m-v0'], 'EleutherAI/pythia-2.8b': ['pythia-2.8b', 'EleutherAI/pythia-2.7b', 'pythia-2.7b'], 'EleutherAI/pythia-2.8b-deduped': ['pythia-2.8b-deduped', 'EleutherAI/pythia-2.7b-deduped', 'pythia-2.7b-deduped'], 'EleutherAI/pythia-2.8b-deduped-v0': ['pythia-2.8b-deduped-v0', 'EleutherAI/pythia-2.7b-deduped-v0', 'pythia-2.7b-deduped-v0'], 'EleutherAI/pythia-2.8b-v0': ['pythia-2.8b-v0', 'EleutherAI/pythia-2.7b-v0', 'pythia-2.7b-v0'], 'EleutherAI/pythia-31m': ['pythia-31m'], 'EleutherAI/pythia-410m': ['pythia-410m', 'EleutherAI/pythia-350m', 'pythia-350m'], 'EleutherAI/pythia-410m-deduped': ['pythia-410m-deduped', 'EleutherAI/pythia-350m-deduped', 'pythia-350m-deduped'], 'EleutherAI/pythia-410m-deduped-v0': ['pythia-410m-deduped-v0', 'EleutherAI/pythia-350m-deduped-v0', 'pythia-350m-deduped-v0'], 'EleutherAI/pythia-410m-v0': ['pythia-410m-v0', 'EleutherAI/pythia-350m-v0', 'pythia-350m-v0'], 'EleutherAI/pythia-6.9b': ['pythia-6.9b', 'EleutherAI/pythia-6.7b', 'pythia-6.7b'], 'EleutherAI/pythia-6.9b-deduped': ['pythia-6.9b-deduped', 'EleutherAI/pythia-6.7b-deduped', 'pythia-6.7b-deduped'], 'EleutherAI/pythia-6.9b-deduped-v0': ['pythia-6.9b-deduped-v0', 'EleutherAI/pythia-6.7b-deduped-v0', 'pythia-6.7b-deduped-v0'], 'EleutherAI/pythia-6.9b-v0': 
['pythia-6.9b-v0', 'EleutherAI/pythia-6.7b-v0', 'pythia-6.7b-v0'], 'EleutherAI/pythia-70m': ['pythia-70m', 'pythia', 'EleutherAI/pythia-19m', 'pythia-19m'], 'EleutherAI/pythia-70m-deduped': ['pythia-70m-deduped', 'EleutherAI/pythia-19m-deduped', 'pythia-19m-deduped'], 'EleutherAI/pythia-70m-deduped-v0': ['pythia-70m-deduped-v0', 'EleutherAI/pythia-19m-deduped-v0', 'pythia-19m-deduped-v0'], 'EleutherAI/pythia-70m-v0': ['pythia-70m-v0', 'pythia-v0', 'EleutherAI/pythia-19m-v0', 'pythia-19m-v0'], 'NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr': ['attn-only-2l-demo', 'attn-only-2l-shortformer-6b-big-lr', 'attn-only-2l-induction-demo', 'attn-only-demo'], 'NeelNanda/Attn_Only_1L512W_C4_Code': ['attn-only-1l', 'attn-only-1l-new', 'attn-only-1l-c4-code'], 'NeelNanda/Attn_Only_2L512W_C4_Code': ['attn-only-2l', 'attn-only-2l-new', 'attn-only-2l-c4-code'], 'NeelNanda/Attn_Only_3L512W_C4_Code': ['attn-only-3l', 'attn-only-3l-new', 'attn-only-3l-c4-code'], 'NeelNanda/Attn_Only_4L512W_C4_Code': ['attn-only-4l', 'attn-only-4l-new', 'attn-only-4l-c4-code'], 'NeelNanda/GELU_1L512W_C4_Code': ['gelu-1l', 'gelu-1l-new', 'gelu-1l-c4-code'], 'NeelNanda/GELU_2L512W_C4_Code': ['gelu-2l', 'gelu-2l-new', 'gelu-2l-c4-code'], 'NeelNanda/GELU_3L512W_C4_Code': ['gelu-3l', 'gelu-3l-new', 'gelu-3l-c4-code'], 'NeelNanda/GELU_4L512W_C4_Code': ['gelu-4l', 'gelu-4l-new', 'gelu-4l-c4-code'], 'NeelNanda/SoLU_10L1280W_C4_Code': ['solu-10l', 'solu-10l-new', 'solu-10l-c4-code'], 'NeelNanda/SoLU_10L_v22_old': ['solu-10l-pile', 'solu-10l-old'], 'NeelNanda/SoLU_12L1536W_C4_Code': ['solu-12l', 'solu-12l-new', 'solu-12l-c4-code'], 'NeelNanda/SoLU_12L_v23_old': ['solu-12l-pile', 'solu-12l-old'], 'NeelNanda/SoLU_1L512W_C4_Code': ['solu-1l', 'solu-1l-new', 'solu-1l-c4-code'], 'NeelNanda/SoLU_1L512W_Wiki_Finetune': ['solu-1l-wiki', 'solu-1l-wiki-finetune', 'solu-1l-finetune'], 'NeelNanda/SoLU_1L_v9_old': ['solu-1l-pile', 'solu-1l-old'], 'NeelNanda/SoLU_2L512W_C4_Code': ['solu-2l', 'solu-2l-new', 
'solu-2l-c4-code'], 'NeelNanda/SoLU_2L_v10_old': ['solu-2l-pile', 'solu-2l-old'], 'NeelNanda/SoLU_3L512W_C4_Code': ['solu-3l', 'solu-3l-new', 'solu-3l-c4-code'], 'NeelNanda/SoLU_4L512W_C4_Code': ['solu-4l', 'solu-4l-new', 'solu-4l-c4-code'], 'NeelNanda/SoLU_4L512W_Wiki_Finetune': ['solu-4l-wiki', 'solu-4l-wiki-finetune', 'solu-4l-finetune'], 'NeelNanda/SoLU_4L_v11_old': ['solu-4l-pile', 'solu-4l-old'], 'NeelNanda/SoLU_6L768W_C4_Code': ['solu-6l', 'solu-6l-new', 'solu-6l-c4-code'], 'NeelNanda/SoLU_6L_v13_old': ['solu-6l-pile', 'solu-6l-old'], 'NeelNanda/SoLU_8L1024W_C4_Code': ['solu-8l', 'solu-8l-new', 'solu-8l-c4-code'], 'NeelNanda/SoLU_8L_v21_old': ['solu-8l-pile', 'solu-8l-old'], 'Qwen/Qwen-14B': ['qwen-14b'], 'Qwen/Qwen-14B-Chat': ['qwen-14b-chat'], 'Qwen/Qwen-1_8B': ['qwen-1.8b'], 'Qwen/Qwen-1_8B-Chat': ['qwen-1.8b-chat'], 'Qwen/Qwen-7B': ['qwen-7b'], 'Qwen/Qwen-7B-Chat': ['qwen-7b-chat'], 'Qwen/Qwen1.5-0.5B': ['qwen1.5-0.5b'], 'Qwen/Qwen1.5-0.5B-Chat': ['qwen1.5-0.5b-chat'], 'Qwen/Qwen1.5-1.8B': ['qwen1.5-1.8b'], 'Qwen/Qwen1.5-1.8B-Chat': ['qwen1.5-1.8b-chat'], 'Qwen/Qwen1.5-14B': ['qwen1.5-14b'], 'Qwen/Qwen1.5-14B-Chat': ['qwen1.5-14b-chat'], 'Qwen/Qwen1.5-4B': ['qwen1.5-4b'], 'Qwen/Qwen1.5-4B-Chat': ['qwen1.5-4b-chat'], 'Qwen/Qwen1.5-7B': ['qwen1.5-7b'], 'Qwen/Qwen1.5-7B-Chat': ['qwen1.5-7b-chat'], 'ai-forever/mGPT': ['mGPT'], 'bigcode/santacoder': ['santacoder'], 'bigscience/bloom-1b1': ['bloom-1b1'], 'bigscience/bloom-1b7': ['bloom-1b7'], 'bigscience/bloom-3b': ['bloom-3b'], 'bigscience/bloom-560m': ['bloom-560m'], 'bigscience/bloom-7b1': ['bloom-7b1'], 'codellama/CodeLlama-7b-Instruct-hf': ['CodeLlama-7b-instruct', 'codellama/CodeLlama-7b-Instruct-hf'], 'codellama/CodeLlama-7b-Python-hf': ['CodeLlama-7b-python', 'codellama/CodeLlama-7b-Python-hf'], 'codellama/CodeLlama-7b-hf': ['CodeLlamallama-2-7b', 'codellama/CodeLlama-7b-hf'], 'distilgpt2': ['distillgpt2', 'distill-gpt2', 'distil-gpt2', 'gpt2-xs'], 'facebook/opt-1.3b': ['opt-1.3b', 'opt-medium'], 
'facebook/opt-125m': ['opt-125m', 'opt-small', 'opt'], 'facebook/opt-13b': ['opt-13b', 'opt-xxl'], 'facebook/opt-2.7b': ['opt-2.7b', 'opt-large'], 'facebook/opt-30b': ['opt-30b', 'opt-xxxl'], 'facebook/opt-6.7b': ['opt-6.7b', 'opt-xl'], 'facebook/opt-66b': ['opt-66b', 'opt-xxxxl'], 'google-t5/t5-base': ['t5-base'], 'google-t5/t5-large': ['t5-large'], 'google-t5/t5-small': ['t5-small'], 'google/gemma-2-27b': ['gemma-2-27b'], 'google/gemma-2-27b-it': ['gemma-2-27b-it'], 'google/gemma-2-2b': ['gemma-2-2b'], 'google/gemma-2-2b-it': ['gemma-2-2b-it'], 'google/gemma-2-9b': ['gemma-2-9b'], 'google/gemma-2-9b-it': ['gemma-2-9b-it'], 'google/gemma-2b': ['gemma-2b'], 'google/gemma-2b-it': ['gemma-2b-it'], 'google/gemma-7b': ['gemma-7b'], 'google/gemma-7b-it': ['gemma-7b-it'], 'gpt2': ['gpt2-small'], 'llama-13b-hf': ['llama-13b'], 'llama-30b-hf': ['llama-30b'], 'llama-65b-hf': ['llama-65b'], 'llama-7b-hf': ['llama-7b'], 'meta-llama/Llama-2-13b-chat-hf': ['Llama-2-13b-chat', 'meta-llama/Llama-2-13b-chat-hf'], 'meta-llama/Llama-2-13b-hf': ['Llama-2-13b', 'meta-llama/Llama-2-13b-hf'], 'meta-llama/Llama-2-70b-chat-hf': ['Llama-2-70b-chat', 'meta-llama-2-70b-chat-hf'], 'meta-llama/Llama-2-7b-chat-hf': ['Llama-2-7b-chat', 'meta-llama/Llama-2-7b-chat-hf'], 'meta-llama/Llama-2-7b-hf': ['Llama-2-7b', 'meta-llama/Llama-2-7b-hf'], 'microsoft/Phi-3-mini-4k-instruct': ['phi-3'], 'microsoft/phi-1': ['phi-1'], 'microsoft/phi-1_5': ['phi-1_5'], 'microsoft/phi-2': ['phi-2'], 'mistralai/Mistral-7B-Instruct-v0.1': ['mistral-7b-instruct'], 'mistralai/Mistral-7B-v0.1': ['mistral-7b'], 'mistralai/Mistral-Nemo-Base-2407': ['mistral-nemo-base-2407'], 'mistralai/Mixtral-8x7B-Instruct-v0.1': ['mixtral-instruct', 'mixtral-8x7b-instruct'], 'mistralai/Mixtral-8x7B-v0.1': ['mixtral', 'mixtral-8x7b'], 'roneneldan/TinyStories-1Layer-21M': ['tiny-stories-1L-21M'], 'roneneldan/TinyStories-1M': ['tiny-stories-1M'], 'roneneldan/TinyStories-28M': ['tiny-stories-28M'], 'roneneldan/TinyStories-2Layers-33M': 
['tiny-stories-2L-33M'], 'roneneldan/TinyStories-33M': ['tiny-stories-33M'], 'roneneldan/TinyStories-3M': ['tiny-stories-3M'], 'roneneldan/TinyStories-8M': ['tiny-stories-8M'], 'roneneldan/TinyStories-Instruct-1M': ['tiny-stories-instruct-1M'], 'roneneldan/TinyStories-Instruct-28M': ['tiny-stories-instruct-28M'], 'roneneldan/TinyStories-Instruct-2Layers-33M': ['tiny-stories-instruct-2L-33M'], 'roneneldan/TinyStories-Instruct-33M': ['tiny-stories-instruct-33M'], 'roneneldan/TinyStories-Instruct-3M': ['tiny-stories-instruct-3M'], 'roneneldan/TinyStories-Instruct-8M': ['tiny-stories-instruct-8M'], 'roneneldan/TinyStories-Instuct-1Layer-21M': ['tiny-stories-instruct-1L-21M'], 'stabilityai/stablelm-base-alpha-3b': ['stablelm-base-alpha-3b', 'stablelm-base-3b'], 'stabilityai/stablelm-base-alpha-7b': ['stablelm-base-alpha-7b', 'stablelm-base-7b'], 'stabilityai/stablelm-tuned-alpha-3b': ['stablelm-tuned-alpha-3b', 'stablelm-tuned-3b'], 'stabilityai/stablelm-tuned-alpha-7b': ['stablelm-tuned-alpha-7b', 'stablelm-tuned-7b'], 'stanford-crfm/alias-gpt2-small-x21': ['stanford-gpt2-small-a', 'alias-gpt2-small-x21', 'gpt2-mistral-small-a', 'gpt2-stanford-small-a'], 'stanford-crfm/arwen-gpt2-medium-x21': ['stanford-gpt2-medium-a', 'arwen-gpt2-medium-x21', 'gpt2-medium-small-a', 'gpt2-stanford-medium-a'], 'stanford-crfm/battlestar-gpt2-small-x49': ['stanford-gpt2-small-b', 'battlestar-gpt2-small-x49', 'gpt2-mistral-small-b', 'gpt2-mistral-small-b'], 'stanford-crfm/beren-gpt2-medium-x49': ['stanford-gpt2-medium-b', 'beren-gpt2-medium-x49', 'gpt2-medium-small-b', 'gpt2-stanford-medium-b'], 'stanford-crfm/caprica-gpt2-small-x81': ['stanford-gpt2-small-c', 'caprica-gpt2-small-x81', 'gpt2-mistral-small-c', 'gpt2-stanford-small-c'], 'stanford-crfm/celebrimbor-gpt2-medium-x81': ['stanford-gpt2-medium-c', 'celebrimbor-gpt2-medium-x81', 'gpt2-medium-small-c', 'gpt2-medium-small-c'], 'stanford-crfm/darkmatter-gpt2-small-x343': ['stanford-gpt2-small-d', 'darkmatter-gpt2-small-x343', 
'gpt2-mistral-small-d', 'gpt2-mistral-small-d'], 'stanford-crfm/durin-gpt2-medium-x343': ['stanford-gpt2-medium-d', 'durin-gpt2-medium-x343', 'gpt2-medium-small-d', 'gpt2-stanford-medium-d'], 'stanford-crfm/eowyn-gpt2-medium-x777': ['stanford-gpt2-medium-e', 'eowyn-gpt2-medium-x777', 'gpt2-medium-small-e', 'gpt2-stanford-medium-e'], 'stanford-crfm/expanse-gpt2-small-x777': ['stanford-gpt2-small-e', 'expanse-gpt2-small-x777', 'gpt2-mistral-small-e', 'gpt2-mistral-small-e']}#
Model aliases for models on HuggingFace.
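The table maps each official HuggingFace name to its accepted aliases. A minimal sketch of how such a table can be inverted so that any alias resolves back to the official name (the two entries below are an abbreviated excerpt of the full table; the inversion is illustrative, not the library's internal code):

```python
# Abbreviated excerpt of MODEL_ALIASES above, for illustration.
MODEL_ALIASES = {
    "gpt2": ["gpt2-small"],
    "EleutherAI/gpt-neo-125M": ["gpt-neo-125M", "gpt-neo-small", "neo-small", "neo"],
}

# Invert the table: each alias points to exactly one official name.
alias_to_official = {
    alias: official
    for official, aliases in MODEL_ALIASES.items()
    for alias in aliases
}

assert alias_to_official["neo"] == "EleutherAI/gpt-neo-125M"
```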
- transformer_lens.loading_from_pretrained.NON_HF_HOSTED_MODEL_NAMES = ['llama-7b-hf', 'llama-13b-hf', 'llama-30b-hf', 'llama-65b-hf']#
Official model names for models not hosted on HuggingFace.
- transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES = ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl', 'distilgpt2', 'facebook/opt-125m', 'facebook/opt-1.3b', 'facebook/opt-2.7b', 'facebook/opt-6.7b', 'facebook/opt-13b', 'facebook/opt-30b', 'facebook/opt-66b', 'EleutherAI/gpt-neo-125M', 'EleutherAI/gpt-neo-1.3B', 'EleutherAI/gpt-neo-2.7B', 'EleutherAI/gpt-j-6B', 'EleutherAI/gpt-neox-20b', 'stanford-crfm/alias-gpt2-small-x21', 'stanford-crfm/battlestar-gpt2-small-x49', 'stanford-crfm/caprica-gpt2-small-x81', 'stanford-crfm/darkmatter-gpt2-small-x343', 'stanford-crfm/expanse-gpt2-small-x777', 'stanford-crfm/arwen-gpt2-medium-x21', 'stanford-crfm/beren-gpt2-medium-x49', 'stanford-crfm/celebrimbor-gpt2-medium-x81', 'stanford-crfm/durin-gpt2-medium-x343', 'stanford-crfm/eowyn-gpt2-medium-x777', 'EleutherAI/pythia-14m', 'EleutherAI/pythia-31m', 'EleutherAI/pythia-70m', 'EleutherAI/pythia-160m', 'EleutherAI/pythia-410m', 'EleutherAI/pythia-1b', 'EleutherAI/pythia-1.4b', 'EleutherAI/pythia-2.8b', 'EleutherAI/pythia-6.9b', 'EleutherAI/pythia-12b', 'EleutherAI/pythia-70m-deduped', 'EleutherAI/pythia-160m-deduped', 'EleutherAI/pythia-410m-deduped', 'EleutherAI/pythia-1b-deduped', 'EleutherAI/pythia-1.4b-deduped', 'EleutherAI/pythia-2.8b-deduped', 'EleutherAI/pythia-6.9b-deduped', 'EleutherAI/pythia-12b-deduped', 'EleutherAI/pythia-70m-v0', 'EleutherAI/pythia-160m-v0', 'EleutherAI/pythia-410m-v0', 'EleutherAI/pythia-1b-v0', 'EleutherAI/pythia-1.4b-v0', 'EleutherAI/pythia-2.8b-v0', 'EleutherAI/pythia-6.9b-v0', 'EleutherAI/pythia-12b-v0', 'EleutherAI/pythia-70m-deduped-v0', 'EleutherAI/pythia-160m-deduped-v0', 'EleutherAI/pythia-410m-deduped-v0', 'EleutherAI/pythia-1b-deduped-v0', 'EleutherAI/pythia-1.4b-deduped-v0', 'EleutherAI/pythia-2.8b-deduped-v0', 'EleutherAI/pythia-6.9b-deduped-v0', 'EleutherAI/pythia-12b-deduped-v0', 'EleutherAI/pythia-160m-seed1', 'EleutherAI/pythia-160m-seed2', 'EleutherAI/pythia-160m-seed3', 'NeelNanda/SoLU_1L_v9_old', 
'NeelNanda/SoLU_2L_v10_old', 'NeelNanda/SoLU_4L_v11_old', 'NeelNanda/SoLU_6L_v13_old', 'NeelNanda/SoLU_8L_v21_old', 'NeelNanda/SoLU_10L_v22_old', 'NeelNanda/SoLU_12L_v23_old', 'NeelNanda/SoLU_1L512W_C4_Code', 'NeelNanda/SoLU_2L512W_C4_Code', 'NeelNanda/SoLU_3L512W_C4_Code', 'NeelNanda/SoLU_4L512W_C4_Code', 'NeelNanda/SoLU_6L768W_C4_Code', 'NeelNanda/SoLU_8L1024W_C4_Code', 'NeelNanda/SoLU_10L1280W_C4_Code', 'NeelNanda/SoLU_12L1536W_C4_Code', 'NeelNanda/GELU_1L512W_C4_Code', 'NeelNanda/GELU_2L512W_C4_Code', 'NeelNanda/GELU_3L512W_C4_Code', 'NeelNanda/GELU_4L512W_C4_Code', 'NeelNanda/Attn_Only_1L512W_C4_Code', 'NeelNanda/Attn_Only_2L512W_C4_Code', 'NeelNanda/Attn_Only_3L512W_C4_Code', 'NeelNanda/Attn_Only_4L512W_C4_Code', 'NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr', 'NeelNanda/SoLU_1L512W_Wiki_Finetune', 'NeelNanda/SoLU_4L512W_Wiki_Finetune', 'ArthurConmy/redwood_attn_2l', 'llama-7b-hf', 'llama-13b-hf', 'llama-30b-hf', 'llama-65b-hf', 'meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-2-7b-chat-hf', 'meta-llama/Llama-2-13b-hf', 'meta-llama/Llama-2-13b-chat-hf', 'meta-llama/Llama-2-70b-chat-hf', 'codellama/CodeLlama-7b-hf', 'codellama/CodeLlama-7b-Python-hf', 'codellama/CodeLlama-7b-Instruct-hf', 'meta-llama/Meta-Llama-3-8B', 'meta-llama/Meta-Llama-3-8B-Instruct', 'meta-llama/Meta-Llama-3-70B', 'meta-llama/Meta-Llama-3-70B-Instruct', 'meta-llama/Llama-3.2-1B', 'meta-llama/Llama-3.2-3B', 'meta-llama/Llama-3.2-1B-Instruct', 'meta-llama/Llama-3.2-3B-Instruct', 'meta-llama/Llama-3.1-70B', 'meta-llama/Llama-3.1-8B', 'meta-llama/Llama-3.1-8B-Instruct', 'meta-llama/Llama-3.1-70B-Instruct', 'Baidicoot/Othello-GPT-Transformer-Lens', 'bert-base-cased', 'roneneldan/TinyStories-1M', 'roneneldan/TinyStories-3M', 'roneneldan/TinyStories-8M', 'roneneldan/TinyStories-28M', 'roneneldan/TinyStories-33M', 'roneneldan/TinyStories-Instruct-1M', 'roneneldan/TinyStories-Instruct-3M', 'roneneldan/TinyStories-Instruct-8M', 'roneneldan/TinyStories-Instruct-28M', 
'roneneldan/TinyStories-Instruct-33M', 'roneneldan/TinyStories-1Layer-21M', 'roneneldan/TinyStories-2Layers-33M', 'roneneldan/TinyStories-Instuct-1Layer-21M', 'roneneldan/TinyStories-Instruct-2Layers-33M', 'stabilityai/stablelm-base-alpha-3b', 'stabilityai/stablelm-base-alpha-7b', 'stabilityai/stablelm-tuned-alpha-3b', 'stabilityai/stablelm-tuned-alpha-7b', 'mistralai/Mistral-7B-v0.1', 'mistralai/Mistral-7B-Instruct-v0.1', 'mistralai/Mistral-Nemo-Base-2407', 'mistralai/Mixtral-8x7B-v0.1', 'mistralai/Mixtral-8x7B-Instruct-v0.1', 'bigscience/bloom-560m', 'bigscience/bloom-1b1', 'bigscience/bloom-1b7', 'bigscience/bloom-3b', 'bigscience/bloom-7b1', 'bigcode/santacoder', 'Qwen/Qwen-1_8B', 'Qwen/Qwen-7B', 'Qwen/Qwen-14B', 'Qwen/Qwen-1_8B-Chat', 'Qwen/Qwen-7B-Chat', 'Qwen/Qwen-14B-Chat', 'Qwen/Qwen1.5-0.5B', 'Qwen/Qwen1.5-0.5B-Chat', 'Qwen/Qwen1.5-1.8B', 'Qwen/Qwen1.5-1.8B-Chat', 'Qwen/Qwen1.5-4B', 'Qwen/Qwen1.5-4B-Chat', 'Qwen/Qwen1.5-7B', 'Qwen/Qwen1.5-7B-Chat', 'Qwen/Qwen1.5-14B', 'Qwen/Qwen1.5-14B-Chat', 'Qwen/Qwen2-0.5B', 'Qwen/Qwen2-0.5B-Instruct', 'Qwen/Qwen2-1.5B', 'Qwen/Qwen2-1.5B-Instruct', 'Qwen/Qwen2-7B', 'Qwen/Qwen2-7B-Instruct', 'microsoft/phi-1', 'microsoft/phi-1_5', 'microsoft/phi-2', 'microsoft/Phi-3-mini-4k-instruct', 'google/gemma-2b', 'google/gemma-7b', 'google/gemma-2b-it', 'google/gemma-7b-it', 'google/gemma-2-2b', 'google/gemma-2-2b-it', 'google/gemma-2-9b', 'google/gemma-2-9b-it', 'google/gemma-2-27b', 'google/gemma-2-27b-it', '01-ai/Yi-6B', '01-ai/Yi-34B', '01-ai/Yi-6B-Chat', '01-ai/Yi-34B-Chat', 'google-t5/t5-small', 'google-t5/t5-base', 'google-t5/t5-large', 'ai-forever/mGPT']#
Official model names for models on HuggingFace.
- transformer_lens.loading_from_pretrained.get_checkpoint_labels(model_name: str, **kwargs)#
Returns the checkpoint labels for a given model, and the label_type (step or token). Raises an error for models that are not checkpointed.
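The contract can be sketched as follows: checkpointed models return a list of labels plus the label type (`"step"` or `"token"`), and anything else raises. This is a hypothetical helper over toy data, not the library's implementation:

```python
# Illustrative sketch of the get_checkpoint_labels contract.
def checkpoint_labels_sketch(model_name, checkpoint_table):
    if model_name not in checkpoint_table:
        # Non-checkpointed models raise instead of returning empty labels.
        raise ValueError(f"{model_name} has no checkpointed versions")
    labels, label_type = checkpoint_table[model_name]
    return labels, label_type

# Toy table: each entry carries its labels and whether they count steps or tokens.
table = {"EleutherAI/pythia-70m": ([0, 1, 2, 4, 8, 16], "step")}
labels, label_type = checkpoint_labels_sketch("EleutherAI/pythia-70m", table)
assert label_type == "step"
```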
- transformer_lens.loading_from_pretrained.get_num_params_of_pretrained(model_name)#
Returns the number of parameters of a pretrained model; used to filter runs to only sufficiently small models.
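For intuition, a rough estimate of what such a size filter compares against can be computed from the config fields alone. This back-of-the-envelope formula for a GPT-2-style model is an illustration, not the library's exact computation:

```python
# Rough parameter count for a GPT-2-style decoder from its config fields.
def approx_num_params(n_layers, d_model, d_vocab, n_ctx):
    embed = d_vocab * d_model + n_ctx * d_model  # token + positional embeddings
    # Attention contributes ~4*d_model^2 (Q, K, V, O) and the MLP ~8*d_model^2
    # (two d_model x 4*d_model matrices), so ~12*d_model^2 per layer.
    per_layer = 12 * d_model ** 2
    return embed + n_layers * per_layer

# GPT-2 small defaults from the Config class above:
n = approx_num_params(n_layers=12, d_model=768, d_vocab=50257, n_ctx=1024)
print(f"{n / 1e6:.0f}M parameters")  # prints "124M parameters"
```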
- transformer_lens.loading_from_pretrained.get_pretrained_model_config(model_name: str, hf_cfg: Optional[dict] = None, checkpoint_index: Optional[int] = None, checkpoint_value: Optional[int] = None, fold_ln: bool = False, device: Optional[Union[str, device]] = None, n_devices: int = 1, default_prepend_bos: Optional[bool] = None, dtype: dtype = torch.float32, first_n_layers: Optional[int] = None, **kwargs)#
Returns the pretrained model config as a HookedTransformerConfig object.
There are two types of pretrained models: HuggingFace models (where AutoModel and AutoConfig work), and models trained by me (NeelNanda), which aren’t as well integrated with the HuggingFace infrastructure.
- Parameters:
model_name – The name of the model. This can be either the official HuggingFace model name, or the name of a model trained by me (NeelNanda).
hf_cfg (dict, optional) – Config of a loaded pretrained HF model, converted to a dictionary.
checkpoint_index (int, optional) – If loading from a checkpoint, the index of the checkpoint to load. Defaults to None.
checkpoint_value (int, optional) – If loading from a checkpoint, the value of the checkpoint to load, i.e. the step or token number (each model has checkpoints labelled with exactly one of these). Defaults to None.
fold_ln (bool, optional) – Whether to fold the layer norm into the subsequent linear layers (see HookedTransformer.fold_layer_norm for details). Defaults to False.
device (str, optional) – The device to load the model onto. By default will load to CUDA if available, else CPU.
n_devices (int, optional) – The number of devices to split the model across. Defaults to 1.
default_prepend_bos (bool, optional) –
Default behavior of whether to prepend the BOS token when HookedTransformer methods tokenize input text (only when the input is a string). Resolution order for default_prepend_bos: 1. If the user passes a value explicitly, use that value. 2. Otherwise, the model-specific default from cfg_dict, if it exists (e.g. False for BLOOM models). 3. Otherwise, the global default (True).
Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position and accordingly lose information from the first token, so prepending the BOS token empirically seems to give better results. Note that you can also locally override the default behavior by passing in prepend_bos=True/False when you call a method that processes an input string.
dtype (torch.dtype, optional) – The dtype to load the TransformerLens model in.
kwargs – Other optional arguments passed to HuggingFace’s from_pretrained. Also given to other HuggingFace functions when compatible.
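The three-step resolution order for default_prepend_bos described in the parameter list above can be sketched as follows (illustrative only, not the library's code; the helper name is hypothetical):

```python
# Sketch of the default_prepend_bos resolution order.
def resolve_prepend_bos(user_value, model_default):
    # 1. An explicitly passed value always wins.
    if user_value is not None:
        return user_value
    # 2. Otherwise fall back to the model-specific default, if any
    #    (e.g. False for BLOOM models).
    if model_default is not None:
        return model_default
    # 3. Otherwise use the global default.
    return True

assert resolve_prepend_bos(None, None) is True    # global default
assert resolve_prepend_bos(None, False) is False  # bloom-style model default
assert resolve_prepend_bos(True, False) is True   # explicit user override
```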