transformer_lens.model_bridge.generalized_components.mpt_alibi_attention module

Bridge for MPT's ALiBi attention. MPT passes the ALiBi bias through the position_bias keyword argument and uses a boolean causal mask.

class transformer_lens.model_bridge.generalized_components.mpt_alibi_attention.MPTALiBiAttentionBridge(name: str, config: Any, split_qkv_matrix: Any = None, submodules: Dict[str, GeneralizedComponent] | None = None, **kwargs: Any)

Bases: ALiBiJointQKVAttentionBridge

ALiBi bridge for MPT. Overrides the ALiBi keyword-argument name, bias shape, mask format, and clip_qkv handling inherited from ALiBiJointQKVAttentionBridge.

forward(*args: Any, **kwargs: Any) → tuple[Tensor, Tensor] | tuple[Tensor, Tensor, None]

Returns a 2-tuple on transformers>=5 and a 3-tuple (with a trailing None) on earlier versions, because MptBlock's unpack arity changed in transformers v5.

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) → Dict[str, Any]

Generate test inputs using MPT's keyword names: position_bias (without a batch dimension) and a boolean causal mask.
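A minimal sketch of what such inputs look like, assuming typical MPT shapes. The function name and the head/model dimensions are illustrative, and the exact mask semantics (which value means "attend") follow MPT's convention rather than being asserted here.

```python
import torch


def example_mpt_inputs(batch_size=2, seq_len=8, n_heads=4, d_model=32):
    """Build MPT-style attention inputs (illustrative shapes only)."""
    hidden_states = torch.randn(batch_size, seq_len, d_model)
    # position_bias carries the ALiBi bias per head; note it has NO batch
    # dimension, unlike the (batch, heads, q, k) layout some models use.
    position_bias = torch.zeros(n_heads, seq_len, seq_len)
    # Boolean lower-triangular causal mask, as MPT expects a bool mask
    # rather than an additive float mask.
    attention_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return {
        "hidden_states": hidden_states,
        "position_bias": position_bias,
        "attention_mask": attention_mask,
    }
```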

set_original_component(original_component: Module) → None

Set the original component that this bridge wraps and initialize LinearBridges for the q, k, v, and o transformations.

Parameters:

original_component – The original attention layer to wrap
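Since MPT stores its attention projections as a single fused QKV matrix (hence the ALiBiJointQKVAttentionBridge base class), initializing separate q/k/v LinearBridges involves splitting that matrix. A hedged sketch of the split, with a hypothetical helper name and under the assumption of a (3*d_model, d_model) fused weight layout:

```python
import torch


def split_fused_qkv(wqkv_weight: torch.Tensor, d_model: int):
    """Split a fused (3*d_model, d_model) QKV weight into three
    (d_model, d_model) weights, one each for the q, k, and v projections.
    Illustrative only; the real bridge wires these into LinearBridge modules."""
    q_w, k_w, v_w = wqkv_weight.split(d_model, dim=0)
    return q_w, k_w, v_w
```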