transformer_lens.components.t5_attention#

class transformer_lens.components.t5_attention.T5Attention(cfg: Union[Dict, HookedTransformerConfig], has_relative_attention_bias: bool = False, attn_type: str = 'global', layer_id: Optional[int] = None)#

Bases: AbstractAttention

T5 attention - with relative attention bias and cross-attention support. This implementation expects you to precompute the relative positional bias and then pass it to forward, for example:

```python
attn = T5Attention(cfg, has_relative_attention_bias=True)
positional_bias = attn.compute_relative_attention_bias(query_len, key_len, device=device)
result = attn(query, key, value, position_bias=positional_bias)
```
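For orientation, the precomputed bias is simply added to the attention scores before the softmax (T5 also omits the usual 1/sqrt(d_head) scaling). The following is a minimal, self-contained sketch of that mechanism with hypothetical names and shapes; it is not the module's actual internals:

```python
import torch


def attention_with_bias(q, k, v, position_bias):
    # q, k, v: [batch, head_index, pos, d_head]
    # position_bias: [1, head_index, pos, kv_pos], broadcast over batch
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k)  # T5 uses unscaled dot-product scores
    scores = scores + position_bias                 # additive relative position bias
    pattern = scores.softmax(dim=-1)
    return torch.einsum("bhqk,bhkd->bhqd", pattern, v)
```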

compute_relative_attention_bias(query_length: int, key_length: int, device=None) → Float[Tensor, '1 head_index pos kv_pos']#

Compute the binned relative position bias, returned with shape [1, head_index, pos, kv_pos].
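The binning follows the standard T5 scheme: half of the buckets cover small offsets exactly, the other half cover larger offsets in logarithmically sized bins up to a maximum distance. Below is a self-contained sketch of that bucketing, modelled on the T5 reference implementation; the function name is illustrative, and in practice the bucket count and maximum distance are read from the model config (typically relative_attention_num_buckets and relative_attention_max_distance):

```python
import math

import torch


def relative_position_bucket(
    relative_position: torch.Tensor,
    bidirectional: bool = True,
    num_buckets: int = 32,
    max_distance: int = 128,
) -> torch.Tensor:
    """Map signed relative positions (key_pos - query_pos) to bucket indices."""
    relative_buckets = torch.zeros_like(relative_position)
    if bidirectional:
        # Half the buckets for positive offsets, half for negative ones.
        num_buckets //= 2
        relative_buckets += (relative_position > 0).long() * num_buckets
        relative_position = relative_position.abs()
    else:
        # Causal case: positive (future) offsets are clamped to zero.
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))

    # Small offsets each get their own bucket ...
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # ... larger offsets share logarithmically sized buckets up to max_distance.
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    relative_position_if_large = torch.min(
        relative_position_if_large,
        torch.full_like(relative_position_if_large, num_buckets - 1),
    )

    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets


# Bucket indices for a (query_len, key_len) grid; the bias itself is then a learned
# embedding lookup of shape [num_buckets, n_heads], rearranged to [1, head_index, pos, kv_pos].
query_len, key_len = 4, 4
context_position = torch.arange(query_len)[:, None]
memory_position = torch.arange(key_len)[None, :]
buckets = relative_position_bucket(memory_position - context_position)
```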