Skip to content

vllm.model_executor.layers.fused_moe.router.gate_linear

GateLinear

Bases: ReplicatedLinear

MoE gate linear layer with multi-tier GEMM dispatch:

  1. DSV3 specialized kernel (SM90+, fp32 out, M<=16, H=7168, E=256/384)
  2. fp32 specialized kernel (SM90+, bf16/fp32 in, fp32 out, M<=32, H=3072, E=256)
  3. cuBLAS bf16×bf16→fp32 (SM90+, bf16 weight, fp32 out_dtype)
  4. F.linear via ReplicatedLinear (ultimate fallback)

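The out_dtype attribute may be left unset (None) at init and assigned once afterwards via set_out_dtype (e.g. when the required dtype depends on the expert quantization method, which is only known later). Calling set_out_dtype when out_dtype is already set raises ValueError.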

Source code in vllm/model_executor/layers/fused_moe/router/gate_linear.py
@PluggableLayer.register("gate_linear")
class GateLinear(ReplicatedLinear):
    """MoE gate linear layer with multi-tier GEMM dispatch:

    1. DSV3 specialized kernel (SM90+, fp32 out, M<=16, H=7168, E=256/384)
    2. fp32 specialized kernel  (SM90+, bf16/fp32 in, fp32 out,
       M<=32, H=3072, E=256)
    3. cuBLAS bf16×bf16→fp32 (SM90+, bf16 weight, fp32 out_dtype)
    4. F.linear via ReplicatedLinear (ultimate fallback)

    ``forward`` tries the tiers in this order and returns from the first
    one whose eligibility flag and runtime input checks pass. Every tier
    returns an ``(output, bias)`` tuple; tiers 1-3 always return ``None``
    as the bias element.

    ``out_dtype`` may be left as ``None`` at construction and assigned
    once afterwards through ``set_out_dtype`` (e.g. when the required
    dtype depends on the expert quantization method which is only known
    later).
    """

    # Dimensions supported by the DSV3 specialized kernel
    DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
    DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
    # Largest x.shape[0] that forward() routes to the DSV3 kernel
    DSV3_MAX_TOKENS = 16

    # Dimensions supported by the fp32 specialized kernel
    FP32_SUPPORTED_NUM_EXPERTS = [256]
    FP32_SUPPORTED_HIDDEN_SIZES = [3072]
    FP32_MAX_TOKENS = 32

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
        out_dtype: torch.dtype | None = None,
        params_dtype: torch.dtype | None = None,
        force_fp32_compute: bool = False,
        prefix: str = "",
    ):
        """Build the gate and precompute the per-tier eligibility flags.

        Args:
            input_size: Input feature size; compared against the
                ``*_SUPPORTED_HIDDEN_SIZES`` lists.
            output_size: Output feature size; compared against the
                ``*_SUPPORTED_NUM_EXPERTS`` lists.
            bias: Forwarded to ``ReplicatedLinear``. Any bias disables
                the three specialized tiers.
            out_dtype: Requested dtype of the router logits, or ``None``
                to decide later with ``set_out_dtype``.
            params_dtype: Forwarded to ``ReplicatedLinear``. Overridden
                with ``torch.float32`` when ``force_fp32_compute`` is set
                and the specialized kernels are unavailable.
            force_fp32_compute: Request fp32 weights for the fallback
                path when no specialized kernel can be used.
            prefix: Forwarded to ``ReplicatedLinear``.
        """
        is_hopper_or_blackwell = current_platform.is_device_capability(
            (9, 0)
        ) or current_platform.is_device_capability_family(100)
        can_use_specialized_kernels = (
            current_platform.is_cuda() and is_hopper_or_blackwell and not bias
        )

        # If fp32 compute is required and no specialized kernel is available,
        # store weights in fp32 so the fallback linear path computes in fp32.
        if force_fp32_compute and not can_use_specialized_kernels:
            params_dtype = torch.float32

        super().__init__(
            input_size,
            output_size,
            bias=bias,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=prefix,
        )
        self.out_dtype = out_dtype

        # Shared precondition of tiers 1-3: CUDA + supported arch + no bias.
        self.allow_specialized_router_gemm = can_use_specialized_kernels

        # DSV3 specialized kernel eligibility (SM90+, exact dims)
        self.allow_dsv3_router_gemm = (
            self.allow_specialized_router_gemm
            and output_size in self.DSV3_SUPPORTED_NUM_EXPERTS
            and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES
        )

        # fp32 specialized kernel eligibility (SM90+, exact dims, fp32 weight)
        self.allow_fp32_router_gemm = (
            self.allow_specialized_router_gemm
            and self.weight.dtype == torch.float32
            and output_size in self.FP32_SUPPORTED_NUM_EXPERTS
            and input_size in self.FP32_SUPPORTED_HIDDEN_SIZES
        )

        # cuBLAS bf16→fp32 eligibility
        self.allow_cublas_router_gemm = self._cublas_router_gemm_eligible()

    def _cublas_router_gemm_eligible(self) -> bool:
        """Return whether the cuBLAS bf16→fp32 tier may be used.

        Single source of truth for the rule applied both at construction
        and when ``out_dtype`` is assigned later: the specialized-kernel
        precondition holds, the weight is bf16 and ``out_dtype`` is fp32.
        """
        return (
            self.allow_specialized_router_gemm
            and self.weight.dtype == torch.bfloat16
            and self.out_dtype == torch.float32
        )

    def set_out_dtype(self, out_dtype: torch.dtype) -> None:
        """Set output dtype for the router logits after init.

        Useful when the required dtype depends on the expert quantization
        method which is only known after the gate is constructed.

        Args:
            out_dtype: Dtype to record as ``self.out_dtype``.

        Raises:
            ValueError: If ``out_dtype`` has already been set.
        """
        if self.out_dtype is not None:
            raise ValueError("out_dtype has already been set")
        self.out_dtype = out_dtype

        # The cuBLAS tier depends on out_dtype, so re-evaluate it now that
        # the dtype is known. An already-enabled flag is left untouched.
        if not self.allow_cublas_router_gemm:
            self.allow_cublas_router_gemm = self._cublas_router_gemm_eligible()

    def forward(
        self, x: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Compute router logits for ``x`` with the first eligible tier.

        Args:
            x: Input activations; ``x.shape[0]`` is used as the token
                count for the DSV3 tier.

        Returns:
            ``(output, bias)``. The bias element is ``None`` for tiers
            1-3 and whatever ``ReplicatedLinear.forward`` returns for
            tier 4.
        """
        # Tier 1: DSV3 specialized kernel
        if self.allow_dsv3_router_gemm and x.shape[0] <= self.DSV3_MAX_TOKENS:
            output = ops.dsv3_router_gemm(
                hidden_states=x,
                router_weight=self.weight,
                output_dtype=self.out_dtype,
            )
            return output, None

        # Tier 2: fp32 specialized kernel (H=3072, E=256, M<=32)
        # Dispatch is wrapped in a custom op so that torch.compile/CUDA-graph
        # capture does not freeze the runtime num_tokens branch.
        if self.allow_fp32_router_gemm and x.dtype in (
            torch.float32,
            torch.bfloat16,
        ):
            output = torch.ops.vllm.fp32_router_gemm_dispatch(x, self.weight)
            return output, None

        # Tier 3: cuBLAS bf16→fp32
        if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16:
            output = torch.mm(x, self.weight.T, out_dtype=torch.float32)
            return output, None

        # Tier 4: F.linear (ReplicatedLinear)
        # The input is aligned to the weight dtype, and the result cast to
        # out_dtype, only when an out_dtype has been configured.
        # NOTE(review): with out_dtype=None a dtype mismatch between x and
        # the weight (e.g. fp32 weights chosen by force_fp32_compute) is
        # passed to the parent forward unchanged; confirm callers always
        # set out_dtype in that configuration.
        if self.out_dtype is not None and x.dtype != self.weight.dtype:
            x = x.to(self.weight.dtype)
        output, output_bias = super().forward(x)
        if self.out_dtype is not None and output.dtype != self.out_dtype:
            output = output.to(self.out_dtype)
        return output, output_bias

set_out_dtype

set_out_dtype(out_dtype: dtype) -> None

Set output dtype for the router logits after init.

Useful when the required dtype depends on the expert quantization method which is only known after the gate is constructed.

Source code in vllm/model_executor/layers/fused_moe/router/gate_linear.py
def set_out_dtype(self, out_dtype: torch.dtype) -> None:
    """Record the router-logit output dtype once the gate already exists.

    Intended for callers that only learn the required dtype after
    construction (for instance once the expert quantization method is
    known). If this makes the cuBLAS bf16→fp32 path applicable, the
    corresponding flag is switched on.

    Raises:
        ValueError: If an output dtype was already configured.
    """
    already_configured = self.out_dtype is not None
    if already_configured:
        raise ValueError("out_dtype has already been set")
    self.out_dtype = out_dtype

    # Nothing to re-evaluate when the cuBLAS path is already enabled,
    # when specialized kernels are unavailable, or when the requested
    # output is not fp32.
    if self.allow_cublas_router_gemm:
        return
    if not self.allow_specialized_router_gemm:
        return
    if out_dtype != torch.float32:
        return

    # fp32 output on a specialized-kernel-capable setup: the cuBLAS path
    # applies exactly when the weight is stored in bf16.
    weight_is_bf16 = self.weight.dtype == torch.bfloat16
    self.allow_cublas_router_gemm = weight_is_bf16

fp32_router_gemm_dispatch_impl

fp32_router_gemm_dispatch_impl(
    x: Tensor, weight: Tensor
) -> Tensor

Dynamically run fp32 specialized gemm if num_tokens <= FP32_MAX_TOKENS, otherwise fall back to F.linear. This must be wrapped in a custom op because our torch.compile integration does not support runtime dispatching on num_tokens.

Source code in vllm/model_executor/layers/fused_moe/router/gate_linear.py
def fp32_router_gemm_dispatch_impl(
    x: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
    """
    Choose the router GEMM implementation from the runtime token count.

    Inputs with at most ``_FP32_ROUTER_GEMM_MAX_TOKENS`` rows go to the
    specialized fp32 router kernel; larger inputs are upcast with
    ``x.float()`` and handed to ``F.linear``.
    The selection lives inside a custom op because our torch.compile
    integration does not support runtime dispatching on num_tokens.
    """
    num_tokens = x.shape[0]
    if num_tokens > _FP32_ROUTER_GEMM_MAX_TOKENS:
        # Above the specialized kernel's token limit: generic fp32 linear.
        return torch.nn.functional.linear(x.float(), weight)
    return ops.fp32_router_gemm(x, weight)