Speed up docs build (#44635)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2026-06-05 15:51:44 +01:00
committed by GitHub
parent c66b19800b
commit a80af24356
32 changed files with 234 additions and 159 deletions
+2
View File
@@ -103,6 +103,8 @@ pre-commit run mypy-3.10 --all-files --hook-stage manual
The line length limit for Python code is 88 characters. If you are not sure, use pre-commit to check.
Use [Google-style docstrings](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings) (`Args:`/`Returns:`/`Raises:` sections), not reStructuredText/Sphinx fields (`:param:`, `:return:`, `:rtype:`).
### Commit messages
Add attribution using commit trailers such as `Co-authored-by:` (other projects use `Assisted-by:` or `Generated-by:`). For example:
+1 -1
View File
@@ -169,7 +169,7 @@ speculative decoding, breaking down the guarantees into three key areas:
> distribution. [View Test Code](https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252)
> - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling
> without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler,
> provides a lossless guarantee. Almost all of the tests in [tests/spec_decode/e2e](/tests/v1/spec_decode).
> provides a lossless guarantee. Almost all of the tests in [tests/spec_decode/e2e](../../../tests/v1/spec_decode)
> verify this property using [this assertion implementation](https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291).
3. **vLLM Logprob Stability**
+9 -9
View File
@@ -83,22 +83,22 @@ plugins:
- "re:vllm\\._.*" # Internal modules
- "vllm.third_party"
- "vllm.vllm_flash_attn"
- "re:vllm\\.grpc\\..*_pb2.*" # Auto-generated protobuf files
- "vllm.transformers_utils.configs"
- "vllm.transformers_utils.processors"
- !ENV [API_AUTONAV_EXCLUDE, "re:^$"] # Match nothing by default
- mkdocstrings:
handlers:
python:
options:
show_symbol_type_heading: true
show_symbol_type_toc: true
filters:
- "!.*_pb2_grpc" # Exclude auto-generated gRPC stubs
summary:
modules: true
show_signature_annotations: true
separate_signature: true
filters: []
show_overloads: true
signature_crossrefs: true
# Recommendations from api-autonav
docstring_section_style: list
parameter_headings: true
show_symbol_type_heading: true
show_symbol_type_toc: true
summary: true
inventories:
- https://docs.python.org/3/objects.inv
- https://typing-extensions.readthedocs.io/en/latest/objects.inv
+10 -6
View File
@@ -868,9 +868,10 @@ def cutlass_scaled_mm_azp(
bias: torch.Tensor | None = None,
) -> torch.Tensor:
"""
:param azp_adj: In the per-tensor case, this should include the azp.
Always per-channel.
:param azp: Only set in the per-token case. Per-token if set.
Args:
azp_adj: In the per-tensor case, this should include the azp.
Always per-channel.
azp: Only set in the per-token case. Per-token if set.
"""
assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
@@ -3886,9 +3887,12 @@ def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor:
Note that Sylvester Hadamard transforms are also symmetric, which means that
this function also applies the (transpose <=> inverse) transform.
:param x: value to be transformed inplace
:param inplace: modify value in place
:return: value after transformation
Args:
x: value to be transformed inplace
inplace: modify value in place
Returns:
value after transformation
"""
return torch.ops._C.hadacore_transform(x, inplace)
+9 -6
View File
@@ -82,11 +82,12 @@ class InductorPass(CustomGraphPass): # type: ignore[misc]
def hash_source(*srcs: str | Any) -> str:
"""
Utility method to hash the sources of functions or objects.
:param srcs: strings or objects to add to the hash.
Objects and functions have their source inspected.
Results are cached by resolved types to avoid repeated
inspect.getsource() calls.
:return:
Args:
srcs: strings or objects to add to the hash.
Objects and functions have their source inspected.
Results are cached by resolved types to avoid repeated
inspect.getsource() calls.
"""
# Resolve instances to their class for a hashable cache key.
cache_key = tuple(
@@ -99,7 +100,9 @@ class InductorPass(CustomGraphPass): # type: ignore[misc]
def hash_dict(dict_: dict[Any, Any]) -> str:
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary.
Returns:
A sha256 hash of the json rep of the dictionary.
"""
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()
@@ -276,9 +276,11 @@ class FixFunctionalizationPass(VllmInductorPass):
"""
Replace mutated getitem users of the auto-functionalized node with the
mutated arguments.
:param node: The auto-functionalized node
:param mutated_args: The mutated arguments, indexed by getitem index.
If the value of an arg is a string, `node.kwargs[arg]` is used.
Args:
node: The auto-functionalized node
mutated_args: The mutated arguments, indexed by getitem index.
If the value of an arg is a string, `node.kwargs[arg]` is used.
"""
for idx, user in self.getitem_users(node).items():
# Some functionalized nodes may return both a result at getitem[0]
@@ -317,10 +319,11 @@ class FixFunctionalizationPass(VllmInductorPass):
as node.kwargs cannot be used.
See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
:param graph: Graph to insert the defunctionalized node into
:param node: The auto-functionalized node to defunctionalize
:param args: If we cannot use kwargs, specify args directly.
If an arg is a string, `node.kwargs[arg]` is used.
Args:
graph: Graph to insert the defunctionalized node into
node: The auto-functionalized node to defunctionalize
args: If we cannot use kwargs, specify args directly.
If an arg is a string, `node.kwargs[arg]` is used.
""" # noqa: E501
assert is_func(node, auto_functionalized), (
f"node must be auto-functionalized, is {node} instead"
@@ -108,9 +108,13 @@ class NoOpEliminationPass(VllmInductorPass):
def dims_equivalent(self, dim: int | SymInt, i_dim: int | SymInt) -> bool:
"""
This function checks if two dimensions are equivalent.
:param dim: The dimension arg to reshape/slice
:param i_dim: The corresponding dimension in the input tensor
:return: Are the dimensions equivalent?
Args:
dim: The dimension arg to reshape/slice
i_dim: The corresponding dimension in the input tensor
Returns:
Are the dimensions equivalent?
There are two cases in which the dimensions are equivalent:
1. The dimensions are equal (both integers)
+10 -7
View File
@@ -180,8 +180,9 @@ class CuMemAllocator:
All data in the memory allocation with the specified tag will be
offloaded to CPU memory, and others will be discarded.
:param offload_tags: The tags of the memory allocation that will be
offloaded. The rest of the memory allocation will be discarded.
Args:
offload_tags: The tags of the memory allocation that will be
offloaded. The rest of the memory allocation will be discarded.
"""
if offload_tags is None:
# by default, allocated tensors are offloaded
@@ -230,9 +231,10 @@ class CuMemAllocator:
All data that is previously offloaded will be loaded back to GPU
memory, and the rest of the data will have empty memory.
:param tags: The tags of the memory allocation that will be loaded
back to GPU memory. If None, all memory allocation will be loaded
back to GPU memory.
Args:
tags: The tags of the memory allocation that will be loaded
back to GPU memory. If None, all memory allocation will be loaded
back to GPU memory.
"""
for ptr, data in self.pointer_to_data.items():
if tags is None or data.tag in tags:
@@ -255,8 +257,9 @@ class CuMemAllocator:
All memory allocations created inside the context will be allocated
in the memory pool and will have the specified tag.
:param tag: The tag of the memory allocation. If None, the default tag
will be used.
Args:
tag: The tag of the memory allocation. If None, the default tag
will be used.
"""
if tag is None:
tag = CuMemAllocator.default_tag
+10 -5
View File
@@ -132,7 +132,8 @@ class KVEventAggregator:
"""
Add events from a worker batch.
:param events: List of KVCacheEvent objects.
Args:
events: List of KVCacheEvent objects.
"""
if not isinstance(events, list):
raise TypeError("events must be a list of KVCacheEvent.")
@@ -142,7 +143,8 @@ class KVEventAggregator:
"""
Return events that appeared in all workers.
:return: List of events present in all workers.
Returns:
List of events present in all workers.
"""
return [
event
@@ -154,7 +156,8 @@ class KVEventAggregator:
"""
Return all events for all workers.
:return: List of events for all workers.
Returns:
List of events for all workers.
"""
return list(self._event_counter.elements())
@@ -168,7 +171,8 @@ class KVEventAggregator:
"""
Increment the number of workers contributing events.
:param count: Number to increment the workers by.
Args:
count: Number to increment the workers by.
"""
if count <= 0:
raise ValueError("count must be positive.")
@@ -184,7 +188,8 @@ class KVEventAggregator:
"""
Return the number of workers.
:return: int number of workers.
Returns:
int number of workers.
"""
return self._num_workers
@@ -439,13 +439,12 @@ def _init_lmcache_engine(
`LMCACHE_CONFIG_FILE` to load the configuration file. If that environment
variable is not set, this function will return None.
:param lmcache_config: The LMCache configuration.
:type lmcache_config: LMCacheEngineConfig
:param vllm_config: The vLLM configuration.
:type vllm_config: VllmConfig
Args:
lmcache_config: The LMCache configuration.
vllm_config: The vLLM configuration.
:return: The initialized LMCache engine
:rtype: LMCacheEngine
Returns:
The initialized LMCache engine
"""
if curr_engine := LMCacheEngineBuilder.get(ENGINE_NAME):
return curr_engine
+17 -11
View File
@@ -113,11 +113,14 @@ def register_op(
"""
Register a new vLLM IR op.
:param f: the native implementation of the op
:param name: the name of the op, defaults to the function name
:param activations: list of activation params, defaults to params starting with 'x'
:param allow_inplace: add a maybe_inplace overload that allows inplace impls
:return: the IrOp object if f is provided, otherwise a decorator
Args:
f: the native implementation of the op
name: the name of the op, defaults to the function name
activations: list of activation params, defaults to params starting with 'x'
allow_inplace: add a maybe_inplace overload that allows inplace impls
Returns:
the IrOp object if f is provided, otherwise a decorator
Example usage:
```python
@@ -245,14 +248,17 @@ class IrOp:
supported: bool = True,
supports_args: Callable[..., bool] | None = None,
inplace: bool = False,
):
) -> Callable[[Callable[..., Any]], "IrOpImpl"]:
"""
Register an implementation for this custom op.
:param provider: The name of the provider, must be unique.
:param supported: Static support check, use this to check platform support.
:param supports_args: Dynamic arg support check, used for types and shapes.
:param inplace: Does this op reuse activation input memory for outputs
:return: A decorator that registers the implementation.
Args:
provider: The name of the provider, must be unique.
supported: Static support check, use this to check platform support.
supports_args: Dynamic arg support check, used for types and shapes.
inplace: Does this op reuse activation input memory for outputs
Returns:
A decorator that registers the implementation.
The decorated function must have the same semantics and signature as
the native implementation.
+3 -3
View File
@@ -12,9 +12,9 @@ from typing import Any
def hash_source(*srcs: str | Any) -> str:
"""
Utility method to hash the sources of functions or objects.
:param srcs: strings or objects to add to the hash.
Objects and functions have their source inspected.
:return:
Args:
srcs: strings or objects to add to the hash.
Objects and functions have their source inspected.
"""
hasher = hashlib.sha256()
for src in srcs:
@@ -599,12 +599,14 @@ class FusedMoE(PluggableLayer):
):
"""
Load grouped weight scales for group quantization or model weights
:param shard_dim: dimension to shard
:param expert_data: parameter for a particular expert
:param shard_id: either w1, w2, or w3
:param loaded_weight: checkpoint weight to load into the param
:param tp_rank: tensor parallel rank
:param load_full_w2: whether or not the w2 loaded should be sharded.
Args:
shard_dim: dimension to shard
expert_data: parameter for a particular expert
shard_id: either w1, w2, or w3
loaded_weight: checkpoint weight to load into the param
tp_rank: tensor parallel rank
load_full_w2: whether or not the w2 loaded should be sharded.
"""
if shard_id == "w2":
# In the case where we have actorder/g_idx, we do not partition the
@@ -178,8 +178,9 @@ class QuantizationConfig(ABC):
Interface for models to update module names referenced in
quantization configs in order to reflect the vllm model structure
:param hf_to_vllm_mapper: maps from hf model structure (the assumed
structure of the qconfig) to vllm model structure
Args:
hf_to_vllm_mapper: maps from hf model structure (the assumed
structure of the qconfig) to vllm model structure
"""
# TODO (@kylesayrs): add implementations for all subclasses
pass
@@ -267,8 +267,11 @@ class CompressedTensorsConfig(QuantizationConfig):
cls, config: dict[str, Any]
) -> tuple[dict[str, SparsityCompressionConfig], list[str]]:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A tuple with two elements
Args:
config: The `quantization_config` dictionary from config.json
Returns:
A tuple with two elements
1. A dictionary mapping target layer names to their corresponding
sparsity_config
2. A list of layer names to ignore for sparsity
@@ -296,8 +299,11 @@ class CompressedTensorsConfig(QuantizationConfig):
cls, config: dict[str, Any]
) -> QUANTIZATION_SCHEME_MAP_TYPE:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
Args:
config: The `quantization_config` dictionary from config.json
Returns:
A dictionary mapping target layer names to their corresponding
quantization_args for weights and input activations
"""
target_scheme_map: dict[str, Any] = dict()
@@ -967,7 +973,9 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
"""
Validator for the kv cache scheme. Useful for controlling which
kv cache quantization schemes are supported in vLLM
:param kv_cache_scheme: the compressed-tensors kv cache scheme
Args:
kv_cache_scheme: the compressed-tensors kv cache scheme
"""
if kv_cache_scheme is None:
return
@@ -38,11 +38,11 @@ class CompressedTensorsScheme(ABC):
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
:param layer: torch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
:param x: input to the layer
:param bias: bias parameter
Args:
layer: torch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
x: input to the layer
bias: bias parameter
"""
raise NotImplementedError()
@@ -133,12 +133,11 @@ def find_matched_target(
*All* component module names must match in order for a match to be
successful. A successful match returns the first component target
:param layer_name: layer name
:param module: torch.nn.Module
:param targets: list of targets to match the layer against
:param fused_mapping: map from fused layer names to its components
:param fused_strategy: either "all" or "any". If using "all", fused
layers match if "all" of its components match
Args:
layer_name: layer name
module: torch.nn.Module
targets: list of targets to match the layer against
fused_mapping: map from fused layer names to its components
"""
if layer_name is None:
@@ -161,9 +160,10 @@ def _find_first_match(
exactly or as a regex after 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
:param value: string to compare the list of targets against
:param targets: list of targets to match the layer against
:param check_contains: whether or not to do a substring match
Args:
value: string to compare the list of targets against
targets: list of targets to match the layer against
check_contains: whether or not to do a substring match
"""
for target in targets:
@@ -205,9 +205,10 @@ def _match_fused_layer(
Implements an "all" matching strategy where a fused layer matches iff
"all" of its components match
:param layer_name: layer name
:param target_layers: list of targets to match the layer against
:param fused_mapping: map from fused layer names to its components
Args:
layer_name: layer name
target_layers: list of targets to match the layer against
fused_mapping: map from fused layer names to its components
Examples:
layer_name = "model.layers.0.self_attn.qkv_proj"
@@ -114,8 +114,9 @@ class GGUFConfig(QuantizationConfig):
Interface for models to update module names referenced in
quantization configs in order to reflect the vllm model structure
:param hf_to_vllm_mapper: maps from hf model structure (the assumed
structure of the qconfig) to vllm model structure
Args:
hf_to_vllm_mapper: maps from hf model structure (the assumed
structure of the qconfig) to vllm model structure
"""
if self.unquantized_modules is not None:
self.unquantized_modules = hf_to_vllm_mapper.apply_list(
@@ -46,16 +46,17 @@ class QuantFP8(CustomOp):
compile_native: bool = True,
):
"""
:param static: static or dynamic quantization
:param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
PER_CHANNEL, or arbitrary block size)
:param num_token_padding: Pad the token dimension of output to this
size
:param tma_aligned_scales: For group quantization, output scales in
TMA-aligned layout
:param column_major_scales: For group quantization, output scales in
column major format
:param compile_native: Manually compile forward_native if compile mode > None
Args:
static: static or dynamic quantization
group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
PER_CHANNEL, or arbitrary block size)
num_token_padding: Pad the token dimension of output to this
size
tma_aligned_scales: For group quantization, output scales in
TMA-aligned layout
column_major_scales: For group quantization, output scales in
column major format
compile_native: Manually compile forward_native if compile mode > None
"""
super().__init__(compile_native=compile_native)
self.static = static
@@ -47,7 +47,8 @@ class BaseKVCacheMethod(QuantizeMethodBase):
- quantize k/v_cache entries before saving them to the cache
- dequantize k/v_cache entries before fetching them from the cache
:param quant_config: the appropriate QuantizationConfig
Args:
quant_config: the appropriate QuantizationConfig
"""
def __init__(self, quant_config: QuantizationConfig):
@@ -87,8 +87,9 @@ class QuarkConfig(QuantizationConfig):
Interface for models to update module names referenced in
quantization configs in order to reflect the vllm model structure
:param hf_to_vllm_mapper: maps from hf model structure (the assumed
structure of the qconfig) to vllm model structure
Args:
hf_to_vllm_mapper: maps from hf model structure (the assumed
structure of the qconfig) to vllm model structure
"""
quant_config_with_hf_to_vllm_mapper: dict[str, Any] = {}
@@ -724,7 +725,9 @@ class QuarkKVCacheMethod(BaseKVCacheMethod):
"""
Validator for the kv cache configuration. Useful for controlling which
kv cache quantization schemes are supported in vLLM
:param kv_cache_config: the quark kv cache scheme
Args:
kv_cache_config: the quark kv cache scheme
"""
if kv_cache_config is None:
return
@@ -38,11 +38,11 @@ class QuarkScheme(ABC):
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
:param layer: torch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
:param x: input to the layer
:param bias: bias parameter
Args:
layer: torch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
x: input to the layer
bias: bias parameter
"""
raise NotImplementedError
@@ -123,7 +123,8 @@ def initialize_online_processing(layer: torch.nn.Module):
Called by either `initialize_layerwise_reload` or an online quantization scheme,
prevents double wrapping in the case of online quantization + reloading
:param layer: layer whose parameter weight loaders will be wrapped
Args:
layer: layer whose parameter weight loaders will be wrapped
"""
info = get_layerwise_info(layer)
@@ -222,8 +223,9 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
This function should be applied after `initialize_layerwise_reload` is applied,
to unwrap the layerwise weight loaders.
:param model: model to finalize processing for
:param model_config: config needed for applying processing to attention layers
Args:
model: model to finalize processing for
model_config: config needed for applying processing to attention layers
"""
if hasattr(model, "_original_do_torchao_reload"):
model._do_torchao_reload = model._original_do_torchao_reload
@@ -175,9 +175,12 @@ def get_numel_loaded(
"""
Determine how many elements would be loaded by a weight loader call.
:param weight loader: used to load weights
:param args: bound arguments to weight loader
:return: number of elements loaded by the weight loader, the return value of the
Args:
weight_loader: used to load weights
args: bound arguments to weight loader
Returns:
number of elements loaded by the weight loader, the return value of the
weight loader
"""
with CopyCounter() as counter:
@@ -20,9 +20,12 @@ def sanitize_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.T
tensors will reference layers, and the WeakKeyDictionary will never evict entries,
even when the model is deleted.
:param tensor: tensor to be sanitized
:param layer: layer whose references should be removed
:return: sanitized tensor
Args:
tensor: tensor to be sanitized
layer: layer whose references should be removed
Returns:
sanitized tensor
"""
for key, value in tensor.__dict__.items():
if isinstance(value, MethodType) and value.__self__ is layer:
@@ -38,10 +41,12 @@ def restore_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Te
Used by `restore_layer_on_meta` to add back layer references, allowing for proper
weight loading.
:param tensor: tensor to be sanitized
:param layer: layer whose references should be removed
:return: sanitized tensor
Args:
tensor: tensor whose layer references should be restored
layer: layer whose references should be added back
Returns:
tensor with layer references restored
"""
for key, value in tensor.__dict__.items():
if isinstance(value, MethodType) and value.__self__ is layer_ref_sentinel:
@@ -49,8 +49,11 @@ def has_device_tensors(bound_args: BoundArguments) -> bool:
"""
Return True if the loaded weights exist on an accelerator device
:param bound_args: args to load weights
:return: True if weights are on accelerator device
Args:
bound_args: args to load weights
Returns:
True if weights are on accelerator device
"""
return any(
isinstance(value, torch.Tensor) and value.device.type not in ("meta", "cpu")
@@ -62,8 +65,11 @@ def get_info_size(info: LayerReloadingInfo) -> int:
"""
Calculate the number of bytes used by loaded weights for a given layer
:param info: layerwise info to get size of
:return: number of bytes used by loaded weights
Args:
info: layerwise info to get size of
Returns:
number of bytes used by loaded weights
"""
return sum(
value.nbytes
+2 -1
View File
@@ -277,7 +277,8 @@ class OlmoModel(nn.Module):
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
Args:
input_ids: A tensor of shape `(batch_size, seq_len)`.
"""
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
+2 -1
View File
@@ -314,7 +314,8 @@ class Olmo2Model(nn.Module):
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
Args:
input_ids: A tensor of shape `(batch_size, seq_len)`.
"""
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
+10 -9
View File
@@ -3,6 +3,7 @@
from collections.abc import Callable, Hashable
from fractions import Fraction
from typing import Any
from weakref import WeakValueDictionary
import torch
@@ -42,10 +43,9 @@ class BasevLLMParameter(Parameter):
"""
Initialize the BasevLLMParameter
:param data: torch tensor with the parameter data
:param weight_loader: weight loader callable
:returns: a torch.nn.parameter
Args:
data: torch tensor with the parameter data
weight_loader: weight loader callable
"""
# During weight loading, we often do something like:
@@ -445,15 +445,16 @@ class SharedWeightParameter(BasevLLMParameter):
"currently support tensor parallelism"
)
def add_partition(self, index: int, data_key: Hashable, *args, **kwargs):
def add_partition(self, index: int, data_key: Hashable, *args: Any, **kwargs: Any):
"""
Add a partition to the weight parameter. Partitions whose `data_key`
is the same will share tensor data
:param index: index of partition to add
:param data_key: hashable key used to key shared tensors
:param *args: arguments for `torch.empty`
:param **kwargs: keyword arguments for `torch.empty`
Args:
index: index of partition to add
data_key: hashable key used to key shared tensors
*args: arguments for `torch.empty`
**kwargs: keyword arguments for `torch.empty`
"""
# load (shared) tensor using `data_key`
if data_key not in self.tensors_registry:
+5 -2
View File
@@ -84,8 +84,11 @@ def maybe_model_redirect(model: str) -> str:
"""
Use model_redirect to redirect the model name to a local folder.
:param model: hf model name
:return: maybe redirect to a local folder
Args:
model: hf model name
Returns:
maybe redirect to a local folder
"""
model_redirect_path = envs.VLLM_MODEL_REDIRECT_PATH
+6 -3
View File
@@ -808,14 +808,17 @@ class AttentionImpl(AttentionImplBase[T], Generic[T]):
) -> torch.Tensor:
raise NotImplementedError
def fused_output_quant_supported(self, quant_key: "QuantKey"):
def fused_output_quant_supported(self, quant_key: "QuantKey") -> bool:
"""
Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization
onto implementations that support it.
:param quant_key: QuantKey object that describes the quantization op
:return: is fusion supported for this type of quantization
Args:
quant_key: QuantKey object that describes the quantization op
Returns:
is fusion supported for this type of quantization
"""
return False
+15 -12
View File
@@ -1878,9 +1878,8 @@ class GPUModelRunner(
SpecDecodeMetadata | None,
]:
"""
:return: tuple[
logits_indices, spec_decode_metadata,
]
Returns:
tuple[logits_indices, spec_decode_metadata]
"""
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
@@ -2205,7 +2204,8 @@ class GPUModelRunner(
slot_mappings: dict[int, torch.Tensor] | None = None,
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
"""
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
Returns:
tuple[attn_metadata, spec_decode_common_attn_metadata]
"""
# Attention metadata is not needed for attention free models
if len(self.kv_cache_config.kv_cache_groups) == 0:
@@ -2503,9 +2503,11 @@ class GPUModelRunner(
num_common_prefix_blocks: list[int],
) -> list[list[int]] | None:
"""
:return: Optional[cascade_attn_prefix_lens]
cascade_attn_prefix_lens is 2D: ``[kv_cache_group_id][attn_group_idx]``,
None if we should not use cascade attention
Returns:
Optional[cascade_attn_prefix_lens]
cascade_attn_prefix_lens is 2D:
``[kv_cache_group_id][attn_group_idx]``,
None if we should not use cascade attention
"""
use_cascade_attn = False
@@ -5324,11 +5326,12 @@ class GPUModelRunner(
"""
Reload weights from a weights iterator or from disk
:param weights_iterator: weights to load into model
:param weights_path: path to load weights from if weights_iterator is not
provided. Use path of original model if neither is provided.
:param is_checkpoint_format: set to False if weights have already been processed
into kernel format (repacking, renaming, etc.)
Args:
weights_iterator: weights to load into model
weights_path: path to load weights from if weights_iterator is not
provided. Use path of original model if neither is provided.
is_checkpoint_format: set to False if weights have already been
processed into kernel format (repacking, renaming, etc.)
"""
# TODO(@kylesayrs): generalize to all runners and loaders
# argument validation