Skip to content

vllm.model_executor.kernels.linear.mixed_precision

Modules:

Name Description
rdna3_w4a16

W4A16 GPTQ kernel for AMD RDNA3 (gfx1100) — fp16 + bf16.

triton_w4a16

Triton-based W4A16 GEMM kernel for ROCm (MI300 and newer); it also accepts CUDA platforms.

xpu
zentorch

Zentorch W4A16 GPTQ weight-only-quantized linear kernel for AMD Zen CPUs.

TritonW4A16LinearKernel

Bases: MPLinearKernel

Triton-based W4A16 GEMM kernel for ROCm (MI300 and newer); its can_implement check also accepts CUDA platforms.

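Supports GPTQ-format int4 weights (uint4b8 symmetric, uint4 asymmetric) with grouped quantization. At load time the packed weight is unpacked from the compressed-tensors checkpoint layout [N, K//8] (8 K-values per int32) and re-packed along N into the kernel's [K, N//8] layout; scales and zero points are transposed.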

Source code in vllm/model_executor/kernels/linear/mixed_precision/triton_w4a16.py
class TritonW4A16LinearKernel(MPLinearKernel):
    """
    Triton-based W4A16 GEMM kernel for ROCm (MI300 and newer); ``can_implement``
    also accepts CUDA platforms.

    Supports GPTQ-format int4 weights (uint4b8 symmetric, uint4 asymmetric)
    with grouped quantization. At load time the packed weight is converted
    from the compressed-tensors checkpoint layout [N, K//8] (8 K-values per
    int32) to the kernel's [K, N//8] layout (8 N-values per int32); scales
    and zero points are transposed so the group dimension comes first.
    """

    SUPPORTED_QUANT_TYPES = TRITON_W4A16_SUPPORTED_QUANT_TYPES

    @classmethod
    def get_min_capability(cls) -> int:
        # No compute-capability floor: Triton handles capability checks
        # itself, and platform support is checked in ``can_implement``.
        return 0

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Report whether layer config ``c`` can run on this kernel.

        Returns ``(True, None)`` when supported, otherwise ``(False, reason)``.
        Checked in order: CUDA or ROCm platform; weight type in
        ``SUPPORTED_QUANT_TYPES``; fp16/bf16 activations; output features
        divisible by 8 (8 int4 values per int32); no activation reordering
        (g_idx); group size in ``TRITON_W4A16_SUPPORTED_GROUP_SIZES`` or equal
        to the full input size; partition input features divisible by the
        effective group size (``-1`` meaning the whole partition).
        """
        if not (current_platform.is_rocm() or current_platform.is_cuda()):
            return False, "TritonW4A16LinearKernel requires CUDA or ROCm"

        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
            return (
                False,
                f"Quant type {c.weight_type} not supported; "
                f"supported: {cls.SUPPORTED_QUANT_TYPES}",
            )

        if c.act_type not in (torch.float16, torch.bfloat16):
            return False, "Only float16/bfloat16 activations are supported"

        # partition_weight_shape is (K, N) here: input features, then output.
        N = c.partition_weight_shape[1]
        if N % 8 != 0:
            return (
                False,
                f"Output features ({N}) must be divisible by 8 "
                "(8 int4 values packed per int32)",
            )

        if c.has_g_idx:
            return (
                False,
                "Activation reordering (g_idx) is not supported by "
                "TritonW4A16LinearKernel",
            )

        gs = c.group_size
        if (
            gs not in TRITON_W4A16_SUPPORTED_GROUP_SIZES
            and gs != c.full_weight_shape[0]
        ):
            return (
                False,
                f"Group size {gs} not supported; "
                f"supported: {TRITON_W4A16_SUPPORTED_GROUP_SIZES} "
                f"or full K ({c.full_weight_shape[0]})",
            )

        K = c.partition_weight_shape[0]
        # group_size == -1 means one group spanning the whole partition.
        eff_gs = gs if gs != -1 else K
        if K % eff_gs != 0:
            return (False, f"Input features {K} not divisible by group size {eff_gs}")

        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """
        Convert compressed-tensors checkpoint layout to kernel layout.

        Checkpoint (from compressed_tensors_wNa16.create_weights):
          weight_packed:     [N, K//8]  int32   input_dim=1, output_dim=0, packed_dim=1
          weight_scale:      [N, K//G]  fp16    input_dim=1, output_dim=0
          weight_zero_point: [N//8, K//G] int32  output_dim=0, packed_dim=0

        Kernel needs:
          qweight: [K, N//8]  int32   (weight_packed unpacked, transposed,
                                       and re-packed along N)
          scales:  [K//G, N]  fp16    (transpose weight_scale)
          qzeros:  [K//G, N//8] int32 (transpose weight_zero_point)
        """

        # ---- qweight: [N, K//8] (K packed) -> [K, N//8] (N packed) ----
        # A transpose/permute alone cannot produce this layout: the checkpoint
        # packs 8 consecutive K-values into each int32, whereas the kernel
        # packs 8 consecutive N-values. Because the packing axis changes, we
        # unpack to [N, K], transpose to [K, N], and re-pack along N. This is
        # a one-time cost at load time and runs on whichever device currently
        # holds the weight tensor.
        def repack_w_q(x: BasevLLMParameter) -> BasevLLMParameter:
            # Step 1: normalise the parameter's physical layout to [N, K//8]
            # (output dim 0, input dim 1, packed along dim 1).
            permute_param_layout_(x, input_dim=1, output_dim=0, packed_dim=1)
            w = x.data  # [N, K//8] int32

            N_dim, K8 = w.shape
            K_dim = K8 * 8
            # Nibble j of each int32 occupies bits [4*j, 4*j + 4).
            shifts = torch.arange(8, device=w.device, dtype=torch.int32) * 4
            # Step 2: unpack to [N, K] int32 with values in [0, 15].
            w_unpacked = ((w.unsqueeze(-1) >> shifts) & 0xF).reshape(N_dim, K_dim)
            # Step 3: transpose to [K, N] int32.
            w_KN = w_unpacked.t().contiguous()
            # Step 4: re-pack 8 consecutive N-values per int32 -> [K, N//8].
            # The shifted nibbles occupy disjoint bit ranges, so summing them
            # combines them exactly as a bitwise OR would.
            N8 = N_dim // 8
            w_repacked = torch.sum(
                (w_KN.view(K_dim, N8, 8) & 0xF) << shifts,
                dim=2,
                dtype=torch.int32,
            )
            x.data = w_repacked.contiguous()
            return x

        def repack_w_s(x: BasevLLMParameter) -> BasevLLMParameter:
            # x.data is [N, K//G]; bring it to [K//G, N].
            permute_param_layout_(x, input_dim=1, output_dim=0)
            x.data = x.data.t().contiguous()
            return x

        self._transform_param(layer, self.w_q_name, repack_w_q)
        self._transform_param(layer, self.w_s_name, repack_w_s)

        if self.w_zp_name is not None:
            zp = getattr(layer, self.w_zp_name, None)
            if zp is not None:
                # Checkpoint: [N//8, K//G] int32 (N packed at dim 0, K//G at dim 1).
                # Zero points are already packed along N, as the kernel wants,
                # so a plain transpose to [K//G, N//8] is sufficient.
                replace_parameter(
                    layer,
                    self.w_zp_name,
                    torch.nn.Parameter(zp.data.t().contiguous(), requires_grad=False),
                )

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Compute ``x @ dequant(W)`` (plus ``bias``) with the Triton W4A16 GEMM.

        Leading dimensions of ``x`` are flattened for the GEMM and restored
        on the output, whose last dimension is the partition's output size.
        """
        c = self.config
        w_q, w_s, w_zp, _ = self._get_weight_params(layer)

        x_2d = x.reshape(-1, x.shape[-1]).contiguous()
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        K = c.partition_weight_shape[0]
        group_size = c.group_size if c.group_size != -1 else K

        # Weight types with a built-in bias (e.g. uint4b8) pass it as a scalar
        # zero point; otherwise 0 is passed and any per-group zero points come
        # from w_zp.
        zp_bias = c.weight_type.bias if c.weight_type.has_bias() else 0

        output = triton_w4a16_gemm(
            a=x_2d,
            b_q=w_q,
            scales=w_s,
            qzeros=w_zp,
            group_size=group_size,
            zp_bias=zp_bias,
        )

        if bias is not None:
            output.add_(bias)

        return output.reshape(out_shape)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None

Convert compressed-tensors checkpoint layout to kernel layout.

Checkpoint layout (from compressed_tensors_wNa16.create_weights): weight_packed is [N, K//8] int32 (input_dim=1, output_dim=0, packed_dim=1); weight_scale is [N, K//G] fp16 (input_dim=1, output_dim=0); weight_zero_point is [N//8, K//G] int32 (output_dim=0, packed_dim=0).

Kernel layout:

qweight: [K, N//8] int32 (weight_packed unpacked, transposed, and re-packed along N); scales: [K//G, N] fp16 (transposed weight_scale); qzeros: [K//G, N//8] int32 (transposed weight_zero_point).

Source code in vllm/model_executor/kernels/linear/mixed_precision/triton_w4a16.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """
    Convert compressed-tensors checkpoint layout to kernel layout.

    Checkpoint (from compressed_tensors_wNa16.create_weights):
      weight_packed:     [N, K//8]  int32   input_dim=1, output_dim=0, packed_dim=1
      weight_scale:      [N, K//G]  fp16    input_dim=1, output_dim=0
      weight_zero_point: [N//8, K//G] int32  output_dim=0, packed_dim=0

    Kernel needs:
      qweight: [K, N//8]  int32   (weight_packed unpacked, transposed,
                                   and re-packed along N)
      scales:  [K//G, N]  fp16    (transpose weight_scale)
      qzeros:  [K//G, N//8] int32 (transpose weight_zero_point)
    """

    # ---- qweight: [N, K//8] (K packed) -> [K, N//8] (N packed) ----
    # A transpose/permute alone cannot produce this layout: the checkpoint
    # packs 8 consecutive K-values into each int32, whereas the kernel packs
    # 8 consecutive N-values. Because the packing axis changes, we unpack to
    # [N, K], transpose to [K, N], and re-pack along N. This is a one-time
    # cost at load time and runs on whichever device currently holds the
    # weight tensor.
    def repack_w_q(x: BasevLLMParameter) -> BasevLLMParameter:
        # Step 1: normalise the parameter's physical layout to [N, K//8]
        # (output dim 0, input dim 1, packed along dim 1).
        permute_param_layout_(x, input_dim=1, output_dim=0, packed_dim=1)
        w = x.data  # [N, K//8] int32

        N_dim, K8 = w.shape
        K_dim = K8 * 8
        # Nibble j of each int32 occupies bits [4*j, 4*j + 4).
        shifts = torch.arange(8, device=w.device, dtype=torch.int32) * 4
        # Step 2: unpack to [N, K] int32 with values in [0, 15].
        w_unpacked = ((w.unsqueeze(-1) >> shifts) & 0xF).reshape(N_dim, K_dim)
        # Step 3: transpose to [K, N] int32.
        w_KN = w_unpacked.t().contiguous()
        # Step 4: re-pack 8 consecutive N-values per int32 -> [K, N//8].
        # The shifted nibbles occupy disjoint bit ranges, so summing them
        # combines them exactly as a bitwise OR would.
        N8 = N_dim // 8
        w_repacked = torch.sum(
            (w_KN.view(K_dim, N8, 8) & 0xF) << shifts,
            dim=2,
            dtype=torch.int32,
        )
        x.data = w_repacked.contiguous()
        return x

    def repack_w_s(x: BasevLLMParameter) -> BasevLLMParameter:
        # x.data is [N, K//G]; bring it to [K//G, N].
        permute_param_layout_(x, input_dim=1, output_dim=0)
        x.data = x.data.t().contiguous()
        return x

    self._transform_param(layer, self.w_q_name, repack_w_q)
    self._transform_param(layer, self.w_s_name, repack_w_s)

    if self.w_zp_name is not None:
        zp = getattr(layer, self.w_zp_name, None)
        if zp is not None:
            # Checkpoint: [N//8, K//G] int32 (N packed at dim 0, K//G at dim 1).
            # Zero points are already packed along N, as the kernel wants,
            # so a plain transpose to [K//G, N//8] is sufficient.
            replace_parameter(
                layer,
                self.w_zp_name,
                torch.nn.Parameter(zp.data.t().contiguous(), requires_grad=False),
            )

XPUW4A8IntLinearKernel

Bases: MPLinearKernel

XPU kernel for W4A8 integer quantization using oneDNN int4_gemm_w4a8.

Weights are symmetric group-quantized int4 packed as uint4. Activations are dynamically quantized per-token to symmetric int8.

Source code in vllm/model_executor/kernels/linear/mixed_precision/xpu.py
class XPUW4A8IntLinearKernel(MPLinearKernel):
    """XPU kernel for W4A8 integer quantization using oneDNN int4_gemm_w4a8.

    Weights are symmetric group-quantized int4 packed as uint4.
    Activations are dynamically quantized per-token to symmetric int8.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        # No capability floor; platform support is checked in can_implement.
        return -1

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Report whether layer config ``c`` can run on this kernel.

        Requires an XPU platform, bf16/fp16 activations, int4 weights without
        zero points, a group size of -1 or a multiple of 32, and input/output
        sizes that are multiples of 8. Logs a one-time warning for non-fp16
        activations because the GEMM produces fp16 output. Returns
        ``(True, None)`` or ``(False, reason)``.
        """
        if not current_platform.is_xpu():
            return False, "XPUW4A8Int only supported on XPU"
        if c.act_type not in (torch.bfloat16, torch.float16):
            return False, "XPUW4A8Int requires BF16/FP16 activations"
        if c.weight_type != scalar_types.int4:
            return (
                False,
                f"XPUW4A8Int requires int4 weights, got {c.weight_type}",
            )
        if c.zero_points:
            return False, "XPUW4A8Int only supports symmetric weight quantization"
        if c.group_size != -1 and c.group_size % 32 != 0:
            return (
                False,
                f"Group size ({c.group_size}) not supported by XPUW4A8Int, "
                "must be a multiple of 32",
            )
        in_size, out_size = c.partition_weight_shape
        if in_size % 8 != 0 or out_size % 8 != 0:
            return (
                False,
                f"in/out sizes ({in_size}, {out_size}) must be multiples of 8",
            )

        if c.act_type != torch.float16:
            logger.warning_once(
                "XPUW4A8IntLinearKernel is running with model dtype %s, "
                "but int4_gemm_w4a8 produces float16 output. Recommend "
                "setting --dtype float16 for best performance.",
                c.act_type,
            )

        return True, None

    def _pack_int4_weight(self, w: torch.Tensor) -> torch.Tensor:
        """Pack signed int4 values into int32 words, 8 values per word.

        ``w`` is ``[N, K]`` with values in ``[-8, 7]`` and ``K`` divisible by
        8. Each value is offset by +8 into ``[0, 15]``; within each run of 8
        consecutive K-values, element ``j`` lands at bits ``[4*j, 4*j + 4)``.
        Returns an int32 tensor of shape ``[N, K // 8]``.
        """
        w_u4 = w.to(torch.int32) + 8  # shift to [0, 15]
        w_u4 = w_u4.reshape(w.shape[0], w.shape[1] // 8, 8)  # [N, K/8, 8]
        shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=w.device)
        packed = ((w_u4 & 0xF) << shifts[None, None, :]).sum(dim=2).to(torch.int32)
        return packed

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert loaded checkpoint tensors into the int4_gemm_w4a8 layout.

        Transposes ``weight_scale``, installs a constant int8 zero point of 8
        (matching the +8 offset applied by ``_pack_int4_weight``), packs the
        int8 weight into int32 words under ``self.w_q_name``, and clears the
        ``weight`` parameter.
        """
        layer.weight_scale.data = layer.weight_scale.data.t().contiguous()

        device = layer.weight_packed.device
        # Symmetric only: packed values are w + 8, so one shared zero point of
        # 8 matches the offset used when packing.
        # TODO: support asymmetric quantization
        weight_zero_point = torch.tensor([8], dtype=torch.int8, device=device)
        layer.weight_zero_point = Parameter(weight_zero_point, requires_grad=False)

        # weight_packed is [out, in] int8, signed int4 values in [-8, 7]
        w = layer.weight_packed.data  # [out, in]

        # TODO: implement asym case
        packed = self._pack_int4_weight(w)  # [out, in/8] packed uint4

        replace_parameter(
            layer,
            self.w_q_name,
            torch.nn.Parameter(packed, requires_grad=False),
        )

        # Free the original unpacked int8 weight (still registered as "weight")
        # to avoid double-storing both int8 [N, K] and int32 [N, K/8] in memory.
        layer.register_parameter("weight", None)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` per token to int8 and run the int4 x int8 GEMM.

        Leading dimensions of ``x`` are flattened to ``[M, K]`` for the GEMM
        and restored on the output; the result is cast back to ``x.dtype``.
        """
        from vllm._xpu_ops import xpu_ops as ops

        reshaped_x = x.reshape(-1, x.shape[-1])  # [M, K]

        # TODO: static and asymmetric quantization case
        # Common code for CompressedTensorsW4A8Int does not read act symmetry data
        quant_x, x_scale, x_zero = ops.dynamic_per_token_int8_quant_ref(
            reshaped_x, True, 8
        )

        out = torch.ops._xpu_C.int4_gemm_w4a8(
            quant_x,
            x_scale,
            x_zero,
            layer.weight_packed.t(),
            layer.weight_scale,
            layer.weight_zero_point,
            self.config.group_size,
            None,  # g_idx not currently supported
            bias,
        )

        # The GEMM ran on the flattened [M, K] view; restore x's leading
        # dimensions so e.g. a [B, S, K] input yields a [B, S, N] output
        # (previously the 2-D [M, N] result was returned as-is).
        out = out.reshape(*x.shape[:-1], out.shape[-1])
        return out.to(x.dtype)

ZentorchWNA16LinearKernel

Bases: CPUWNA16LinearKernel

W4A16 GPTQ kernel backed by torch.ops.zentorch.zentorch_woq_linear.

Source code in vllm/model_executor/kernels/linear/mixed_precision/zentorch.py
class ZentorchWNA16LinearKernel(CPUWNA16LinearKernel):
    """W4A16 GPTQ kernel backed by ``torch.ops.zentorch.zentorch_woq_linear``.

    Layers rejected by ``_zentorch_woq_eligible`` are processed and executed
    by the parent ``CPUWNA16LinearKernel`` (``cpu_gemm_wna16``) instead.
    """

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Report whether config ``c`` can use this kernel.

        Requires the parent ``CPUWNA16LinearKernel`` checks to pass, an AMD
        Zen CPU, registered zentorch WOQ repack/linear ops, and no activation
        reordering. Returns ``(True, None)`` or ``(False, reason)``.
        """
        ok, reason = super().can_implement(c)
        if not ok:
            return ok, reason

        if not current_platform.is_zen_cpu():
            return False, "ZentorchWNA16 requires an AMD Zen CPU."

        if not has_zentorch_op(["zentorch_woq_repack_weight", "zentorch_woq_linear"]):
            return (
                False,
                "torch.ops.zentorch.{zentorch_woq_repack_weight, "
                "zentorch_woq_linear} are not registered.",
            )

        if c.has_g_idx:
            return False, "ZentorchWNA16 does not support activation re-ordering."
        return True, None

    def _zentorch_woq_eligible(self, layer: torch.nn.Module) -> bool:
        """Eligibility predicate for the zentorch W4A16 GPTQ fast path.

        Returns False (so the caller falls back to the ``cpu_gemm_wna16``
        path via ``super()`` with ``layer`` untouched) unless all hold:

        * no activation reordering (no g_idx tensor on the layer and
          ``config.has_g_idx`` falsy);
        * both the packed-weight and scale parameters exist;
        * the packed weight is stored as int32 holding 8 values per word;
        * the weight is packed along the input dim and that dim is 1
          (compressed-tensors GPTQ layout; AWQ packs along the output dim);
        * both tensors are 2-D and ``in_features`` is divisible by the
          number of scale groups.
        """
        if (
            self.w_gidx_name is not None
            and getattr(layer, self.w_gidx_name, None) is not None
        ) or getattr(self.config, "has_g_idx", False):
            return False

        weight_packed = getattr(layer, self.w_q_name, None)
        weight_scale = getattr(layer, self.w_s_name, None)
        if weight_packed is None or weight_scale is None:
            return False

        # Check storage dtype first: torch.iinfo raises TypeError for
        # floating-point dtypes, which would escape as an exception instead
        # of taking the documented "non-int32 storage" fallback.
        if weight_packed.dtype != torch.int32:
            return False

        bits = self.config.weight_type.mantissa
        pack_factor = torch.iinfo(weight_packed.dtype).bits // bits
        # 4-bit -> 8 values per int32.
        if pack_factor != 8:
            return False

        # GPTQ-only, compressed-tensors layout: packed along the input dim,
        # which must be dim 1. AWQ packs along the output dim instead.
        in_dim = getattr(weight_packed, "input_dim", None)
        pk_dim = getattr(weight_packed, "packed_dim", None)
        if in_dim != 1 or pk_dim != 1:
            return False

        if weight_packed.dim() != 2 or weight_scale.dim() != 2:
            return False

        # in_features must be divisible by num_groups.
        in_features = weight_packed.shape[1] * pack_factor
        num_groups = weight_scale.shape[1]
        return num_groups > 0 and in_features % num_groups == 0

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack CT GPTQ weights into the zentorch WOQ layout.

        Falls back to ``CPUWNA16LinearKernel.process_weights_after_loading``
        via ``super()`` when the layer doesn't satisfy
        ``_zentorch_woq_eligible``.

        On success, stores ``_zentorch_woq_packed``, ``_zentorch_woq_scale``
        and ``_zentorch_woq_zero_point`` on ``layer``, replaces the original
        weight/scale/zero-point tensors with empty tensors, and sets
        ``layer._zentorch_processed_weights`` to ``True`` (later calls return
        early).
        """
        if getattr(layer, "_zentorch_processed_weights", False):
            return

        if not self._zentorch_woq_eligible(layer):
            logger.info_once(
                "[zen_cpu] ZentorchWNA16 fast path not eligible for this "
                "layer (AWQ pack layout, g_idx, or non-int32 storage); "
                "falling back to CPUWNA16LinearKernel (cpu_gemm_wna16)."
            )
            super().process_weights_after_loading(layer)
            return

        # Drop optional params the config marks as unused.
        if (not self.config.zero_points) and (self.w_zp_name is not None):
            setattr(layer, self.w_zp_name, None)

        if (not self.config.has_g_idx) and (self.w_gidx_name is not None):
            setattr(layer, self.w_gidx_name, None)

        weight_q = getattr(layer, self.w_q_name)
        weight_s = getattr(layer, self.w_s_name)
        weight_packed = weight_q.data if hasattr(weight_q, "data") else weight_q
        weight_scale = weight_s.data if hasattr(weight_s, "data") else weight_s

        bits = self.config.weight_type.mantissa
        pack_factor = torch.iinfo(weight_packed.dtype).bits // bits
        # weight_scale is [out_features, num_groups].
        out_features, num_groups = weight_scale.shape[0], weight_scale.shape[1]
        in_features = weight_packed.shape[1] * pack_factor
        original_shape = torch.Size([out_features, in_features])
        unpack_from_int32 = _import_unpack_from_int32()
        repack_op = torch.ops.zentorch.zentorch_woq_repack_weight.default

        weight_unpacked = unpack_from_int32(
            weight_packed,
            bits,
            original_shape,
            packed_dim=weight_q.packed_dim,
        )

        zp_param = (
            getattr(layer, self.w_zp_name, None) if self.w_zp_name is not None else None
        )
        # For uint4 the unpacked values are shifted by +8 into [0, 15]; the
        # same shift is applied to the zero points below.
        needs_unsigned_offset = self.config.weight_type == scalar_types.uint4

        if needs_unsigned_offset:
            weight_unpacked = (weight_unpacked.to(torch.int32) + 8).clamp(0, 15)
        repacked = repack_op(weight_unpacked.to(torch.int8).contiguous())

        if zp_param is None:
            zp_tc = None
        else:
            zp_tensor = zp_param.data if hasattr(zp_param, "data") else zp_param
            zp = unpack_from_int32(
                zp_tensor,
                bits,
                (out_features, num_groups),
                packed_dim=zp_param.packed_dim,
            )
            if needs_unsigned_offset:
                zp = (zp.to(torch.int32) + 8).clamp(0, 15)
            # int8 [num_groups, out_features]
            zp_tc = zp.to(torch.int8).t().contiguous()

        layer._zentorch_woq_packed = repacked.t()
        # [num_groups, out_features]
        layer._zentorch_woq_scale = weight_scale.t().contiguous()
        layer._zentorch_woq_zero_point = zp_tc

        # Release the checkpoint-layout tensors now that zentorch copies exist.
        for param_name in (self.w_q_name, self.w_s_name, self.w_zp_name):
            if param_name is None:
                continue
            param = getattr(layer, param_name, None)
            if param is None:
                continue
            if hasattr(param, "data"):
                param.data = torch.empty(0)
            else:
                setattr(layer, param_name, torch.empty(0))

        layer._zentorch_kind = "compressed_tensors_w4a16_gptq"
        layer._zentorch_processed_weights = True
        logger.info_once(
            "[zen_cpu] Using zentorch_woq_linear for W4A16 GPTQ "
            "(weight_type=%s, has_zp=%s)",
            self.config.weight_type,
            zp_tc is not None,
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run ``zentorch_woq_linear`` on layers processed for the fast path.

        Layers without ``_zentorch_processed_weights`` set are delegated to
        ``CPUWNA16LinearKernel.apply_weights``.
        """
        if getattr(layer, "_zentorch_processed_weights", False):
            return torch.ops.zentorch.zentorch_woq_linear.default(
                x,
                layer._zentorch_woq_packed,
                layer._zentorch_woq_scale,
                layer._zentorch_woq_zero_point,
                bias,
            )
        return super().apply_weights(layer, x, bias)

_zentorch_woq_eligible

_zentorch_woq_eligible(layer: Module) -> bool

Eligibility predicate for the zentorch W4A16 GPTQ fast path.

Constraints: if any check fails, the layer is left untouched and processing falls back to the cpu_gemm_wna16 path via super().

Source code in vllm/model_executor/kernels/linear/mixed_precision/zentorch.py
def _zentorch_woq_eligible(self, layer: torch.nn.Module) -> bool:
    """Eligibility predicate for the zentorch W4A16 GPTQ fast path.

    Constraints (any failure -> ``cpu_gemm_wna16`` path via ``super()``
    with ``layer`` untouched).
    """
    if (
        self.w_gidx_name is not None
        and getattr(layer, self.w_gidx_name, None) is not None
    ) or (getattr(self.config, "has_g_idx", False)):
        return False

    weight_packed = getattr(layer, self.w_q_name, None)
    weight_scale = getattr(layer, self.w_s_name, None)
    if weight_packed is None or weight_scale is None:
        return False

    bits = self.config.weight_type.mantissa
    pack_factor = torch.iinfo(weight_packed.dtype).bits // bits
    # 4-bit -> 8 values per int32;
    if pack_factor != 8:
        return False

    # GPTQ-only. AWQ packs along the output dim instead.
    in_dim = getattr(weight_packed, "input_dim", None)
    pk_dim = getattr(weight_packed, "packed_dim", None)
    if in_dim is None or pk_dim is None or in_dim != pk_dim:
        return False

    is_ct_format = in_dim == pk_dim == 1
    if not is_ct_format:
        return False

    if weight_packed.dim() != 2 or weight_scale.dim() != 2:
        return False

    # 4-bit -> 8 values per int32; in_features must be divisible by num_groups.
    in_features = weight_packed.shape[1] * 8
    num_groups = weight_scale.shape[1]
    return num_groups > 0 and in_features % num_groups == 0

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None

Repack CT GPTQ weights into the zentorch WOQ layout.

Falls back to CPUWNA16LinearKernel.process_weights_after_loading via super() when the layer doesn't satisfy _zentorch_woq_eligible.

On success, layer._zentorch_processed_weights is set to True.

Source code in vllm/model_executor/kernels/linear/mixed_precision/zentorch.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Convert compressed-tensors GPTQ weights to the zentorch WOQ layout.

    Returns immediately when ``layer._zentorch_processed_weights`` is already
    truthy. When ``_zentorch_woq_eligible(layer)`` rejects the layer, logs
    once and delegates to the parent ``process_weights_after_loading`` via
    ``super()``.

    Otherwise it stores three attributes on ``layer``:

    * ``_zentorch_woq_packed``: the transposed output of
      ``zentorch_woq_repack_weight``;
    * ``_zentorch_woq_scale``: the contiguous ``[num_groups, out_features]``
      scales;
    * ``_zentorch_woq_zero_point``: int8 ``[num_groups, out_features]`` or
      ``None``.

    It then replaces the original weight, scale and zero-point tensors with
    empty tensors and marks the layer as processed.
    """
    if getattr(layer, "_zentorch_processed_weights", False):
        return

    if not self._zentorch_woq_eligible(layer):
        logger.info_once(
            "[zen_cpu] ZentorchWNA16 fast path not eligible for this "
            "layer (AWQ pack layout, g_idx, or non-int32 storage); "
            "falling back to CPUWNA16LinearKernel (cpu_gemm_wna16)."
        )
        super().process_weights_after_loading(layer)
        return

    cfg = self.config

    # Optional params the config marks as unused are detached up front, so
    # the zero-point lookup further down sees None for them.
    if not cfg.zero_points and self.w_zp_name is not None:
        setattr(layer, self.w_zp_name, None)
    if not cfg.has_g_idx and self.w_gidx_name is not None:
        setattr(layer, self.w_gidx_name, None)

    def _raw(obj):
        # Parameters expose their storage via ``.data``; plain objects don't.
        return obj.data if hasattr(obj, "data") else obj

    q_param = getattr(layer, self.w_q_name)
    s_param = getattr(layer, self.w_s_name)
    packed_w = _raw(q_param)
    scales = _raw(s_param)

    nbits = cfg.weight_type.mantissa
    values_per_word = torch.iinfo(packed_w.dtype).bits // nbits
    n_out = scales.shape[0]
    n_groups = scales.shape[1]
    full_shape = torch.Size([n_out, packed_w.shape[1] * values_per_word])
    unpack = _import_unpack_from_int32()
    zen_repack = torch.ops.zentorch.zentorch_woq_repack_weight.default

    w_int = unpack(packed_w, nbits, full_shape, packed_dim=q_param.packed_dim)

    if self.w_zp_name is None:
        zp_holder = None
    else:
        zp_holder = getattr(layer, self.w_zp_name, None)
    # uint4 values (and their zero points) are shifted by +8 into [0, 15].
    shift_to_unsigned = cfg.weight_type == scalar_types.uint4

    if shift_to_unsigned:
        w_int = (w_int.to(torch.int32) + 8).clamp(0, 15)
    woq_weight = zen_repack(w_int.to(torch.int8).contiguous())

    zp_int8 = None
    if zp_holder is not None:
        zp_vals = unpack(
            _raw(zp_holder),
            nbits,
            (n_out, n_groups),
            packed_dim=zp_holder.packed_dim,
        )
        if shift_to_unsigned:
            zp_vals = (zp_vals.to(torch.int32) + 8).clamp(0, 15)
        zp_int8 = zp_vals.to(torch.int8).t().contiguous()

    layer._zentorch_woq_packed = woq_weight.t()
    layer._zentorch_woq_scale = scales.t().contiguous()
    layer._zentorch_woq_zero_point = zp_int8

    # Drop the checkpoint-layout storage now that zentorch copies exist.
    for attr_name in (self.w_q_name, self.w_s_name, self.w_zp_name):
        current = None if attr_name is None else getattr(layer, attr_name, None)
        if current is None:
            continue
        if hasattr(current, "data"):
            current.data = torch.empty(0)
        else:
            setattr(layer, attr_name, torch.empty(0))

    layer._zentorch_kind = "compressed_tensors_w4a16_gptq"
    layer._zentorch_processed_weights = True
    logger.info_once(
        "[zen_cpu] Using zentorch_woq_linear for W4A16 GPTQ "
        "(weight_type=%s, has_zp=%s)",
        cfg.weight_type,
        zp_int8 is not None,
    )