@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
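        # ROCm AITER provides an alternative fused-experts kernel; the import
        # is deferred until the feature is actually enabled.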
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
if self.rocm_aiter_moe_enabled:
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
self.rocm_aiter_fused_experts = rocm_aiter_fused_experts
else:
self.rocm_aiter_fused_experts = None # type: ignore
        # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUs.
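        # It is additionally gated on expert parallelism being enabled, data
        # parallelism being disabled, and VLLM_USE_FLASHINFER_MOE_FP16.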
self.flashinfer_cutlass_moe_enabled = (
has_flashinfer_cutlass_fused_moe()
and envs.VLLM_USE_FLASHINFER_MOE_FP16
and self.moe.moe_parallel_config.use_ep
and self.moe.moe_parallel_config.dp_size == 1
and current_platform.get_device_capability()[0] >= 9
)
if self.flashinfer_cutlass_moe_enabled:
logger.info_once(
"Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod"
)
from functools import partial
from .flashinfer_cutlass_moe import flashinfer_cutlass_moe
self.flashinfer_cutlass_moe = partial(
flashinfer_cutlass_moe,
quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
tp_rank=self.moe.moe_parallel_config.tp_rank,
tp_size=self.moe.moe_parallel_config.tp_size,
ep_rank=self.moe.moe_parallel_config.ep_rank,
ep_size=self.moe.moe_parallel_config.ep_size,
)
else:
if (
self.moe.moe_parallel_config.use_ep
and self.moe.moe_parallel_config.dp_size == 1
):
logger.info_once(
"FlashInfer CUTLASS MoE is available for EP"
" but not enabled, consider setting"
" VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.",
scope="local",
)
elif self.moe.moe_parallel_config.dp_size > 1:
logger.info_once(
"FlashInfer CUTLASS MoE is currently not available for DP.",
scope="local",
)
self.flashinfer_cutlass_moe = None # type: ignore
@property
def supports_eplb(self) -> bool:
return True
@property
def allow_inplace(self) -> bool:
return True
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
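        # The ROCm AITER backend calls its fused kernel directly in
        # forward_cuda, so no modular prepare/finalize stage is built for it.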
if self.rocm_aiter_moe_enabled:
return None
else:
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
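        # Pick the Triton experts implementation that matches the activation
        # format produced by the prepare/finalize stage.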
assert self.moe_quant_config is not None
if (
prepare_finalize.activation_format
== FusedMoEActivationFormat.BatchedExperts
):
logger.debug("BatchedTritonExperts %s", self.moe)
return BatchedTritonExperts(
max_num_tokens=self.moe.max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
)
else:
logger.debug("TritonExperts %s", self.moe)
return TritonExperts(self.moe_quant_config)
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
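        # With a fused gate/up activation (act-and-mul), w13 packs both the
        # gate and up projections, so its output dimension is doubled.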
if self.moe.is_act_and_mul:
w13_up_dim = 2 * intermediate_size_per_partition
else:
w13_up_dim = intermediate_size_per_partition
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
w13_up_dim,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
if self.moe.has_bias:
w13_bias = torch.nn.Parameter(
torch.zeros(num_experts, w13_up_dim, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
if self.moe.has_bias:
w2_bias = torch.nn.Parameter(
torch.zeros(num_experts, hidden_size, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        # Pad the weight tensor. This is an optimization on the ROCm platform,
        # which can benefit from tensors being located far enough from one
        # another in memory.
if (
envs.VLLM_ROCM_MOE_PADDING
and current_platform.is_rocm()
and weight.stride(-1) == 1
and (weight.stride(-2) * weight.element_size()) % 512 == 0
):
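            # Pad the last dim by 256 bytes worth of elements and immediately
            # slice the padding off again: the logical shape is unchanged, but
            # rows now sit further apart in memory.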
num_pad = 256 // weight.element_size()
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
return weight
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
        # Pad the weights for better performance on ROCm.
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
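        # The AITER kernels consume a pre-shuffled weight layout.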
if self.rocm_aiter_moe_enabled:
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
layer.w13_weight.data = shuffled_w13
layer.w2_weight.data = shuffled_w2
if self.flashinfer_cutlass_moe_enabled:
            # Swap the two halves so the fused weight is laid out as [w3; w1],
            # as the FlashInfer CUTLASS kernel expects.
w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
layer.w13_weight.data = w13_weight_swapped.contiguous()
if current_platform.is_xpu():
import intel_extension_for_pytorch as ipex
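            # IPEX provides a prepacked gated-MLP MoE module; experts_start_id
            # is the offset of this rank's first local expert.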
ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight,
layer.w2_weight,
use_prepack=True,
experts_start_id=ep_rank_start,
)
elif current_platform.is_cpu():
from vllm.model_executor.layers.fused_moe import cpu_fused_moe
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
dtype_w13 = layer.w13_weight.dtype
_, n_w13, k_w13 = layer.w13_weight.size()
dtype_w2 = layer.w2_weight.dtype
_, n_w2, k_w2 = layer.w2_weight.size()
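            # Prefer the SGL kernel when it is enabled and supports the weight
            # shapes and dtypes; its weights must first be repacked in place.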
if (
envs.VLLM_CPU_SGL_KERNEL
and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13)
and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2)
):
packed_w13_weight = torch.ops._C.convert_weight_packed(
layer.w13_weight
)
assert packed_w13_weight.size() == layer.w13_weight.size()
layer.w13_weight.copy_(packed_w13_weight)
del packed_w13_weight
packed_w2_weight = torch.ops._C.convert_weight_packed(
layer.w2_weight
)
assert packed_w2_weight.size() == layer.w2_weight.size()
layer.w2_weight.copy_(packed_w2_weight)
layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer)
else:
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
else:
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
def apply(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward(
layer=layer,
x=x,
router_logits=router_logits,
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
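        # Even without quantization, any expert biases are carried in the
        # quant-config object passed to the fused kernels.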
if self.moe.has_bias:
return biased_moe_quant_config(
layer.w13_bias,
layer.w2_bias,
)
else:
return FUSED_MOE_UNQUANTIZED_CONFIG
def forward_cuda(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
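        # Route tokens to experts, then run the fused expert GEMMs with the
        # backend selected at init time (ROCm AITER, FlashInfer CUTLASS, or
        # the Triton fallback).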
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
if self.rocm_aiter_moe_enabled:
result = self.rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=layer.expert_map,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
elif self.flashinfer_cutlass_moe_enabled:
return self.flashinfer_cutlass_moe(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
result = fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=layer.activation,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
)
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
return result, zero_expert_result
else:
return result
def forward_cpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
layer.enable_eplb is not False
or layer.expert_load_view is not None
or layer.logical_to_physical_map is not None
or layer.logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for CPU.")
return layer.cpu_fused_moe(
layer,
x,
layer.use_grouped_topk,
layer.top_k,
router_logits,
layer.renormalize,
layer.topk_group,
layer.num_expert_group,
layer.global_num_experts,
layer.expert_map,
layer.custom_routing_function,
layer.scoring_func,
layer.routed_scaling_factor,
layer.e_score_correction_bias,
layer.apply_router_weight_on_input,
layer.activation,
)
def forward_xpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
layer.enable_eplb is not False
or layer.expert_load_view is not None
or layer.logical_to_physical_map is not None
or layer.logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for XPU.")
return layer.ipex_fusion(
x,
layer.use_grouped_topk,
layer.top_k,
router_logits,
layer.renormalize,
layer.topk_group,
layer.num_expert_group,
custom_routing_function=layer.custom_routing_function,
)
def forward_tpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not layer.use_grouped_topk
assert layer.num_expert_group is None
assert layer.topk_group is None
assert layer.custom_routing_function is None
assert layer.apply_router_weight_on_input is False
if layer.scoring_func != "softmax":
raise NotImplementedError(
"Only softmax scoring function is supported for TPU."
)
if layer.e_score_correction_bias is not None:
raise NotImplementedError(
"Expert score correction bias is not supported for TPU."
)
assert layer.activation == "silu", (
f"{layer.activation} is not supported for TPU."
)
assert layer.routed_scaling_factor == 1.0, (
f"routed_scaling_factor {layer.routed_scaling_factor} is "
"not supported for TPU."
)
if (
layer.enable_eplb is not False
or layer.expert_load_view is not None
or layer.logical_to_physical_map is not None
or layer.logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for TPU.")
return fused_moe_pallas(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk=layer.top_k,
gating_output=router_logits,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
renormalize=layer.renormalize,
)
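    # Bind the platform-specific forward as the CustomOp's native
    # implementation; there is no separate reference path.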
if current_platform.is_tpu():
forward_native = forward_tpu
elif current_platform.is_cpu():
forward_native = forward_cpu
elif current_platform.is_xpu():
forward_native = forward_xpu
else:
forward_native = forward_cuda