Skip to content

vllm.v1.attention.backends.mla.prefill

Modules:

Name Description
base

Abstract base class for MLA prefill backends.

flash_attn

FlashAttention backend for MLA prefill.

flashinfer

FlashInfer backend for MLA prefill.

registry

Registry for MLA prefill backends.

selector

Selector for MLA prefill backends.

tokenspeed_mla

TokenSpeed CuTe DSL backend for MLA prefill.

trtllm_ragged

TRT-LLM Ragged backend for MLA prefill.

MLAPrefillBackend

Bases: ABC

Abstract base class for MLA prefill backends.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
class MLAPrefillBackend(ABC):
    """Abstract base class for MLA prefill backends."""

    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
    ]
    requires_r1_mla_dimensions: ClassVar[bool] = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @classmethod
    def supports_compute_capability(cls, device_capability: "DeviceCapability") -> bool:
        return True

    @classmethod
    def supports_dtype(cls, dtype: torch.dtype) -> bool:
        return dtype in cls.supported_dtypes

    @classmethod
    def is_available(cls) -> bool:
        return True

    @classmethod
    def validate_configuration(
        cls,
        device_capability: "DeviceCapability",
        selector_config: "MLAPrefillSelectorConfig",
    ) -> list[str]:
        invalid_reasons: list[str] = []

        if not cls.supports_compute_capability(device_capability):
            invalid_reasons.append(
                f"compute capability {device_capability.major}."
                f"{device_capability.minor} not supported"
            )

        if not cls.supports_dtype(selector_config.dtype):
            invalid_reasons.append(f"dtype {selector_config.dtype} not supported")

        if not cls.is_available():
            invalid_reasons.append("required dependencies not available")

        if cls.requires_r1_mla_dimensions and not selector_config.is_r1_compatible:
            invalid_reasons.append(
                "model does not have DeepSeek R1 MLA dimensions "
                "(qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128)"
            )

        return invalid_reasons

    def __init__(
        self,
        num_heads: int,
        scale: float,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        vllm_config: "VllmConfig",
    ) -> None:
        self.num_heads = num_heads
        self.scale = scale
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.vllm_config = vllm_config

    def prepare_metadata(  # noqa: B027
        self,
        prefill_metadata: "MLACommonPrefillMetadata",
    ) -> None:
        """Prepare backend-specific metadata before the forward pass.

        Called by the metadata builder after constructing the prefill metadata.
        """
        self._prefill_metadata = prefill_metadata

    @abstractmethod
    def run_prefill_new_tokens(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        return_softmax_lse: bool,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    @abstractmethod
    def run_prefill_context_chunk(
        self,
        chunk_idx: int,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

prepare_metadata

prepare_metadata(
    prefill_metadata: MLACommonPrefillMetadata,
) -> None

Prepare backend-specific metadata before the forward pass.

Called by the metadata builder after constructing the prefill metadata.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
def prepare_metadata(  # noqa: B027
    self,
    prefill_metadata: "MLACommonPrefillMetadata",
) -> None:
    """Hook for backend-specific setup ahead of the forward pass.

    The metadata builder invokes this once it has assembled the prefill
    metadata. This default simply records the metadata on
    ``self._prefill_metadata`` so later calls can read it.
    """
    # Stash for use by the run_prefill_* methods; overriding subclasses may
    # derive additional state here.
    self._prefill_metadata = prefill_metadata

MLAPrefillBackendEnum

Bases: Enum

Enumeration of all supported MLA prefill backends.

Source code in vllm/v1/attention/backends/mla/prefill/registry.py
class MLAPrefillBackendEnum(Enum, metaclass=_MLAPrefillBackendEnumMeta):
    """Enumeration of all supported MLA prefill backends.

    Each member's value is the default fully qualified class path of the
    backend. Entries in ``_MLA_PREFILL_OVERRIDES`` take precedence over it.
    """

    FLASH_ATTN = (
        "vllm.v1.attention.backends.mla.prefill.flash_attn.FlashAttnPrefillBackend"
    )
    FLASHINFER = (
        "vllm.v1.attention.backends.mla.prefill.flashinfer.FlashInferPrefillBackend"
    )
    TRTLLM_RAGGED = (
        "vllm.v1.attention.backends.mla.prefill.trtllm_ragged."
        "TrtllmRaggedPrefillBackend"
    )
    TOKENSPEED_MLA = (
        "vllm.v1.attention.backends.mla.prefill.tokenspeed_mla."
        "TokenspeedMLAPrefillBackend"
    )
    # Placeholder for third-party/custom backends; it must be registered via
    # register_mla_prefill_backend before use. Its value is None rather than
    # an empty string so that Enum does not treat it as an alias of another
    # member whose value is an empty string.
    CUSTOM = None

    def get_path(self) -> str:
        """Return this backend's class path, preferring any registered override.

        Returns:
            The fully qualified class path string.

        Raises:
            ValueError: If MLAPrefillBackendEnum.CUSTOM (or any member with a
                falsy path) is used without being registered.
        """
        resolved = _MLA_PREFILL_OVERRIDES.get(self, self.value)
        if resolved:
            return resolved
        raise ValueError(
            f"MLA prefill backend {self.name} must be registered before "
            f"use. Use register_mla_prefill_backend("
            f"MLAPrefillBackendEnum.{self.name}, "
            f"'your.module.YourClass')"
        )

    def get_class(self) -> "type[MLAPrefillBackend]":
        """Import and return the backend class, preferring any override.

        Returns:
            The backend class.

        Raises:
            ImportError: If the backend class cannot be imported.
            ValueError: If CUSTOM is used without being registered.
        """
        qualname = self.get_path()
        return resolve_obj_by_qualname(qualname)

    def is_overridden(self) -> bool:
        """Return True if an override is registered for this member."""
        overridden = self in _MLA_PREFILL_OVERRIDES
        return overridden

    def clear_override(self) -> None:
        """Drop any registered override so the default path is used again.

        Calling this for a member with no override is a harmless no-op.
        """
        _MLA_PREFILL_OVERRIDES.pop(self, None)

clear_override

clear_override() -> None

Clear any override for this backend, reverting to the default.

Source code in vllm/v1/attention/backends/mla/prefill/registry.py
def clear_override(self) -> None:
    """Drop any registered override so the default class path applies again.

    Members without an override are left untouched (no error is raised).
    """
    # pop with a default makes this safe to call unconditionally.
    _MLA_PREFILL_OVERRIDES.pop(self, None)

get_class

get_class() -> type[MLAPrefillBackend]

Get the backend class (respects overrides).

Returns:

Type Description
type[MLAPrefillBackend]

The backend class

Raises:

Type Description
ImportError

If the backend class cannot be imported

ValueError

If CUSTOM is used without being registered

Source code in vllm/v1/attention/backends/mla/prefill/registry.py
def get_class(self) -> "type[MLAPrefillBackend]":
    """Import and return the backend class, preferring any override.

    Returns:
        The backend class.

    Raises:
        ImportError: If the backend class cannot be imported.
        ValueError: If CUSTOM is used without being registered.
    """
    qualname = self.get_path()
    return resolve_obj_by_qualname(qualname)

get_path

get_path() -> str

Get the class path for this backend (respects overrides).

Returns:

Type Description
str

The fully qualified class path string

Raises:

Type Description
ValueError

If MLAPrefillBackendEnum.CUSTOM is used without being registered

Source code in vllm/v1/attention/backends/mla/prefill/registry.py
def get_path(self) -> str:
    """Return this backend's class path, preferring any registered override.

    Returns:
        The fully qualified class path string.

    Raises:
        ValueError: If MLAPrefillBackendEnum.CUSTOM (or any member with a
            falsy path) is used without being registered.
    """
    resolved = _MLA_PREFILL_OVERRIDES.get(self, self.value)
    if resolved:
        return resolved
    # Falsy path (e.g. CUSTOM's None value with no override): tell the user
    # how to register it.
    raise ValueError(
        f"MLA prefill backend {self.name} must be registered before "
        f"use. Use register_mla_prefill_backend("
        f"MLAPrefillBackendEnum.{self.name}, "
        f"'your.module.YourClass')"
    )

is_overridden

is_overridden() -> bool

Check if this backend has been overridden.

Source code in vllm/v1/attention/backends/mla/prefill/registry.py
def is_overridden(self) -> bool:
    """Return True if an override is registered for this member."""
    overridden = self in _MLA_PREFILL_OVERRIDES
    return overridden

get_mla_prefill_backend

get_mla_prefill_backend(
    vllm_config: VllmConfig,
) -> type[MLAPrefillBackend]

Select the MLA prefill backend based on configuration and device.

This function first checks for an explicit user preference via `mla_prefill_backend` in `AttentionConfig`, then falls back to automatic priority-based selection. If the device capability cannot be determined, the FlashAttention backend is returned.

Parameters:

Name Type Description Default
vllm_config VllmConfig

The vLLM configuration.

required

Returns:

Type Description
type[MLAPrefillBackend]

The selected prefill backend class.

Source code in vllm/v1/attention/backends/mla/prefill/selector.py
def get_mla_prefill_backend(
    vllm_config: "VllmConfig",
) -> "type[MLAPrefillBackend]":
    """Select the MLA prefill backend based on configuration and device.

    This function first checks for explicit user preferences via
    mla_prefill_backend in AttentionConfig, then falls back to automatic
    priority-based selection. If the platform cannot report a device
    capability, the FlashAttention backend is returned immediately (an
    explicit ``mla_prefill_backend`` is not consulted in that case).

    Args:
        vllm_config: The vLLM configuration.

    Returns:
        The selected prefill backend class.

    Raises:
        ValueError: If the explicitly selected backend cannot be imported or
            fails ``validate_configuration`` for this device/model, or (via
            ``get_class``) if CUSTOM is selected without being registered.
    """
    from vllm.platforms import current_platform

    device_capability = current_platform.get_device_capability()
    if device_capability is None:
        logger.info_once(
            "Device capability not available, using FlashAttention MLA prefill backend."
        )
        return MLAPrefillBackendEnum.FLASH_ATTN.get_class()

    attention_config = vllm_config.attention_config

    selector_config = MLAPrefillSelectorConfig(
        dtype=vllm_config.model_config.dtype,
        is_r1_compatible=is_deepseek_r1_mla_compatible(vllm_config),
    )

    selected_backend = attention_config.mla_prefill_backend
    if selected_backend is None:
        return _auto_select_mla_prefill_backend(
            device_capability,
            selector_config,
        )

    backend_cls: type[MLAPrefillBackend] | None = None
    import_error: ImportError | None = None
    try:
        backend_cls = selected_backend.get_class()
        # Kept inside the try: availability checks may themselves import.
        invalid_reasons = backend_cls.validate_configuration(
            device_capability, selector_config
        )
    except ImportError as err:
        # Preserve the import failure's message so users can see which
        # module is missing, and chain it onto the ValueError below.
        import_error = err
        invalid_reasons = [f"ImportError: {err}"]
    if invalid_reasons:
        raise ValueError(
            f"Selected MLA prefill backend {selected_backend.name} "
            f"is not valid for this configuration. "
            f"Reason: {invalid_reasons}"
        ) from import_error
    # Guaranteed: no ImportError means get_class() succeeded.
    assert backend_cls is not None
    # info_once, matching the fallback branch above: this selector may be
    # invoked repeatedly and the message never changes.
    logger.info_once("Using %s MLA prefill backend.", selected_backend.name)
    return backend_cls

register_mla_prefill_backend

register_mla_prefill_backend(
    backend: MLAPrefillBackendEnum,
    class_path: str | None = None,
) -> Callable[[type], type]

Register or override an MLA prefill backend implementation.

Parameters:

Name Type Description Default
backend MLAPrefillBackendEnum

The MLAPrefillBackendEnum member to register.

required
class_path str | None

Optional fully qualified class path. If omitted, the returned decorator derives it from the decorated class's `__module__` and `__qualname__`.

None

Returns:

Type Description
Callable[[type], type]

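A decorator that registers the decorated class if class_path is None; otherwise the override is registered immediately and an identity (no-op) decorator is returned.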

Examples:

Override an existing MLA prefill backend:

```python
@register_mla_prefill_backend(MLAPrefillBackendEnum.FLASH_ATTN)
class MyCustomFlashAttn(MLAPrefillBackend):
    ...
```

Register a custom third-party MLA prefill backend:

```python
@register_mla_prefill_backend(MLAPrefillBackendEnum.CUSTOM)
class MyCustomPrefillBackend(MLAPrefillBackend):
    ...
```

Direct registration:

```python
register_mla_prefill_backend(
    MLAPrefillBackendEnum.CUSTOM,
    "my.module.MyCustomPrefillBackend",
)
```

Source code in vllm/v1/attention/backends/mla/prefill/registry.py
def register_mla_prefill_backend(
    backend: MLAPrefillBackendEnum,
    class_path: str | None = None,
) -> Callable[[type], type]:
    """Register or override the implementation behind an MLA prefill backend.

    Args:
        backend: The MLAPrefillBackendEnum member to register.
        class_path: Optional fully qualified class path. When omitted, the
            returned decorator derives the path from the decorated class's
            ``__module__`` and ``__qualname__``.

    Returns:
        If ``class_path`` is None, a decorator that records the decorated
        class and returns it unchanged. Otherwise the override is recorded
        immediately and an identity function is returned.

    Examples:
        # Override an existing MLA prefill backend
        @register_mla_prefill_backend(MLAPrefillBackendEnum.FLASH_ATTN)
        class MyCustomFlashAttn(MLAPrefillBackend):
            ...

        # Register a custom third-party MLA prefill backend
        @register_mla_prefill_backend(MLAPrefillBackendEnum.CUSTOM)
        class MyCustomPrefillBackend(MLAPrefillBackend):
            ...

        # Direct registration
        register_mla_prefill_backend(
            MLAPrefillBackendEnum.CUSTOM,
            "my.module.MyCustomPrefillBackend"
        )
    """
    if class_path is None:
        # Decorator form: the path is only known once the class exists.
        def decorator(cls: type) -> type:
            _MLA_PREFILL_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
            return cls

        return decorator

    # Direct form: record now and hand back a pass-through so the call is
    # still usable as a decorator.
    _MLA_PREFILL_OVERRIDES[backend] = class_path
    return lambda x: x