Skip to content

vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe

Utility helpers for NVFP4 + FlashInfer fused-MoE path

interleave_linear_and_gate

interleave_linear_and_gate(
    x: Tensor, group_size: int = 64, dim: int = -1
) -> Tensor

Interleave the two halves of the tensor (linear and gate weight rows) in groups of group_size rows for the CuteDSL wrapper.

Source code in vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
def interleave_linear_and_gate(
    x: torch.Tensor,
    group_size: int = 64,
    dim: int = -1,
) -> torch.Tensor:
    """Interleave the two halves of ``x`` along ``dim`` in ``group_size`` chunks.

    ``x`` is viewed as two equal halves along ``dim`` (e.g. linear rows then
    gate rows), each cut into chunks of ``group_size``.  The result
    alternates chunks: half0[0], half1[0], half0[1], half1[1], ...  The output
    has the same shape as ``x`` and is a fresh contiguous tensor.

    Args:
        x: Tensor to rearrange.
        group_size: Number of consecutive entries along ``dim`` per chunk.
        dim: Dimension to interleave along; negative values are allowed.

    Returns:
        The interleaved tensor, same shape as ``x``.

    Raises:
        AssertionError: If the size along ``dim`` is not a multiple of
            ``2 * group_size``.
    """
    original_shape = x.size()
    axis = dim % x.dim()
    span = original_shape[axis]
    assert span % (group_size * 2) == 0, (
        f"dim {axis} size {span} must be divisible by {group_size * 2}"
    )
    num_groups = span // (group_size * 2)

    # Split `axis` into (half, group, within-group); the half index sits at
    # position `axis` and the group index right after it.
    split_shape = (
        *original_shape[:axis],
        2,
        num_groups,
        group_size,
        *original_shape[axis + 1 :],
    )
    grouped = x.view(split_shape)

    # Swapping half and group axes makes the two halves alternate per group.
    interleaved = grouped.transpose(axis, axis + 1).contiguous()
    return interleaved.view(original_shape)

is_flashinfer_fp4_cutlass_moe_available

is_flashinfer_fp4_cutlass_moe_available() -> bool

Return True when FlashInfer CUTLASS NV-FP4 kernels can be used.

Source code in vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
    """Report whether the FlashInfer CUTLASS NV-FP4 MoE path can be used.

    Checks, in order, stopping at the first one that fails:

    1. the ``VLLM_USE_FLASHINFER_MOE_FP4`` env flag is enabled;
    2. ``has_flashinfer_cutlass_fused_moe()`` is true;
    3. the current platform is CUDA;
    4. ``current_platform.has_device_capability(100)`` is true.
    """
    if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
        return False
    if not has_flashinfer_cutlass_fused_moe():
        return False
    if not current_platform.is_cuda():
        return False
    return current_platform.has_device_capability(100)

prepare_nvfp4_moe_layer_for_flashinfer_cutedsl

prepare_nvfp4_moe_layer_for_flashinfer_cutedsl(
    layer: FusedMoE,
    w13: Tensor,
    w13_scale: Tensor,
    w13_scale_2: Tensor,
    a13_scale: Tensor,
    w2: Tensor,
    w2_scale: Tensor,
    w2_scale_2: Tensor,
    a2_scale: Tensor,
) -> tuple[
    Tensor,
    Tensor,
    Tensor,
    Tensor,
    Tensor,
    Tensor,
    Tensor,
    Tensor,
]

Prepare weights for the CuteDSL wrapper-based NvFP4 MoE backend.

Converts weight scale factors to the MMA layout expected by CuteDslMoEWrapper and interleaves the w13 gate/linear rows.

Source code in vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
def _swizzle_and_convert_sf_to_mma_layout(scale: torch.Tensor) -> torch.Tensor:
    """Swizzle a block-scale tensor and convert it to CuteDSL MMA layout."""
    from flashinfer.cute_dsl.utils import convert_sf_to_mma_layout

    swizzled = swizzle_blockscale(scale)
    # swizzle_blockscale yields a 3-D tensor (groups, padded rows, padded
    # scale columns); the converter takes it flattened to 2-D with the group
    # count passed separately.
    num_groups, m_padded, k_sf_padded = swizzled.shape
    return convert_sf_to_mma_layout(
        swizzled.reshape(num_groups * m_padded, k_sf_padded),
        m=m_padded,
        # k is given in elements; each scale factor covers sf_vec_size (16).
        k=k_sf_padded * 16,
        num_groups=num_groups,
        sf_vec_size=16,
    )


def prepare_nvfp4_moe_layer_for_flashinfer_cutedsl(
    layer: "FusedMoE",
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    a13_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a2_scale: torch.Tensor,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Prepare weights for the CuteDSL wrapper-based NvFP4 MoE backend.

    Converts weight scale factors to the MMA layout expected by
    CuteDslMoEWrapper and interleaves the w13 gate/linear rows.

    Args:
        layer: The MoE layer being prepared (not read by this function).
        w13: Fused first-GEMM weights. Dim 0 is taken as the expert count;
            dim 1 holds the two halves that are swapped and interleaved.
        w13_scale: Block scales for ``w13``. They are rearranged along dim 1
            exactly like ``w13``, then converted to MMA layout.
        w13_scale_2: Returned unchanged.
        a13_scale: Reduced to its max, cast to float32 and broadcast (a
            stride-0 ``expand`` view) to one value per expert.
        w2: Returned unchanged.
        w2_scale: Converted to MMA layout.
        w2_scale_2: Returned unchanged.
        a2_scale: Reduced and broadcast like ``a13_scale``.

    Returns:
        ``(w13, w13_scale, w13_scale_2, a13_scale,
        w2, w2_scale, w2_scale_2, a2_scale)`` after preparation.
    """
    # Global scaling factors (same as other FlashInfer backends).
    num_experts = w13.shape[0]
    a13_scale = a13_scale.max().to(torch.float32).expand(num_experts)
    a2_scale = a2_scale.max().to(torch.float32).expand(num_experts)

    # Swap the two halves of dim 1 ([w1, w3] -> [w3, w1]) for weights and
    # scales together, so both are split at the same point.
    w13, w13_scale = reorder_w1w3_to_w3w1(w13, w13_scale, dim=1)

    # Interleave the two halves in 64-row groups for weights and scales.
    w13 = interleave_linear_and_gate(w13, group_size=64, dim=1)
    w13_scale = interleave_linear_and_gate(w13_scale, group_size=64, dim=1)

    # Convert scale factors: linear -> swizzled -> MMA layout.
    w13_scale = _swizzle_and_convert_sf_to_mma_layout(w13_scale)
    w2_scale = _swizzle_and_convert_sf_to_mma_layout(w2_scale)

    return (
        w13,
        w13_scale,
        w13_scale_2,
        a13_scale,
        w2,
        w2_scale,
        w2_scale_2,
        a2_scale,
    )

reorder_w1w3_to_w3w1

reorder_w1w3_to_w3w1(
    weight: Tensor, scale: Tensor, dim: int = -2
) -> tuple[Tensor, Tensor]

Reorder the concatenated [w1, w3] tensors to [w3, w1].

Source code in vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
def reorder_w1w3_to_w3w1(
    weight: torch.Tensor, scale: torch.Tensor, dim: int = -2
) -> tuple[torch.Tensor, torch.Tensor]:
    """Re-order the concatenated `[w1, w3]` tensors to `[w3, w1]`.

    Both tensors hold the w1 part followed by the w3 part along ``dim``; the
    two halves are swapped in each.

    Args:
        weight: Concatenated ``[w1, w3]`` weight tensor.
        scale: Matching scale tensor. It must have the same size as
            ``weight`` along ``dim``, because both are split at the same
            midpoint.
        dim: Dimension holding the concatenated halves.

    Returns:
        ``(weight, scale)`` with their halves swapped, each contiguous.

    Raises:
        AssertionError: If the size along ``dim`` is odd, or if ``scale``
            and ``weight`` differ in size along ``dim``.
    """
    size = weight.size(dim)
    assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
    # The scale is split at the weight's midpoint. A scale of a different
    # size would be split into the wrong pieces (or into more than two), so
    # reject it explicitly.
    scale_size = scale.size(dim)
    assert scale_size == size, (
        f"Expected scale size in dim {dim} to match weight size {size}, "
        f"got {scale_size}"
    )
    half = size // 2

    w1, w3 = weight.split(half, dim=dim)
    s1, s3 = scale.split(half, dim=dim)

    return (
        torch.cat([w3, w1], dim=dim).contiguous(),
        torch.cat([s3, s1], dim=dim).contiguous(),
    )