Skip to content

vllm.model_executor.kernels.linear

This module re-exports linear kernel implementations to provide a stable import interface during an ongoing reorganization. Upcoming PRs will remove the `scaled_mm` and `mixed_precision` subdirectories and reorganize kernels by provider (aiter, cutlass, flashinfer, etc.) rather than by precision type. By centralizing exports here, we minimize the need to update imports across other modules when the internal structure changes. If you are adding a new kernel selector or kernel implementation, add it to this `__init__.py` to maintain import stability.

Modules:

Name Description
Mxfp8LinearKernel
base
mixed_precision
mxfp4
mxfp8
nvfp4
scaled_mm
zentorch_utils

Gates zentorch CPU linear dispatch on platform/op availability.

AiterInt8ScaledMMLinearKernel

Bases: CutlassInt8ScaledMMLinearKernel

Source code in vllm/model_executor/kernels/linear/scaled_mm/aiter.py
class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
    """Int8 W8A8 scaled GEMM on ROCm through AITER.

    Reuses the weight preparation of ``CutlassInt8ScaledMMLinearKernel``
    (which stores ``w_q`` in ``[K, N]`` format) and dispatches the GEMM to
    ``rocm_aiter_ops.w8a8_gemm``. Only symmetric quantization is supported.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report whether this kernel can run in the current environment.

        Returns:
            ``(True, None)`` on ROCm when ``compute_capability`` (if given)
            is at least 90, ``aiter`` is importable and AITER linear ops are
            enabled; otherwise ``(False, reason)``.
        """
        if not current_platform.is_rocm():
            return False, "Requires ROCm."

        if compute_capability is not None and compute_capability < 90:
            return False, "requires compute capability 90 and above."

        try:
            import aiter  # noqa: F401 # deliberately attempt to import aiter
        except Exception:
            # Any import-time failure makes the backend unusable; report it
            # as "unsupported" instead of aborting kernel selection.
            return False, "requires `aiter` to be installed."

        if not rocm_aiter_ops.is_linear_enabled():
            return (
                False,
                "requires setting `VLLM_ROCM_USE_AITER=1` "
                "and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
                "`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
            )
        return True, None

    @classmethod
    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Accept only layer configs with symmetric input quantization."""
        if not c.input_symmetric:
            return False, "supports symmetric quantization only."
        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        `AiterInt8ScaledMMLinearKernel` implements a fused version of
            `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
        where scale_a * a and scale_b * b are implemented using numpy-style
        broadcasting.
        Currently only supports per-tensor-per-tensor GEMM
        and per-token-per-channel GEMM through AITER
        w8a8 scaled gemm. `AiterInt8ScaledMMLinearKernel` also does not support
        AITER block scaled GEMM and mixed-precision GEMM.

        Args:
            layer: module holding the quantized weight and scale parameters.
            x: activation tensor; its dtype (bf16 or fp16) is the output dtype.
            bias: optional bias of length N with the same dtype as ``x``.

        Returns:
            The GEMM result in ``x.dtype``.
        """
        w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)

        # ops.scaled_int8_quant supports both dynamic and static quant:
        # * dynamic, i_s is None and x_s computed from x.
        # * static, i_s is scalar and x_s is i_s.
        symmetric = azp_adj is None
        assert symmetric, (
            "AiterInt8ScaledMMLinearKernel only supports symmetric quantization."
        )
        x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=symmetric)

        assert x_zp is None, (
            "AiterInt8ScaledMMLinearKernel only supports symmetric quantization."
        )
        out_dtype = x.dtype

        assert w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0
        assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
        # Explicit parentheses: bias is either absent, or matches N and dtype.
        assert bias is None or (
            bias.shape[0] == w_q.shape[1] and bias.dtype == out_dtype
        )

        m = x_q.shape[0]  # a
        n = w_q.shape[1]  # b

        per_tensor_scale_a = x_s.numel() == 1
        per_tensor_scale_b = w_s.numel() == 1
        per_token_scale_a = x_s.numel() == m
        per_channel_scale_b = w_s.numel() == n

        # @TODO:
        # Maybe broadcast the per-tensor-scale into per-channel-scale
        # if one of the scale is a per-channel-scale.
        # For now, it only supports:
        # - per-tensor-per-tensor a8w8 scaled GEMM, and
        # - per-token-per-channel a8w8 scaled GEMM
        assert (per_tensor_scale_a and per_tensor_scale_b) or (
            per_token_scale_a and per_channel_scale_b
        ), (
            "Currently only support per-tensor-per-tensor GEMM "
            "and per-token-per-channel GEMM through AITER"
            " w8a8 scaled gemm. `AiterInt8ScaledMMLinearKernel` "
            "does not support AITER block scaled GEMM."
        )

        # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
        # a to be [M, K]
        # b to be [N, K]
        # CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
        return rocm_aiter_ops.w8a8_gemm(x_q, w_q.t(), x_s, w_s, bias, out_dtype)

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Tensor | None = None
) -> Tensor

`AiterInt8ScaledMMLinearKernel` implements a fused version of `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`, where `scale_a * a` and `scale_b * b` are implemented using numpy-style broadcasting. It currently supports only per-tensor-per-tensor GEMM and per-token-per-channel GEMM through AITER w8a8 scaled GEMM. `AiterInt8ScaledMMLinearKernel` also does not support AITER block-scaled GEMM or mixed-precision GEMM.

Source code in vllm/model_executor/kernels/linear/scaled_mm/aiter.py
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    `AiterInt8ScaledMMLinearKernel` implements a fused version of
        `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
    where scale_a * a and scale_b * b are implemented using numpy-style
    broadcasting.
    Currently only supports per-tensor-per-tensor GEMM
    and per-token-per-channel GEMM through AITER
    w8a8 scaled gemm. `AiterInt8ScaledMMLinearKernel` also does not support
    AITER block scaled GEMM and mixed-precision GEMM.

    Args:
        layer: module holding the quantized weight and scale parameters.
        x: activation tensor; its dtype (bf16 or fp16) is the output dtype.
        bias: optional bias of length N with the same dtype as ``x``.

    Returns:
        The GEMM result in ``x.dtype``.
    """
    w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)

    # ops.scaled_int8_quant supports both dynamic and static quant:
    # * dynamic, i_s is None and x_s computed from x.
    # * static, i_s is scalar and x_s is i_s.
    symmetric = azp_adj is None
    assert symmetric, (
        "AiterInt8ScaledMMLinearKernel only supports symmetric quantization."
    )
    x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=symmetric)

    assert x_zp is None, (
        "AiterInt8ScaledMMLinearKernel only supports symmetric quantization."
    )
    out_dtype = x.dtype

    assert w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    # Explicit parentheses: bias is either absent, or matches N and dtype.
    assert bias is None or (
        bias.shape[0] == w_q.shape[1] and bias.dtype == out_dtype
    )

    m = x_q.shape[0]  # a
    n = w_q.shape[1]  # b

    per_tensor_scale_a = x_s.numel() == 1
    per_tensor_scale_b = w_s.numel() == 1
    per_token_scale_a = x_s.numel() == m
    per_channel_scale_b = w_s.numel() == n

    # @TODO:
    # Maybe broadcast the per-tensor-scale into per-channel-scale
    # if one of the scale is a per-channel-scale.
    # For now, it only supports:
    # - per-tensor-per-tensor a8w8 scaled GEMM, and
    # - per-token-per-channel a8w8 scaled GEMM
    assert (per_tensor_scale_a and per_tensor_scale_b) or (
        per_token_scale_a and per_channel_scale_b
    ), (
        "Currently only support per-tensor-per-tensor GEMM "
        "and per-token-per-channel GEMM through AITER"
        " w8a8 scaled gemm. `AiterInt8ScaledMMLinearKernel` "
        "does not support AITER block scaled GEMM."
    )

    # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
    # a to be [M, K]
    # b to be [N, K]
    # CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
    return rocm_aiter_ops.w8a8_gemm(x_q, w_q.t(), x_s, w_s, bias, out_dtype)

CutlassFP8ScaledMMLinearKernel

Bases: FP8ScaledMMLinearKernel

Source code in vllm/model_executor/kernels/linear/scaled_mm/cutlass.py
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    """FP8 scaled GEMM via ``ops.cutlass_scaled_mm`` (CUDA only).

    The weight ``B`` is column-major ``[K, N]``. When ``K`` or ``N`` is not a
    multiple of 16, the weight (and a weight scale with more than one
    element, when ``N`` needs padding) is statically padded once after
    loading. At runtime the activation is padded along ``K``, the bias along
    ``N``, and the padded output columns are sliced away.
    """

    def __init__(
        self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
    ) -> None:
        # Unpadded N; set in process_weights_after_loading and required by
        # apply_scaled_mm to slice away static padding.
        self.logical_output_size: int | None = None
        super().__init__(c, layer_param_names)

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Supported on any CUDA platform."""
        if not current_platform.is_cuda():
            return False, "requires CUDA."
        return True, None

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Every FP8 layer config is accepted."""
        return True, None

    @staticmethod
    def _pad_to_alignment(
        x: torch.Tensor, dim: int, alignment: int, value: float = 0.0
    ) -> torch.Tensor:
        """Pad tensor ``x`` along ``dim`` to the next multiple of
        ``alignment``.

        Padding is appended at the end of ``dim`` and filled with ``value``;
        ``x`` is returned unchanged when it is already aligned.
        """
        remainder = x.shape[dim] % alignment
        if remainder == 0:
            return x
        pad_size = alignment - remainder
        # F.pad takes (left, right) pairs starting from the last dimension;
        # -(2 * dim + 1) is the "right" slot belonging to ``dim``.
        pad_spec = [0] * (2 * x.dim())
        pad_spec[-(2 * dim + 1)] = pad_size
        return torch.nn.functional.pad(x, pad_spec, value=value)

    @staticmethod
    def padded_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        """Copy ``loaded_weight`` into ``param``, which may be padded.

        If the shapes differ, the weight is written into the leading slice of
        ``param`` and the padded tail is left untouched.
        """
        if loaded_weight.shape != param.shape:
            slices = tuple(slice(0, s) for s in loaded_weight.shape)
            param.data[slices].copy_(loaded_weight)
        else:
            param.data.copy_(loaded_weight.view(param.shape))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Record the logical N and pad the weight/scale to 16-alignment."""
        weight_name, weight_scale_name, _, _ = self.layer_param_names
        weight = getattr(layer, weight_name)

        # keep the logical output width so runtime can slice away static padding.
        self.logical_output_size = weight.shape[1]

        pad_k = (16 - weight.shape[0] % 16) % 16
        pad_n = (16 - weight.shape[1] % 16) % 16
        if pad_k == 0 and pad_n == 0:
            return

        # B is column-major [K, N]
        padded_weight = torch.nn.functional.pad(
            weight.t().contiguous(),
            (0, pad_k, 0, pad_n),
        ).t()
        replace_parameter(layer, weight_name, padded_weight.data)
        set_weight_attrs(
            getattr(layer, weight_name),
            {
                "weight_loader": self.padded_weight_loader,
            },
        )

        weight_scale = getattr(layer, weight_scale_name, None)
        if weight_scale is not None and pad_n > 0 and weight_scale.numel() > 1:
            flat_scale = weight_scale.reshape(-1)
            # Pad with 1.0 so padded output columns are scaled harmlessly.
            padded_scale = self._pad_to_alignment(
                flat_scale, dim=0, alignment=16, value=1.0
            ).view(-1, *weight_scale.shape[1:])
            replace_parameter(layer, weight_scale_name, padded_scale.data)
            # Attach the padded loader to the *scale* parameter that was just
            # replaced (previously this targeted the weight a second time).
            set_weight_attrs(
                getattr(layer, weight_scale_name),
                {
                    "weight_loader": self.padded_weight_loader,
                },
            )

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Run the CUTLASS GEMM, padding ``A``/``bias`` to match the padded
        ``B`` and slicing the output back to the logical width.

        Raises:
            AssertionError: if process_weights_after_loading has not run.
        """
        padded_k, padded_n = B.shape
        output_size = self.logical_output_size
        assert output_size is not None
        pad_k = padded_k - A.shape[1]
        pad_n = padded_n - output_size

        if pad_k > 0:
            A = self._pad_to_alignment(A, dim=1, alignment=16)
        if pad_n > 0 and bias is not None:
            bias = self._pad_to_alignment(bias, dim=0, alignment=16)

        output = ops.cutlass_scaled_mm(
            A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
        )

        if pad_n > 0:
            output = output[..., :output_size].contiguous()

        return output.view(*output_shape[:-1], output_size)

_pad_to_alignment staticmethod

_pad_to_alignment(
    x: Tensor, dim: int, alignment: int, value: float = 0.0
) -> Tensor

Pad tensor x along dim to the next multiple of alignment.

Source code in vllm/model_executor/kernels/linear/scaled_mm/cutlass.py
@staticmethod
def _pad_to_alignment(
    x: torch.Tensor, dim: int, alignment: int, value: float = 0.0
) -> torch.Tensor:
    """Pad tensor ``x`` along ``dim`` to the next multiple of
    ``alignment``."""
    remainder = x.shape[dim] % alignment
    if remainder == 0:
        return x
    pad_size = alignment - remainder
    pad_spec = [0] * (2 * x.dim())
    pad_spec[-(2 * dim + 1)] = pad_size
    return torch.nn.functional.pad(x, pad_spec, value=value)

CutlassNvFp4LinearKernel

Bases: NvFp4LinearKernel

NVFP4 GEMM via the vLLM CUTLASS kernel.

Source code in vllm/model_executor/kernels/linear/nvfp4/cutlass.py
class CutlassNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 GEMM via the vLLM CUTLASS kernel."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        # Availability is decided solely by the CUTLASS FP4 capability probe.
        if not cutlass_fp4_supported():
            return False, "CUTLASS FP4 kernels not available"
        return True, None

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Swizzle the block scales and pad the packed weight for CUTLASS.

        The number of padding columns is stored on the layer as
        ``weights_padding_cols`` so activations can be quantized to a
        matching width in ``apply_weights``.
        """
        swizzled_scale = swizzle_blockscale(layer.weight_scale.data)
        layer.weight_scale = torch.nn.Parameter(swizzled_scale, requires_grad=False)

        weight_padded, pad_cols = pad_nvfp4_weight_for_cutlass(layer.weight.data)
        layer.weight = torch.nn.Parameter(weight_padded, requires_grad=False)
        layer.weights_padding_cols = pad_cols

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` to NVFP4, run the CUTLASS GEMM, and add ``bias``."""
        n_out = layer.output_size_per_partition
        result_shape = [*x.shape[:-1], n_out]
        pad_cols = getattr(layer, "weights_padding_cols", 0)

        # Padding is tracked in packed weight columns; doubling converts it to
        # an element count (presumably two FP4 values per packed byte).
        x_fp4, x_blockscale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="cutlass",
            padded_n=x.shape[-1] + 2 * pad_cols,
        )

        result = cutlass_scaled_fp4_mm(
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            x.dtype,
        )
        result = slice_nvfp4_output(result, n_out)

        if bias is not None:
            result = result + bias
        return result.view(*result_shape)

EmulationMxfp8LinearKernel

Bases: Mxfp8LinearKernel

Software emulation fallback for MXFP8 (dequant to BF16).

Source code in vllm/model_executor/kernels/linear/mxfp8/emulation.py
class EmulationMxfp8LinearKernel(Mxfp8LinearKernel):
    """Software emulation fallback for MXFP8 (dequant to BF16)."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        # No hardware requirement: the emulation path is always selectable.
        return True, None

    @classmethod
    def can_implement(cls, c: Mxfp8LinearLayerConfig) -> tuple[bool, str | None]:
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Make the weight contiguous and trim scales to ``[N, K // block]``."""
        w = layer.weight.data  # [N, K]
        n_rows, k_cols = w.shape
        blocks_per_row = k_cols // MXFP8_BLOCK_SIZE

        # Keep one scale per MXFP8 block of each row; anything the stored
        # scale tensor holds beyond [N, K // MXFP8_BLOCK_SIZE] is dropped.
        scale = layer.weight_scale.data[:n_rows, :blocks_per_row].contiguous()

        layer.weight = Parameter(w.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(scale, requires_grad=False)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Dequantize the weight to BF16 and run a plain linear layer.

        Raises:
            ValueError: if the weight scale has the wrong dtype or is not 2D.
        """
        weight_scale = layer.weight_scale
        dtype_ok = weight_scale.dtype == MXFP8_SCALE_DTYPE
        if not dtype_ok:
            raise ValueError(
                f"Emulation backend requires {MXFP8_SCALE_DTYPE} "
                f"weight_scale dtype, got {weight_scale.dtype}."
            )
        if weight_scale.ndim != 2:
            raise ValueError(
                f"Emulation backend requires 2D weight_scale, "
                f"got {weight_scale.ndim}D. "
                f"Ensure process_weights_after_loading was called."
            )

        dequantized = dequant_mxfp8_to_bf16(layer.weight, weight_scale)
        # Cast back so the caller sees its own activation dtype.
        return torch.nn.functional.linear(x, dequantized, bias).to(x.dtype)

EmulationNvFp4LinearKernel

Bases: NvFp4LinearKernel

Software emulation fallback for NVFP4 (dequant → BF16 matmul).

Source code in vllm/model_executor/kernels/linear/nvfp4/emulation.py
class EmulationNvFp4LinearKernel(NvFp4LinearKernel):
    """Software emulation fallback for NVFP4 (dequant → BF16 matmul)."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        # Last-resort fallback: never refuses.
        return True, None

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Place the shared E2M1 lookup table on the weight's device.

        This happens eagerly because `.to(device)` is not allowed while a
        CUDA graph is being captured.
        """
        target_device = layer.weight.device
        kE2M1ToFloat_handle.val = kE2M1ToFloat_handle.val.to(target_device)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the emulated NVFP4 GEMM and add ``bias`` when given."""
        emulated = run_nvfp4_emulations(
            x=x,
            input_global_scale=layer.input_global_scale_inv,
            weight=layer.weight,
            weight_scale_swizzled=layer.weight_scale,
            weight_global_scale=layer.weight_global_scale,
            swizzle=False,
        )
        return emulated if bias is None else emulated + bias

FbgemmNvFp4LinearKernel

Bases: NvFp4LinearKernel

NVFP4 GEMM via FBGEMM.

Source code in vllm/model_executor/kernels/linear/nvfp4/fbgemm.py
class FbgemmNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 GEMM via FBGEMM."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not has_fbgemm_gpu():
            return False, "fbgemm_gpu required"
        return True, None

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Swizzle block scales and store them as a flat uint8 tensor."""
        flat_scale = swizzle_blockscale(layer.weight_scale.data).view(-1)
        layer.weight_scale = torch.nn.Parameter(
            flat_scale.view(torch.uint8), requires_grad=False
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` to NVFP4 and run FBGEMM's ``f4f4bf16`` op."""
        import fbgemm_gpu  # noqa: F401 - registers torch.ops.fbgemm.*

        n_out = layer.output_size_per_partition
        result_shape = [*x.shape[:-1], n_out]

        x_fp4, x_blockscale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="fbgemm",
        )
        # The op takes the activation scales as a flat uint8 buffer, matching
        # how the weight scales were stored after loading.
        x_scale_bytes = x_blockscale.view(-1).view(torch.uint8)

        raw = torch.ops.fbgemm.f4f4bf16(
            x_fp4,
            layer.weight,
            x_scale_bytes,
            layer.weight_scale,
            layer.alpha,
            use_mx=False,
        )
        result = slice_nvfp4_output(raw.to(x.dtype), n_out)

        if bias is not None:
            result = result + bias
        return result.view(*result_shape)

FlashInferB12xNvFp4LinearKernel

Bases: NvFp4LinearKernel

NVFP4 GEMM via FlashInfer's b12x CuTe DSL warp-level MMA kernel (SM120+).

Source code in vllm/model_executor/kernels/linear/nvfp4/flashinfer.py
class FlashInferB12xNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 GEMM via FlashInfer's b12x CuTe DSL warp-level MMA kernel (SM120+)."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        # The FlashInfer probe only runs when the device is SM120+.
        supported = (
            current_platform.has_device_capability(120)
            and has_flashinfer_b12x_gemm()
        )
        if supported:
            return True, None
        return (
            False,
            "FlashInfer b12x requires SM120+ and FlashInfer "
            "with Sm120BlockScaledDenseGemmKernel",
        )

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Swizzle block scales and pad the packed weight.

        The padding column count is recorded in ``layer.weights_padding_cols``.
        """
        swizzled_scale = swizzle_blockscale(layer.weight_scale.data)
        layer.weight_scale = torch.nn.Parameter(swizzled_scale, requires_grad=False)

        weight_padded, pad_cols = pad_nvfp4_weight_for_cutlass(layer.weight.data)
        layer.weight = torch.nn.Parameter(weight_padded, requires_grad=False)
        layer.weights_padding_cols = pad_cols

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x``, pad it to the weight's padded width, run b12x GEMM."""
        n_out = layer.output_size_per_partition
        result_shape = [*x.shape[:-1], n_out]

        x_fp4, x_blockscale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="b12x",
        )

        # The quantized activation is padded after quantization so its width
        # lines up with the padding recorded in process_weights_after_loading.
        pad_cols = getattr(layer, "weights_padding_cols", 0)
        x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, pad_cols)

        result = flashinfer_scaled_fp4_mm(
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            x.dtype,
            backend="b12x",
        )
        result = slice_nvfp4_output(result, n_out)

        if bias is not None:
            result = result + bias
        return result.view(*result_shape)

FlashInferCudnnNvFp4LinearKernel

Bases: NvFp4LinearKernel

NVFP4 GEMM via FlashInfer's cuDNN wrapper.

Source code in vllm/model_executor/kernels/linear/nvfp4/flashinfer.py
class FlashInferCudnnNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 GEMM via FlashInfer's cuDNN wrapper."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not has_flashinfer():
            return False, "FlashInfer required"
        return True, None

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Prepare weights in the swizzled + padded layout cuDNN shares with
        CUTLASS; the padding column count goes to
        ``layer.weights_padding_cols``."""
        swizzled_scale = swizzle_blockscale(layer.weight_scale.data)
        layer.weight_scale = torch.nn.Parameter(swizzled_scale, requires_grad=False)

        weight_padded, pad_cols = pad_nvfp4_weight_for_cutlass(layer.weight.data)
        layer.weight = torch.nn.Parameter(weight_padded, requires_grad=False)
        layer.weights_padding_cols = pad_cols

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` to NVFP4 and run the cuDNN-backed FlashInfer GEMM."""
        n_out = layer.output_size_per_partition
        result_shape = [*x.shape[:-1], n_out]
        pad_cols = getattr(layer, "weights_padding_cols", 0)

        # Quantize straight to the padded width (padding is counted in packed
        # columns, hence the factor of 2).
        x_fp4, x_blockscale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="flashinfer-cudnn",
            padded_n=x.shape[-1] + 2 * pad_cols,
        )

        result = flashinfer_scaled_fp4_mm(
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            x.dtype,
            backend="cudnn",
        )
        result = slice_nvfp4_output(result, n_out)

        if bias is not None:
            result = result + bias
        return result.view(*result_shape)

FlashInferCutlassMxfp8LinearKernel

Bases: Mxfp8LinearKernel

MXFP8 W8A8 GEMM via FlashInfer CUTLASS (SM100+).

Source code in vllm/model_executor/kernels/linear/mxfp8/flashinfer.py
class FlashInferCutlassMxfp8LinearKernel(Mxfp8LinearKernel):
    """MXFP8 W8A8 GEMM via FlashInfer CUTLASS (SM100+)."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Requires a device with capability >= 100 (Blackwell)."""
        if current_platform.has_device_capability(100):
            return True, None
        return False, "requires >=sm_100 (Blackwell)"

    @classmethod
    def can_implement(cls, c: Mxfp8LinearLayerConfig) -> tuple[bool, str | None]:
        """Every MXFP8 layer config is accepted."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Trim the weight scales to ``[N, K // MXFP8_BLOCK_SIZE]`` and swizzle
        them with ``swizzle_mxfp8_scale``; also make the weight contiguous."""
        weight = layer.weight.data  # [N, K]
        N, K = weight.shape

        scale_k = K // MXFP8_BLOCK_SIZE
        weight_scale_2d = layer.weight_scale.data[:N, :scale_k].contiguous()
        weight_scale_swizzled = swizzle_mxfp8_scale(weight_scale_2d, M=N, K=K)

        layer.weight = Parameter(weight.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(
            weight_scale_swizzled.contiguous(), requires_grad=False
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` to MXFP8 and run FlashInfer's ``mm_mxfp8``.

        The flattened batch dimension M is zero-padded up to a multiple of
        128 before quantization and the extra rows are dropped afterwards.

        Returns:
            Tensor of shape ``(*x.shape[:-1], N)`` in ``x.dtype``.

        Raises:
            AssertionError: if K < 128, K is not a multiple of
                ``MXFP8_BLOCK_SIZE``, or N < 128.
        """
        weight = layer.weight
        weight_scale = layer.weight_scale
        out_dtype = x.dtype
        N, K = weight.shape

        input_shape = x.shape
        # reshape (not view) so non-contiguous activations are accepted; for
        # contiguous input this is still a zero-copy view.
        input_2d = x.reshape(-1, K)
        M_orig = input_2d.shape[0]

        # Minimum K/N accepted below, and the multiple M is padded to.
        min_dim = 128

        assert min_dim <= K, (
            f"mm_mxfp8 requires K >= {min_dim}, got K={K}. "
            f"in_features is too small for mm_mxfp8."
        )
        assert K % MXFP8_BLOCK_SIZE == 0, (
            f"mm_mxfp8 requires K to be divisible by {MXFP8_BLOCK_SIZE}, got K={K}."
        )
        assert min_dim <= N, (
            f"mm_mxfp8 requires N >= {min_dim}, got N={N}. "
            f"out_features is too small for mm_mxfp8."
        )

        M_padded = ((M_orig + min_dim - 1) // min_dim) * min_dim
        if M_padded != M_orig:
            pad_rows = M_padded - M_orig
            input_2d = torch.nn.functional.pad(input_2d, (0, 0, 0, pad_rows))

        input_mxfp8, input_scale = mxfp8_e4m3_quantize(
            input_2d, is_sf_swizzled_layout=True
        )

        if not weight.is_contiguous():
            weight = weight.contiguous()

        output = vllm_flashinfer.mm_mxfp8(
            input_mxfp8,
            weight.t(),
            input_scale,
            weight_scale,
            out_dtype=out_dtype,
            backend="cutlass",
        )

        if M_padded != M_orig:
            output = output[:M_orig, :]

        if bias is not None:
            output = output + bias

        output_shape = (*input_shape[:-1], N)
        return output.view(output_shape)

FlashInferCutlassNvFp4LinearKernel

Bases: NvFp4LinearKernel

NVFP4 GEMM via FlashInfer's CUTLASS wrapper.

Source code in vllm/model_executor/kernels/linear/nvfp4/flashinfer.py
class FlashInferCutlassNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 GEMM via FlashInfer's CUTLASS wrapper."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
            cutlass_fp4_supported,
        )

        # Checked in order with short-circuiting: CUTLASS FP4 support, then
        # SM100+, then FlashInfer availability.
        requirements_met = (
            cutlass_fp4_supported()
            and current_platform.has_device_capability(100)
            and has_flashinfer()
        )
        if not requirements_met:
            return False, "FlashInfer + >=sm_100 required"
        return True, None

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Swizzle block scales and pad the packed weight; the padding column
        count is recorded in ``layer.weights_padding_cols``."""
        swizzled_scale = swizzle_blockscale(layer.weight_scale.data)
        layer.weight_scale = torch.nn.Parameter(swizzled_scale, requires_grad=False)

        weight_padded, pad_cols = pad_nvfp4_weight_for_cutlass(layer.weight.data)
        layer.weight = torch.nn.Parameter(weight_padded, requires_grad=False)
        layer.weights_padding_cols = pad_cols

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` to NVFP4 and run FlashInfer's CUTLASS GEMM."""
        n_out = layer.output_size_per_partition
        result_shape = [*x.shape[:-1], n_out]
        pad_cols = getattr(layer, "weights_padding_cols", 0)

        x_fp4, x_blockscale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="flashinfer-cutlass",
            padded_n=x.shape[-1] + 2 * pad_cols,
        )

        result = flashinfer_scaled_fp4_mm(
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            x.dtype,
            backend="cutlass",
        )
        result = slice_nvfp4_output(result, n_out)

        if bias is not None:
            result = result + bias
        return result.view(*result_shape)

FlashInferFp8DeepGEMMDynamicBlockScaledKernel

Bases: Fp8BlockScaledDynamicMMLinearKernel

Conditional FlashInfer / DeepGEMM FP8 block-scaled GEMM.

Dispatches between two kernels based on input batch size:

- Small batches (M < 32): FlashInfer's swapAB trick for better utilisation.
- Large batches (M >= 32): DeepGEMM for peak throughput.

apply_input_quant is False because FlashInfer accepts BF16 input and handles FP8 conversion internally. The DeepGEMM branch therefore quantises BF16→FP8 inside apply_mm via a closure before dispatching to the DeepGEMM kernel — keeping both branches compatible with the single BF16 tensor operand list passed by torch.cond.

Source code in vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py
class FlashInferFp8DeepGEMMDynamicBlockScaledKernel(
    Fp8BlockScaledDynamicMMLinearKernel
):
    """
    Conditional FlashInfer / DeepGEMM FP8 block-scaled GEMM.

    Dispatches between two kernels based on input batch size:
    - Small batches (M < 32): FlashInfer's swapAB trick for better utilisation.
    - Large batches (M >= 32): DeepGEMM for peak throughput.

    apply_input_quant is False because FlashInfer accepts BF16 input and
    handles FP8 conversion internally.  The DeepGEMM branch therefore
    quantises BF16→FP8 inside apply_mm via a closure before dispatching to
    the DeepGEMM kernel — keeping both branches compatible with the single
    BF16 tensor operand list passed by torch.cond.
    """

    # Primary (small-M) kernel and the fallback (large-M) kernel types.
    base_type: ClassVar[type[FlashInferFp8BlockScaledMMKernel]] = (
        FlashInferFp8BlockScaledMMKernel
    )
    fallback_type: ClassVar[type[DeepGemmFp8BlockScaledMMKernel]] = (
        DeepGemmFp8BlockScaledMMKernel
    )
    # Input stays unquantized; quantization happens inside the dispatched op.
    apply_input_quant: ClassVar[bool] = False

    def __init__(self, config: FP8ScaledMMLinearLayerConfig):
        super().__init__(config)
        # Narrow the attribute types created by the parent constructor.
        self.base: FlashInferFp8BlockScaledMMKernel
        self.fallback: DeepGemmFp8BlockScaledMMKernel

    def process_weights_after_loading(self, layer: torch.nn.Module):
        # Both kernels share one parameter layout; only DeepGEMM needs
        # post-processing, so running its step once covers both.
        self.fallback.process_weights_after_loading(layer)

    def apply_block_scaled_mm(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
    ) -> torch.Tensor:
        """Run the dynamic FlashInfer/DeepGEMM block-scaled GEMM op.

        ``As`` is not forwarded: with ``apply_input_quant`` False, ``A`` is
        passed unquantized and the op quantizes it itself.
        """
        col_group = self.weight_group_shape.col
        e8m0_scales = self.fallback.use_deep_gemm_e8m0
        gemm_op = torch.ops.vllm.dynamic_flashinfer_deepgemm_blockscale_gemm
        return gemm_op(A, B, Bs, col_group, e8m0_scales)

FlashInferMxFp4LinearKernel

Bases: MxFp4LinearKernel

MXFP4 W4A4 GEMM via FlashInfer CUTLASS (SM100+).

Source code in vllm/model_executor/kernels/linear/mxfp4/flashinfer.py
class FlashInferMxFp4LinearKernel(MxFp4LinearKernel):
    """MXFP4 W4A4 GEMM via FlashInfer CUTLASS (SM100+).

    Activations are quantised to MXFP4 on the fly with FlashInfer, and the
    GEMM itself runs on FlashInfer's ``cute-dsl`` backend.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        # Requires both a capability-100+ device and the CuTe-DSL backend.
        usable = (
            current_platform.has_device_capability(100) and has_flashinfer_cutedsl()
        )
        if not usable:
            return False, "FlashInfer + >=sm_100 (Blackwell) required"
        return True, None

    @classmethod
    def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None]:
        # Every MXFP4 layer config is accepted.
        return (True, None)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        num_rows, num_scale_cols = layer.weight_scale.shape
        in_features = num_scale_cols * _MXFP4_GROUP_SIZE

        # The swizzle pads the row count up to a multiple of 128 (CUTLASS
        # tiling), so reshape against the padded count.
        rows_padded = (num_rows + 127) // 128 * 128
        swizzled = swizzle_mxfp4_scales(
            layer.weight_scale.data, num_rows, in_features
        )
        layer.weight_scale = Parameter(
            swizzled.reshape(rows_padded, -1), requires_grad=False
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        from vllm.utils.flashinfer import (
            flashinfer_mxfp4_quantize,
            flashinfer_scaled_fp4_mm,
        )

        leading_dims = x.shape[:-1]
        # Flatten all leading dims into one batch dimension for the GEMM.
        q_input, q_scale = flashinfer_mxfp4_quantize(x.reshape(-1, x.shape[-1]))
        result = flashinfer_scaled_fp4_mm(
            q_input,
            layer.weight,
            q_scale,
            layer.weight_scale,
            alpha=None,
            out_dtype=x.dtype,
            backend="cute-dsl",
            block_size=_MXFP4_GROUP_SIZE,
            use_nvfp4=False,
        )

        if bias is not None:
            result = result + bias
        return result.view(leading_dims + (layer.output_size_per_partition,))

FlashInferTrtllmNvFp4LinearKernel

Bases: NvFp4LinearKernel

NVFP4 GEMM via FlashInfer's TensorRT-LLM wrapper.

Source code in vllm/model_executor/kernels/linear/nvfp4/flashinfer.py
class FlashInferTrtllmNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 GEMM via FlashInfer's TensorRT-LLM wrapper.

    Weights and block scales are shuffled once after loading; activations
    are quantised to NVFP4 per call before the ``trtllm`` GEMM backend runs.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not has_flashinfer():
            return False, "FlashInfer required"
        return True, None

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        # Every NVFP4 layer config is accepted.
        return (True, None)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

        # Epilogue tile size along M, shared by both shuffles.
        tile_m = 128
        packed_weight = layer.weight.data
        block_scales = layer.weight_scale.data

        shuffled_weight = shuffle_matrix_a(packed_weight.view(torch.uint8), tile_m)
        # Scales are shuffled as raw bytes, restored to their original shape,
        # then reinterpreted as FP8-E4M3.
        shuffled_scales = (
            shuffle_matrix_sf_a(block_scales.view(torch.uint8), tile_m)
            .reshape(block_scales.shape)
            .view(torch.float8_e4m3fn)
        )

        layer.weight = torch.nn.Parameter(shuffled_weight, requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(shuffled_scales, requires_grad=False)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        n_out = layer.output_size_per_partition
        target_shape = [*x.shape[:-1], n_out]

        q_input, q_blockscale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="flashinfer-trtllm",
        )

        raw = flashinfer_scaled_fp4_mm(
            q_input,
            layer.weight,
            q_blockscale,
            layer.weight_scale,
            layer.alpha,
            x.dtype,
            backend="trtllm",
        )
        result = slice_nvfp4_output(raw, n_out)

        if bias is not None:
            result = result + bias
        return result.view(*target_shape)

MarlinMxfp8LinearKernel

Bases: Mxfp8LinearKernel

MXFP8 W8A16 GEMM via Marlin (SM80+).

Source code in vllm/model_executor/kernels/linear/mxfp8/marlin.py
class MarlinMxfp8LinearKernel(Mxfp8LinearKernel):
    """MXFP8 W8A16 GEMM via Marlin (SM80+).

    Weights are converted to Marlin's layout once after loading; the
    activation tensor is handed to the Marlin GEMM as-is.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
            is_fp8_marlin_supported,
        )

        if not is_fp8_marlin_supported():
            return False, "Marlin FP8 not available"
        return True, None

    @classmethod
    def can_implement(cls, c: Mxfp8LinearLayerConfig) -> tuple[bool, str | None]:
        # Every MXFP8 layer config is accepted.
        return (True, None)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
            prepare_mxfp8_layer_for_marlin,
        )

        # Repacking is delegated entirely to the shared Marlin helper.
        prepare_mxfp8_layer_for_marlin(layer)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
            apply_mxfp8_marlin_linear,
        )

        out = apply_mxfp8_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )
        return out

MarlinNvFp4LinearKernel

Bases: NvFp4LinearKernel

NVFP4 weight-only GEMM via Marlin (W4A16).

Source code in vllm/model_executor/kernels/linear/nvfp4/marlin.py
class MarlinNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 weight-only GEMM via Marlin (W4A16).

    Used as a fallback when native FP4 compute is unavailable; a one-time
    warning is logged when weights are prepared.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not is_fp4_marlin_supported():
            return False, "Marlin FP4 not available"
        return True, None

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        # Every NVFP4 layer config is accepted.
        return (True, None)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        warning_text = (
            "Your GPU does not have native support for FP4 computation but "
            "FP4 quantization is being used. Weight-only FP4 compression "
            "will be used leveraging the Marlin kernel. This may degrade "
            "performance for compute-heavy workloads."
        )
        logger.warning_once(warning_text)
        prepare_fp4_layer_for_marlin(layer)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        out = apply_fp4_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_global_scale=layer.weight_global_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )
        return out

MxFp4LinearKernel

Bases: ABC

Base class for MXFP4 quantized linear kernels.

Each subclass implements a specific GEMM backend (CUTLASS, Marlin, etc.). The kernel selection mechanism iterates over registered subclasses in priority order, calling is_supported and can_implement to find the best match for the current hardware.

Source code in vllm/model_executor/kernels/linear/mxfp4/base.py
class MxFp4LinearKernel(ABC):
    """Base class for MXFP4 quantized linear kernels.

    Each subclass implements a specific GEMM backend (CUTLASS, Marlin, etc.).
    The kernel selection mechanism iterates over registered subclasses in
    priority order, calling ``is_supported`` and ``can_implement`` to find the
    best match for the current hardware.
    """

    def __init__(self, config: MxFp4LinearLayerConfig) -> None:
        # Re-check both gates so a kernel cannot be constructed for a config
        # or platform it rejects. Include the kernel's own reason string in
        # the failure message instead of discarding it.
        can_impl, impl_reason = self.can_implement(config)
        assert can_impl, (
            f"{type(self).__name__} cannot implement this config: {impl_reason}"
        )
        supported, support_reason = self.is_supported()
        assert supported, (
            f"{type(self).__name__} is not supported on this platform: "
            f"{support_reason}"
        )
        self.config = config

    @classmethod
    @abstractmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Return whether this kernel can run on the current platform."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Return whether this kernel can handle *config*."""
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Transform weights into the format required by this kernel.

        Called once after checkpoint weights have been loaded onto the
        device.  Implementations should repack / swizzle / pad weights
        and scales in-place on *layer*.
        """
        raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the quantized GEMM."""
        raise NotImplementedError

apply_weights abstractmethod

apply_weights(
    layer: Module, x: Tensor, bias: Tensor | None = None
) -> Tensor

Run the quantized GEMM.

Source code in vllm/model_executor/kernels/linear/mxfp4/base.py
@abstractmethod
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run the quantized GEMM."""
    raise NotImplementedError

can_implement abstractmethod classmethod

can_implement(
    config: MxFp4LinearLayerConfig,
) -> tuple[bool, str | None]

Return whether this kernel can handle config.

Source code in vllm/model_executor/kernels/linear/mxfp4/base.py
@classmethod
@abstractmethod
def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None]:
    """Report whether this kernel handles *config*, plus an optional reason.

    Abstract: every concrete kernel must override this.
    """
    raise NotImplementedError()

is_supported abstractmethod classmethod

is_supported(
    compute_capability: int | None = None,
) -> tuple[bool, str | None]

Return whether this kernel can run on the current platform.

Source code in vllm/model_executor/kernels/linear/mxfp4/base.py
@classmethod
@abstractmethod
def is_supported(
    cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
    """Return whether this kernel can run on the current platform."""
    raise NotImplementedError

process_weights_after_loading abstractmethod

process_weights_after_loading(layer: Module) -> None

Transform weights into the format required by this kernel.

Called once after checkpoint weights have been loaded onto the device. Implementations should repack / swizzle / pad weights and scales in-place on layer.

Source code in vllm/model_executor/kernels/linear/mxfp4/base.py
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Rewrite *layer*'s weights into this kernel's required format.

    Invoked once, after checkpoint weights are on the device. Concrete
    kernels repack / swizzle / pad weights and scales in place on *layer*.
    """
    raise NotImplementedError()

MxFp4LinearLayerConfig dataclass

Configuration for an MXFP4 linear layer.

All MXFP4 layers share the same structure: packed uint8 weights (2 FP4 values per byte) and per-block weight scales (group size 32).

Source code in vllm/model_executor/kernels/linear/mxfp4/base.py
@dataclass
class MxFp4LinearLayerConfig:
    """Layer configuration shared by every MXFP4 linear layer.

    MXFP4 layers all have the same structure, so no fields are needed:
    uint8-packed weights (two FP4 values per byte) with per-block weight
    scales at group size 32.
    """

Mxfp8LinearLayerConfig dataclass

Configuration for an MXFP8 linear layer.

All MXFP8 layers share the same structure: FP8-E4M3 weights with uint8 (E8M0) per-block scales at block size 32.

Source code in vllm/model_executor/kernels/linear/mxfp8/Mxfp8LinearKernel.py
@dataclass
class Mxfp8LinearLayerConfig:
    """Layer configuration shared by every MXFP8 linear layer.

    MXFP8 layers all have the same structure, so no fields are needed:
    FP8-E4M3 weights with uint8 (E8M0) per-block scales at block size 32.
    """

NvFp4LinearKernel

Bases: ABC

Base class for NVFP4 quantized linear kernels.

Each subclass implements a specific GEMM backend (CUTLASS, Marlin, etc.). The kernel selection mechanism iterates over registered subclasses in priority order, calling is_supported and can_implement to find the best match for the current hardware.

Source code in vllm/model_executor/kernels/linear/nvfp4/base.py
class NvFp4LinearKernel(ABC):
    """Base class for NVFP4 quantized linear kernels.

    Each subclass implements a specific GEMM backend (CUTLASS, Marlin, etc.).
    The kernel selection mechanism iterates over registered subclasses in
    priority order, calling ``is_supported`` and ``can_implement`` to find the
    best match for the current hardware.
    """

    def __init__(self, config: NvFp4LinearLayerConfig) -> None:
        # Re-check both gates so a kernel cannot be constructed for a config
        # or platform it rejects. Include the kernel's own reason string in
        # the failure message instead of discarding it.
        can_impl, impl_reason = self.can_implement(config)
        assert can_impl, (
            f"{type(self).__name__} cannot implement this config: {impl_reason}"
        )
        supported, support_reason = self.is_supported()
        assert supported, (
            f"{type(self).__name__} is not supported on this platform: "
            f"{support_reason}"
        )
        self.config = config

    @classmethod
    @abstractmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Return whether this kernel can run on the current platform."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Return whether this kernel can handle *config*."""
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Transform weights into the format required by this kernel.

        Called once after checkpoint weights have been loaded onto the
        device.  Implementations should repack / swizzle / pad weights
        and scales in-place on *layer*.
        """
        raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the quantized GEMM."""
        raise NotImplementedError

apply_weights abstractmethod

apply_weights(
    layer: Module, x: Tensor, bias: Tensor | None = None
) -> Tensor

Run the quantized GEMM.

Source code in vllm/model_executor/kernels/linear/nvfp4/base.py
@abstractmethod
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run the quantized GEMM."""
    raise NotImplementedError

can_implement abstractmethod classmethod

can_implement(
    config: NvFp4LinearLayerConfig,
) -> tuple[bool, str | None]

Return whether this kernel can handle config.

Source code in vllm/model_executor/kernels/linear/nvfp4/base.py
@classmethod
@abstractmethod
def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
    """Report whether this kernel handles *config*, plus an optional reason.

    Abstract: every concrete kernel must override this.
    """
    raise NotImplementedError()

is_supported abstractmethod classmethod

is_supported(
    compute_capability: int | None = None,
) -> tuple[bool, str | None]

Return whether this kernel can run on the current platform.

Source code in vllm/model_executor/kernels/linear/nvfp4/base.py
@classmethod
@abstractmethod
def is_supported(
    cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
    """Return whether this kernel can run on the current platform."""
    raise NotImplementedError

process_weights_after_loading abstractmethod

process_weights_after_loading(layer: Module) -> None

Transform weights into the format required by this kernel.

Called once after checkpoint weights have been loaded onto the device. Implementations should repack / swizzle / pad weights and scales in-place on layer.

Source code in vllm/model_executor/kernels/linear/nvfp4/base.py
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Rewrite *layer*'s weights into this kernel's required format.

    Invoked once, after checkpoint weights are on the device. Concrete
    kernels repack / swizzle / pad weights and scales in place on *layer*.
    """
    raise NotImplementedError()

NvFp4LinearLayerConfig dataclass

Configuration for an NVFP4 linear layer.

All NVFP4 layers share the same structure: packed uint8 weights (2 FP4 values per byte), FP8-E4M3 per-block weight scales (group size 16), and scalar global scales for both weights and activations.

Source code in vllm/model_executor/kernels/linear/nvfp4/base.py
@dataclass
class NvFp4LinearLayerConfig:
    """Layer configuration shared by every NVFP4 linear layer.

    NVFP4 layers all have the same structure, so no fields are needed:
    uint8-packed weights (two FP4 values per byte), FP8-E4M3 per-block
    weight scales at group size 16, and scalar global scales for both the
    weights and the activations.
    """

TritonW4A16LinearKernel

Bases: MPLinearKernel

Triton-based W4A16 GEMM kernel for ROCm (MI300 and newer).

Supports GPTQ-format int4 weights (uint4b8 symmetric, uint4 asymmetric) with grouped quantization. Weight tensors are repacked from the compressed-tensors checkpoint layout (K packed, [N, K//8]) to the kernel's [K, N//8] layout (N packed).

Source code in vllm/model_executor/kernels/linear/mixed_precision/triton_w4a16.py
class TritonW4A16LinearKernel(MPLinearKernel):
    """
    Triton-based W4A16 GEMM kernel, primarily targeting ROCm (MI300 and newer).

    ``can_implement`` accepts both ROCm and CUDA platforms.

    Supports GPTQ-format int4 weights (uint4b8 symmetric, uint4 asymmetric)
    with grouped quantization. At load time the compressed-tensors checkpoint
    layout is converted to the kernel's layout: the K-packed ``[N, K//8]``
    weight is unpacked and repacked along N into ``[K, N//8]``, and the
    scales / zero points are transposed so that K//G is the leading dim.
    """

    SUPPORTED_QUANT_TYPES = TRITON_W4A16_SUPPORTED_QUANT_TYPES

    @classmethod
    def get_min_capability(cls) -> int:
        # Triton handles capability checks itself
        return 0

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if not (current_platform.is_rocm() or current_platform.is_cuda()):
            return False, "TritonW4A16LinearKernel requires CUDA or ROCm"

        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
            return (
                False,
                f"Quant type {c.weight_type} not supported; "
                f"supported: {cls.SUPPORTED_QUANT_TYPES}",
            )

        if c.act_type not in (torch.float16, torch.bfloat16):
            return False, "Only float16/bfloat16 activations are supported"

        # The kernel packs 8 N-values per int32, so N must divide by 8.
        N = c.partition_weight_shape[1]
        if N % 8 != 0:
            return (
                False,
                f"Output features ({N}) must be divisible by 8 "
                "(8 int4 values packed per int32)",
            )

        if c.has_g_idx:
            return (
                False,
                "Activation reordering (g_idx) is not supported by "
                "TritonW4A16LinearKernel",
            )

        gs = c.group_size
        if (
            gs not in TRITON_W4A16_SUPPORTED_GROUP_SIZES
            and gs != c.full_weight_shape[0]
        ):
            return (
                False,
                f"Group size {gs} not supported; "
                f"supported: {TRITON_W4A16_SUPPORTED_GROUP_SIZES} "
                f"or full K ({c.full_weight_shape[0]})",
            )

        # group_size == -1 means a single group spanning the partition's K.
        K = c.partition_weight_shape[0]
        eff_gs = gs if gs != -1 else K
        if K % eff_gs != 0:
            return (False, f"Input features {K} not divisible by group size {eff_gs}")

        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """
        Convert compressed-tensors checkpoint layout to kernel layout.

        Checkpoint (from compressed_tensors_wNa16.create_weights):
          weight_packed:     [N, K//8]  int32   input_dim=1, output_dim=0, packed_dim=1
          weight_scale:      [N, K//G]  fp16    input_dim=1, output_dim=0
          weight_zero_point: [N//8, K//G] int32  output_dim=0, packed_dim=0

        Kernel needs:
          qweight: [K, N//8]  int32   (unpack along K, repack along N)
          scales:  [K//G, N]  fp16    (transpose weight_scale)
          qzeros:  [K//G, N//8] int32 (transpose weight_zero_point)
        """

        # The checkpoint packs 8 consecutive K-values into each int32, while
        # the kernel packs 8 consecutive N-values into each int32. Changing
        # which axis is packed cannot be done with a transpose alone, so the
        # weight is unpacked to [N, K], transposed to [K, N], and repacked
        # along N. This is a one-time cost at load time and runs on whatever
        # device the weight tensor already lives on.
        def repack_w_q(x: BasevLLMParameter) -> BasevLLMParameter:
            # Normalise the parameter so output (N) is dim 0 and the packed
            # input (K) axis is dim 1: x.data is then [N, K//8] int32.
            permute_param_layout_(x, input_dim=1, output_dim=0, packed_dim=1)
            w = x.data  # [N, K//8] int32

            N_dim, K8 = w.shape
            K_dim = K8 * 8
            # Nibble i of each int32 sits at bit offset 4*i.
            shifts = torch.arange(8, device=w.device, dtype=torch.int32) * 4
            # Unpack: [N, K//8, 1] >> [8] -> [N, K//8, 8] -> [N, K]. The
            # & 0xF drops the sign-extension of the arithmetic right shift.
            w_unpacked = ((w.unsqueeze(-1) >> shifts) & 0xF).reshape(N_dim, K_dim)
            w_KN = w_unpacked.t().contiguous()  # [K, N]
            # Repack: 8 consecutive N-values per int32 -> [K, N//8]. The
            # nibbles occupy disjoint bits, so summing them equals OR-ing.
            # can_implement rejects partitions whose N is not a multiple of 8.
            N8 = N_dim // 8
            w_repacked = torch.sum(
                (w_KN.view(K_dim, N8, 8) & 0xF) << shifts,
                dim=2,
                dtype=torch.int32,
            )
            x.data = w_repacked.contiguous()
            return x

        def repack_w_s(x: BasevLLMParameter) -> BasevLLMParameter:
            # x.data is [N, K//G] fp16, bring to [K//G, N]
            permute_param_layout_(x, input_dim=1, output_dim=0)
            x.data = x.data.t().contiguous()
            return x

        self._transform_param(layer, self.w_q_name, repack_w_q)
        self._transform_param(layer, self.w_s_name, repack_w_s)

        if self.w_zp_name is not None:
            zp = getattr(layer, self.w_zp_name, None)
            if zp is not None:
                # Checkpoint: [N//8, K//G] int32 (N already packed at dim 0).
                # The packed axis is already N, so a plain transpose yields
                # the kernel's [K//G, N//8] layout.
                replace_parameter(
                    layer,
                    self.w_zp_name,
                    torch.nn.Parameter(zp.data.t().contiguous(), requires_grad=False),
                )

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        c = self.config
        w_q, w_s, w_zp, _ = self._get_weight_params(layer)

        # Flatten leading dims into a contiguous 2-D activation [M, K].
        x_2d = x.reshape(-1, x.shape[-1]).contiguous()
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        K = c.partition_weight_shape[0]
        # group_size == -1 means a single group spanning the partition's K.
        group_size = c.group_size if c.group_size != -1 else K

        # Weight types with a built-in bias (e.g. symmetric uint4b8) pass that
        # bias as a scalar zero point; other types pass 0. w_zp is forwarded
        # as qzeros either way.
        zp_bias = c.weight_type.bias if c.weight_type.has_bias() else 0

        output = triton_w4a16_gemm(
            a=x_2d,
            b_q=w_q,
            scales=w_s,
            qzeros=w_zp,
            group_size=group_size,
            zp_bias=zp_bias,
        )

        if bias is not None:
            output.add_(bias)

        return output.reshape(out_shape)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None

Convert compressed-tensors checkpoint layout to kernel layout.

Checkpoint layout (from compressed_tensors_wNa16.create_weights):

- weight_packed: [N, K//8] int32 — input_dim=1, output_dim=0, packed_dim=1
- weight_scale: [N, K//G] fp16 — input_dim=1, output_dim=0
- weight_zero_point: [N//8, K//G] int32 — output_dim=0, packed_dim=0

Kernel layout:

- qweight: [K, N//8] int32 (unpacked along K, repacked along N from weight_packed)
- scales: [K//G, N] fp16 (transpose of weight_scale)
- qzeros: [K//G, N//8] int32 (transpose of weight_zero_point)

Source code in vllm/model_executor/kernels/linear/mixed_precision/triton_w4a16.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """
    Convert compressed-tensors checkpoint layout to kernel layout.

    Checkpoint (from compressed_tensors_wNa16.create_weights):
      weight_packed:     [N, K//8]  int32   input_dim=1, output_dim=0, packed_dim=1
      weight_scale:      [N, K//G]  fp16    input_dim=1, output_dim=0
      weight_zero_point: [N//8, K//G] int32  output_dim=0, packed_dim=0

    Kernel needs:
      qweight: [K, N//8]  int32   (unpack along K, repack along N)
      scales:  [K//G, N]  fp16    (transpose weight_scale)
      qzeros:  [K//G, N//8] int32 (transpose weight_zero_point)
    """

    # The checkpoint packs 8 consecutive K-values into each int32, while the
    # kernel packs 8 consecutive N-values into each int32. Changing which axis
    # is packed cannot be done with a transpose alone, so the weight is
    # unpacked to [N, K], transposed to [K, N], and repacked along N. This is
    # a one-time cost at load time and runs on whatever device the weight
    # tensor already lives on.
    def repack_w_q(x: BasevLLMParameter) -> BasevLLMParameter:
        # Normalise the parameter so output (N) is dim 0 and the packed
        # input (K) axis is dim 1: x.data is then [N, K//8] int32.
        permute_param_layout_(x, input_dim=1, output_dim=0, packed_dim=1)
        w = x.data  # [N, K//8] int32

        N_dim, K8 = w.shape
        K_dim = K8 * 8
        # Nibble i of each int32 sits at bit offset 4*i.
        shifts = torch.arange(8, device=w.device, dtype=torch.int32) * 4
        # Unpack: [N, K//8, 1] >> [8] -> [N, K//8, 8] -> [N, K]. The & 0xF
        # drops the sign-extension of the arithmetic right shift.
        w_unpacked = ((w.unsqueeze(-1) >> shifts) & 0xF).reshape(N_dim, K_dim)
        w_KN = w_unpacked.t().contiguous()  # [K, N]
        # Repack: 8 consecutive N-values per int32 -> [K, N//8]. The nibbles
        # occupy disjoint bits, so summing them equals OR-ing.
        # can_implement rejects partitions whose N is not a multiple of 8.
        N8 = N_dim // 8
        w_repacked = torch.sum(
            (w_KN.view(K_dim, N8, 8) & 0xF) << shifts,
            dim=2,
            dtype=torch.int32,
        )
        x.data = w_repacked.contiguous()
        return x

    def repack_w_s(x: BasevLLMParameter) -> BasevLLMParameter:
        # x.data is [N, K//G] fp16, bring to [K//G, N]
        permute_param_layout_(x, input_dim=1, output_dim=0)
        x.data = x.data.t().contiguous()
        return x

    self._transform_param(layer, self.w_q_name, repack_w_q)
    self._transform_param(layer, self.w_s_name, repack_w_s)

    if self.w_zp_name is not None:
        zp = getattr(layer, self.w_zp_name, None)
        if zp is not None:
            # Checkpoint: [N//8, K//G] int32 (N already packed at dim 0).
            # The packed axis is already N, so a plain transpose yields the
            # kernel's [K//G, N//8] layout.
            replace_parameter(
                layer,
                self.w_zp_name,
                torch.nn.Parameter(zp.data.t().contiguous(), requires_grad=False),
            )

XPUMxFp8LinearKernel

Bases: Mxfp8LinearKernel

MXFP8 W8A8 GEMM on XPU.

Source code in vllm/model_executor/kernels/linear/mxfp8/xpu.py
class XPUMxFp8LinearKernel(Mxfp8LinearKernel):
    """MXFP8 W8A8 GEMM on XPU.

    Activations are quantised to MXFP8 per call and multiplied against the
    transposed FP8 weight via ``torch.ops._xpu_C.fp8_gemm``.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if current_platform.is_xpu():
            return True, None
        return False, "XPUMxFp8 only support on XPU"

    @classmethod
    def can_implement(cls, c: Mxfp8LinearLayerConfig) -> tuple[bool, str | None]:
        # Every MXFP8 layer config is accepted.
        return (True, None)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Reinterpret the raw scale bytes as E8M0, then store them transposed
        # and contiguous; the weight is stored as a transposed view.
        scales_t = layer.weight_scale.view(torch.float8_e8m0fnu).t().contiguous()
        replace_parameter(layer, "weight", layer.weight.t())
        replace_parameter(layer, "weight_scale", scales_t.data)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        result_dtype = x.dtype
        q_input, q_scale = quant_mxfp8(x)
        gemm = torch.ops._xpu_C.fp8_gemm
        return gemm(
            q_input,
            layer.weight,
            result_dtype,
            q_scale,
            layer.weight_scale,
            bias,
        )

XPUW4A8IntLinearKernel

Bases: MPLinearKernel

XPU kernel for W4A8 integer quantization using oneDNN int4_gemm_w4a8.

Weights are symmetric group-quantized int4 packed as uint4. Activations are dynamically quantized per-token to symmetric int8.

Source code in vllm/model_executor/kernels/linear/mixed_precision/xpu.py
class XPUW4A8IntLinearKernel(MPLinearKernel):
    """XPU kernel for W4A8 integer quantization using oneDNN int4_gemm_w4a8.

    Weights are symmetric group-quantized int4 packed as uint4.
    Activations are dynamically quantized per-token to symmetric int8.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        # No compute-capability gate; platform checks live in can_implement.
        return -1

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Check platform, dtypes, symmetry, group size and shape alignment."""
        if not current_platform.is_xpu():
            return False, "XPUW4A8Int only supported on XPU"
        if c.act_type not in (torch.bfloat16, torch.float16):
            return False, "XPUW4A8Int requires BF16/FP16 activations"
        if c.weight_type != scalar_types.int4:
            return (
                False,
                f"XPUW4A8Int requires int4 weights, got {c.weight_type}",
            )
        if c.zero_points:
            return False, "XPUW4A8Int only supports symmetric weight quantization"
        if c.group_size != -1 and c.group_size % 32 != 0:
            return (
                False,
                f"Group size ({c.group_size}) not supported by XPUW4A8Int, "
                "must be a multiple of 32",
            )
        # Both dims must be multiples of 8: _pack_int4_weight folds 8 values
        # of the input dim into one int32.
        in_size, out_size = c.partition_weight_shape
        if in_size % 8 != 0 or out_size % 8 != 0:
            return (
                False,
                f"in/out sizes ({in_size}, {out_size}) must be multiples of 8",
            )

        if c.act_type != torch.float16:
            logger.warning_once(
                "XPUW4A8IntLinearKernel is running with model dtype %s, "
                "but int4_gemm_w4a8 produces float16 output. Recommend "
                "setting --dtype float16 for best performance.",
                c.act_type,
            )

        return True, None

    def _pack_int4_weight(self, w: torch.Tensor) -> torch.Tensor:
        """Pack signed int4 values into int32, 8 values per element.

        Args:
            w: ``[N, K]`` int8 tensor with values in ``[-8, 7]``; ``K`` must
                be a multiple of 8.

        Returns:
            ``[N, K // 8]`` int32 tensor. Value ``i`` of each group of 8 is
            shifted to ``[0, 15]`` (+8) and stored in bits ``4*i .. 4*i+3``.
        """
        # w is [N, K] int8 with values in [-8, 7]
        w_u4 = w.to(torch.int32) + 8  # shift to [0, 15]
        w_u4 = w_u4.reshape(w.shape[0], w.shape[1] // 8, 8)  # [N, K/8, 8]
        shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=w.device)
        # The top nibble may wrap negative in int32; sum() on an integer
        # tensor accumulates in int64 and the final cast keeps the 32-bit
        # pattern, so the packed bits are still correct.
        packed = ((w_u4 & 0xF) << shifts[None, None, :]).sum(dim=2).to(torch.int32)
        return packed

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Transpose scales, add a fixed zero point and pack int4 weights."""
        layer.weight_scale.data = layer.weight_scale.data.t().contiguous()

        device = layer.weight_packed.device
        # TODO: support asymmetric quantization
        # Symmetric weights: the +8 shift in _pack_int4_weight is undone by a
        # constant zero point of 8.
        weight_zero_point = torch.tensor([8], dtype=torch.int8, device=device)
        layer.weight_zero_point = Parameter(weight_zero_point, requires_grad=False)

        # weight_packed is [out, in] int8, signed int4 values in [-8, 7]
        w = layer.weight_packed.data  # [out, in]

        # TODO: implement asym case
        packed = self._pack_int4_weight(w)  # [out, in/8] packed uint4

        replace_parameter(
            layer,
            self.w_q_name,
            torch.nn.Parameter(packed, requires_grad=False),
        )

        # Free the original unpacked int8 weight (still registered as "weight")
        # to avoid double-storing both int8 [N, K] and int32 [N, K/8] in memory.
        layer.register_parameter("weight", None)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the W4A8 GEMM on ``x`` of shape ``[..., K]``.

        Leading dims of ``x`` are flattened for the GEMM and restored on the
        result, which has shape ``[..., N]`` and dtype ``x.dtype``.
        """
        reshaped_x = x.reshape(-1, x.shape[-1])  # [M, K]
        from vllm._xpu_ops import xpu_ops as ops

        # TODO: static and asymmetric quantization case
        # Common code for CompressedTensorsW4A8Int does not read act symmetry data
        # NOTE(review): the (True, 8) arguments presumably mean
        # (symmetric, 8-bit) — confirm against the xpu_ops signature.
        quant_x, x_scale, x_zero = ops.dynamic_per_token_int8_quant_ref(
            reshaped_x, True, 8
        )

        out = torch.ops._xpu_C.int4_gemm_w4a8(
            quant_x,
            x_scale,
            x_zero,
            layer.weight_packed.t(),
            layer.weight_scale,
            layer.weight_zero_point,
            self.config.group_size,
            None,  # g_idx not currently supported
            bias,
        )

        # The GEMM ran on the flattened [M, K] view; restore x's leading dims
        # so inputs with more than 2 dims get a [..., N] output (no-op for 2-D).
        out = out.reshape(*x.shape[:-1], out.shape[-1])
        return out.to(x.dtype)

ZentorchInt8ScaledMMLinearKernel

Bases: Int8ScaledMMLinearKernel

Source code in vllm/model_executor/kernels/linear/scaled_mm/zentorch.py
class ZentorchInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
    """W8A8 int8 kernel backed by ``torch.ops.zentorch.zentorch_dynamic_qlinear``.

    Accepts only dynamic, symmetric activation quantization combined with
    per-channel weight scales (see ``can_implement``).
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        # compute_capability is unused: support depends on the host CPU and
        # on the zentorch op being registered.
        if not current_platform.is_cpu():
            return False, "requires CPU."
        if not current_platform.is_zen_cpu():
            return False, "requires AMD Zen CPU."
        if has_zentorch_op(["zentorch_dynamic_qlinear"]):
            return True, None
        return (
            False,
            "torch.ops.zentorch.zentorch_dynamic_qlinear is not registered.",
        )

    @classmethod
    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        # The first violated requirement determines the reported reason.
        reason = None
        if c.is_static_input_scheme:
            reason = "requires dynamic activation quantization."
        elif not c.input_symmetric:
            reason = "requires symmetric activation quantization."
        elif not c.is_channelwise:
            reason = "requires per-channel weight quantization."
        return reason is None, reason

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Prepare weights for ``zentorch_dynamic_qlinear``.

        The int8 weight keeps its [N, K] layout but is made contiguous; the
        per-channel scale is flattened to ``(N,)`` and cast to bf16.
        """
        q_attr, s_attr, _, _, _ = self.layer_param_names

        qweight = getattr(layer, q_attr)
        out_features = qweight.shape[0]
        replace_parameter(
            layer,
            q_attr,
            torch.nn.Parameter(qweight.data.contiguous(), requires_grad=False),
        )

        # Collapse a trailing singleton dim ([N, 1] -> [N]) before the cast.
        scale = getattr(layer, s_attr).data
        if scale.dim() == 2 and scale.shape[-1] == 1:
            scale = scale.squeeze(-1)
        scale = scale.to(torch.bfloat16).contiguous()
        assert scale.shape == (out_features,), (
            f"[zen_cpu] expected weight scale shape ({out_features},), "
            f"got {tuple(scale.shape)}"
        )

        replace_parameter(
            layer,
            s_attr,
            torch.nn.Parameter(scale, requires_grad=False),
        )
        logger.info_once(
            "[zen_cpu] Using zentorch_dynamic_qlinear for W8A8 (dynamic-symmetric)"
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        q_attr, s_attr, _, _, _ = self.layer_param_names
        qweight = getattr(layer, q_attr)
        weight_scale = getattr(layer, s_attr)
        # No activation scale/zero point is passed; activations are presumably
        # quantized inside the op (dynamic scheme) — confirm with zentorch docs.
        return torch.ops.zentorch.zentorch_dynamic_qlinear(
            x,
            qweight,
            weight_scale,
            bias,
            zentorch_op_name="zentorch::zentorch_dynamic_qlinear",
        )

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None

Prepare weights for zentorch_dynamic_qlinear.

Keeps weight in [N, K] layout (int8, contiguous) and converts the per-channel weight scale to bf16 with shape (N,).

Source code in vllm/model_executor/kernels/linear/scaled_mm/zentorch.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Prepare weights for ``zentorch_dynamic_qlinear``.

    The int8 weight keeps its [N, K] layout but is made contiguous; the
    per-channel scale is flattened to ``(N,)`` and cast to bf16.
    """
    q_attr, s_attr, _, _, _ = self.layer_param_names

    qweight = getattr(layer, q_attr)
    out_features = qweight.shape[0]
    replace_parameter(
        layer,
        q_attr,
        torch.nn.Parameter(qweight.data.contiguous(), requires_grad=False),
    )

    # Collapse a trailing singleton dim ([N, 1] -> [N]) before the cast.
    scale = getattr(layer, s_attr).data
    if scale.dim() == 2 and scale.shape[-1] == 1:
        scale = scale.squeeze(-1)
    scale = scale.to(torch.bfloat16).contiguous()
    assert scale.shape == (out_features,), (
        f"[zen_cpu] expected weight scale shape ({out_features},), "
        f"got {tuple(scale.shape)}"
    )

    replace_parameter(
        layer,
        s_attr,
        torch.nn.Parameter(scale, requires_grad=False),
    )
    logger.info_once(
        "[zen_cpu] Using zentorch_dynamic_qlinear for W8A8 (dynamic-symmetric)"
    )

ZentorchWNA16LinearKernel

Bases: CPUWNA16LinearKernel

W4A16 GPTQ kernel backed by torch.ops.zentorch.zentorch_woq_linear.

Source code in vllm/model_executor/kernels/linear/mixed_precision/zentorch.py
class ZentorchWNA16LinearKernel(CPUWNA16LinearKernel):
    """W4A16 GPTQ kernel backed by ``torch.ops.zentorch.zentorch_woq_linear``.

    Layers that pass ``_zentorch_woq_eligible`` get their weights repacked
    into private ``layer._zentorch_woq_*`` attributes and run through zentorch.
    Every other layer is delegated to ``CPUWNA16LinearKernel``.
    """

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        # All CPUWNA16 constraints apply first; zentorch-specific checks follow.
        ok, reason = super().can_implement(c)
        if not ok:
            return ok, reason

        if not current_platform.is_zen_cpu():
            return False, "ZentorchWNA16 requires an AMD Zen CPU."

        # Both the load-time repack op and the run-time GEMM op must exist.
        if not has_zentorch_op(["zentorch_woq_repack_weight", "zentorch_woq_linear"]):
            return (
                False,
                "torch.ops.zentorch.{zentorch_woq_repack_weight, "
                "zentorch_woq_linear} are not registered.",
            )

        if c.has_g_idx:
            return False, "ZentorchWNA16 does not support activation re-ordering."
        return True, None

    def _zentorch_woq_eligible(self, layer: torch.nn.Module) -> bool:
        """Eligibility predicate for the zentorch W4A16 GPTQ fast path.

        Returns False (so the caller uses the ``cpu_gemm_wna16`` path via
        ``super()`` with ``layer`` untouched) unless all of these hold: no
        g_idx tensor and no ``has_g_idx`` in the config; packed weight and
        scale both present; 8 values per packed storage element;
        ``input_dim == packed_dim == 1``; 2-D weight and scale; and an input
        size that is divisible by the scale's group count.
        """
        # Activation re-ordering (g_idx) is not handled by the fast path.
        if (
            self.w_gidx_name is not None
            and getattr(layer, self.w_gidx_name, None) is not None
        ) or (getattr(self.config, "has_g_idx", False)):
            return False

        weight_packed = getattr(layer, self.w_q_name, None)
        weight_scale = getattr(layer, self.w_s_name, None)
        if weight_packed is None or weight_scale is None:
            return False

        # NOTE(review): assumes weight_type.mantissa is the per-value bit
        # width (4 for the uint4 variants). torch.iinfo raises TypeError for a
        # floating-point weight_packed dtype rather than returning False.
        bits = self.config.weight_type.mantissa
        pack_factor = torch.iinfo(weight_packed.dtype).bits // bits
        # 4-bit -> 8 values per int32;
        if pack_factor != 8:
            return False

        # GPTQ-only. AWQ packs along the output dim instead.
        in_dim = getattr(weight_packed, "input_dim", None)
        pk_dim = getattr(weight_packed, "packed_dim", None)
        if in_dim is None or pk_dim is None or in_dim != pk_dim:
            return False

        # Only the layout packed along dim 1 (compressed-tensors) is accepted.
        is_ct_format = in_dim == pk_dim == 1
        if not is_ct_format:
            return False

        if weight_packed.dim() != 2 or weight_scale.dim() != 2:
            return False

        # 4-bit -> 8 values per int32; in_features must be divisible by num_groups.
        in_features = weight_packed.shape[1] * 8
        num_groups = weight_scale.shape[1]
        return num_groups > 0 and in_features % num_groups == 0

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack CT GPTQ weights into the zentorch WOQ layout.

        Falls back to ``CPUWNA16LinearKernel.process_weights_after_loading``
        via ``super()`` when the layer doesn't satisfy
        ``_zentorch_woq_eligible``.

        On success, ``layer._zentorch_processed_weights`` is set to ``True``.
        The repacked weight, transposed scale and optional zero point are
        stored on ``layer._zentorch_woq_packed``, ``_zentorch_woq_scale`` and
        ``_zentorch_woq_zero_point``. The original weight/scale/zero-point
        tensors are replaced by empty tensors.
        """
        # Already repacked for this layer: nothing to do.
        if getattr(layer, "_zentorch_processed_weights", False):
            return

        if not self._zentorch_woq_eligible(layer):
            logger.info_once(
                "[zen_cpu] ZentorchWNA16 fast path not eligible for this "
                "layer (AWQ pack layout, g_idx, or non-int32 storage); "
                "falling back to CPUWNA16LinearKernel (cpu_gemm_wna16)."
            )
            super().process_weights_after_loading(layer)
            return

        # Null out attributes the config says are unused. For zero points this
        # makes the zp_param lookup below return None, so zp_tc stays None.
        if (not self.config.zero_points) and (self.w_zp_name is not None):
            setattr(layer, self.w_zp_name, None)

        if (not self.config.has_g_idx) and (self.w_gidx_name is not None):
            setattr(layer, self.w_gidx_name, None)

        # Work on raw tensors (unwrap Parameters via .data when present).
        weight_q = getattr(layer, self.w_q_name)
        weight_s = getattr(layer, self.w_s_name)
        weight_packed = weight_q.data if hasattr(weight_q, "data") else weight_q
        weight_scale = weight_s.data if hasattr(weight_s, "data") else weight_s

        # Logical (unpacked) weight shape is [out_features, in_features].
        bits = self.config.weight_type.mantissa
        pack_factor = torch.iinfo(weight_packed.dtype).bits // bits
        out_features, num_groups = weight_scale.shape[0], weight_scale.shape[1]
        in_features = weight_packed.shape[1] * pack_factor
        original_shape = torch.Size([out_features, in_features])
        unpack_from_int32 = _import_unpack_from_int32()
        repack_op = torch.ops.zentorch.zentorch_woq_repack_weight.default

        weight_unpacked = unpack_from_int32(
            weight_packed,
            bits,
            original_shape,
            packed_dim=weight_q.packed_dim,
        )

        zp_param = (
            getattr(layer, self.w_zp_name, None) if self.w_zp_name is not None else None
        )
        # NOTE(review): assumes unpack_from_int32 returns values re-centred to
        # a signed range ([-8, 7] for 4 bits); for plain uint4 the +8 below
        # restores [0, 15]. Confirm against unpack_from_int32.
        needs_unsigned_offset = self.config.weight_type == scalar_types.uint4

        if needs_unsigned_offset:
            weight_unpacked = (weight_unpacked.to(torch.int32) + 8).clamp(0, 15)
        repacked = repack_op(weight_unpacked.to(torch.int8).contiguous())

        if zp_param is None:
            zp_tc = None
        else:
            # Zero points: unpack to [out_features, num_groups], apply the same
            # unsigned offset, then store as int8 [num_groups, out_features].
            zp_tensor = zp_param.data if hasattr(zp_param, "data") else zp_param
            zp = unpack_from_int32(
                zp_tensor,
                bits,
                (out_features, num_groups),
                packed_dim=zp_param.packed_dim,
            )
            if needs_unsigned_offset:
                zp = (zp.to(torch.int32) + 8).clamp(0, 15)
            zp_tc = zp.to(torch.int8).t().contiguous()

        # Operands for zentorch_woq_linear; the packed weight is kept as a
        # transposed view, the scale as a contiguous transposed copy.
        layer._zentorch_woq_packed = repacked.t()
        layer._zentorch_woq_scale = weight_scale.t().contiguous()
        layer._zentorch_woq_zero_point = zp_tc

        # Release the original checkpoint tensors. Objects exposing .data get
        # an empty payload (keeping their registration); anything else is
        # replaced by an empty tensor.
        for param_name in (self.w_q_name, self.w_s_name, self.w_zp_name):
            if param_name is None:
                continue
            param = getattr(layer, param_name, None)
            if param is None:
                continue
            if hasattr(param, "data"):
                param.data = torch.empty(0)
            else:
                setattr(layer, param_name, torch.empty(0))

        layer._zentorch_kind = "compressed_tensors_w4a16_gptq"
        layer._zentorch_processed_weights = True
        logger.info_once(
            "[zen_cpu] Using zentorch_woq_linear for W4A16 GPTQ "
            "(weight_type=%s, has_zp=%s)",
            self.config.weight_type,
            zp_tc is not None,
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run zentorch_woq_linear on repacked layers, else defer to super()."""
        if getattr(layer, "_zentorch_processed_weights", False):
            return torch.ops.zentorch.zentorch_woq_linear.default(
                x,
                layer._zentorch_woq_packed,
                layer._zentorch_woq_scale,
                layer._zentorch_woq_zero_point,
                bias,
            )
        return super().apply_weights(layer, x, bias)

_zentorch_woq_eligible

_zentorch_woq_eligible(layer: Module) -> bool

Eligibility predicate for the zentorch W4A16 GPTQ fast path.

If any constraint fails, the predicate returns False and the layer falls back to the cpu_gemm_wna16 path via super(), with layer left untouched.

Source code in vllm/model_executor/kernels/linear/mixed_precision/zentorch.py
def _zentorch_woq_eligible(self, layer: torch.nn.Module) -> bool:
    """Report whether ``layer`` can use the zentorch W4A16 GPTQ fast path.

    ``False`` tells the caller to take the ``cpu_gemm_wna16`` path via
    ``super()`` and leave ``layer`` untouched. Eligibility requires:

    * no g_idx tensor on the layer and no ``has_g_idx`` in the config,
    * both the packed weight and its scale present,
    * 8 values per packed storage element,
    * ``input_dim == packed_dim == 1`` (GPTQ / compressed-tensors layout;
      AWQ packs along the output dim instead),
    * 2-D weight and scale, and
    * an unpacked input size divisible by the scale's group count.
    """
    gidx_attr = self.w_gidx_name
    has_gidx_tensor = (
        gidx_attr is not None and getattr(layer, gidx_attr, None) is not None
    )
    if has_gidx_tensor or getattr(self.config, "has_g_idx", False):
        return False

    packed = getattr(layer, self.w_q_name, None)
    scales = getattr(layer, self.w_s_name, None)
    if packed is None or scales is None:
        return False

    # int4 -> 8 values per 32-bit storage element.
    value_bits = self.config.weight_type.mantissa
    values_per_elem = torch.iinfo(packed.dtype).bits // value_bits
    if values_per_elem != 8:
        return False

    input_dim = getattr(packed, "input_dim", None)
    packed_dim = getattr(packed, "packed_dim", None)
    if input_dim is None or packed_dim is None:
        return False
    if not (input_dim == packed_dim == 1):
        return False

    if packed.dim() != 2 or scales.dim() != 2:
        return False

    num_groups = scales.shape[1]
    if num_groups <= 0:
        return False
    unpacked_in_features = packed.shape[1] * 8
    return unpacked_in_features % num_groups == 0

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None

Repack CT GPTQ weights into the zentorch WOQ layout.

Falls back to CPUWNA16LinearKernel.process_weights_after_loading via super() when the layer doesn't satisfy _zentorch_woq_eligible.

On success, layer._zentorch_processed_weights is set to True.

Source code in vllm/model_executor/kernels/linear/mixed_precision/zentorch.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Repack CT GPTQ weights into the zentorch WOQ layout.

    Falls back to ``CPUWNA16LinearKernel.process_weights_after_loading``
    via ``super()`` when the layer doesn't satisfy
    ``_zentorch_woq_eligible``.

    On success, ``layer._zentorch_processed_weights`` is set to ``True``.
    The repacked weight, transposed scale and optional zero point are stored
    on ``layer._zentorch_woq_packed``, ``_zentorch_woq_scale`` and
    ``_zentorch_woq_zero_point``. The original weight/scale/zero-point
    tensors are replaced by empty tensors.
    """
    # Already repacked for this layer: nothing to do.
    if getattr(layer, "_zentorch_processed_weights", False):
        return

    if not self._zentorch_woq_eligible(layer):
        logger.info_once(
            "[zen_cpu] ZentorchWNA16 fast path not eligible for this "
            "layer (AWQ pack layout, g_idx, or non-int32 storage); "
            "falling back to CPUWNA16LinearKernel (cpu_gemm_wna16)."
        )
        super().process_weights_after_loading(layer)
        return

    # Null out attributes the config says are unused. For zero points this
    # makes the zp_param lookup below return None, so zp_tc stays None.
    if (not self.config.zero_points) and (self.w_zp_name is not None):
        setattr(layer, self.w_zp_name, None)

    if (not self.config.has_g_idx) and (self.w_gidx_name is not None):
        setattr(layer, self.w_gidx_name, None)

    # Work on raw tensors (unwrap Parameters via .data when present).
    weight_q = getattr(layer, self.w_q_name)
    weight_s = getattr(layer, self.w_s_name)
    weight_packed = weight_q.data if hasattr(weight_q, "data") else weight_q
    weight_scale = weight_s.data if hasattr(weight_s, "data") else weight_s

    # Logical (unpacked) weight shape is [out_features, in_features].
    bits = self.config.weight_type.mantissa
    pack_factor = torch.iinfo(weight_packed.dtype).bits // bits
    out_features, num_groups = weight_scale.shape[0], weight_scale.shape[1]
    in_features = weight_packed.shape[1] * pack_factor
    original_shape = torch.Size([out_features, in_features])
    unpack_from_int32 = _import_unpack_from_int32()
    repack_op = torch.ops.zentorch.zentorch_woq_repack_weight.default

    weight_unpacked = unpack_from_int32(
        weight_packed,
        bits,
        original_shape,
        packed_dim=weight_q.packed_dim,
    )

    zp_param = (
        getattr(layer, self.w_zp_name, None) if self.w_zp_name is not None else None
    )
    # NOTE(review): assumes unpack_from_int32 returns values re-centred to a
    # signed range ([-8, 7] for 4 bits); for plain uint4 the +8 below
    # restores [0, 15]. Confirm against unpack_from_int32.
    needs_unsigned_offset = self.config.weight_type == scalar_types.uint4

    if needs_unsigned_offset:
        weight_unpacked = (weight_unpacked.to(torch.int32) + 8).clamp(0, 15)
    repacked = repack_op(weight_unpacked.to(torch.int8).contiguous())

    if zp_param is None:
        zp_tc = None
    else:
        # Zero points: unpack to [out_features, num_groups], apply the same
        # unsigned offset, then store as int8 [num_groups, out_features].
        zp_tensor = zp_param.data if hasattr(zp_param, "data") else zp_param
        zp = unpack_from_int32(
            zp_tensor,
            bits,
            (out_features, num_groups),
            packed_dim=zp_param.packed_dim,
        )
        if needs_unsigned_offset:
            zp = (zp.to(torch.int32) + 8).clamp(0, 15)
        zp_tc = zp.to(torch.int8).t().contiguous()

    # Operands for zentorch_woq_linear; the packed weight is kept as a
    # transposed view, the scale as a contiguous transposed copy.
    layer._zentorch_woq_packed = repacked.t()
    layer._zentorch_woq_scale = weight_scale.t().contiguous()
    layer._zentorch_woq_zero_point = zp_tc

    # Release the original checkpoint tensors. Objects exposing .data get an
    # empty payload (keeping their registration); anything else is replaced
    # by an empty tensor.
    for param_name in (self.w_q_name, self.w_s_name, self.w_zp_name):
        if param_name is None:
            continue
        param = getattr(layer, param_name, None)
        if param is None:
            continue
        if hasattr(param, "data"):
            param.data = torch.empty(0)
        else:
            setattr(layer, param_name, torch.empty(0))

    layer._zentorch_kind = "compressed_tensors_w4a16_gptq"
    layer._zentorch_processed_weights = True
    logger.info_once(
        "[zen_cpu] Using zentorch_woq_linear for W4A16 GPTQ "
        "(weight_type=%s, has_zp=%s)",
        self.config.weight_type,
        zp_tc is not None,
    )

_filter_kernels_by_backend

_filter_kernels_by_backend(
    backend: str, kernels: list[type]
) -> list[type]

Filter a kernel priority list to only those matching the backend.

Source code in vllm/model_executor/kernels/linear/__init__.py
def _filter_kernels_by_backend(
    backend: str,
    kernels: list[type],
) -> list[type]:
    """Keep only the kernels that belong to ``backend``, preserving order.

    Membership comes from ``_LINEAR_BACKEND_KERNEL_MAP``; an unknown backend
    name yields an empty list.
    """
    allowed = _LINEAR_BACKEND_KERNEL_MAP.get(backend, set())
    selected = []
    for kernel in kernels:
        if kernel in allowed:
            selected.append(kernel)
    return selected

_get_linear_backend

_get_linear_backend() -> str

Get the linear_backend setting from the current vllm config.

Source code in vllm/model_executor/kernels/linear/__init__.py
def _get_linear_backend() -> str:
    """Return ``kernel_config.linear_backend`` from the active vLLM config.

    Falls back to ``"auto"`` when no vLLM config is currently set.
    """
    # Imported lazily, presumably to avoid an import cycle with vllm.config.
    from vllm.config import get_current_vllm_config_or_none

    vllm_config = get_current_vllm_config_or_none()
    if vllm_config is None:
        return "auto"
    return vllm_config.kernel_config.linear_backend

choose_mp_linear_kernel

choose_mp_linear_kernel(
    config: MPLinearLayerConfig,
    compute_capability: int | None = None,
) -> type[MPLinearKernel]

Choose an MPLinearKernel that can implement the given config for the given compute capability. Attempts to choose the best kernel in terms of performance.

Parameters:

Name Type Description Default
config MPLinearLayerConfig

Description of the linear layer to be implemented.

required
compute_capability Optional[int]

The compute capability of the target device, if None uses current_platform to get the compute capability. Defaults to None.

None

Raises:

Type Description
ValueError

If no kernel can implement the given config.

Returns:

Type Description
type[MPLinearKernel]

The chosen kernel class.

Source code in vllm/model_executor/kernels/linear/__init__.py
def choose_mp_linear_kernel(
    config: MPLinearLayerConfig, compute_capability: int | None = None
) -> type[MPLinearKernel]:
    """
    Choose an MPLinearKernel that can implement the given config for the given
    compute capability. Attempts to choose the best kernel in terms of
    performance.

    Kernels registered for the current platform are tried in priority order
    (narrowed by ``--linear-backend`` when it is not ``"auto"``). The first
    kernel that is not listed in ``VLLM_DISABLED_KERNELS``, meets its minimum
    compute capability, and reports ``can_implement(config)`` is returned.

    Args:
        config (MPLinearLayerConfig): Description of the linear layer to be
            implemented.
        compute_capability (Optional[int], optional): The compute capability of
            the target device, if None uses `current_platform` to get
            the compute capability. Defaults to None.

    Raises:
        ValueError: If the compute capability cannot be determined, if the
            requested ``--linear-backend`` has no mixed-precision kernel, or
            if no kernel can implement the given config.

    Returns:
        type[MPLinearKernel]: Chosen kernel.
    """
    if compute_capability is None:
        if current_platform is None:
            raise ValueError("Cannot determine compute capability")
        _cc = current_platform.get_device_capability()
        if _cc is not None:
            compute_capability = _cc[0] * 10 + _cc[1]

    # .get() so a platform without registered kernels ends in the documented
    # ValueError below instead of a bare KeyError.
    platform_kernels = _POSSIBLE_KERNELS.get(current_platform._enum, [])

    # Apply --linear-backend filtering when set.
    linear_backend = _get_linear_backend()
    if linear_backend != "auto":
        filtered = _filter_kernels_by_backend(linear_backend, platform_kernels)
        if not filtered:
            raise ValueError(
                f"--linear-backend={linear_backend} was requested but no "
                f"'{linear_backend}' kernel exists for mixed-precision layers."
            )
        platform_kernels = filtered

    failure_reasons = []
    if not platform_kernels:
        failure_reasons.append(
            " no mixed-precision kernels are registered for platform "
            f"{current_platform._enum}"
        )

    for kernel in platform_kernels:
        if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
            failure_reasons.append(
                f" {kernel.__name__} disabled by environment variable"
            )
            continue
        if compute_capability is not None:
            min_capability = kernel.get_min_capability()
            if min_capability > compute_capability:
                failure_reasons.append(
                    f" {kernel.__name__} requires capability {min_capability}, "
                    f"current compute capability is {compute_capability}"
                )
                continue

        can_implement, failure_reason = kernel.can_implement(config)
        if can_implement:
            return kernel
        failure_reasons.append(
            f" {kernel.__name__} cannot implement due to: {failure_reason}"
        )

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "WNA16 linear layer. Reasons: \n" + "\n".join(failure_reasons)
    )

choose_scaled_mm_linear_kernel

choose_scaled_mm_linear_kernel(
    config: _KernelConfigT,
    possible_kernels: dict[
        PlatformEnum, list[type[_KernelT]]
    ],
    compute_capability: int | None = None,
    force_kernel: type[_KernelT] | None = None,
) -> type[_KernelT]

Choose a _KernelT that can implement the given config for the given compute capability. Attempts to choose the best kernel in terms of performance.

Parameters:

Name Type Description Default
config _KernelConfigT

Description of the linear layer to be implemented.

required
possible_kernels dict[PlatformEnum, list[type[_KernelT]]]

A dictionary of platforms and their list of possible kernels.

required
compute_capability Optional[int]

The compute capability of the target device, if None uses current_platform to get the compute capability. Defaults to None.

None
force_kernel Optional[type[_KernelT]]

An optional kernel to try before possible_kernels. It is returned if it is supported and can implement the config; otherwise selection falls back to possible_kernels. If None, only possible_kernels are tried.

None

Raises:

Type Description
ValueError

If no kernel can implement the given config.

Returns:

Name Type Description
_KernelT type[_KernelT]

Chosen kernel.

Source code in vllm/model_executor/kernels/linear/__init__.py
def choose_scaled_mm_linear_kernel(
    config: _KernelConfigT,
    possible_kernels: dict[PlatformEnum, list[type[_KernelT]]],
    compute_capability: int | None = None,
    force_kernel: type[_KernelT] | None = None,
) -> type[_KernelT]:
    """
    Choose a _KernelT that can implement the given config for the
    given compute capability. Attempts to choose the best kernel in terms of
    performance.

    Args:
        config (_KernelConfigT): Description of the linear layer
            to be implemented.
        possible_kernels (dict[PlatformEnum, list[type[_KernelT]]]): A
            dictionary of platforms and their list of possible kernels, in
            priority order.
        compute_capability (Optional[int], optional): The compute capability of
            the target device, if None uses `current_platform` to get the
            compute capability. Defaults to None.
        force_kernel (Optional[type[_KernelT]]): A kernel to try first. It is
            returned if it is supported and can implement ``config``;
            otherwise selection continues with ``possible_kernels``.

    Raises:
        ValueError: If the requested ``--linear-backend`` has no kernel for
            this layer type, or if no kernel can implement the given config.

    Returns:
        _KernelT: Chosen kernel.
    """
    if force_kernel is not None:
        forced_ok, _ = is_supported_and_can_implement_kernel(
            force_kernel, config, compute_capability
        )
        if forced_ok:
            return force_kernel
        logger.info_once(
            "Tried to force %s, but the kernel couldn't be implemented",
            force_kernel.__name__,
            scope="global",
        )

    candidates = possible_kernels[current_platform._enum]

    # Narrow the candidates when --linear-backend names a specific backend.
    linear_backend = _get_linear_backend()
    if linear_backend != "auto":
        candidates = _filter_kernels_by_backend(linear_backend, candidates)
        if not candidates:
            raise ValueError(
                f"--linear-backend={linear_backend} was requested but no "
                f"'{linear_backend}' kernel exists for this layer type."
            )

    rejected = []
    for candidate in candidates:
        usable, why = is_supported_and_can_implement_kernel(
            candidate, config, compute_capability
        )
        if usable:
            return candidate
        rejected.append(why)

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "ScaledMM linear layer. Reasons: \n" + "\n".join(rejected)
    )

init_mxfp4_linear_kernel

init_mxfp4_linear_kernel() -> MxFp4LinearKernel

Select and instantiate the best MXFP4 linear kernel for the current platform.

Source code in vllm/model_executor/kernels/linear/__init__.py
def init_mxfp4_linear_kernel() -> MxFp4LinearKernel:
    """Pick and instantiate the MXFP4 linear kernel for the current platform.

    With ``--linear-backend=auto`` and ``VLLM_MXFP4_USE_MARLIN`` set, the
    Marlin kernel is forced and must report ``is_supported()``. Otherwise the
    candidates in ``_POSSIBLE_MXFP4_KERNELS`` for this platform are tried in
    order, narrowed by ``--linear-backend`` when it is not ``"auto"``. The
    first candidate that is not in ``VLLM_DISABLED_KERNELS`` and reports
    ``is_supported()`` is used.

    Raises:
        ValueError: if the forced Marlin kernel is unsupported, if the
            requested backend has no MXFP4 kernel, or if every candidate was
            rejected (the message lists each reason).
    """
    linear_backend = _get_linear_backend()

    # VLLM_MXFP4_USE_MARLIN only applies when no explicit backend was chosen.
    if linear_backend == "auto" and envs.VLLM_MXFP4_USE_MARLIN:
        forced = MarlinMxFp4LinearKernel
        supported, why = forced.is_supported()
        if not supported:
            raise ValueError(
                f"Forced MXFP4 kernel {forced.__name__} is not "
                f"supported: {why}"
            )
        logger.info_once("Using %s for MXFP4 GEMM", forced.__name__)
        return forced(MxFp4LinearLayerConfig())

    candidates = list(_POSSIBLE_MXFP4_KERNELS.get(current_platform._enum, []))

    if linear_backend != "auto":
        candidates = _filter_kernels_by_backend(linear_backend, candidates)
        if not candidates:
            raise ValueError(
                f"--linear-backend={linear_backend} was requested but no "
                f"'{linear_backend}' kernel exists for MXFP4 layers."
            )

    rejected = []
    for kernel_cls in candidates:
        name = kernel_cls.__name__
        if name in envs.VLLM_DISABLED_KERNELS:
            rejected.append(f" {name} disabled by environment variable")
            continue

        supported, why = kernel_cls.is_supported()
        if supported:
            logger.info_once("Using %s for MXFP4 GEMM", name)
            return kernel_cls(MxFp4LinearLayerConfig())
        rejected.append(f"{name}: {why}")

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "MXFP4 linear layer. Reasons: \n" + "\n".join(rejected)
    )

init_mxfp8_linear_kernel

init_mxfp8_linear_kernel() -> Mxfp8LinearKernel

Select and instantiate the best MXFP8 linear kernel for the current platform.

Source code in vllm/model_executor/kernels/linear/__init__.py
def init_mxfp8_linear_kernel() -> Mxfp8LinearKernel:
    """Select and instantiate the best MXFP8 linear kernel for the
    current platform.

    Candidates are taken, in registration order, from the MXFP8 registry
    entry for the current platform and optionally narrowed by
    ``--linear-backend``. A candidate is rejected if it is named in
    ``VLLM_DISABLED_KERNELS``, if its ``is_supported()`` check fails, or if
    ``can_implement()`` refuses the layer config. The first candidate that
    survives all checks is instantiated and returned.

    Raises:
        ValueError: If ``--linear-backend`` names a backend that has no MXFP8
            kernel, or if every candidate was rejected. In the latter case the
            message lists the per-kernel reasons.
    """
    layer_config = Mxfp8LinearLayerConfig()

    candidates = list(_POSSIBLE_MXFP8_KERNELS.get(current_platform._enum, []))

    # An explicit --linear-backend restricts the search to that provider.
    backend = _get_linear_backend()
    if backend != "auto":
        candidates = _filter_kernels_by_backend(backend, candidates)
        if not candidates:
            raise ValueError(
                f"--linear-backend={backend} was requested but no "
                f"'{backend}' kernel exists for MXFP8 layers."
            )

    rejections = []
    for candidate in candidates:
        name = candidate.__name__
        if name in envs.VLLM_DISABLED_KERNELS:
            rejections.append(f" {name} disabled by environment variable")
            continue

        # can_implement is only consulted once the platform check passes.
        ok, why = candidate.is_supported()
        if ok:
            ok, why = candidate.can_implement(layer_config)
        if not ok:
            rejections.append(f"{name}: {why}")
            continue

        logger.info_once("Using %s for MXFP8 GEMM", name)
        return candidate(layer_config)

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "MXFP8 linear layer. Reasons: \n" + "\n".join(rejections)
    )

init_nvfp4_linear_kernel

init_nvfp4_linear_kernel() -> NvFp4LinearKernel

Select and instantiate the best NVFP4 linear kernel for the current platform.

Source code in vllm/model_executor/kernels/linear/__init__.py
def _resolve_forced_nvfp4_kernel(
    linear_backend: str,
) -> type[NvFp4LinearKernel] | None:
    """Return the NVFP4 kernel class pinned by environment overrides, or None.

    ``VLLM_BATCH_INVARIANT`` forces deterministic execution and takes
    precedence over both ``--linear-backend`` and the deprecated env vars.

    - If the batch-invariant CUTLASS kernel reports itself supported, it is
      pinned.
    - Otherwise the emulation kernel is pinned.
    - In either case a warning is logged when this contradicts an explicit
      ``--linear-backend``.

    The deprecated overrides (``VLLM_USE_FBGEMM``,
    ``VLLM_USE_NVFP4_CT_EMULATIONS``, ``VLLM_NVFP4_GEMM_BACKEND``) are only
    honoured when ``linear_backend`` is ``"auto"``. Their deprecation warnings
    are emitted from vllm/envs.py.

    Raises:
        ValueError: If ``VLLM_NVFP4_GEMM_BACKEND`` names an unknown backend.
    """
    if envs.VLLM_BATCH_INVARIANT:
        cutlass_ok, why = CutlassNvFp4LinearKernel.is_supported()
        if not cutlass_ok:
            if linear_backend not in ("auto", "emulation"):
                logger.warning_once(
                    "VLLM_BATCH_INVARIANT overrides --linear-backend=%s; "
                    "using the emulation backend for deterministic execution.",
                    linear_backend,
                )
            logger.info_once(
                "VLLM_BATCH_INVARIANT is set but the batch-invariant NVFP4 "
                "kernel is not supported on this platform; falling back to "
                "emulation for deterministic execution. Reason: %s",
                why,
            )
            return EmulationNvFp4LinearKernel

        if linear_backend in ("auto", "cutlass"):
            logger.info_once(
                "VLLM_BATCH_INVARIANT forces NVFP4 linear to use the "
                "CUTLASS backend for deterministic execution."
            )
        else:
            logger.warning_once(
                "VLLM_BATCH_INVARIANT overrides --linear-backend=%s; "
                "using the CUTLASS backend for deterministic execution.",
                linear_backend,
            )
        return CutlassNvFp4LinearKernel

    # An explicit --linear-backend disables the deprecated env-var overrides.
    if linear_backend != "auto":
        return None

    if envs.VLLM_USE_FBGEMM:
        return FbgemmNvFp4LinearKernel
    if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
        return EmulationNvFp4LinearKernel
    if envs.VLLM_NVFP4_GEMM_BACKEND is None:
        return None

    backend_name = envs.VLLM_NVFP4_GEMM_BACKEND
    kernel = _NVFP4_BACKEND_TO_KERNEL.get(backend_name)
    if kernel is None:
        raise ValueError(
            f"Unknown VLLM_NVFP4_GEMM_BACKEND={backend_name!r}. "
            f"Valid choices: "
            f"{list(_NVFP4_BACKEND_TO_KERNEL.keys())}"
        )
    return kernel


def init_nvfp4_linear_kernel() -> NvFp4LinearKernel:
    """Select and instantiate the best NVFP4 linear kernel for the
    current platform.

    Selection happens in two stages.

    1. A kernel pinned by environment overrides (see
       ``_resolve_forced_nvfp4_kernel``) wins outright, but it must pass its
       own ``is_supported()`` check.
    2. Otherwise, candidates from the NVFP4 registry for the current platform
       are tried in order, optionally narrowed by ``--linear-backend``. A
       candidate is skipped if it is disabled via ``VLLM_DISABLED_KERNELS``,
       unsupported, or unable to implement the layer config.

    Raises:
        ValueError: If a pinned kernel is unsupported, if ``--linear-backend``
            has no NVFP4 kernel, if ``VLLM_NVFP4_GEMM_BACKEND`` is unknown, or
            if no candidate kernel is usable.
    """
    layer_config = NvFp4LinearLayerConfig()
    backend = _get_linear_backend()

    pinned = _resolve_forced_nvfp4_kernel(backend)
    if pinned is not None:
        ok, why = pinned.is_supported()
        if not ok:
            raise ValueError(
                f"Forced NVFP4 kernel {pinned.__name__} is not "
                f"supported: {why}"
            )
        logger.info_once("Using %s for NVFP4 GEMM", pinned.__name__)
        return pinned(layer_config)

    # Nothing pinned: auto-select from the registry (optionally filtered).
    candidates = list(_POSSIBLE_NVFP4_KERNELS.get(current_platform._enum, []))
    if backend != "auto":
        candidates = _filter_kernels_by_backend(backend, candidates)
        if not candidates:
            raise ValueError(
                f"--linear-backend={backend} was requested but no "
                f"'{backend}' kernel exists for NVFP4 layers."
            )

    rejections = []
    for candidate in candidates:
        name = candidate.__name__
        if name in envs.VLLM_DISABLED_KERNELS:
            rejections.append(f" {name} disabled by environment variable")
            continue

        # can_implement is only consulted once the platform check passes.
        ok, why = candidate.is_supported()
        if ok:
            ok, why = candidate.can_implement(layer_config)
        if not ok:
            rejections.append(f"{name}: {why}")
            continue

        # Landing on emulation after rejecting other kernels: surface why.
        if candidate is EmulationNvFp4LinearKernel and rejections:
            logger.warning_once(
                "NVFP4 linear falling back to the slow and unoptimized "
                "emulation backend as no optimized backend is available "
                "(unavailable reasons:\n - %s\n). "
                "In case you expect one of these backends to be used, "
                "please verify your environment.",
                "\n - ".join(rejections),
            )

        logger.info_once("Using %s for NVFP4 GEMM", name)
        return candidate(layer_config)

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "NVFP4 linear layer. Reasons: \n" + "\n".join(rejections)
    )

register_linear_kernel

register_linear_kernel(
    kernel_class: type,
    platform: PlatformEnum,
    kernel_type: str = "mp",
) -> None

Register a new linear kernel class to be considered in kernel selection.

Parameters:

Name Type Description Default
kernel_class type

The kernel class to register.

required
platform PlatformEnum

The platform for which this kernel is applicable.

required
kernel_type str

The type of the kernel: one of "mp", "int8", "fp8", "mxfp8", "nvfp4", or "mxfp4". Defaults to "mp".

'mp'

Raises:

Type Description
ValueError

If the kernel_type is not recognized.

Source code in vllm/model_executor/kernels/linear/__init__.py
def register_linear_kernel(
    kernel_class: type,
    platform: PlatformEnum,
    kernel_type: str = "mp",
) -> None:
    """
    Register a new linear kernel class to be considered in kernel selection.

    The class is appended to the end of the per-platform candidate list in
    the registry selected by ``kernel_type``. A platform with no list yet
    gets an empty one first. Because it is appended, the class is considered
    after any kernels already registered for that platform. The visible
    MXFP8/NVFP4/MXFP4 selectors, for example, iterate these lists in order.

    Args:
        kernel_class (type): The kernel class to register.
        platform (PlatformEnum): The platform for which this kernel is applicable.
        kernel_type (str): The type of the kernel: one of "mp", "int8", "fp8",
            "mxfp8", "nvfp4", or "mxfp4". Defaults to "mp".

    Raises:
        ValueError: If the kernel_type is not recognized.
    """
    # Built per call so it always reflects the current module-level registries.
    registries = {
        "mp": _POSSIBLE_KERNELS,
        "int8": _POSSIBLE_INT8_KERNELS,
        "fp8": _POSSIBLE_FP8_KERNELS,
        "mxfp8": _POSSIBLE_MXFP8_KERNELS,
        "nvfp4": _POSSIBLE_NVFP4_KERNELS,
        "mxfp4": _POSSIBLE_MXFP4_KERNELS,
    }
    registry = registries.get(kernel_type)
    if registry is None:
        raise ValueError(f"Unrecognized kernel type: {kernel_type}")
    registry.setdefault(platform, []).append(kernel_class)