transformer_lens.model_bridge.generalized_components.mpt_alibi_attention module¶
MPT ALiBi attention bridge. MPT passes its ALiBi bias through the position_bias kwarg and uses a boolean causal mask.
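Example (a minimal sketch, not this module's code) of the two inputs described above: an ALiBi position bias with no batch dimension and a boolean causal mask. The shapes and the slope formula are illustrative assumptions:

    import torch

    def alibi_position_bias(num_heads: int, seq_len: int) -> torch.Tensor:
        # Standard ALiBi slopes for a power-of-two head count: 2^(-8 * (h + 1) / num_heads).
        slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)])
        # Non-positive distance of each key position behind the final query position.
        distances = torch.arange(1 - seq_len, 1, dtype=torch.float32).view(1, 1, seq_len)
        return slopes.view(num_heads, 1, 1) * distances  # shape (num_heads, 1, seq_len), no batch dim

    def bool_causal_mask(seq_len: int) -> torch.Tensor:
        # True where attention is allowed (lower triangle), False above the diagonal.
        return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

    position_bias = alibi_position_bias(num_heads=8, seq_len=8)  # ALiBi bias, no batch dimension
    attention_mask = bool_causal_mask(seq_len=8)                 # boolean causal mask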
- class transformer_lens.model_bridge.generalized_components.mpt_alibi_attention.MPTALiBiAttentionBridge(name: str, config: Any, split_qkv_matrix: Any = None, submodules: Dict[str, GeneralizedComponent] | None = None, **kwargs: Any)¶
Bases: ALiBiJointQKVAttentionBridge
ALiBi bridge for MPT: overrides the ALiBi kwarg name, bias shape, mask format, and clip_qkv handling.
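For context, clip_qkv refers to MPT's optional clamping of the fused QKV projection output. A minimal sketch of the idea (not the bridge's actual code; the helper name is hypothetical):

    import torch

    def apply_clip_qkv(qkv: torch.Tensor, clip_qkv: float | None) -> torch.Tensor:
        # MPT-style clamping of the fused QKV projection output to [-clip_qkv, clip_qkv].
        if clip_qkv is None:
            return qkv
        return qkv.clamp(min=-clip_qkv, max=clip_qkv)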
- forward(*args: Any, **kwargs: Any) tuple[Tensor, Tensor] | tuple[Tensor, Tensor, None]¶
Returns a 2-tuple on transformers >= 5 and a 3-tuple on transformers < 5; MptBlock's unpacking arity changed in v5.
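Example of consuming the version-dependent return value; a hedged sketch where outputs stands for whatever a call to this bridge returned:

    from typing import Any

    def unpack_attn_outputs(outputs: tuple) -> tuple[Any, Any]:
        # transformers >= 5: (attn_out, attn_weights); transformers < 5: (attn_out, attn_weights, None)
        if len(outputs) == 3:
            attn_out, attn_weights, _ = outputs
        else:
            attn_out, attn_weights = outputs
        return attn_out, attn_weights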
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]¶
Generate test inputs using MPT's kwarg names: a position_bias tensor without a batch dimension and a boolean causal mask.
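Usage sketch, assuming bridge is an already-constructed MPTALiBiAttentionBridge; only the position_bias key is documented above, so the attention_mask key name is an assumption:

    inputs = bridge.get_random_inputs(batch_size=2, seq_len=8)
    print(inputs["position_bias"].shape)   # ALiBi bias without a batch dimension
    print(inputs["attention_mask"].dtype)  # expected torch.bool causal mask (key name assumed)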
- set_original_component(original_component: Module) None¶
Set the original component that this bridge wraps and initialize LinearBridges for q, k, v, and o transformations.
- Parameters:
original_component – The original attention layer to wrap
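End-to-end sketch of wrapping a Hugging Face MPT attention layer. The checkpoint name, the attribute path (which follows HF's native MptModel layout), the bridge name, and the choice of config object passed to the bridge are assumptions, not verified against this module:

    from transformers import AutoModelForCausalLM
    from transformer_lens.model_bridge.generalized_components.mpt_alibi_attention import (
        MPTALiBiAttentionBridge,
    )

    hf_model = AutoModelForCausalLM.from_pretrained("mosaicml/mpt-7b")
    # Whether the bridge expects a TransformerLens config or the raw HF config is assumed here.
    bridge = MPTALiBiAttentionBridge(name="blocks.0.attn", config=hf_model.config)
    bridge.set_original_component(hf_model.transformer.blocks[0].attn)  # also sets up q/k/v/o LinearBridges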