Skip to content

vllm.lora.ops.triton_ops.fused_moe_lora_op

_adjust_kernel_inputs

_adjust_kernel_inputs(
    num_active_loras: Tensor,
    sorted_token_ids: Tensor | None,
    expert_ids: Tensor,
)

Helper function that computes the LoRA grid dimension and the token/expert row strides, depending on whether sorted_token_ids is None.

Source code in vllm/lora/ops/triton_ops/fused_moe_lora_op.py
def _adjust_kernel_inputs(
    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
    sorted_token_ids: torch.Tensor | None,
    expert_ids: torch.Tensor,
):
    """
    helper function to adjust kernel inputs when sorted_token_ids is None
    """
    if sorted_token_ids is None:
        stride_tl = 0
        stride_el = 0
        grid_lora_dim = 1
    else:
        stride_tl = sorted_token_ids.stride(0)
        stride_el = expert_ids.stride(0)
        grid_lora_dim = num_active_loras.item()
    return grid_lora_dim, stride_tl, stride_el

_fused_moe_lora_small_batch_kernel

_fused_moe_lora_small_batch_kernel(
    x_ptr,
    A_ptrs,
    B_ptrs,
    out_ptr,
    topk_weights_ptr,
    expert_ids_ptr,
    token_lora_mapping_ptr,
    adapter_enabled_ptr,
    N,
    K,
    top_k_num,
    max_loras,
    work_total,
    pair_slices,
    stride_xm,
    stride_xk,
    stride_A_lora,
    stride_A_expert,
    stride_A_r,
    stride_A_k,
    stride_B_lora,
    stride_B_expert,
    stride_B_n,
    stride_B_r,
    stride_om,
    stride_on,
    slice_n_offset,
    n_tiles_per_program,
    n_chunks_per_pair_slice,
    token_mapping_factor: constexpr,
    MUL_ROUTED_WEIGHT: constexpr,
    ADD_INPUTS: constexpr,
    BLOCK_R: constexpr,
    actual_rank: constexpr,
    BLOCK_N: constexpr,
    BLOCK_K: constexpr,
    NUM_SLICES: constexpr,
    EVEN_K: constexpr,
)

Persistent fused MoE-LoRA kernel for naive_block_assignment inputs.

Each work item is one (pair × slice × n_chunk) triple; a program may process several work items via the persistent stride loop. A "chunk" covers n_tiles_per_program consecutive output-N tiles, all of which share a single shrink — so the rank-vector is computed once per work item and the A weights for that (lora, expert, slice) are loaded once instead of n_tiles_per_program times.

The wrapper picks n_tiles_per_program to keep the grid close to 2*SM_count. At very small batch (pair_slices × N_tiles ≤ 2*SM_count) the chunk size collapses to 1 and behaviour matches a per-tile GEMV. As batch grows the chunk grows, trading some N-axis parallelism for shrink reuse. When work_total exceeds the launched grid, the outer stride loop drains the leftover work units serially.

Source code in vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@triton.heuristics({"EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0})
@triton.jit
def _fused_moe_lora_small_batch_kernel(
    # ---- pointers ----
    x_ptr,
    A_ptrs,
    B_ptrs,
    out_ptr,
    topk_weights_ptr,
    expert_ids_ptr,  # (num_tokens * top_k_num,)
    token_lora_mapping_ptr,  # (num_tokens,)
    adapter_enabled_ptr,
    # ---- dims ----
    N,
    K,
    top_k_num,
    max_loras,
    work_total,  # = pair_slices * n_chunks_per_pair_slice
    pair_slices,  # = num_tokens * top_k_num * NUM_SLICES
    # ---- strides ----
    stride_xm,
    stride_xk,
    stride_A_lora,
    stride_A_expert,
    stride_A_r,
    stride_A_k,
    stride_B_lora,
    stride_B_expert,
    stride_B_n,
    stride_B_r,
    stride_om,
    stride_on,
    # ---- scalar (runtime ints, NOT constexpr) ----
    # n_tiles_per_program / n_chunks_per_pair_slice are deliberately
    # runtime: each distinct value would otherwise trigger a fresh Triton
    # compile -> fresh kernel binary -> fresh CUDA graph instance per
    # batch size. Production traces showed that variant explosion adding
    # ~5.9k graph instantiations on top of legacy. Runtime args mean one
    # shared binary across all chunk sizes.
    slice_n_offset,
    n_tiles_per_program,
    n_chunks_per_pair_slice,
    # ---- constexpr ----
    token_mapping_factor: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    BLOCK_R: tl.constexpr,
    actual_rank: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    NUM_SLICES: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    """Persistent fused MoE-LoRA kernel for naive_block_assignment inputs.

    The work space is a flat range of `work_total` items. Each item is one
    (pair × slice × n_chunk) triple, where a "pair" is one flattened
    (token, top_k) row. A program may process several items through the
    stride loop below. A "chunk" covers `n_tiles_per_program` consecutive
    output-N tiles, all of which share a single shrink. The rank-vector is
    therefore computed once per work item, and the A weights for that
    (lora, expert, slice) are loaded once instead of n_tiles_per_program
    times.

    The wrapper picks `n_tiles_per_program` to keep the grid close to
    2*SM_count. When the per-tile work count (pair_slices * N_tiles) is
    ≤ 2*SM_count, the chunk size collapses to 1 and behaviour matches a
    per-tile GEMV. As batch grows the chunk grows, trading some N-axis
    parallelism for shrink reuse. When `work_total` exceeds the launched
    grid, the outer stride loop drains the leftover work units serially.

    A work item is skipped in three cases:
    - its token's lora id is outside [0, max_loras);
    - adapter_enabled[lora_id] == 0;
    - its expert id is negative.
    A skipped item loads no A/B weights and stores nothing. With
    ADD_INPUTS=False its output row is therefore left as-is, not zeroed.

    `slice_n_offset` is added to `out_ptr` without being multiplied by
    `stride_on`, so the caller must pass it as an element offset.
    """
    pid = tl.program_id(axis=0)
    num_programs = tl.num_programs(axis=0)

    offs_r = tl.arange(0, BLOCK_R)
    rank_mask = offs_r < actual_rank
    # Clamp OOB rank lanes so they address row 0 of A/B; the mask zeros
    # the loaded values. Required when BLOCK_R > actual_rank (e.g. rank=4
    # padded to 16) -- without clamping, tl.load would address the next
    # expert's memory.
    safe_offs_r = tl.where(rank_mask, offs_r, 0)
    offs_k = tl.arange(0, BLOCK_K)

    # Persistent stride loop: when grid < work_total each program walks
    # multiple work items. When grid == work_total the loop runs exactly
    # once and the kernel degenerates to the per-tile GEMV.
    for work_id in range(pid, work_total, num_programs):
        # Decode work_id: the n_chunk index varies fastest, then slice, then pair.
        n_chunk_idx = work_id % n_chunks_per_pair_slice
        pair_slice_idx = work_id // n_chunks_per_pair_slice
        # NUM_SLICES is constexpr (typ. 1 or 2) so divmod folds.
        pair_idx = pair_slice_idx // NUM_SLICES
        slice_id = pair_slice_idx % NUM_SLICES

        # Resolve lora_id / expert_id; skip the body for inactive lanes.
        # Using a single `valid` flag instead of early `return` keeps the
        # outer stride loop alive — `return` would exit the whole program
        # and skip later work items assigned to this SM.
        token_idx = pair_idx // top_k_num
        lora_id = tl.load(token_lora_mapping_ptr + token_idx)
        valid = (lora_id >= 0) & (lora_id < max_loras)
        # Index 0 stands in for an out-of-range lora_id so the load stays in
        # bounds; `valid` is already False in that case.
        enabled = tl.load(adapter_enabled_ptr + tl.where(valid, lora_id, 0))
        valid = valid & (enabled != 0)
        expert_id = tl.load(expert_ids_ptr + pair_idx)
        valid = valid & (expert_id >= 0)

        if valid:
            # A_ptrs / B_ptrs hold one raw address per slice; reinterpret
            # them with the output's element type.
            cur_A_ptr = tl.load(A_ptrs + slice_id).to(
                tl.pointer_type(out_ptr.dtype.element_ty)
            )
            cur_B_ptr = tl.load(B_ptrs + slice_id).to(
                tl.pointer_type(out_ptr.dtype.element_ty)
            )
            A_base = cur_A_ptr + lora_id * stride_A_lora + expert_id * stride_A_expert
            B_base = cur_B_ptr + lora_id * stride_B_lora + expert_id * stride_B_expert

            # token_mapping_factor == 1 reads x per pair, == top_k_num reads
            # x per token (the wrapper chooses based on MUL_ROUTED_WEIGHT).
            x_row = pair_idx // token_mapping_factor
            x_row_ptr = x_ptr + x_row * stride_xm

            # SHRINK GEMV (once per program; reused across n_tiles_per_program
            # expand tiles below). Sum-reduction over BLOCK_K with fp32
            # accumulator — same precision path as the one_shot kernel.
            rank_vec = tl.zeros((BLOCK_R,), dtype=tl.float32)
            if EVEN_K:
                # K is a multiple of BLOCK_K: no K-mask needed.
                for kb in range(0, K, BLOCK_K):
                    cur_k = kb + offs_k
                    x_tile = tl.load(x_row_ptr + cur_k * stride_xk).to(tl.float32)
                    a_tile = tl.load(
                        A_base
                        + safe_offs_r[:, None] * stride_A_r
                        + cur_k[None, :] * stride_A_k,
                        mask=rank_mask[:, None],
                        other=0.0,
                    ).to(tl.float32)
                    rank_vec += tl.sum(a_tile * x_tile[None, :], axis=1)
            else:
                # Ragged K: mask the tail of the last K block.
                for kb in range(0, K, BLOCK_K):
                    cur_k = kb + offs_k
                    k_mask = cur_k < K
                    x_tile = tl.load(
                        x_row_ptr + cur_k * stride_xk, mask=k_mask, other=0.0
                    ).to(tl.float32)
                    a_tile = tl.load(
                        A_base
                        + safe_offs_r[:, None] * stride_A_r
                        + cur_k[None, :] * stride_A_k,
                        mask=rank_mask[:, None] & k_mask[None, :],
                        other=0.0,
                    ).to(tl.float32)
                    rank_vec += tl.sum(a_tile * x_tile[None, :], axis=1)

            # EXPAND: walk n_tiles_per_program consecutive output-N tiles
            # using the same rank_vec. The loop is a runtime range (not
            # tl.static_range) so a single compiled kernel handles every
            # chunk size — see the note on the kernel signature.
            n_tile_start = n_chunk_idx * n_tiles_per_program
            # slice_n_offset is not scaled by stride_on here (see docstring).
            out_row_ptr = out_ptr + slice_id * slice_n_offset + pair_idx * stride_om

            if MUL_ROUTED_WEIGHT:
                moe_w = tl.load(topk_weights_ptr + pair_idx).to(tl.float32)

            for nt in range(n_tiles_per_program):
                n_lo = (n_tile_start + nt) * BLOCK_N
                # The last chunk can run past the final N tile; skip those.
                if n_lo < N:
                    offs_n = n_lo + tl.arange(0, BLOCK_N)
                    n_mask = offs_n < N
                    b_tile = tl.load(
                        B_base
                        + offs_n[:, None] * stride_B_n
                        + safe_offs_r[None, :] * stride_B_r,
                        mask=n_mask[:, None] & rank_mask[None, :],
                        other=0.0,
                    ).to(tl.float32)
                    out_tile = tl.sum(b_tile * rank_vec[None, :], axis=1)

                    if MUL_ROUTED_WEIGHT:
                        out_tile = out_tile * moe_w

                    out_ptrs = out_row_ptr + offs_n * stride_on
                    if ADD_INPUTS:
                        # Read-modify-write: accumulate in fp32, then cast.
                        prev = tl.load(out_ptrs, mask=n_mask, other=0.0).to(tl.float32)
                        tl.store(
                            out_ptrs,
                            (prev + out_tile).to(out_ptr.dtype.element_ty),
                            mask=n_mask,
                        )
                    else:
                        tl.store(
                            out_ptrs,
                            out_tile.to(out_ptr.dtype.element_ty),
                            mask=n_mask,
                        )

_get_expert_id

_get_expert_id(
    expert_ids_ptr,
    lora_id,
    pid_m,
    stride_el,
    max_loras,
    naive_block_assignment: constexpr,
)

Returns expert_id

Source code in vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@triton.jit
def _get_expert_id(
    expert_ids_ptr,
    lora_id,
    pid_m,
    stride_el,
    max_loras,
    naive_block_assignment: tl.constexpr,
):
    """Load the expert id for block ``pid_m``.

    With naive block assignment the table is indexed directly by ``pid_m``.
    Otherwise it is indexed at ``lora_id * stride_el + pid_m``. Offsets at
    or beyond ``max_loras * stride_el`` are masked off and yield -1.
    """
    if naive_block_assignment:
        expert_id = tl.load(expert_ids_ptr + pid_m)
    else:
        flat_idx = lora_id * stride_el + pid_m
        expert_id = tl.load(
            expert_ids_ptr + flat_idx,
            mask=flat_idx < max_loras * stride_el,
            other=-1,
        )
    return expert_id

_get_lora_id

_get_lora_id(
    lora_ids,
    token_lora_mapping_ptr,
    lora_idx,
    pid_m,
    top_k_num,
    naive_block_assignment: constexpr,
)

Returns lora_id

Source code in vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@triton.jit
def _get_lora_id(
    lora_ids,
    token_lora_mapping_ptr,
    lora_idx,
    pid_m,
    top_k_num,
    naive_block_assignment: tl.constexpr,
):
    """Load the LoRA id for the current block.

    With naive block assignment ``pid_m`` is treated as a flattened
    (token, top_k) row. The id is looked up in ``token_lora_mapping_ptr`` at
    token index ``pid_m // top_k_num``. Otherwise it is ``lora_ids[lora_idx]``.
    """
    if naive_block_assignment:
        lora_id = tl.load(token_lora_mapping_ptr + pid_m // top_k_num)
    else:
        lora_id = tl.load(lora_ids + lora_idx)
    return lora_id

_get_ptr

_get_ptr(lora_weights: list[Tensor], device: device)

_LORA_PTR_DICT collects the required information during profile_run. After this, it remains constant and subsequent usage is through the LUT. Refer to: https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py

Source code in vllm/lora/ops/triton_ops/fused_moe_lora_op.py
def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
    """Return a uint64 tensor on ``device`` holding each weight's data pointer.

    `_LORA_PTR_DICT` collects the required information during `profile_run`.
    After this, it remains constant and subsequent usage is through the LUT.
    Refer to:
    https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py

    The cache is keyed only by the tuple of pointer values. ``device`` is
    used only when an entry is created for the first time.
    """
    key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)

    ptr_tensor = _LORA_PTR_DICT.get(key)
    if ptr_tensor is None:
        # The key already holds the pointers in order; reuse it instead of
        # querying data_ptr() a second time for every weight.
        ptr_tensor = torch.tensor(key, device=device, dtype=torch.uint64)
        _LORA_PTR_DICT[key] = ptr_tensor
    return ptr_tensor

_get_token_offs

_get_token_offs(
    sorted_token_ids_ptr,
    lora_id,
    pid_m,
    offs,
    stride_tl,
    max_loras,
    num_valid_tokens,
    naive_block_assignment: constexpr,
    BLOCK_SIZE_M: constexpr,
)

Returns token offsets

Source code in vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@triton.jit
def _get_token_offs(
    sorted_token_ids_ptr,
    lora_id,
    pid_m,
    offs,
    stride_tl,
    max_loras,
    num_valid_tokens,
    naive_block_assignment: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Return per-lane token offsets for block ``pid_m``.

    Naive assignment: the lane where ``offs == 0`` gets ``pid_m`` and every
    other lane gets ``num_valid_tokens``. That value presumably acts as an
    out-of-range sentinel for the caller — confirm at the call site.

    Otherwise: load ``sorted_token_ids`` at
    ``stride_tl * lora_id + pid_m * BLOCK_SIZE_M + offs``. Offsets at or
    beyond ``max_loras * stride_tl`` are masked and read as 0.
    """
    if naive_block_assignment:
        token_offs = tl.where(offs == 0, pid_m, num_valid_tokens)
    else:
        row_offs = pid_m * BLOCK_SIZE_M + offs
        flat_idx = stride_tl * lora_id + row_offs
        token_offs = tl.load(
            sorted_token_ids_ptr + flat_idx,
            mask=flat_idx < max_loras * stride_tl,
            other=0,
        )
    return token_offs

_pick_small_batch_chunk

_pick_small_batch_chunk(
    pair_slices: int, N_tiles: int, sm_count: int
) -> int

Pick n_tiles_per_program so the launched grid stays near 2*SM_count.

Sizes for occupancy first (more programs in flight → better latency hiding for the K-loop A/x loads). Once the per-tile grid already exceeds 2*SM_count we increase the chunk size to amortise the shrink cost — at that point the GPU is saturated by per-program work and packing more tiles per program lets the rank_vec be reused.

Source code in vllm/lora/ops/triton_ops/fused_moe_lora_op.py
def _pick_small_batch_chunk(pair_slices: int, N_tiles: int, sm_count: int) -> int:
    """Choose `n_tiles_per_program` so the launched grid stays near
    2*SM_count.

    Occupancy comes first: while the per-tile work count
    (``pair_slices * N_tiles``) fits in 2*SM_count programs, every program
    gets a single tile, so more programs are in flight to hide the K-loop
    A/x load latency. Beyond that point the GPU is already saturated, and
    the chunk grows to amortise the shrink by reusing rank_vec across more
    tiles. A chunk never exceeds ``N_tiles`` (one pair-slice's worth).
    """
    grid_budget = max(1, 2 * sm_count)
    tiles_total = pair_slices * N_tiles
    if tiles_total > grid_budget:
        # Ceil-divide so that tiles_total / chunk <= grid_budget.
        chunk = -(-tiles_total // grid_budget)
        return min(chunk, N_tiles)
    return 1

_run_fused_moe_lora_one_shot

_run_fused_moe_lora_one_shot(
    output: Tensor,
    qcurr_hidden_states: Tensor,
    lora_a_stacked: list[Tensor],
    lora_b_stacked: list[Tensor],
    topk_weights: Tensor,
    sorted_token_ids: Tensor | None,
    expert_ids: Tensor,
    num_tokens_post_padded: Tensor | None,
    token_lora_mapping: Tensor,
    max_lora_rank: int,
    top_k_num: int,
    lora_ids: Tensor,
    num_active_loras: Tensor,
    adapter_enabled: Tensor,
    mul_routed_weight: bool,
    block_size_m: int,
    add_inputs: bool = True,
) -> None

Fast-path wrapper: launches one fused shrink+expand kernel.

The shape contract matches _fused_moe_lora. output has shape (num_tokens, top_k_num, num_slices * N_per_slice). When add_inputs=True (default), the kernel performs an in-place read-modify-write of output. When add_inputs=False, the kernel overwrites output with the LoRA delta only. The latter is used by the dual-stream path that sums LoRA into the base output on a separate stream.

Source code in vllm/lora/ops/triton_ops/fused_moe_lora_op.py
def _run_fused_moe_lora_one_shot(
    output: torch.Tensor,
    qcurr_hidden_states: torch.Tensor,
    lora_a_stacked: list[torch.Tensor],
    lora_b_stacked: list[torch.Tensor],
    topk_weights: torch.Tensor,
    sorted_token_ids: torch.Tensor | None,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor | None,
    token_lora_mapping: torch.Tensor,
    max_lora_rank: int,
    top_k_num: int,
    lora_ids: torch.Tensor,
    num_active_loras: torch.Tensor,
    adapter_enabled: torch.Tensor,
    mul_routed_weight: bool,
    block_size_m: int,
    add_inputs: bool = True,
) -> None:
    """Fast-path wrapper: launches one fused shrink+expand kernel.

    The shape contract matches `_fused_moe_lora`. `output` has shape
    `(num_tokens, top_k_num, num_slices * N_per_slice)`. When
    `add_inputs=True` (default) the kernel reads-modifies-writes `output`
    in place; when `add_inputs=False` the kernel overwrites `output` with
    the LoRA delta only. The latter is used by the dual-stream path that
    sums LoRA into the base output on a separate stream.
    """
    num_slices = len(lora_a_stacked)
    device = qcurr_hidden_states.device

    A0 = lora_a_stacked[0]
    B0 = lora_b_stacked[0]
    max_loras_w = A0.shape[0]
    rank = A0.shape[2]
    K = A0.shape[3]
    N_per_slice = B0.shape[2]

    # rank padding is to next pow2 with a floor of 16 (tensor-core minimum
    # K-dim). Beyond 128 the (BLOCK_M, BLOCK_R) accumulator outgrows the
    # register file; rank tiling would be needed but is out of scope for
    # this kernel. Tried floor=32 to double MMA density per K-step but it
    # regressed across all M (+8 to +40%): the (64,32) fp32 accumulator +
    # widened B tile pushed register count past spill threshold, lowering
    # occupancy by more than the MMA gain saved.
    assert rank <= 128, (
        f"fused_moe_lora_one_shot supports max_lora_rank<=128; got rank={rank}"
    )
    BLOCK_R = max(triton.next_power_of_2(rank), 16)

    num_experts = A0.shape[1]
    naive = sorted_token_ids is None
    if sorted_token_ids is None:
        EM_grid = topk_weights.numel()
        BLOCK_M = 16
        stride_tl_ = 0
        stride_el = 0
        grid_lora_dim = 1
    else:
        EM_grid = sorted_token_ids.shape[1]
        # BLOCK_M must equal moe_lora_align_block_size's block_size. The
        # caller passes that explicitly; deriving it from tensor shapes is
        # unsafe because sorted_token_ids.shape[1] is the raw padded length
        # (not necessarily a multiple of block_size — e.g. OLMoE prefill
        # produces sorted=139200 with expert_ids=1088 and block_size=128).
        # tl.arange and tl.dot need block_size_m to be a power of 2 and at
        # least 16. The Python-side assertion gives a clearer error than
        # the cryptic Triton compile failure.
        assert block_size_m >= 16 and (block_size_m & (block_size_m - 1)) == 0, (
            f"shrink_block_size_m must be a power of 2 and >=16; got {block_size_m}"
        )
        BLOCK_M = block_size_m
        stride_tl_ = sorted_token_ids.stride(0)
        stride_el = expert_ids.stride(0)
        grid_lora_dim = int(num_active_loras.item())

    # Empty-work guards: the grid would otherwise have a zero dimension,
    # which Triton rejects. None of these is a hot path in production -- a
    # batch with zero tokens, an EM_grid of zero, or zero active LoRAs all
    # mean there's nothing to add to `output`.
    if EM_grid == 0 or grid_lora_dim == 0 or num_slices == 0:
        return

    token_mapping_factor = 1 if mul_routed_weight else top_k_num

    A_ptrs = _get_ptr(lora_a_stacked, device)
    B_ptrs = _get_ptr(lora_b_stacked, device)

    # Flatten (num_tokens, top_k) → flat_token axis. The kernel addresses
    # output via offs_token * stride_om, which is correct iff the dim-0 /
    # dim-1 strides collapse cleanly: stride(0) == top_k * stride(1). All
    # production callers pass contiguous output, so this always holds; the
    # explicit check guards against future regressions where a non-trivial
    # view (e.g. permute) would silently break in-place accumulation.
    assert output.dim() == 3, f"output must be 3-D, got {output.shape}"
    assert output.stride(0) == output.shape[1] * output.stride(1), (
        "fused_moe_lora_one_shot requires output.stride(0) == top_k*stride(1); "
        f"got shape={output.shape} strides={output.stride()}"
    )
    out_view = output.view(-1, output.shape[-1])
    M_blocks = triton.cdiv(EM_grid, BLOCK_M) if not naive else EM_grid

    # NPID_FACTOR heuristic: scale N-axis parallelism when base CTA count is
    # short of saturating the SM array. Cap by the cost of redundant shrink.
    sm_count = current_platform.num_compute_units(device.index)
    base_programs = max(M_blocks * num_slices * grid_lora_dim, 1)
    shrink_ratio = K / max(K + N_per_slice, 1)
    max_npid_by_budget = max(1, int(1.5 / max(shrink_ratio, 1e-3)) + 1)
    target = 2 * sm_count
    if base_programs >= int(1.5 * sm_count):
        npid = 1
    else:
        npid_occ = max(1, min(16, (target + base_programs - 1) // base_programs))
        npid = min(npid_occ, max_npid_by_budget)
    npid = max(1, min(npid, max(1, N_per_slice // 128)))

    # Robust defaults across the prefill regime (H100/H200/B200, bf16/fp16).
    # NPID > 1 is the small-M / under-saturated path -- more warps help
    # amortise the inner-N expand loop. ns=3 instead of 4: GB200 ncu showed
    # the 4-stage pipeline pushed register count to 168/thread and capped
    # achieved occupancy at ~17% (3 blocks/SM, register-bound); ns=3 frees
    # ~30 regs/thread which keeps a 4th block resident on small grids.
    # Tried BLOCK_N=64 for w13 (N=192) to avoid the half-wasted second
    # tile: regressed 11-29% because the "waste" was just masked stores
    # (cheap) and the extra iteration added load + index overhead.
    if npid > 1:
        block_n, nw, ns = 128, 8, 3
    else:
        block_n, nw, ns = 128, 4, 3
    # BLOCK_K choice: for hidden-sized K (≥256, i.e. the K=hidden_size
    # shrink input on w13) force BLOCK_K=128 -- the wider tile halves the
    # K-loop trip count and removes the scoreboard stalls that dominated
    # M=16-64 on GB200 (kernel time -13% to -37% vs the work_per_expert
    # heuristic which picked 64 for low-tokens-per-expert ratios). For
    # small-K shapes (e.g. w2 with K=192 where the down-proj reads the
    # MoE intermediate) keep the work_per_expert heuristic: BLOCK_K=128
    # would force the EVEN_K=False masked path and add no K-loop savings
    # (K/64=3 vs K/128=2 masked) while inflating per-program startup.
    if K >= 256:
        block_k = 128
    else:
        work_per_expert = topk_weights.numel() / max(num_experts, 1)
        block_k = 128 if work_per_expert >= 16 else 64

    grid = (M_blocks * npid, num_slices, grid_lora_dim)

    _fused_moe_lora_one_shot_kernel[grid](
        qcurr_hidden_states,
        A_ptrs,
        B_ptrs,
        out_view,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        token_lora_mapping,
        lora_ids,
        adapter_enabled,
        N_per_slice,
        K,
        topk_weights.numel(),
        top_k_num,
        max_loras_w,
        qcurr_hidden_states.stride(0),
        qcurr_hidden_states.stride(1),
        A0.stride(0),
        A0.stride(1),
        A0.stride(2),
        A0.stride(3),
        B0.stride(0),
        B0.stride(1),
        B0.stride(2),
        B0.stride(3),
        out_view.stride(0),
        out_view.stride(1),
        stride_tl_,
        stride_el,
        N_per_slice,
        token_mapping_factor=token_mapping_factor,
        naive_block_assignment=naive,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        BLOCK_M=BLOCK_M,
        BLOCK_R=BLOCK_R,
        actual_rank=rank,
        NPID_FACTOR=npid,
        BLOCK_N=block_n,
        BLOCK_K=block_k,
        ADD_INPUTS=add_inputs,
        num_warps=nw,
        num_stages=ns,
    )

_run_fused_moe_lora_small_batch

_run_fused_moe_lora_small_batch(
    output: Tensor,
    qcurr_hidden_states: Tensor,
    lora_a_stacked: list[Tensor],
    lora_b_stacked: list[Tensor],
    topk_weights: Tensor,
    expert_ids_flat: Tensor,
    token_lora_mapping: Tensor,
    top_k_num: int,
    adapter_enabled: Tensor,
    mul_routed_weight: bool,
    add_inputs: bool = True,
) -> None

Small-batch GEMV-style wrapper. Naive-block-assignment inputs only.

Shape contract matches _run_fused_moe_lora_one_shot: output is (num_tokens, top_k_num, num_slices * N_per_slice) with contiguous-style strides, expert_ids_flat is the flattened topk_ids of shape (num_tokens * top_k_num,), and the rank-padded LoRA weights live in lora_a_stacked / lora_b_stacked.

The kernel is persistent over (pair × slice × n_chunk) work items — each program does one shrink and reuses the rank vector across n_tiles_per_program expand tiles. The chunk size scales with the pair-slice count so very small batches keep per-tile parallelism while medium batches cut redundant shrinks.

Source code in vllm/lora/ops/triton_ops/fused_moe_lora_op.py
def _run_fused_moe_lora_small_batch(
    output: torch.Tensor,
    qcurr_hidden_states: torch.Tensor,
    lora_a_stacked: list[torch.Tensor],
    lora_b_stacked: list[torch.Tensor],
    topk_weights: torch.Tensor,
    expert_ids_flat: torch.Tensor,  # (num_tokens * top_k_num,)
    token_lora_mapping: torch.Tensor,
    top_k_num: int,
    adapter_enabled: torch.Tensor,
    mul_routed_weight: bool,
    add_inputs: bool = True,
) -> None:
    """Small-batch GEMV-style wrapper. Naive-block-assignment inputs only.

    Shape contract matches `_run_fused_moe_lora_one_shot`:
    - `output` is `(num_tokens, top_k_num, num_slices * N_per_slice)` with
      contiguous-style strides;
    - `expert_ids_flat` is the flattened `topk_ids` of shape
      `(num_tokens * top_k_num,)`;
    - the rank-padded LoRA weights live in `lora_a_stacked` /
      `lora_b_stacked`.

    The kernel is persistent over (pair × slice × n_chunk) work items. Each
    work item does one shrink and reuses the rank vector across
    `n_tiles_per_program` expand tiles. The chunk size scales with the
    pair-slice count, so very small batches keep per-tile parallelism while
    medium batches cut redundant shrinks.

    Args:
        output: 3-D destination whose first two dims must collapse into
            one flat-token axis (stride(0) == shape[1] * stride(1)).
        qcurr_hidden_states: shrink input; rows addressed via stride(0),
            columns via stride(1).
        lora_a_stacked: one A tensor per slice. Read as
            (max_loras, ?, rank, K) from lora_a_stacked[0].shape.
        lora_b_stacked: one B tensor per slice; shape[2] is N_per_slice.
        topk_weights: routing weights; shape[0] is the token count.
        expert_ids_flat: expert id per (token, top_k) pair.
        token_lora_mapping: LoRA id per token (forwarded to the kernel).
        top_k_num: number of experts per token.
        adapter_enabled: per-LoRA enable flags (forwarded to the kernel).
        mul_routed_weight: whether the kernel scales by topk_weights. It
            also selects the token-mapping factor (1 if True, else
            top_k_num).
        add_inputs: accumulate into `output` (True) or overwrite it (False).

    Raises:
        AssertionError: rank > 64, or an `output` that is not 3-D or whose
            leading strides do not collapse.
    """
    num_slices = len(lora_a_stacked)
    # Return before lora_a_stacked[0] is indexed below. With zero slices
    # there is nothing to add, and a guard placed after the [0] lookup
    # would never be reached (IndexError first).
    if num_slices == 0:
        return
    device = qcurr_hidden_states.device

    A0 = lora_a_stacked[0]
    B0 = lora_b_stacked[0]
    max_loras_w = A0.shape[0]
    rank = A0.shape[2]
    K = A0.shape[3]
    N_per_slice = B0.shape[2]

    # Rank padding: floor 16 (tensor-core min K), ceil to next pow2. The
    # ≤64 cap is set conservatively for the prototype. At rank 64 the
    # per-program register footprint is rank_vec(64 fp32) + b_tile(BLOCK_N
    # × 64 fp32) = e.g. 128*64*4 = 32 KiB, well within the SM's 64K x
    # 32-bit (256 KiB) register file even with num_warps=8. Doubling to
    # 128 would push us against the limit and require shared-memory
    # staging.
    assert rank <= 64, f"fused_moe_lora_small_batch supports rank<=64; got rank={rank}"
    BLOCK_R = max(triton.next_power_of_2(rank), 16)

    num_tokens = topk_weights.shape[0]
    M_grid = num_tokens * top_k_num
    if M_grid == 0:
        return

    token_mapping_factor = 1 if mul_routed_weight else top_k_num

    A_ptrs = _get_ptr(lora_a_stacked, device)
    B_ptrs = _get_ptr(lora_b_stacked, device)

    assert output.dim() == 3, f"output must be 3-D, got {output.shape}"
    assert output.stride(0) == output.shape[1] * output.stride(1), (
        "fused_moe_lora_small_batch requires output.stride(0) == "
        f"top_k*stride(1); got shape={output.shape} strides={output.stride()}"
    )
    out_view = output.view(-1, output.shape[-1])
    # _fused_moe_lora_small_batch_kernel adds slice_id * slice_n_offset to
    # the output pointer without multiplying by stride_on, so express the
    # slice width in elements of the (possibly strided) last dim. For
    # contiguous output (stride 1) this is exactly N_per_slice.
    slice_n_offset = N_per_slice * out_view.stride(1)

    # Block sizes. BLOCK_N=128 matches the one_shot's expand tile and gives
    # 6-24 N tiles for typical N ∈ [768, 3072], enough to saturate the SM
    # array once M_grid * num_slices reaches ~SM_count. BLOCK_K=128 halves
    # the K-loop trip count vs 64 and pays for itself once K ≥ 1024 (the
    # only regime we care about — hidden sizes are always large here).
    BLOCK_N = 128
    BLOCK_K = 128
    nw = 4
    ns = 3

    N_tiles = triton.cdiv(N_per_slice, BLOCK_N)
    pair_slices = M_grid * num_slices

    sm_count = current_platform.num_compute_units(device.index)
    n_tiles_per_program = _pick_small_batch_chunk(pair_slices, N_tiles, sm_count)
    n_chunks = triton.cdiv(N_tiles, n_tiles_per_program)
    work_total = pair_slices * n_chunks

    # Grid sizing: keep parallelism uncapped when work_total is small (so
    # very small batches still spread across SMs); cap at 2*SM_count once
    # we have plenty of work, letting the in-kernel stride loop drain the
    # remainder.
    grid_size = min(work_total, max(1, 2 * sm_count))
    grid = (grid_size,)

    _fused_moe_lora_small_batch_kernel[grid](
        qcurr_hidden_states,
        A_ptrs,
        B_ptrs,
        out_view,
        topk_weights,
        expert_ids_flat,
        token_lora_mapping,
        adapter_enabled,
        N_per_slice,
        K,
        top_k_num,
        max_loras_w,
        work_total,
        pair_slices,
        qcurr_hidden_states.stride(0),
        qcurr_hidden_states.stride(1),
        A0.stride(0),
        A0.stride(1),
        A0.stride(2),
        A0.stride(3),
        B0.stride(0),
        B0.stride(1),
        B0.stride(2),
        B0.stride(3),
        out_view.stride(0),
        out_view.stride(1),
        slice_n_offset,
        n_tiles_per_program,
        n_chunks,
        token_mapping_factor=token_mapping_factor,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        ADD_INPUTS=add_inputs,
        BLOCK_R=BLOCK_R,
        actual_rank=rank,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        NUM_SLICES=num_slices,
        num_warps=nw,
        num_stages=ns,
    )