Skip to content

vllm.distributed.weight_transfer

Weight transfer engines for syncing model weights from trainers to inference workers.

Modules:

Name Description
base

Base class for weight transfer engines.

factory

Factory for weight transfer engines with lazy loading.

ipc_engine

IPC-based weight transfer engine using CUDA IPC for communication.

nccl_engine

NCCL-based weight transfer engine.

packed_tensor

Packed tensor utilities for efficient weight transfer.

WeightTransferEngine

Bases: ABC, Generic[TInitInfo, TUpdateInfo]

Base class for weight transfer engines that handle transport of model weights from a trainer to inference workers.

This abstraction separates weight transfer transport logic from the worker implementation, allowing different backends (NCCL, CUDA IPC; RDMA is still TODO) to be plugged in.

Subclasses should define:

- `init_info_cls`: type of the backend-specific initialization info
- `update_info_cls`: type of the backend-specific update info

Source code in vllm/distributed/weight_transfer/base.py
class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
    """
    Base class for weight transfer engines that handle transport of model weights
    from a trainer to inference workers.

    This abstraction separates weight transfer transport logic from the worker
    implementation, allowing different backends (NCCL, CUDA IPC; RDMA is still
    TODO) to be plugged in.

    Subclasses should define:
        init_info_cls: Type of backend-specific initialization info
        update_info_cls: Type of backend-specific update info

    ``parse_init_info`` / ``parse_update_info`` build these types by expanding a
    plain dictionary as keyword arguments (``cls(**data)``). Both types must
    therefore accept their fields as keyword arguments (e.g. dataclasses).
    """

    # Subclasses should override these class attributes
    init_info_cls: type[TInitInfo]
    update_info_cls: type[TUpdateInfo]

    def __init__(
        self,
        config: WeightTransferConfig,
        parallel_config: ParallelConfig,
        model: torch.nn.Module,
    ) -> None:
        """
        Initialize the weight transfer engine.

        Only stores the given references; no transport is set up here (see
        ``init_transfer_engine``).

        Args:
            config: The configuration for the weight transfer engine
            parallel_config: The configuration for the parallel setup
            model: The local model instance which will receive the weights
        """
        self.config = config
        self.parallel_config = parallel_config
        self.model = model

    def parse_init_info(self, init_dict: dict[str, Any]) -> TInitInfo:
        """
        Construct typed init info from dict with validation.

        Args:
            init_dict: Dictionary containing backend-specific initialization parameters

        Returns:
            Typed backend-specific init info dataclass

        Raises:
            ValueError: If init_dict is invalid for this backend, i.e. constructing
                ``init_info_cls(**init_dict)`` raised a TypeError (for example
                unknown or missing keys, or a non-mapping argument). The original
                TypeError is chained as ``__cause__``.
        """
        try:
            return self.init_info_cls(**init_dict)
        except TypeError as e:
            raise ValueError(
                f"Invalid init_info for {self.__class__.__name__}: {e}"
            ) from e

    def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo:
        """
        Construct typed update info from dict with validation.

        Args:
            update_dict: Dictionary containing backend-specific update parameters

        Returns:
            Typed backend-specific update info dataclass

        Raises:
            ValueError: If update_dict is invalid for this backend, i.e.
                constructing ``update_info_cls(**update_dict)`` raised a TypeError
                (for example unknown or missing keys, or a non-mapping argument).
                The original TypeError is chained as ``__cause__``.
        """
        try:
            return self.update_info_cls(**update_dict)
        except TypeError as e:
            raise ValueError(
                f"Invalid update_info for {self.__class__.__name__}: {e}"
            ) from e

    @abstractmethod
    def init_transfer_engine(self, init_info: TInitInfo) -> None:
        """
        Initialize the weight transfer mechanism.
        This is called once at the beginning of training.

        Args:
            init_info: Backend-specific initialization info
        """
        raise NotImplementedError

    @abstractmethod
    def receive_weights(
        self,
        update_info: TUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        """
        Receive weights from the trainer and load them incrementally.

        Args:
            update_info: Backend-specific update info containing parameter metadata
                        and any backend-specific data
            load_weights: Callable that loads weights into the model. It is
                         called incrementally, each time with a list of
                         (name, tensor) pairs rather than the complete set of
                         weights, to avoid OOM.
        """
        raise NotImplementedError

    @abstractmethod
    def shutdown(self) -> None:
        """
        Shutdown the weight transfer engine.
        This should be called when the worker is shutting down.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, torch.Tensor]],
        trainer_args: dict[str, Any] | Any,
    ) -> None:
        """
        Send weights from trainer to inference workers.

        This is a static method. The trainer process can call it directly on the
        engine class, without constructing an engine instance, to send weights to
        all inference workers.

        Args:
            iterator: Iterator yielding (name, tensor) tuples of model parameters.
                     The tensors should be on the appropriate device for the backend.
            trainer_args: Backend-specific arguments needed to send weights,
                         typically a dictionary (the annotation also admits other
                         objects). The structure depends on the backend:
                         - NCCL: Contains 'group', 'src', 'packed', etc.
                         - IPC: Contains 'mode' ('http' or 'ray'),
                                'llm_handle' (for Ray), 'url' (for HTTP), etc.

        Example:
            >>> # EngineCls is the concrete WeightTransferEngine subclass
            >>> param_iter = ((n, p) for n, p in model.named_parameters())
            >>> EngineCls.trainer_send_weights(param_iter, trainer_args)
        """
        raise NotImplementedError

__init__

__init__(
    config: WeightTransferConfig,
    parallel_config: ParallelConfig,
    model: Module,
) -> None

Initialize the weight transfer engine.

Parameters:

Name Type Description Default
config WeightTransferConfig

The configuration for the weight transfer engine

required
parallel_config ParallelConfig

The configuration for the parallel setup

required
model Module

The local model instance which will receive the weights

required
Source code in vllm/distributed/weight_transfer/base.py
def __init__(
    self,
    config: WeightTransferConfig,
    parallel_config: ParallelConfig,
    model: torch.nn.Module,
) -> None:
    """
    Store the configuration and target model for this weight transfer engine.

    Args:
        config: Configuration of the weight transfer engine itself
        parallel_config: Configuration of the parallel setup
        model: Local model instance into which received weights are loaded
    """
    # Only keep references here; backend transport setup happens elsewhere.
    self.config, self.parallel_config, self.model = config, parallel_config, model

init_transfer_engine abstractmethod

init_transfer_engine(init_info: TInitInfo) -> None

Initialize the weight transfer mechanism. This is called once at the beginning of training.

Parameters:

Name Type Description Default
init_info TInitInfo

Backend-specific initialization info

required
Source code in vllm/distributed/weight_transfer/base.py
@abstractmethod
def init_transfer_engine(self, init_info: TInitInfo) -> None:
    """
    Set up the backend's weight transfer mechanism.

    Invoked a single time, at the start of training. Concrete engines must
    override this; the base version only raises NotImplementedError.

    Args:
        init_info: Typed, backend-specific initialization info
    """
    raise NotImplementedError()

parse_init_info

parse_init_info(init_dict: dict[str, Any]) -> TInitInfo

Construct typed init info from dict with validation.

Parameters:

Name Type Description Default
init_dict dict[str, Any]

Dictionary containing backend-specific initialization parameters

required

Returns:

Type Description
TInitInfo

Typed backend-specific init info dataclass

Raises:

Type Description
ValueError

If init_dict is invalid for this backend

Source code in vllm/distributed/weight_transfer/base.py
def parse_init_info(self, init_dict: dict[str, Any]) -> TInitInfo:
    """
    Build this backend's typed init info from a plain dictionary.

    The dictionary is expanded as keyword arguments into ``init_info_cls``.

    Args:
        init_dict: Backend-specific initialization parameters keyed by field name

    Returns:
        An instance of ``init_info_cls``, the backend's typed init info

    Raises:
        ValueError: If building ``init_info_cls`` from the dictionary raises a
            TypeError (e.g. unknown or missing keys); the TypeError is chained
            as the cause
    """
    try:
        parsed = self.init_info_cls(**init_dict)
    except TypeError as err:
        engine_name = self.__class__.__name__
        raise ValueError(f"Invalid init_info for {engine_name}: {err}") from err
    return parsed

parse_update_info

parse_update_info(
    update_dict: dict[str, Any],
) -> TUpdateInfo

Construct typed update info from dict with validation.

Parameters:

Name Type Description Default
update_dict dict[str, Any]

Dictionary containing backend-specific update parameters

required

Returns:

Type Description
TUpdateInfo

Typed backend-specific update info dataclass

Raises:

Type Description
ValueError

If update_dict is invalid for this backend

Source code in vllm/distributed/weight_transfer/base.py
def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo:
    """
    Build this backend's typed update info from a plain dictionary.

    The dictionary is expanded as keyword arguments into ``update_info_cls``.

    Args:
        update_dict: Backend-specific update parameters keyed by field name

    Returns:
        An instance of ``update_info_cls``, the backend's typed update info

    Raises:
        ValueError: If building ``update_info_cls`` from the dictionary raises a
            TypeError (e.g. unknown or missing keys); the TypeError is chained
            as the cause
    """
    try:
        parsed = self.update_info_cls(**update_dict)
    except TypeError as err:
        engine_name = self.__class__.__name__
        raise ValueError(f"Invalid update_info for {engine_name}: {err}") from err
    return parsed

receive_weights abstractmethod

receive_weights(
    update_info: TUpdateInfo,
    load_weights: Callable[
        [list[tuple[str, Tensor]]], None
    ],
) -> None

Receive weights from the trainer and load them incrementally.

Parameters:

Name Type Description Default
update_info TUpdateInfo

Backend-specific update info containing parameter metadata and any backend-specific data

required
load_weights Callable[[list[tuple[str, Tensor]]], None]

Callable that loads weights into the model. It is called incrementally, each time with a list of (name, tensor) pairs rather than the complete set of weights, to avoid OOM.

required
Source code in vllm/distributed/weight_transfer/base.py
@abstractmethod
def receive_weights(
    self,
    update_info: TUpdateInfo,
    load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
) -> None:
    """
    Receive weights from the trainer and load them incrementally.

    Concrete engines must override this; the base version only raises
    NotImplementedError.

    Args:
        update_info: Backend-specific update info containing parameter metadata
                    and any backend-specific data
        load_weights: Callable that loads weights into the model. It is called
                     incrementally, each time with a list of (name, tensor)
                     pairs rather than the complete set of weights, to avoid OOM.
    """
    raise NotImplementedError

shutdown abstractmethod

shutdown() -> None

Shutdown the weight transfer engine. This should be called when the worker is shutting down.

Source code in vllm/distributed/weight_transfer/base.py
@abstractmethod
def shutdown(self) -> None:
    """
    Tear down the weight transfer engine.

    Meant to be invoked when the worker shuts down. Concrete engines must
    override this; the base version only raises NotImplementedError.
    """
    raise NotImplementedError()

trainer_send_weights abstractmethod staticmethod

trainer_send_weights(
    iterator: Iterator[tuple[str, Tensor]],
    trainer_args: dict[str, Any] | Any,
) -> None

Send weights from trainer to inference workers.

This is a static method that can be called from the trainer process to send weights to all inference workers.

Parameters:

Name Type Description Default
iterator Iterator[tuple[str, Tensor]]

Iterator of model parameters that yields (name, tensor) tuples. The tensors should be on the appropriate device for the backend.

required
trainer_args dict[str, Any] | Any

Backend-specific arguments needed to send weights, typically a dictionary. The structure depends on the backend:

- NCCL: contains 'group', 'src', 'packed', etc.
- IPC: contains 'mode' ('http' or 'ray'), 'llm_handle' (for Ray), 'url' (for HTTP), etc.

required
Example

param_iter = ((n, p) for n, p in model.named_parameters())
engine.trainer_send_weights(param_iter, trainer_args)

Source code in vllm/distributed/weight_transfer/base.py
@staticmethod
@abstractmethod
def trainer_send_weights(
    iterator: Iterator[tuple[str, torch.Tensor]],
    trainer_args: dict[str, Any] | Any,
) -> None:
    """
    Send weights from trainer to inference workers.

    This is a static method. The trainer process can call it directly on the
    engine class, without constructing an engine instance, to send weights to
    all inference workers. Concrete engines must override it; the base version
    only raises NotImplementedError.

    Args:
        iterator: Iterator yielding (name, tensor) tuples of model parameters.
                 The tensors should be on the appropriate device for the backend.
        trainer_args: Backend-specific arguments needed to send weights,
                     typically a dictionary (the annotation also admits other
                     objects). The structure depends on the backend:
                     - NCCL: Contains 'group', 'src', 'packed', etc.
                     - IPC: Contains 'mode' ('http' or 'ray'),
                            'llm_handle' (for Ray), 'url' (for HTTP), etc.

    Example:
        >>> # EngineCls is the concrete WeightTransferEngine subclass
        >>> param_iter = ((n, p) for n, p in model.named_parameters())
        >>> EngineCls.trainer_send_weights(param_iter, trainer_args)
    """
    raise NotImplementedError

WeightTransferEngineFactory

Factory for creating weight transfer engines with lazy loading.

This factory implements a registry pattern that supports:

- Lazy loading: engine modules are only imported when actually needed
- Extensibility: custom engines can be registered at runtime
- Centralized registration: all built-in engines are registered in one place

Source code in vllm/distributed/weight_transfer/factory.py
class WeightTransferEngineFactory:
    """Registry-backed factory that builds weight transfer engines on demand.

    Features:
    - Lazy loading: an engine's module is imported only when that engine is
      requested through ``create_engine``
    - Extensibility: additional engines can be registered at runtime
    - Centralized registration: built-in engines are all registered in one place
    """

    # Backend name -> zero-argument callable returning the engine class
    # (importing its module first for lazily registered engines).
    _registry: dict[str, Callable[[], type[WeightTransferEngine]]] = {}

    @classmethod
    def register_engine(
        cls,
        name: str,
        module_path_or_cls: str | type[WeightTransferEngine],
        class_name: str | None = None,
    ) -> None:
        """Add an engine to the registry, either lazily or by class object.

        Two calling conventions are accepted:
        1. Lazy loading: register_engine(name, module_path, class_name)
        2. Direct class: register_engine(name, engine_cls)

        Args:
            name: Backend name under which the engine is registered (e.g. "nccl")
            module_path_or_cls: Dotted module path to import later, or the engine
                class itself
            class_name: Attribute name of the engine class inside the module;
                mandatory when a module path is given, ignored otherwise

        Raises:
            ValueError: If ``name`` is already registered, or if a module path is
                given without ``class_name``
        """
        registry = cls._registry
        if name in registry:
            raise ValueError(f"Weight transfer engine '{name}' is already registered.")

        if not isinstance(module_path_or_cls, str):
            # Direct registration: hand back the given class object unchanged.
            engine_cls = module_path_or_cls
            registry[name] = lambda: engine_cls
            return

        # Lazy registration: remember where the class lives and import it later.
        if class_name is None:
            raise ValueError(
                "class_name is required when registering with module path"
            )
        module_path = module_path_or_cls

        def loader() -> type[WeightTransferEngine]:
            return getattr(importlib.import_module(module_path), class_name)

        registry[name] = loader

    @classmethod
    def create_engine(
        cls,
        config: "WeightTransferConfig",
        parallel_config: "ParallelConfig",
        model: "torch.nn.Module",
    ) -> WeightTransferEngine:
        """Instantiate the engine registered for ``config.backend``.

        Args:
            config: Weight transfer configuration; its ``backend`` selects the engine
            parallel_config: Parallel configuration forwarded to the engine
            model: Local model instance that will receive the weights

        Returns:
            A freshly constructed engine instance

        Raises:
            ValueError: If no engine is registered under ``config.backend``
        """
        backend = config.backend
        load_engine_cls = cls._registry.get(backend)
        if load_engine_cls is None:
            raise ValueError(
                f"Invalid weight transfer backend: {backend}. "
                f"Available engines: {list(cls._registry.keys())}"
            )

        engine_cls = load_engine_cls()
        logger.info("Creating weight transfer engine: %s", engine_cls.__name__)
        return engine_cls(config, parallel_config, model)

create_engine classmethod

create_engine(
    config: WeightTransferConfig,
    parallel_config: ParallelConfig,
    model: Module,
) -> WeightTransferEngine

Create a weight transfer engine instance.

Parameters:

Name Type Description Default
config WeightTransferConfig

Weight transfer configuration containing the backend name

required
parallel_config ParallelConfig

Parallel configuration for the engine

required
model Module

The local model instance which will receive the weights

required

Returns:

Type Description
WeightTransferEngine

An initialized weight transfer engine instance

Raises:

Type Description
ValueError

If the backend is not registered

Source code in vllm/distributed/weight_transfer/factory.py
@classmethod
def create_engine(
    cls,
    config: "WeightTransferConfig",
    parallel_config: "ParallelConfig",
    model: "torch.nn.Module",
) -> WeightTransferEngine:
    """Instantiate the engine registered for ``config.backend``.

    Args:
        config: Weight transfer configuration; its ``backend`` selects the engine
        parallel_config: Parallel configuration forwarded to the engine
        model: Local model instance that will receive the weights

    Returns:
        A freshly constructed engine instance

    Raises:
        ValueError: If no engine is registered under ``config.backend``
    """
    backend = config.backend
    load_engine_cls = cls._registry.get(backend)
    if load_engine_cls is None:
        raise ValueError(
            f"Invalid weight transfer backend: {backend}. "
            f"Available engines: {list(cls._registry.keys())}"
        )

    engine_cls = load_engine_cls()
    logger.info("Creating weight transfer engine: %s", engine_cls.__name__)
    return engine_cls(config, parallel_config, model)

register_engine classmethod

register_engine(
    name: str,
    module_path_or_cls: str | type[WeightTransferEngine],
    class_name: str | None = None,
) -> None

Register an engine with lazy-loading or direct class reference.

Supports two calling conventions:

1. Lazy loading: register_engine(name, module_path, class_name)
2. Direct class: register_engine(name, engine_cls)

Parameters:

Name Type Description Default
name str

The name to register the engine under (e.g., "nccl")

required
module_path_or_cls str | type[WeightTransferEngine]

Either a module path string for lazy loading, or the engine class directly

required
class_name str | None

Name of the engine class (required if module_path is string)

None

Raises:

Type Description
ValueError

If an engine with the same name is already registered

Source code in vllm/distributed/weight_transfer/factory.py
@classmethod
def register_engine(
    cls,
    name: str,
    module_path_or_cls: str | type[WeightTransferEngine],
    class_name: str | None = None,
) -> None:
    """Add an engine to the registry, either lazily or by class object.

    Two calling conventions are accepted:
    1. Lazy loading: register_engine(name, module_path, class_name)
    2. Direct class: register_engine(name, engine_cls)

    Args:
        name: Backend name under which the engine is registered (e.g. "nccl")
        module_path_or_cls: Dotted module path to import later, or the engine
            class itself
        class_name: Attribute name of the engine class inside the module;
            mandatory when a module path is given, ignored otherwise

    Raises:
        ValueError: If ``name`` is already registered, or if a module path is
            given without ``class_name``
    """
    registry = cls._registry
    if name in registry:
        raise ValueError(f"Weight transfer engine '{name}' is already registered.")

    if not isinstance(module_path_or_cls, str):
        # Direct registration: hand back the given class object unchanged.
        engine_cls = module_path_or_cls
        registry[name] = lambda: engine_cls
        return

    # Lazy registration: remember where the class lives and import it later.
    if class_name is None:
        raise ValueError(
            "class_name is required when registering with module path"
        )
    module_path = module_path_or_cls

    def loader() -> type[WeightTransferEngine]:
        return getattr(importlib.import_module(module_path), class_name)

    registry[name] = loader