class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
"""
CuteDSL NvFP4 MoE experts using the FlashInfer functional API.
Uses Standard activation format (non-batched). The kernel handles
routing, expert computation, and reduction internally.
Supports expert parallelism natively.
"""
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
super().__init__(
moe_config=moe_config,
quant_config=quant_config,
)
assert quant_config.quant_dtype == "nvfp4", (
"Only nvfp4 quantization is currently supported."
)
self.out_dtype = moe_config.in_dtype
self.hidden_dim = moe_config.hidden_dim
self.intermediate_size_per_partition = (
moe_config.intermediate_size_per_partition
)
self.topk = moe_config.experts_per_token
self.local_num_experts = moe_config.num_local_experts
self.global_num_experts = moe_config.num_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank
self.local_expert_offset = self.ep_rank * self.local_num_experts
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_current_device() -> bool:
p = current_platform
return (
p.is_cuda()
and p.is_device_capability_family(100)
and has_flashinfer_cutedsl_moe_nvfp4()
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
SUPPORTED_W_A = [
(kNvfp4Static, kNvfp4Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation == MoEActivation.SILU
@staticmethod
def _supports_parallel_config(
moe_parallel_config: FusedMoEParallelConfig,
) -> bool:
return True
def supports_expert_map(self) -> bool:
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (0,)
workspace2 = (0,)
# K is packed (K//2 for uint8), so output uses hidden_dim.
assert self.hidden_dim == K * 2
output = (M, self.hidden_dim)
return (workspace1, workspace2, output)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor | None,
workspace2: torch.Tensor | None,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool | None,
):
assert self.quant_dtype == "nvfp4"
assert a1q_scale is not None
assert self.w1_scale is not None
assert self.w2_scale is not None
# a1q_scale is (M, K//16) float8_e4m3fn from fp4_quantize.
# The functional API expects x_sf with trailing dim: (M, K//16, 1).
x_sf = a1q_scale.unsqueeze(-1)
from vllm.utils.flashinfer import _is_fi_autotuning, autotune
with autotune(_is_fi_autotuning):
flashinfer_cute_dsl_fused_moe_nvfp4(
x=hidden_states,
x_sf=x_sf,
token_selected_experts=topk_ids.to(torch.int32),
token_final_scales=topk_weights.float(),
w1_weight=w1,
w1_weight_sf=self.w1_scale,
w1_alpha=self.g1_alphas,
fc2_input_scale=self.a2_gscale,
w2_weight=w2,
w2_weight_sf=self.w2_scale,
w2_alpha=self.g2_alphas,
num_experts=self.global_num_experts,
top_k=self.topk,
num_local_experts=self.local_num_experts,
local_expert_offset=self.local_expert_offset,
moe_output=output,
)