def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    use_mm_prefix: bool = False,
    use_per_head_quant_scales: bool = False,
    attn_type: str | None = None,
    num_heads: int | None = None,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it.

    Bundles the caller's model/attention properties together with settings
    read from the current vLLM config into an ``AttentionSelectorConfig``.
    It then delegates the actual choice to ``_cached_get_attn_backend``.

    Args:
        head_size: Attention head size.
        dtype: Model dtype forwarded to the selector config.
        kv_cache_dtype: KV-cache dtype name, or ``None``. When given, it must
            be one of the literal values of ``CacheDType``.
        use_mla, has_sink, use_sparse, use_mm_prefix,
        use_per_head_quant_scales: Feature flags forwarded unchanged.
        attn_type: Attention type. A falsy value falls back to
            ``AttentionType.DECODER``.
        num_heads: Optional head count passed through to the cached resolver.

    Returns:
        The backend class returned by ``_cached_get_attn_backend``.

    Raises:
        AssertionError: If ``kv_cache_dtype`` is not a valid ``CacheDType``.
            This check is an ``assert`` and is skipped under ``python -O``.
    """
    if kv_cache_dtype is not None:
        allowed_dtypes = get_args(CacheDType)
        assert kv_cache_dtype in allowed_dtypes, (
            f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
            f"Valid values are: {allowed_dtypes}"
        )

    # Function-local import, as in the original code. NOTE(review): probably
    # avoids an import cycle with vllm.config; confirm before hoisting it.
    from vllm.config import get_current_vllm_config

    config = get_current_vllm_config()

    # Forward a block size only when the user set one explicitly; otherwise
    # pass None.
    cache_cfg = config.cache_config
    user_block_size = (
        cache_cfg.block_size
        if cache_cfg is not None and cache_cfg.user_specified_block_size
        else None
    )

    # With no KV-transfer config at all, the connector is treated as disabled.
    transfer_cfg = config.kv_transfer_config
    kv_connector_enabled = (
        False if transfer_cfg is None else transfer_cfg.is_kv_transfer_instance
    )

    attn_cfg = config.attention_config
    resolved_attn_type = attn_type or AttentionType.DECODER

    selector_cfg = AttentionSelectorConfig(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
        block_size=user_block_size,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
        use_mm_prefix=use_mm_prefix,
        use_per_head_quant_scales=use_per_head_quant_scales,
        attn_type=resolved_attn_type,
        use_non_causal=attn_cfg.use_non_causal,
        use_batch_invariant=envs.VLLM_BATCH_INVARIANT,
        use_kv_connector=kv_connector_enabled,
    )
    return _cached_get_attn_backend(
        backend=attn_cfg.backend,
        attn_selector_config=selector_cfg,
        num_heads=num_heads,
    )