Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

[AutoDeploy] merge feat/ad-2025-07-07 (#6196)

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Co-authored-by: Gal Hubara-Agam <96368689+galagam@users.noreply.github.com>
Co-authored-by: Neta Zmora <nzmora@nvidia.com>
Co-authored-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>
Co-authored-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Co-authored-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Co-authored-by: Grzegorz Kwasniewski <213329731+greg-kwasniewski1@users.noreply.github.com>

Parent: 5234502717
Commit: 41fb8aa8b1
benchmarks/cpp/__init__.py               | 0  (normal file)
benchmarks/cpp/utils/__init__.py         | 0  (normal file)
examples/auto_deploy/.vscode/launch.json | 6  (vendored)
@@ -16,8 +16,10 @@
"--args.model-factory=AutoModelForCausalLM",
"--benchmark.enabled=false",
"--prompt.batch-size=2",
"--args.model-kwargs",
"num_hidden_layers=3,num_attention_heads=32",
"--args.model-kwargs.num-hidden-layers=3",
"--args.model-kwargs.num-attention-heads=32",
"--prompt.sp-kwargs.max-tokens=128",
// "--dry-run", // uncomment to print the final config and return
],
"console": "integratedTerminal",
"justMyCode": false,
@@ -6,7 +6,7 @@

<div align="left">

AutoDeploy is designed to simplify and accelerate the deployment of PyTorch models, including off-the-shelf models like those from Hugging Face, to TensorRT-LLM. It automates graph transformations to integrate inference optimizations such as tensor parallelism, KV-caching and quantization. AutoDeploy supports optimized in-framework deployment, minimizing the amount of manual modification needed.
AutoDeploy is an experimental feature in beta stage designed to simplify and accelerate the deployment of PyTorch models, including off-the-shelf models like those from Hugging Face, to TensorRT-LLM. It automates graph transformations to integrate inference optimizations such as tensor parallelism, KV-caching and quantization. AutoDeploy supports optimized in-framework deployment, minimizing the amount of manual modification needed.

______________________________________________________________________

@@ -146,7 +146,7 @@ Below is a non-exhaustive list of common config options:

| `--args.skip-loading-weights` | Only load the architecture, not the weights |
| `--args.model-kwargs` | Extra kwargs that are being passed to the model initializer in the model factory |
| `--args.tokenizer-kwargs` | Extra kwargs that are being passed to the tokenizer initializer in the model factory |
| `--args.world-size` | The number of GPUs for Tensor Parallel |
| `--args.world-size` | The number of GPUs used for auto-sharding the model |
| `--args.runtime` | Specifies which type of Engine to use during runtime (`"demollm"` or `"trtllm"`) |
| `--args.compile-backend` | Specifies how to compile the graph at the end |
| `--args.attn-backend` | Specifies kernel implementation for attention |

@@ -157,7 +157,7 @@ Below is a non-exhaustive list of common config options:

| `--prompt.batch-size` | Number of queries to generate |
| `--benchmark.enabled` | Whether to run the built-in benchmark (true/false) |

For default values and additional configuration options, refer to the `ExperimentConfig` class in [build_and_run_ad.py](./build_and_run_ad.py) file.
For default values and additional configuration options, refer to the [`ExperimentConfig`](./build_and_run_ad.py) class in [build_and_run_ad.py](./build_and_run_ad.py) file.

Here is a more complete example of using the script:

@@ -172,7 +172,7 @@ python build_and_run_ad.py \

  --benchmark.enabled True
```

#### Logging Level
### Logging Level

Use the following env variable to specify the logging level of our built-in logger ordered by
decreasing verbosity;
@@ -223,9 +223,6 @@ AutoDeploy can be seamlessly integrated into your existing workflows using TRT-L

Here is an example of how you can build an LLM object with AutoDeploy integration:

<details>
<summary>Click to expand the example</summary>

```
from tensorrt_llm._torch.auto_deploy import LLM

@@ -233,7 +230,7 @@ from tensorrt_llm._torch.auto_deploy import LLM

# Construct the LLM high-level interface object with autodeploy as backend
llm = LLM(
    model=<HF_MODEL_CARD_OR_DIR>,
    world_size=<NUM_WORLD_RANK>,
    world_size=<DESIRED_WORLD_SIZE>,
    compile_backend="torch-compile",
    model_kwargs={"num_hidden_layers": 2}, # test with smaller model configuration
    attn_backend="flashinfer", # choose between "triton" and "flashinfer"

@@ -249,28 +246,207 @@ llm = LLM(

```

Please consult the [AutoDeploy `LLM` API](../../tensorrt_llm/_torch/auto_deploy/llm.py) and the
[`AutoDeployConfig` class](../../tensorrt_llm/_torch/auto_deploy/llm_args.py)
for more detail on how AutoDeploy is configured via the `**kwargs` of the `LLM` API.
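
For orientation, here is a minimal sketch (not part of this commit) of how a validated config can be expanded into the `**kwargs` of the `LLM` API, mirroring what `build_and_run_ad.py` does via `config.args.to_dict()`. The field names follow the examples shown elsewhere in this diff and may differ in other versions of `llm_args.py`.

```python
# Minimal sketch, assuming the field names used in the examples above.
from tensorrt_llm._torch.auto_deploy import LLM, AutoDeployConfig

ad_config = AutoDeployConfig(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    world_size=2,
    compile_backend="torch-compile",
    attn_backend="flashinfer",
    model_kwargs={"num_hidden_layers": 2},  # smaller config for a quick smoke test
)

# Any valid config field can also be passed directly as a keyword argument:
llm = LLM(**ad_config.to_dict())
```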

### Expert Configuration of LLM API

For expert TensorRT-LLM users, we also expose the full set of [`LlmArgs`](../../tensorrt_llm/_torch/auto_deploy/llm_args.py)
*at your own risk* (the argument list diverges from TRT-LLM's argument list):

<details>
<summary>Click to expand for more details on using LlmArgs directly</summary>

- All config fields that are used by the AutoDeploy core pipeline (i.e. the `InferenceOptimizer`) are
  _exclusively_ exposed in the [`AutoDeployConfig` class](../../tensorrt_llm/_torch/auto_deploy/llm_args.py).
  Please make sure to refer to those first.
- For expert users we expose the full set of [`LlmArgs`](../../tensorrt_llm/_torch/auto_deploy/llm_args.py)
  that can be used to configure the [AutoDeploy `LLM` API](../../tensorrt_llm/_torch/auto_deploy/llm.py) including runtime options.
- Note that some fields in the full [`LlmArgs`](../../tensorrt_llm/_torch/auto_deploy/llm_args.py)
  object are overlapping, duplicated, and/or _ignored_ in AutoDeploy, particularly arguments
  pertaining to configuring the model itself since AutoDeploy's model ingestion+optimize pipeline
  significantly differs from the default manual workflow in TensorRT-LLM.
- However, with the proper care the full [`LlmArgs`](../../tensorrt_llm/_torch/auto_deploy/llm_args.py)
  objects can be used to configure advanced runtime options in TensorRT-LLM.
- Note that any valid field can be simply provided as keyword argument ("`**kwargs`") to the
  [AutoDeploy `LLM` API](../../tensorrt_llm/_torch/auto_deploy/llm.py).

</details>

For more examples on TRT-LLM LLM API, visit [`this page`](https://nvidia.github.io/TensorRT-LLM/examples/llm_api_examples.html).
### Expert Configuration of `build_and_run_ad.py`

______________________________________________________________________

For expert users, `build_and_run_ad.py` provides advanced configuration capabilities through a flexible argument parser powered by Pydantic Settings and OmegaConf. You can use dot notation for CLI arguments, provide multiple YAML configuration files, and leverage sophisticated configuration precedence rules to create complex deployment configurations.

<details>
<summary>Click to expand for detailed configuration examples</summary>

#### CLI Arguments with Dot Notation

The script supports flexible CLI argument parsing using dot notation to modify nested configurations dynamically. You can target any field in both the [`ExperimentConfig`](./build_and_run_ad.py) and nested [`AutoDeployConfig`](../../tensorrt_llm/_torch/auto_deploy/llm_args.py)/[`LlmArgs`](../../tensorrt_llm/_torch/auto_deploy/llm_args.py) objects:

```bash
# Configure model parameters
# NOTE: config values like num_hidden_layers are automatically resolved into the appropriate nested
# dict value ``{"args": {"model_kwargs": {"num_hidden_layers": 10}}}`` although not explicitly
# specified as CLI arg
python build_and_run_ad.py \
  --model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
  --args.model-kwargs.num-hidden-layers=10 \
  --args.model-kwargs.hidden-size=2048 \
  --args.tokenizer-kwargs.padding-side=left

# Configure runtime and backend settings
python build_and_run_ad.py \
  --model "TinyLlama/TinyLlama-1.1B-Chat-v1.0" \
  --args.world-size=2 \
  --args.compile-backend=torch-opt \
  --args.attn-backend=flashinfer

# Configure prompting and benchmarking
python build_and_run_ad.py \
  --model "microsoft/phi-4" \
  --prompt.batch-size=4 \
  --prompt.sp-kwargs.max-tokens=200 \
  --prompt.sp-kwargs.temperature=0.7 \
  --benchmark.enabled=true \
  --benchmark.bs=8 \
  --benchmark.isl=1024
```
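
The note in the first example ("automatically resolved into the appropriate nested dict value") is easy to verify in isolation. Below is an illustrative sketch using the same OmegaConf dot-list parsing that `build_and_run_ad.py` relies on; the script's own handling is shown in `process_extra_cli_args` further down this diff.

```python
# Illustration only: how a dot-notation argument resolves into a nested dict.
from omegaconf import OmegaConf

cfg = OmegaConf.from_dotlist(["args.model_kwargs.num_hidden_layers=10"])
print(OmegaConf.to_container(cfg))
# {'args': {'model_kwargs': {'num_hidden_layers': 10}}}
```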

#### YAML Configuration Files

Both [`ExperimentConfig`](./build_and_run_ad.py) and [`AutoDeployConfig`](../../tensorrt_llm/_torch/auto_deploy/llm_args.py)/[`LlmArgs`](../../tensorrt_llm/_torch/auto_deploy/llm_args.py) inherit from [`DynamicYamlMixInForSettings`](../../tensorrt_llm/_torch/auto_deploy/utils/_config.py), enabling you to provide multiple YAML configuration files that are automatically deep-merged at runtime.

Create a YAML configuration file (e.g., `my_config.yaml`):

```yaml
# my_config.yaml
args:
  model_kwargs:
    num_hidden_layers: 12
    hidden_size: 1024
  world_size: 4
  compile_backend: torch-compile
  attn_backend: triton
  max_seq_len: 2048
  max_batch_size: 16
  transforms:
    sharding:
      strategy: auto
    quantization:
      enabled: false

prompt:
  batch_size: 8
  sp_kwargs:
    max_tokens: 150
    temperature: 0.8
    top_k: 50

benchmark:
  enabled: true
  num: 20
  bs: 4
  isl: 1024
  osl: 256
```

Create an additional override file (e.g., `production.yaml`):

```yaml
# production.yaml
args:
  world_size: 8
  compile_backend: torch-opt
  max_batch_size: 32

benchmark:
  enabled: false
```

Then use these configurations:

```bash
# Using single YAML config
python build_and_run_ad.py \
  --model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
  --yaml-configs my_config.yaml

# Using multiple YAML configs (deep merged in order, later files have higher priority)
python build_and_run_ad.py \
  --model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
  --yaml-configs my_config.yaml production.yaml

# Targeting nested AutoDeployConfig with separate YAML
python build_and_run_ad.py \
  --model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
  --yaml-configs my_config.yaml \
  --args.yaml-configs autodeploy_overrides.yaml
```
#### Configuration Precedence and Deep Merging

The configuration system follows a strict precedence order where higher priority sources override lower priority ones:

1. **CLI Arguments** (highest priority) - Direct command line arguments
1. **YAML Configs** - Files specified via `--yaml-configs` and `--args.yaml-configs`
1. **Default Settings** (lowest priority) - Built-in defaults from the config classes

**Deep Merging**: Unlike simple overwriting, deep merging intelligently combines nested dictionaries recursively. For example:

```yaml
# Base config
args:
  model_kwargs:
    num_hidden_layers: 10
    hidden_size: 1024
  max_seq_len: 2048
```

```yaml
# Override config
args:
  model_kwargs:
    hidden_size: 2048  # This will override
    # num_hidden_layers: 10 remains unchanged
  world_size: 4  # This gets added
```
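
For readers unfamiliar with the semantics, the sketch below shows the merge behavior described above applied to the two snippets. It is illustrative only; the actual implementation lives in `tensorrt_llm._torch.auto_deploy.utils._config.deep_merge_dicts` and may differ.

```python
# Illustrative sketch of recursive deep merging (not the library's implementation).
def deep_merge(base: dict, override: dict) -> dict:
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)  # recurse into nested dicts
        else:
            merged[key] = value  # leaves from the override win
    return merged

base = {"args": {"model_kwargs": {"num_hidden_layers": 10, "hidden_size": 1024}, "max_seq_len": 2048}}
override = {"args": {"model_kwargs": {"hidden_size": 2048}, "world_size": 4}}
print(deep_merge(base, override))
# {'args': {'model_kwargs': {'num_hidden_layers': 10, 'hidden_size': 2048},
#           'max_seq_len': 2048, 'world_size': 4}}
```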

**Nested Config Behavior**: When using nested configurations, outer YAML configs become init settings for inner objects, giving them higher precedence:

```bash
# The outer yaml-configs affects the entire ExperimentConfig
# The inner args.yaml-configs affects only the AutoDeployConfig
python build_and_run_ad.py \
  --model "meta-llama/Meta-Llama-3.1-8B-Instruct" \
  --yaml-configs experiment_config.yaml \
  --args.yaml-configs autodeploy_config.yaml \
  --args.world-size=8  # CLI override beats both YAML configs
```

#### Built-in Default Configuration

Both [`AutoDeployConfig`](../../tensorrt_llm/_torch/auto_deploy/llm_args.py) and [`LlmArgs`](../../tensorrt_llm/_torch/auto_deploy/llm_args.py) classes automatically load a built-in [`default.yaml`](../../tensorrt_llm/_torch/auto_deploy/config/default.yaml) configuration file that provides sensible defaults for the AutoDeploy inference optimizer pipeline. This file is specified in the [`_get_config_dict()`](../../tensorrt_llm/_torch/auto_deploy/llm_args.py) function and defines default transform configurations for graph optimization stages.

The built-in defaults are automatically merged with your configurations at the lowest priority level, ensuring that your custom settings always override the defaults. You can inspect the current default configuration to understand the baseline transform pipeline:

```bash
# View the default configuration
cat tensorrt_llm/_torch/auto_deploy/config/default.yaml

# Override specific transform settings
python build_and_run_ad.py \
  --model "TinyLlama/TinyLlama-1.1B-Chat-v1.0" \
  --args.transforms.export-to-gm.strict=true
```

</details>

## Roadmap

1. **Model Coverage:**

   - Expand support for additional LLM variants and features:
     - LoRA
     - Speculative Decoding
     - Model specialization for disaggregated serving

1. **Performance Optimization:**

   - Enhance inference speed and efficiency with:
     - MoE fusion and all-reduce fusion techniques
     - Reuse of TRT-LLM PyTorch operators for greater efficiency

______________________________________________________________________

Check out our [Github Project Board](https://github.com/orgs/NVIDIA/projects/83) to learn more about
the current progress in AutoDeploy and where you can help.

## Disclaimer
@@ -1,13 +1,23 @@
"""Main entrypoint to build, test, and prompt AutoDeploy inference models."""

from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Iterator, List, Optional, Union

import torch
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic_settings import BaseSettings, CliApp, CliImplicitFlag
from omegaconf import OmegaConf
from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic_settings import (
    BaseSettings,
    CliApp,
    CliImplicitFlag,
    CliUnknownArgs,
    SettingsConfigDict,
)

from tensorrt_llm._torch.auto_deploy import LLM, DemoLLM, LlmArgs
from tensorrt_llm._torch.auto_deploy.llm_args import _try_decode_dict_with_str_values
from tensorrt_llm._torch.auto_deploy import LLM, AutoDeployConfig, DemoLLM
from tensorrt_llm._torch.auto_deploy.utils._config import (
    DynamicYamlMixInForSettings,
    deep_merge_dicts,
)
from tensorrt_llm._torch.auto_deploy.utils.benchmark import benchmark, store_benchmark_results
from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger
from tensorrt_llm.llmapi.llm import RequestOutput
@@ -18,7 +28,11 @@ torch._dynamo.config.cache_size_limit = 20


class PromptConfig(BaseModel):
    """Prompt configuration."""
    """Prompt configuration.

    This configuration class can be used for this example script to configure the example prompts
    and the sampling parameters.
    """

    batch_size: int = Field(default=2, description="Number of queries")
    queries: Union[str, List[str]] = Field(

@@ -54,13 +68,16 @@ class PromptConfig(BaseModel):
    @classmethod
    def validate_sp_kwargs(cls, sp_kwargs):
        """Insert desired defaults for sampling params and try parsing string values as JSON."""
        sp_kwargs = {**cls.model_fields["sp_kwargs"].default_factory(), **sp_kwargs}
        sp_kwargs = _try_decode_dict_with_str_values(sp_kwargs)
        return sp_kwargs
        default = cls.model_fields["sp_kwargs"].get_default(call_default_factory=True)
        return deep_merge_dicts(default, sp_kwargs)


class BenchmarkConfig(BaseModel):
    """Benchmark configuration."""
    """Benchmark configuration.

    This configuration class can be used for this example script to configure the simple
    benchmarking we run at the end of the script.
    """

    enabled: bool = Field(default=False, description="If true, run simple benchmark")
    num: int = Field(default=10, ge=1, description="By default run 10 times and get average")

@@ -73,18 +90,26 @@
    )
class ExperimentConfig(BaseSettings):
    """Experiment Configuration based on Pydantic BaseModel."""

    model_config = ConfigDict(
class ExperimentConfig(DynamicYamlMixInForSettings, BaseSettings):
    """Experiment Configuration for the example script.

    This configuration aggregates all relevant configurations for this example script. It is also
    used to auto-generate the CLI interface.
    """

    model_config = SettingsConfigDict(
        extra="forbid",
        cli_kebab_case=True,
        cli_ignore_unknown_args=True,
        nested_model_default_partial_update=True,
    )
    extra_cli_args: CliUnknownArgs

    ### CORE ARGS ##################################################################################
    # The main LLM arguments - contains model, tokenizer, backend configs, etc.
    args: LlmArgs = Field(
        description="The main LLM arguments containing model, tokenizer, backend configs, etc."
    # The main AutoDeploy arguments - contains model, tokenizer, backend configs, etc.
    args: AutoDeployConfig = Field(
        description="The main AutoDeploy arguments containing model, tokenizer, backend configs, etc. "
        "Please check `tensorrt_llm._torch.auto_deploy.llm_args.AutoDeployConfig` for more details."
    )

    # Optional model field for convenience - if provided, will be used to initialize args.model
@@ -119,16 +144,50 @@ class ExperimentConfig(BaseSettings):
            data["args"]["model"] = data["model"]
        return data

    @model_validator(mode="before")
    @classmethod
    def process_extra_cli_args(cls, data: Dict) -> Dict:
        """Process extra CLI args.

        This model validator enables the user to provide additional CLI args that may not be
        auto-generated by the CLI app. A common use case for this would be to modify graph
        transforms dynamically via CLI arguments.

        For example, the user can provide a CLI argument for raw dictionaries like this, e.g., for
        ``model_kwargs``: ``--args.model-kwargs.num-hidden-layers=10``.
        """
        # build a clean dotlist: ["a.b=1","c.d.e=foo",…]
        raw: List[str] = data.pop("extra_cli_args", [])
        dotlist = []
        it: Iterator[str] = iter(raw)
        for tok in it:
            if not tok.startswith("--"):
                continue
            body = tok[2:]
            if "=" in body:
                body, val = body.split("=", 1)
            else:
                # flag + separate value
                val = next(it, None)
            # ensure kebab-case is converted to snake_case
            dotlist.append(f"{body.replace('-', '_')}={val}")

        return deep_merge_dicts(data, OmegaConf.from_dotlist(dotlist))
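
As a quick illustration of what this validator produces (hypothetical token values, following the docstring's example):

```python
from omegaconf import OmegaConf

# Hypothetical extra CLI tokens: one "--key value" pair and one "--key=value" form.
raw = ["--args.model-kwargs.num-hidden-layers", "10", "--args.world-size=4"]
# The loop above turns these into a kebab->snake dotlist:
dotlist = ["args.model_kwargs.num_hidden_layers=10", "args.world_size=4"]
print(OmegaConf.to_container(OmegaConf.from_dotlist(dotlist)))
# {'args': {'model_kwargs': {'num_hidden_layers': 10}, 'world_size': 4}}
```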

    @field_validator("model", mode="after")
    @classmethod
    def sync_model_with_args(cls, model_value, info):
        args: LlmArgs = info.data["args"]
        return args.model if args is not None else model_value
        if "args" not in info.data:
            return model_value
        args: AutoDeployConfig = info.data["args"]
        return args.model

    @field_validator("prompt", mode="after")
    @classmethod
    def sync_prompt_batch_size_with_args_max_batch_size(cls, prompt: PromptConfig, info):
        args: LlmArgs = info.data["args"]
        if "args" not in info.data:
            return prompt
        args: AutoDeployConfig = info.data["args"]
        if args.max_batch_size < prompt.batch_size:
            args.max_batch_size = prompt.batch_size
        return prompt

@@ -136,7 +195,9 @@ class ExperimentConfig(BaseSettings):
    @field_validator("benchmark", mode="after")
    @classmethod
    def adjust_args_for_benchmark(cls, benchmark: BenchmarkConfig, info):
        args: LlmArgs = info.data["args"]
        if "args" not in info.data:
            return benchmark
        args: AutoDeployConfig = info.data["args"]
        if benchmark.enabled:
            # propagate benchmark settings to args
            args.max_batch_size = max(benchmark.bs, args.max_batch_size)

@@ -151,7 +212,6 @@ def build_llm_from_config(config: ExperimentConfig) -> LLM:
        "demollm": DemoLLM,
        "trtllm": LLM,
    }
    ad_logger.info(f"{config.args._parallel_config=}")
    llm = llm_lookup[config.args.runtime](**config.args.to_dict())
    return llm
@@ -6,7 +6,7 @@ import torch
from diffusers import DiffusionPipeline

from tensorrt_llm._torch.auto_deploy.compile import compile_and_capture
from tensorrt_llm._torch.auto_deploy.transformations.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations.library.fusion import fuse_gemms
from tensorrt_llm._torch.auto_deploy.transformations.library.quantization import quantize
from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger

@@ -138,10 +138,10 @@ def main():

    if args.restore_from:
        quant_state_dict = model.state_dict()
        gm = quantize(gm, {}).to("cuda")
        quantize(gm, {}).to("cuda")
        gm.load_state_dict(quant_state_dict, strict=False)

    gm = fuse_gemms(gm)
    fuse_gemms(gm)

    gm = compile_and_capture(gm, backend="torch-opt", args=(), kwargs=flux_kwargs)
@@ -30,7 +30,8 @@ nvidia-nccl-cu12
nvidia-cuda-nvrtc-cu12
transformers==4.53.1
pydantic>=2.9.1
pydantic-settings
pydantic-settings[yaml]
omegaconf
pillow==10.3.0
wheel<=0.45.1
optimum
setup.py | 3

@@ -115,6 +115,7 @@ package_data += [
    'tools/plugin_gen/templates/*',
    'bench/build/benchmark_config.yml',
    'evaluate/lm_eval_tasks/**/*',
    "_torch/auto_deploy/config/*.yaml",
]

@@ -185,7 +186,7 @@ def extract_from_precompiled(precompiled_location: str, package_data: List[str],

    with zipfile.ZipFile(wheel_path) as wheel:
        for file in wheel.filelist:
            if file.filename.endswith(".py"):
            if file.filename.endswith((".py", ".yaml")):
                continue
            for filename_pattern in package_data:
                if fnmatch.fnmatchcase(file.filename,
@@ -1,5 +1,5 @@
# import submodules that require registration process
from . import compile, custom_ops, models, shim  # noqa: F401
from . import compile, custom_ops, export, models, shim  # noqa: F401

# import AutoDeploy LLM and LlmArgs
from .llm import *
@@ -35,10 +35,11 @@ class CapturedGraph(nn.Module):
        self._out_buffer_flat: List[torch.Tensor] = None
        self._args_hash: Optional[Tuple[int, ...]] = None
        self.cuda_graph_batch_sizes = (
            cuda_graph_batch_sizes
            sorted(cuda_graph_batch_sizes, reverse=True)
            if cuda_graph_batch_sizes is not None
            else self._get_graph_batch_sizes(self.max_batch_size)
        )
        self._cuda_graph_mem_pool = None

    def _get_hash(self, flat_args: List[Any]) -> Tuple[int, ...]:
        return tuple(hash(a) for a in flat_args)

@@ -64,7 +65,7 @@ class CapturedGraph(nn.Module):
        # capture graph now
        torch.cuda.synchronize()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
        with torch.cuda.graph(graph, pool=self._cuda_graph_mem_pool):
            # compute output
            out = self.model(*args, **kwargs)
            # write out into output buffer up to out batch size

@@ -73,7 +74,7 @@ class CapturedGraph(nn.Module):
            for o_buffer, o in zip(self._out_buffer_flat, out_flat):
                o_buffer[: o.shape[0]] = o
        torch.cuda.synchronize()

        self._cuda_graph_mem_pool = self._cuda_graph_mem_pool or graph.pool()
        return graph

    @staticmethod

@@ -88,7 +89,7 @@ class CapturedGraph(nn.Module):
            batch_sizes.update(range(multiplier, max_bs + 1, multiplier))

        # return as sorted list
        return sorted(batch_sizes)
        return sorted(batch_sizes, reverse=True)

    def capture_graph(self, *args, **kwargs):
        """Capture and pre-fetch the graph for variable batch size."""

@@ -118,6 +119,7 @@ class CapturedGraph(nn.Module):

        # capture output once with max batch size to capture output buffers
        with CudaGraphWarmUpPhase():
            ad_logger.info(f"Warm up with {self.max_batch_size=} before graph capture")
            out = self.model(*args, **kwargs)
        self._out_buffer_flat, out_spec = tree_flatten(out)
        assert out_spec == self._out_spec, "Output spec mismatch."
tensorrt_llm/_torch/auto_deploy/config/default.yaml | 21 (normal file)

@@ -0,0 +1,21 @@
# Additional default args for AutoDeployConfig/LlmArgs in _torch/auto_deploy/llm_args.py
transforms:
  build_model:
    stage: factory
    device: meta
    # nothing to clean up
    run_graph_cleanup: false
    requires_clean_graph: false
  export_to_gm:
    stage: export
    clone_state_dict: false
    strict: false
    # nothing to clean up
    run_graph_cleanup: false
    requires_clean_graph: false
  cleanup_noop_slice:
    stage: post_export
  cleanup_noop_add:
    stage: post_export
  cleanup_input_constraints:
    stage: post_export
@@ -7,7 +7,9 @@ from .flashinfer_rope import *
from .linear import *
from .mla import *
from .quant import *
from .rms_norm import *
from .torch_attention import *
from .torch_backend_attention import *
from .torch_moe import *
from .torch_rope import *
from .triton_attention import *
@@ -100,6 +100,8 @@ def _paged_generate_mha(
        n_heads,
        d_head,
        SEQ_BLOCK_SIZE,
        False,
        None,
    )

@@ -338,6 +340,7 @@ def _generate_mha_rope_fusion(
        d_head,
        SEQ_BLOCK_SIZE,
        HEAD_BLOCK_SIZE,
        -1,
    )
    attention_kv_stage2[(b, n_heads, 1)](
        stage1_output_values,

@@ -348,6 +351,8 @@ def _generate_mha_rope_fusion(
        n_heads,
        d_head,
        SEQ_BLOCK_SIZE,
        False,
        None,
    )

@@ -414,7 +419,9 @@ def _flattened_context_mha_rope_fusion(
        d_head,
        SEQ_BLOCK,
        max_cache_seq_len,
        num_stages=2,
        -1,
        False,
        None,
    )
@@ -117,14 +117,20 @@ class SequenceInfo:
        # if the provided max_num_tokens is less than the max_batch_size * max_seq_len,
        # we use the provided max_num_tokens to calculate the number of pages
        total_tokens = min(self.max_num_tokens, self.max_batch_size * max_seq_len_adjusted)
        self._num_pages = (total_tokens) // self.page_size + (total_tokens % self.page_size > 0)
        # Num pages can not be less than max_batch_size.
        self._num_pages = max(
            self.max_batch_size,
            (total_tokens) // self.page_size + (total_tokens % self.page_size > 0),
        )
        self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int)
        self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long)
        self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
        self.input_pos = torch.empty_like(self.seq_len)
        self.cache_loc = torch.empty(self.num_pages, dtype=torch.int)
        self.pages_per_seq = torch.empty_like(self.seq_len)

        assert self.num_pages >= self.max_batch_size, (
            "num_pages must be greater than max_batch_size"
        )
        # dynamic shape descriptors for tensor args
        self._dynamic_shapes: Optional[Tuple[Dict[str, Dim]]] = None
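
A quick numeric check of the page-count rule above (hypothetical values, just to illustrate the ceiling division and the new `max_batch_size` floor):

```python
# Hypothetical numbers, not taken from this commit:
# page_size=64, max_batch_size=16, max_num_tokens=512, max_seq_len=2048.
total_tokens = min(512, 16 * 2048)            # -> 512
pages_by_tokens = 512 // 64 + (512 % 64 > 0)  # ceiling division -> 8
num_pages = max(16, pages_by_tokens)          # floor at max_batch_size -> 16
```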

@@ -378,10 +384,11 @@ class SequenceInfo:
    def _update_position_ids(self) -> None:
        # set new position_ids as new tensor from input_pos and seq_len via torch.arange
        position_ids_list = [
            torch.arange(in_pos, in_pos + seq_len, dtype=torch.long)
            num
            for in_pos, seq_len in zip(self.input_positions, self.sequence_lengths)
            for num in range(in_pos, in_pos + seq_len)
        ]
        self.position_ids = torch.cat(position_ids_list, dim=0).to(self.device)
        self.position_ids = torch.tensor(position_ids_list, dtype=torch.long).to(self.device)

        # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
        if self.is_generate:

@@ -398,13 +405,15 @@ class SequenceInfo:
        seq_lens = [len(ids) for ids in input_ids]
        self.seq_len.zero_()
        self.seq_len[: len(seq_lens)].copy_(torch.tensor(seq_lens), non_blocking=True)

        # We'll preserve the dtype of the input_ids tensor if it is a tensor, otherwise we'll use int
        dtype = input_ids.dtype if isinstance(input_ids, torch.Tensor) else torch.int
        # set new input_ids as new tensor from flattened input_ids
        ids_tnsr_list = [
            lst.detach() if isinstance(lst, torch.Tensor) else torch.tensor(lst, dtype=torch.int)
        ids_list = [
            val
            for lst in input_ids
            for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst)
        ]
        self.input_ids = torch.cat(ids_tnsr_list, dim=0).to(self.device)
        self.input_ids = torch.tensor(ids_list, dtype=dtype).to(self.device)

        # set derivative properties
        self._sequence_lengths = seq_lens
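
For intuition, the flattening above behaves like this toy example (hypothetical inputs, not taken from the commit):

```python
import torch

# Hypothetical ragged batch of token ids: two sequences of length 3 and 2.
input_ids = [[101, 7592, 2088], [101, 2129]]
seq_lens = [len(ids) for ids in input_ids]             # [3, 2]
ids_list = [val for lst in input_ids for val in lst]   # [101, 7592, 2088, 101, 2129]
flat = torch.tensor(ids_list, dtype=torch.int)         # shape [5]; later viewed as [1, total_len]
```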

tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py | 82 (normal file)

@@ -0,0 +1,82 @@
"""Custom operator for FlashInfer and Triton RMSNorm implementation."""
|
||||
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
from .triton_kernels.rms_norm import rms_norm
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::flashinfer_rms_norm", mutates_args=())
|
||||
def flashinfer_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
|
||||
"""Custom operator for FlashInfer RMSNorm implementation.
|
||||
|
||||
Args:
|
||||
input: Input tensor to normalize.
|
||||
weight: Scaling weights for the normalized output.
|
||||
eps: Small constant for numerical stability.
|
||||
|
||||
Returns:
|
||||
Normalized and scaled tensor using FlashInfer implementation.
|
||||
"""
|
||||
# Flashinfer rmsnorm expects a 2D input
|
||||
input_flat = input.reshape(-1, input.shape[-1])
|
||||
rmsnorm_flat = flashinfer.norm.rmsnorm(input_flat, weight, eps)
|
||||
return rmsnorm_flat.reshape(input.shape)
|
||||
|
||||
|
||||
@flashinfer_rmsnorm.register_fake
|
||||
def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
|
||||
"""Fake implementation for the custom operator during tracing.
|
||||
|
||||
Args:
|
||||
input: Input tensor to normalize.
|
||||
weight: Scaling weights for the normalized output.
|
||||
eps: Small constant for numerical stability.
|
||||
|
||||
Returns:
|
||||
Empty tensor with same shape as input.
|
||||
"""
|
||||
return torch.empty_like(input)
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::triton_rms_norm", mutates_args=())
|
||||
def triton_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
|
||||
"""Custom operator for Triton RMSNorm implementation.
|
||||
|
||||
Args:
|
||||
input: Input tensor to normalize.
|
||||
weight: Scaling weights for the normalized output.
|
||||
eps: Small constant for numerical stability.
|
||||
|
||||
Returns:
|
||||
Normalized and scaled tensor using Triton implementation.
|
||||
"""
|
||||
return rms_norm(input, weight, eps)
|
||||
|
||||
|
||||
@triton_rmsnorm.register_fake
|
||||
def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
|
||||
"""Fake implementation for the custom operator during tracing."""
|
||||
return torch.empty_like(input)
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::torch_rmsnorm", mutates_args=())
|
||||
def torch_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
|
||||
"""Custom operator for Torch RMSNorm implementation.
|
||||
|
||||
Args:
|
||||
input: Input tensor to normalize.
|
||||
weight: Scaling weights for the normalized output.
|
||||
eps: Small constant for numerical stability.
|
||||
"""
|
||||
input_dtype = input.dtype
|
||||
input = input.to(torch.float32)
|
||||
variance = input.pow(2).mean(-1, keepdim=True)
|
||||
input = input * torch.rsqrt(variance + eps)
|
||||
return weight * input.to(input_dtype)
|
||||
|
||||
|
||||
@torch_rmsnorm.register_fake
|
||||
def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
|
||||
"""Fake implementation for the custom operator during tracing."""
|
||||
return torch.empty_like(input)
|
||||
@ -7,6 +7,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# TODO (nvchenghaoz): Remove related kernels once we have a backend-specific implementation for attention.
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::torch_attention_repeat_kv", mutates_args=())
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
@ -113,6 +115,9 @@ def bsnd_grouped_sdpa(
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
logit_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Attention that assumes the input layout is bsnd.
|
||||
|
||||
@ -132,7 +137,16 @@ def bsnd_grouped_sdpa(
|
||||
|
||||
@bsnd_grouped_sdpa.register_fake
|
||||
def bsnd_grouped_sdpa_fake(
|
||||
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
sinks=None,
|
||||
sliding_window=None,
|
||||
logit_cap=None,
|
||||
):
|
||||
"""Fake implementation of bnsd grouped SDPA."""
|
||||
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
|
||||
|
||||
@ -0,0 +1,495 @@
|
||||
"""Torch backend attention using pure PyTorch reference implementations."""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch._ops import OpOverloadPacket
|
||||
from torch._subclasses import FakeTensor
|
||||
from torch.fx import Node
|
||||
|
||||
from ..utils.logger import ad_logger
|
||||
from ..utils.node_utils import extract_op_args
|
||||
from .attention_interface import (
|
||||
AttentionDescriptor,
|
||||
AttentionLayout,
|
||||
AttentionRegistry,
|
||||
BufferInitializerDict,
|
||||
CacheConfig,
|
||||
CacheInitializerDict,
|
||||
Constant,
|
||||
MHACallable,
|
||||
PrepareMetadataCallable,
|
||||
SequenceInfo,
|
||||
)
|
||||
from .torch_attention import repeat_kv, update_kv_cache
|
||||
|
||||
|
||||
def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[float]) -> torch.Tensor:
|
||||
"""Apply logit softcapping using the formula: logit_cap * tanh(logits / logit_cap)"""
|
||||
if logit_cap is not None and logit_cap > 0.0:
|
||||
return logit_cap * torch.tanh(attn_scores / logit_cap)
|
||||
return attn_scores
|
||||
|
||||
|
||||
def _torch_generate_mha(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
scale: float,
|
||||
out: torch.Tensor,
|
||||
logit_cap: Optional[float] = None,
|
||||
sliding_window_size: Optional[int] = None,
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""Generate-only attention (single token per sequence) using manual computation with existing update_kv_cache."""
|
||||
b, s, n_heads, head_dim = q.shape # q has shape (b, 1, n_heads, head_dim) in generate phase
|
||||
assert s == 1, f"Expected sequence length 1 for generate phase, got {s}"
|
||||
n_kv_heads = k.shape[2] # k has shape (b, 1, n_kv_heads, head_dim)
|
||||
|
||||
# Update KV cache for single token
|
||||
for i in range(b):
|
||||
cache_idx = cache_loc[i].item()
|
||||
pos = input_pos[i].item()
|
||||
k_cache[cache_idx, pos] = k[i, 0] # Remove sequence dim
|
||||
v_cache[cache_idx, pos] = v[i, 0] # Remove sequence dim
|
||||
|
||||
# Compute attention for each sequence using manual computation
|
||||
for i in range(b):
|
||||
cache_idx = cache_loc[i].item()
|
||||
pos = input_pos[i].item()
|
||||
|
||||
# Get query, key, value for this sequence
|
||||
q_i = q[i, 0] # [n_heads, head_dim]
|
||||
|
||||
# Apply sliding window: limit the range of keys/values we attend to
|
||||
if sliding_window_size is not None and sliding_window_size > 0:
|
||||
# Sliding window: attend to [max(0, pos - sliding_window_size + 1), pos]
|
||||
start_pos = max(0, pos - sliding_window_size + 1)
|
||||
k_i = k_cache[cache_idx, start_pos : pos + 1] # [window_len, n_kv_heads, head_dim]
|
||||
v_i = v_cache[cache_idx, start_pos : pos + 1] # [window_len, n_kv_heads, v_head_dim]
|
||||
else:
|
||||
# No sliding window: attend to all previous tokens [0, pos]
|
||||
k_i = k_cache[cache_idx, : pos + 1] # [seq_len, n_kv_heads, head_dim]
|
||||
v_i = v_cache[cache_idx, : pos + 1] # [seq_len, n_kv_heads, v_head_dim]
|
||||
|
||||
# Transpose for attention: [n_heads, 1, head_dim] and [n_kv_heads, seq_len, head_dim]
|
||||
q_i = q_i.unsqueeze(1) # [n_heads, 1, head_dim]
|
||||
k_i = k_i.transpose(0, 1) # [n_kv_heads, seq_len, head_dim]
|
||||
v_i = v_i.transpose(0, 1) # [n_kv_heads, seq_len, v_head_dim]
|
||||
|
||||
# Handle GQA using existing repeat_kv function if needed
|
||||
if n_heads != n_kv_heads:
|
||||
n_rep = n_heads // n_kv_heads
|
||||
# Reshape to [batch, num_kv_heads, seq_len, head_dim] for repeat_kv
|
||||
# k_i is currently [n_kv_heads, seq_len, head_dim]
|
||||
k_i_batch = k_i.unsqueeze(0) # [1, n_kv_heads, seq_len, head_dim]
|
||||
v_i_batch = v_i.unsqueeze(0) # [1, n_kv_heads, seq_len, v_head_dim]
|
||||
k_i_expanded = repeat_kv(k_i_batch, n_rep) # [1, n_heads, seq_len, head_dim]
|
||||
v_i_expanded = repeat_kv(v_i_batch, n_rep) # [1, n_heads, seq_len, v_head_dim]
|
||||
k_i = k_i_expanded[0] # [n_heads, seq_len, head_dim]
|
||||
v_i = v_i_expanded[0] # [n_heads, seq_len, v_head_dim]
|
||||
|
||||
# Compute attention scores
|
||||
attn_scores = torch.matmul(q_i, k_i.transpose(-2, -1)) * scale # [n_heads, 1, seq_len]
|
||||
|
||||
# Apply logit softcapping if enabled
|
||||
attn_scores = _apply_logit_softcapping(attn_scores, logit_cap)
|
||||
|
||||
# Apply sinks if provided (following the model file pattern)
|
||||
if sinks is not None:
|
||||
# Concatenate sinks to attention scores
|
||||
sinks = sinks.reshape(-1, 1, 1).expand(-1, attn_scores.shape[-2], -1)
|
||||
attn_weights = torch.cat([attn_scores, sinks], dim=-1)
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
||||
# Use only the non-sink portion for computing output (ignore sinks)
|
||||
attn_out = torch.matmul(
|
||||
attn_weights[..., : -sinks.size(-1)], v_i
|
||||
) # [n_heads, 1, v_head_dim]
|
||||
else:
|
||||
attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
|
||||
attn_out = torch.matmul(attn_weights, v_i) # [n_heads, 1, v_head_dim]
|
||||
|
||||
# Store result: remove sequence dimension
|
||||
out[i] = attn_out.squeeze(1) # [n_heads, v_head_dim]
|
||||
|
||||
|
||||
def _torch_context_mha(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
scale: float,
|
||||
out: torch.Tensor,
|
||||
logit_cap: Optional[float] = None,
|
||||
sliding_window_size: Optional[int] = None,
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
"""Context attention (multiple tokens, potentially multiple sequences) using existing torch functions."""
|
||||
# Update KV cache first using existing function
|
||||
update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, cache_loc, seq_start)
|
||||
|
||||
# Compute attention for each sequence
|
||||
attn_outputs = []
|
||||
for idx in range(seq_len.shape[0]):
|
||||
seq_len_i = seq_len[idx].item()
|
||||
input_pos_i = input_pos[idx].item()
|
||||
cache_loc_i = cache_loc[idx].item()
|
||||
seq_start_i = seq_start[idx].item()
|
||||
|
||||
# Skip sequences with zero length
|
||||
if seq_len_i == 0:
|
||||
continue
|
||||
|
||||
# Get query for this sequence
|
||||
q_seq = q[seq_start_i : seq_start_i + seq_len_i] # [seq_len_i, n_heads, head_dim]
|
||||
|
||||
# Get keys and values from cache
|
||||
kv_seq_len = input_pos_i + seq_len_i
|
||||
k_seq = k_cache[cache_loc_i, :kv_seq_len] # [kv_seq_len, n_kv_heads, head_dim]
|
||||
v_seq = v_cache[cache_loc_i, :kv_seq_len] # [kv_seq_len, n_kv_heads, head_dim]
|
||||
|
||||
# Manual attention computation (shared path for both softcapping and non-softcapping)
|
||||
n_heads = q_seq.shape[1]
|
||||
n_kv_heads = k_seq.shape[1]
|
||||
|
||||
# Transpose to [batch, num_heads, seq_len, head_dim] format
|
||||
q_seq_t = q_seq.transpose(0, 1).unsqueeze(0) # [1, n_heads, seq_len_i, head_dim]
|
||||
k_seq_t = k_seq.transpose(0, 1).unsqueeze(0) # [1, n_kv_heads, kv_seq_len, head_dim]
|
||||
v_seq_t = v_seq.transpose(0, 1).unsqueeze(0) # [1, n_kv_heads, kv_seq_len, head_dim]
|
||||
|
||||
# Handle GQA by repeating KV if needed
|
||||
if n_heads != n_kv_heads:
|
||||
n_rep = n_heads // n_kv_heads
|
||||
k_seq_t = repeat_kv(k_seq_t, n_rep) # [1, n_heads, kv_seq_len, head_dim]
|
||||
v_seq_t = repeat_kv(v_seq_t, n_rep) # [1, n_heads, kv_seq_len, head_dim]
|
||||
|
||||
# Compute attention scores: Q @ K^T
|
||||
attn_scores = (
|
||||
torch.matmul(q_seq_t, k_seq_t.transpose(-2, -1)) * scale
|
||||
) # [1, n_heads, seq_len_i, kv_seq_len]
|
||||
|
||||
# Apply causal mask
|
||||
causal_mask = torch.triu(
|
||||
torch.ones(seq_len_i, kv_seq_len, device=q.device, dtype=torch.bool),
|
||||
diagonal=kv_seq_len - seq_len_i + 1,
|
||||
)
|
||||
|
||||
# Apply sliding window mask if specified
|
||||
if sliding_window_size is not None and sliding_window_size > 0:
|
||||
# Create sliding window mask: each query position i can only attend to keys in [i-window_size+1, i]
|
||||
# For context phase, we need to account for the offset between query and key positions
|
||||
|
||||
# Query positions are [input_pos_i, input_pos_i + seq_len_i)
|
||||
# Key positions are [0, input_pos_i + seq_len_i)
|
||||
query_positions = torch.arange(
|
||||
input_pos_i, input_pos_i + seq_len_i, device=q.device
|
||||
) # [seq_len_i]
|
||||
key_positions = torch.arange(0, kv_seq_len, device=q.device) # [kv_seq_len]
|
||||
|
||||
# Create position difference matrix: query_pos - key_pos
|
||||
pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(
|
||||
0
|
||||
) # [seq_len_i, kv_seq_len]
|
||||
|
||||
# Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window_size
|
||||
sliding_window_mask = (pos_diff < 0) | (
|
||||
pos_diff >= sliding_window_size
|
||||
) # [seq_len_i, kv_seq_len]
|
||||
|
||||
# Combine causal and sliding window masks
|
||||
combined_mask = causal_mask | sliding_window_mask
|
||||
else:
|
||||
combined_mask = causal_mask
|
||||
|
||||
attn_scores.masked_fill_(combined_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
|
||||
|
||||
# Apply logit softcapping if enabled
|
||||
attn_scores = _apply_logit_softcapping(attn_scores, logit_cap)
|
||||
|
||||
# Apply sinks if provided (following the model file pattern)
|
||||
if sinks is not None:
|
||||
# Concatenate sinks to attention scores
|
||||
sinks = sinks.reshape(1, -1, 1, 1).expand(
|
||||
attn_scores.shape[0], -1, attn_scores.shape[-2], -1
|
||||
)
|
||||
attn_weights = torch.cat([attn_scores, sinks], dim=-1)
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
||||
# Use only the non-sink portion for computing output (ignore sinks)
|
||||
attn_out = torch.matmul(
|
||||
attn_weights[..., : -sinks.size(-1)], v_seq_t
|
||||
) # [1, n_heads, seq_len_i, v_head_dim]
|
||||
else:
|
||||
attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
|
||||
attn_out = torch.matmul(attn_weights, v_seq_t) # [1, n_heads, seq_len_i, v_head_dim]
|
||||
|
||||
# Remove batch dimension and transpose back to [seq_len_i, n_heads, v_head_dim]
|
||||
attn_out = attn_out[0].transpose(0, 1)
|
||||
|
||||
attn_outputs.append(attn_out)
|
||||
|
||||
# Concatenate all outputs
|
||||
if len(attn_outputs) == 0:
|
||||
# No sequences to process - this shouldn't happen but handle gracefully
|
||||
out.zero_()
|
||||
elif len(attn_outputs) == 1:
|
||||
# Single sequence
|
||||
out.copy_(attn_outputs[0])
|
||||
else:
|
||||
# Multiple sequences or context phase
|
||||
out.copy_(torch.cat(attn_outputs, dim=0))
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::torch_cached_attention_with_cache", mutates_args=())
|
||||
def torch_backend_mha_with_cache(
|
||||
# Q, K, V
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
# METADATA
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
# CACHES
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
# BUFFERS
|
||||
# <none>
|
||||
# CONSTANTS
|
||||
scale: Optional[float],
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
sliding_window_size: Optional[int] = None,
|
||||
logit_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Torch backend MHA with cache that takes q, k, v in BSND layout."""
|
||||
# Get dimensions
|
||||
num_kv_heads, qk_head_dim = k_cache.shape[-2:]
|
||||
v_head_dim = v_cache.shape[-1]
|
||||
b, s = q.shape[:2]
|
||||
|
||||
# check for num_heads
|
||||
num_heads = q.shape[2] // qk_head_dim if q.ndim == 3 else q.shape[2]
|
||||
|
||||
# Define output shape
|
||||
output_shape = (b, s, num_heads * v_head_dim) if q.ndim == 3 else (b, s, num_heads, v_head_dim)
|
||||
|
||||
# Reshape to standard layout
|
||||
if s == 1:
|
||||
bs_view = (b, s)
|
||||
else:
|
||||
bs_view = (b * s,)
|
||||
|
||||
q = q.contiguous().view(*bs_view, num_heads, qk_head_dim)
|
||||
k = k.contiguous().view(*bs_view, num_kv_heads, qk_head_dim)
|
||||
v = v.contiguous().view(*bs_view, num_kv_heads, v_head_dim)
|
||||
|
||||
scale = 1.0 / math.sqrt(qk_head_dim) if scale is None else scale
|
||||
|
||||
# Create output tensor
|
||||
y = q.new_empty(*bs_view, num_heads, v_head_dim).contiguous()
|
||||
|
||||
# Compute attention
|
||||
if s == 1:
|
||||
# Generate-only phase
|
||||
_torch_generate_mha(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
cache_loc,
|
||||
input_pos,
|
||||
scale,
|
||||
y,
|
||||
logit_cap,
|
||||
sliding_window_size,
|
||||
sinks,
|
||||
)
|
||||
else:
|
||||
# Context phase
|
||||
_torch_context_mha(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
k_cache,
|
||||
v_cache,
|
||||
seq_len,
|
||||
seq_start,
|
||||
scale,
|
||||
y,
|
||||
logit_cap,
|
||||
sliding_window_size,
|
||||
sinks,
|
||||
)
|
||||
|
||||
return y.view(*output_shape)
|
||||
|
||||
|
||||
@torch_backend_mha_with_cache.register_fake
|
||||
def torch_backend_mha_with_cache_fake(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
seq_start: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
scale: Optional[float],
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
sliding_window_size: Optional[int] = None,
|
||||
logit_cap: Optional[float] = None,
|
||||
):
|
||||
return q.new_empty(*q.shape[:-1], v.shape[-1]).contiguous()
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::torch_cached_attention_prepare_metadata", mutates_args=())
|
||||
def torch_backend_prepare_metadata(
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
pages_per_seq: torch.Tensor,
|
||||
page_size: int,
|
||||
) -> List[torch.Tensor]:
|
||||
"""Prepare metadata for torch backend attention (similar to triton backend)."""
|
||||
num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
|
||||
seq_start = torch.zeros_like(seq_len[:num_seq])
|
||||
seq_start[1:] = torch.cumsum(seq_len[: num_seq - 1], 0)
|
||||
return (
|
||||
seq_len[:num_seq].clone(),
|
||||
input_pos[:num_seq].clone(),
|
||||
cache_loc[:num_seq].clone(),
|
||||
seq_start,
|
||||
)
|
||||
|
||||
|
||||
@torch_backend_prepare_metadata.register_fake
|
||||
def torch_backend_prepare_metadata_fake(
|
||||
input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size
|
||||
):
|
||||
num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
|
||||
return (
|
||||
torch.empty_like(seq_len[:num_seq]),
|
||||
torch.empty_like(input_pos[:num_seq]),
|
||||
torch.empty_like(cache_loc[:num_seq]),
|
||||
torch.empty_like(seq_len[:num_seq]),
|
||||
)
|
||||
|
||||
|
||||
@AttentionRegistry.register("torch")
|
||||
class TorchBackendAttention(AttentionDescriptor):
|
||||
@classmethod
|
||||
def is_paged(cls) -> bool:
|
||||
"""Return if the attention op is paged or not."""
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_attention_layout(cls) -> AttentionLayout:
|
||||
"""Get the attention layout expected by the source op and the cached attention op."""
|
||||
return "bsnd"
|
||||
|
||||
@classmethod
|
||||
def get_num_qkv_args(cls) -> int:
|
||||
"""Get the number of qkv arguments expected by the source op."""
|
||||
return 3
|
||||
|
||||
@classmethod
|
||||
def get_source_attention_op(cls) -> OpOverloadPacket:
|
||||
return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
|
||||
|
||||
@classmethod
|
||||
def get_cached_attention_op(cls) -> MHACallable:
|
||||
return torch.ops.auto_deploy.torch_cached_attention_with_cache
|
||||
|
||||
@classmethod
|
||||
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
|
||||
return torch.ops.auto_deploy.torch_cached_attention_prepare_metadata, 4
|
||||
|
||||
@classmethod
|
||||
def get_cache_initializers(
|
||||
cls, source_attn_node: Node, cache_config: CacheConfig
|
||||
) -> CacheInitializerDict:
|
||||
# source op is [bsnd] layout already
|
||||
k_fake: FakeTensor = source_attn_node.args[1].meta["val"]
|
||||
v_fake: FakeTensor = source_attn_node.args[2].meta["val"]
|
||||
num_kv_heads = k_fake.shape[2]
|
||||
k_head_dim = k_fake.shape[3]
|
||||
v_head_dim = v_fake.shape[3]
|
||||
|
||||
def _get_k_cache(si: SequenceInfo):
|
||||
assert not si.is_paged, "Paged cache not supported for torch backend"
|
||||
return torch.empty(
|
||||
si.num_pages,
|
||||
si.page_size,
|
||||
num_kv_heads,
|
||||
k_head_dim,
|
||||
device=si.device,
|
||||
dtype=cache_config.dtype or k_fake.dtype,
|
||||
)
|
||||
|
||||
def _get_v_cache(si: SequenceInfo):
|
||||
assert not si.is_paged, "Paged cache not supported for torch backend"
|
||||
return torch.empty(
|
||||
si.num_pages,
|
||||
si.page_size,
|
||||
num_kv_heads,
|
||||
v_head_dim,
|
||||
device=si.device,
|
||||
dtype=cache_config.dtype or v_fake.dtype,
|
||||
)
|
||||
|
||||
return {"k_cache": _get_k_cache, "v_cache": _get_v_cache}
|
||||
|
||||
@classmethod
|
||||
def get_global_buffer_initializers(cls, source_attn_node: Node) -> BufferInitializerDict:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
|
||||
# Check other arguments
|
||||
attn_mask, dropout_p, is_causal = extract_op_args(
|
||||
source_attn_node, "attn_mask", "dropout_p", "is_causal"
|
||||
)
|
||||
if attn_mask is not None or dropout_p != 0.0 or not is_causal:
|
||||
ad_logger.debug(
|
||||
"Unsupported attention arguments for "
|
||||
f"{source_attn_node=}: {attn_mask=}, {dropout_p=}, {is_causal=}"
|
||||
)
|
||||
|
||||
# Get scale from args or kwargs
|
||||
if len(source_attn_node.args) > 6:
|
||||
scale = source_attn_node.args[6]
|
||||
else:
|
||||
scale = source_attn_node.kwargs.get("scale", None)
|
||||
|
||||
# Validate scale
|
||||
if not isinstance(scale, float):
|
||||
ad_logger.warning("Provided scale is not a float. Using default scale instead.")
|
||||
scale = None
|
||||
|
||||
# Get sinks, sliding_window, and logit_cap from args or kwargs
|
||||
sinks = extract_op_args(source_attn_node, "sinks")[0]
|
||||
sliding_window = extract_op_args(source_attn_node, "sliding_window")[0]
|
||||
logit_cap = extract_op_args(source_attn_node, "logit_cap")[0]
|
||||
|
||||
return [
|
||||
scale, # softmax scale
|
||||
sinks, # sinks parameter
|
||||
sliding_window, # sliding window parameter
|
||||
logit_cap, # logit cap parameter
|
||||
]
|
||||
@ -1,9 +1,45 @@
from typing import List
from typing import Callable, List

import torch
import torch.nn.functional as F


def _template_moe(
    x: torch.Tensor,
    selected_experts: torch.Tensor,
    routing_weights: torch.Tensor,
    mlps: List[Callable[[torch.Tensor], torch.Tensor]],
) -> torch.Tensor:
    """Mixtral-style generic MoE template, dispatching tokens to expert MLPs based on routing info."""
    x_shape = x.shape
    hidden_dim = x_shape[-1]
    x = x.view(-1, hidden_dim)
    num_experts = len(mlps)

    final_hidden_states = torch.zeros_like(x)
    valid_mask = (selected_experts >= 0) & (selected_experts < num_experts)
    # For out-of-range indices, set them to num_experts
    selected_experts_fixed = torch.where(
        valid_mask, selected_experts, torch.full_like(selected_experts, num_experts)
    )
    # Create one-hot encoding with an extra class.
    one_hot = F.one_hot(selected_experts_fixed, num_classes=num_experts + 1)
    expert_mask = one_hot[..., :num_experts].permute(2, 1, 0)

    for expert_idx in range(num_experts):
        idx, top_x = torch.where(expert_mask[expert_idx])
        tokens_for_this_expert = x[None, top_x].reshape(-1, hidden_dim)
        if not tokens_for_this_expert.shape[0]:
            continue  # input of shape [0, hidden_dim] breaks fp4 kernel

        expert_out = mlps[expert_idx](tokens_for_this_expert)
        current_hidden_states = expert_out * routing_weights[top_x, idx, None]
        final_hidden_states.index_add_(
            0, top_x, current_hidden_states.to(final_hidden_states.dtype)
        )
    return final_hidden_states.view(x_shape)

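To make the routing bookkeeping in _template_moe concrete, here is a small self-contained sketch of the one-hot dispatch trick (toy tensors; -1 stands in for an out-of-range expert id):

    import torch
    import torch.nn.functional as F

    selected_experts = torch.tensor([[0, 2], [1, -1]])  # (tokens, top_k); -1 is an invalid pick
    num_experts = 3
    valid = (selected_experts >= 0) & (selected_experts < num_experts)
    fixed = torch.where(valid, selected_experts, torch.full_like(selected_experts, num_experts))
    one_hot = F.one_hot(fixed, num_classes=num_experts + 1)    # extra "null" class absorbs invalid picks
    expert_mask = one_hot[..., :num_experts].permute(2, 1, 0)  # (experts, top_k, tokens)
    for e in range(num_experts):
        topk_slot, token_idx = torch.where(expert_mask[e])
        print(f"expert {e} -> tokens {token_idx.tolist()} (top-k slots {topk_slot.tolist()})")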
@torch.library.custom_op("auto_deploy::torch_moe", mutates_args=())
|
||||
def torch_moe(
|
||||
x: torch.Tensor,
|
||||
@ -33,41 +69,17 @@ def torch_moe(
|
||||
torch.Tensor: Output tensor with the same shape as the input x.
|
||||
"""
|
||||
|
||||
x_shape = x.shape
|
||||
hidden_dim = x_shape[-1]
|
||||
x = x.view(-1, hidden_dim)
|
||||
num_experts = len(w1_weight)
|
||||
|
||||
final_hidden_states = torch.zeros_like(x)
|
||||
valid_mask = (selected_experts >= 0) & (selected_experts < num_experts)
|
||||
# For out-of-range indices, set them to num_experts
|
||||
selected_experts_fixed = torch.where(
|
||||
valid_mask, selected_experts, torch.full_like(selected_experts, num_experts)
|
||||
)
|
||||
# Create one-hot encoding with an extra class.
|
||||
one_hot = torch.nn.functional.one_hot(selected_experts_fixed, num_classes=num_experts + 1)
|
||||
expert_mask = one_hot[..., :num_experts].permute(2, 1, 0)
|
||||
|
||||
for expert_idx in range(num_experts):
|
||||
idx, top_x = torch.where(expert_mask[expert_idx])
|
||||
tokens_for_this_expert = x[None, top_x].reshape(-1, hidden_dim)
|
||||
|
||||
gate_out = F.linear(tokens_for_this_expert, w1_weight[expert_idx])
|
||||
up_out = F.linear(tokens_for_this_expert, w3_weight[expert_idx])
|
||||
activated = F.silu(gate_out)
|
||||
prod = activated * up_out
|
||||
expert_out = F.linear(prod, w2_weight[expert_idx])
|
||||
|
||||
current_hidden_states = expert_out * routing_weights[top_x, idx, None]
|
||||
final_hidden_states.index_add_(
|
||||
0, top_x, current_hidden_states.to(final_hidden_states.dtype)
|
||||
def make_mlp(i):
|
||||
return lambda inp: F.linear(
|
||||
F.silu(F.linear(inp, w1_weight[i])) * F.linear(inp, w3_weight[i]), w2_weight[i]
|
||||
)
|
||||
|
||||
return final_hidden_states.view(x_shape)
|
||||
mlps = [make_mlp(i) for i in range(len(w1_weight))]
|
||||
return _template_moe(x, selected_experts, routing_weights, mlps)
|
||||
|
||||
|
||||
@torch_moe.register_fake
|
||||
def torch_moe(
|
||||
def torch_moe_fake(
|
||||
x: torch.Tensor,
|
||||
selected_experts: torch.Tensor,
|
||||
routing_weights: torch.Tensor,
|
||||
@ -133,7 +145,7 @@ def torch_fused_moe(
|
||||
|
||||
|
||||
@torch_fused_moe.register_fake
|
||||
def torch_fused_moe(
|
||||
def torch_fused_moe_fake(
|
||||
x: torch.Tensor,
|
||||
selected_experts: torch.Tensor,
|
||||
routing_weights: torch.Tensor,
|
||||
@ -141,3 +153,174 @@ def torch_fused_moe(
|
||||
w2_stacked_weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::torch_quant_fp8_moe", mutates_args=())
|
||||
def torch_quant_fp8_moe(
|
||||
x: torch.Tensor,
|
||||
selected_experts: torch.Tensor,
|
||||
routing_weights: torch.Tensor,
|
||||
w1_weight: List[torch.Tensor],
|
||||
w2_weight: List[torch.Tensor],
|
||||
w3_weight: List[torch.Tensor],
|
||||
w1_input_scale: List[torch.Tensor],
|
||||
w2_input_scale: List[torch.Tensor],
|
||||
w3_input_scale: List[torch.Tensor],
|
||||
w1_weight_scale: List[torch.Tensor],
|
||||
w2_weight_scale: List[torch.Tensor],
|
||||
w3_weight_scale: List[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
FP8 MoE op using quantized linear operations.
|
||||
|
||||
Computes a Mixture-of-Experts layer similar to the reference auto_deploy::torch_moe op, but uses the
|
||||
quantized FP8 linear op for expert computations.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, H) or (B, S, H).
|
||||
selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
|
||||
routing_weights: Tensor of normalized routing weights.
|
||||
w1_weight, w2_weight, w3_weight: Lists of pre-quantized weight tensors for the three linear ops.
|
||||
w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors for the corresponding ops.
|
||||
w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors for the corresponding ops.
|
||||
|
||||
"""
|
||||
|
||||
def make_fp8_mlp(i):
|
||||
def mlp(inp):
|
||||
gate_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
|
||||
inp,
|
||||
w1_weight[i],
|
||||
bias=None,
|
||||
input_scale=w1_input_scale[i],
|
||||
weight_scale=w1_weight_scale[i],
|
||||
)
|
||||
up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
|
||||
inp,
|
||||
w3_weight[i],
|
||||
bias=None,
|
||||
input_scale=w3_input_scale[i],
|
||||
weight_scale=w3_weight_scale[i],
|
||||
)
|
||||
prod = F.silu(gate_out) * up_out
|
||||
return torch.ops.auto_deploy.torch_quant_fp8_linear(
|
||||
prod,
|
||||
w2_weight[i],
|
||||
bias=None,
|
||||
input_scale=w2_input_scale[i],
|
||||
weight_scale=w2_weight_scale[i],
|
||||
)
|
||||
|
||||
return mlp
|
||||
|
||||
mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
|
||||
return _template_moe(x, selected_experts, routing_weights, mlps)
|
||||
|
||||
|
||||
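As an unquantized reference, each expert above computes a SwiGLU-style gated MLP; a plain-PyTorch equivalent (random stand-in weights, no FP8 scales) looks like this:

    import torch
    import torch.nn.functional as F

    hidden, inter = 16, 32
    x = torch.randn(4, hidden)
    w1 = torch.randn(inter, hidden)  # gate projection
    w3 = torch.randn(inter, hidden)  # up projection
    w2 = torch.randn(hidden, inter)  # down projection
    expert_out = F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)
    assert expert_out.shape == (4, hidden)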
@torch_quant_fp8_moe.register_fake
|
||||
def torch_quant_fp8_moe_fake(
|
||||
x: torch.Tensor,
|
||||
selected_experts: torch.Tensor,
|
||||
routing_weights: torch.Tensor,
|
||||
w1_weight: List[torch.Tensor],
|
||||
w2_weight: List[torch.Tensor],
|
||||
w3_weight: List[torch.Tensor],
|
||||
w1_input_scale: List[torch.Tensor],
|
||||
w2_input_scale: List[torch.Tensor],
|
||||
w3_input_scale: List[torch.Tensor],
|
||||
w1_weight_scale: List[torch.Tensor],
|
||||
w2_weight_scale: List[torch.Tensor],
|
||||
w3_weight_scale: List[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
@torch.library.custom_op("auto_deploy::torch_quant_fp4_moe", mutates_args=())
|
||||
def torch_quant_fp4_moe(
|
||||
x: torch.Tensor,
|
||||
selected_experts: torch.Tensor,
|
||||
routing_weights: torch.Tensor,
|
||||
w1_weight: List[torch.Tensor],
|
||||
w2_weight: List[torch.Tensor],
|
||||
w3_weight: List[torch.Tensor],
|
||||
w1_input_scale: List[torch.Tensor],
|
||||
w2_input_scale: List[torch.Tensor],
|
||||
w3_input_scale: List[torch.Tensor],
|
||||
w1_weight_scale: List[torch.Tensor],
|
||||
w2_weight_scale: List[torch.Tensor],
|
||||
w3_weight_scale: List[torch.Tensor],
|
||||
w1_alpha: List[torch.Tensor],
|
||||
w2_alpha: List[torch.Tensor],
|
||||
w3_alpha: List[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
FP4 MoE op using quantized linear operations.
|
||||
|
||||
Computes a Mixture-of-Experts layer similar to the reference auto_deploy::torch_moe op,
|
||||
but uses the NVFP4 quantized linear op for expert computations.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, H) or (B, S, H).
|
||||
selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
|
||||
routing_weights: Tensor of normalized routing weights.
|
||||
w1_weight, w2_weight, w3_weight: Lists of pre-quantized weight tensors for the three linear ops.
|
||||
w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors.
|
||||
w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors.
|
||||
w1_alpha, w2_alpha, w3_alpha: Lists of alpha scale tensors for FP4 quantization.
|
||||
"""
|
||||
|
||||
def make_fp4_mlp(i):
|
||||
def mlp(inp):
|
||||
if inp.shape[0] == 0:
|
||||
return torch.zeros_like(inp)
|
||||
gate_out = torch.ops.auto_deploy.torch_quant_fp4_linear(
|
||||
inp,
|
||||
w1_weight[i],
|
||||
bias=None,
|
||||
input_scale=w1_input_scale[i],
|
||||
weight_scale=w1_weight_scale[i],
|
||||
alpha=w1_alpha[i],
|
||||
)
|
||||
up_out = torch.ops.auto_deploy.torch_quant_fp4_linear(
|
||||
inp,
|
||||
w3_weight[i],
|
||||
bias=None,
|
||||
input_scale=w3_input_scale[i],
|
||||
weight_scale=w3_weight_scale[i],
|
||||
alpha=w3_alpha[i],
|
||||
)
|
||||
prod = F.silu(gate_out) * up_out
|
||||
return torch.ops.auto_deploy.torch_quant_fp4_linear(
|
||||
prod,
|
||||
w2_weight[i],
|
||||
bias=None,
|
||||
input_scale=w2_input_scale[i],
|
||||
weight_scale=w2_weight_scale[i],
|
||||
alpha=w2_alpha[i],
|
||||
)
|
||||
|
||||
return mlp
|
||||
|
||||
mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
|
||||
return _template_moe(x, selected_experts, routing_weights, mlps)
|
||||
|
||||
|
||||
@torch_quant_fp4_moe.register_fake
|
||||
def torch_quant_fp4_moe_fake(
|
||||
x: torch.Tensor,
|
||||
selected_experts: torch.Tensor,
|
||||
routing_weights: torch.Tensor,
|
||||
w1_weight: List[torch.Tensor],
|
||||
w2_weight: List[torch.Tensor],
|
||||
w3_weight: List[torch.Tensor],
|
||||
w1_input_scale: List[torch.Tensor],
|
||||
w2_input_scale: List[torch.Tensor],
|
||||
w3_input_scale: List[torch.Tensor],
|
||||
w1_weight_scale: List[torch.Tensor],
|
||||
w2_weight_scale: List[torch.Tensor],
|
||||
w3_weight_scale: List[torch.Tensor],
|
||||
w1_alpha: List[torch.Tensor],
|
||||
w2_alpha: List[torch.Tensor],
|
||||
w3_alpha: List[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
@ -41,6 +41,8 @@ def _generate_mha(
|
||||
input_pos: torch.Tensor,
|
||||
scale: float,
|
||||
out: torch.Tensor,
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
):
|
||||
b, (n_heads, q_d_head) = q.shape[0], q.shape[-2:]
|
||||
max_seq_len, n_kv_heads = k_cache.shape[1:3]
|
||||
@ -97,7 +99,10 @@ def _generate_mha(
|
||||
v_d_head,
|
||||
SEQ_BLOCK_SIZE,
|
||||
HEAD_BLOCK_SIZE,
|
||||
sliding_window if sliding_window is not None else -1,
|
||||
)
|
||||
has_sinks = sinks is not None
|
||||
|
||||
attention_kv_stage2[(b, n_heads, 1)](
|
||||
stage1_output_values,
|
||||
stage1_output_logsumexp,
|
||||
@ -107,6 +112,8 @@ def _generate_mha(
|
||||
n_heads,
|
||||
v_d_head,
|
||||
SEQ_BLOCK_SIZE,
|
||||
has_sinks,
|
||||
sinks,
|
||||
)
|
||||
|
||||
|
||||
@ -122,6 +129,8 @@ def _flattened_context_mha(
|
||||
seq_start: torch.Tensor,
|
||||
scale: float,
|
||||
out: torch.Tensor,
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> None:
|
||||
# NOTE: s_total == sum(seq_len)
|
||||
s_total, n_heads, q_d_head = q.shape
|
||||
@ -149,6 +158,8 @@ def _flattened_context_mha(
|
||||
|
||||
# TODO: use input_pos to get the correct cache locations
|
||||
grid = (BATCH_SIZE, n_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
|
||||
has_sinks = sinks is not None
|
||||
|
||||
context_attention_kv_flattened[grid](
|
||||
q,
|
||||
seq_len,
|
||||
@ -165,7 +176,9 @@ def _flattened_context_mha(
|
||||
v_d_head,
|
||||
SEQ_BLOCK,
|
||||
max_cache_seq_len,
|
||||
num_stages=2,
|
||||
sliding_window if sliding_window is not None else -1,
|
||||
has_sinks,
|
||||
sinks,
|
||||
)
|
||||
|
||||
|
||||
@ -187,6 +200,8 @@ def flattened_mha_with_cache(
|
||||
# <none>
|
||||
# CONSTANTS
|
||||
scale: Optional[float],
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Flattened MHA with cache that takes q, k, v in BSND layout.
|
||||
|
||||
@ -223,7 +238,9 @@ def flattened_mha_with_cache(
|
||||
y = q.new_empty(*bs_view, num_heads, v_head_dim).contiguous()
|
||||
if s == 1:
|
||||
# generate-only phase
|
||||
_generate_mha(q, k, v, k_cache, v_cache, cache_loc, input_pos, scale, y)
|
||||
_generate_mha(
|
||||
q, k, v, k_cache, v_cache, cache_loc, input_pos, scale, y, sinks, sliding_window
|
||||
)
|
||||
else:
|
||||
# mixed context + generate phase
|
||||
_flattened_context_mha(
|
||||
@ -238,6 +255,8 @@ def flattened_mha_with_cache(
|
||||
seq_start,
|
||||
scale,
|
||||
y,
|
||||
sinks,
|
||||
sliding_window,
|
||||
)
|
||||
|
||||
return y.view(*output_shape)
|
||||
@ -255,6 +274,8 @@ def flattened_mha_fake(
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
scale: Optional[float],
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
):
|
||||
return q.new_empty(*q.shape[:-1], v.shape[-1]).contiguous()
|
||||
|
||||
@ -388,7 +409,11 @@ class TritonAttention(AttentionDescriptor):
|
||||
if not isinstance(scale, float):
|
||||
ad_logger.warning("Provided scale is not a float, Using default scale instead.")
|
||||
scale = None
|
||||
|
||||
# Get sinks and sliding_window from args or kwargs
|
||||
sinks = extract_op_args(source_attn_node, "sinks")[0]
|
||||
sliding_window = extract_op_args(source_attn_node, "sliding_window")[0]
|
||||
return [
|
||||
scale, # softmax scale
|
||||
sinks,
|
||||
sliding_window,
|
||||
]
|
||||
|
||||
@ -112,6 +112,7 @@ def gqa_attention_kv_stage1(
|
||||
V_D_HEAD: tl.constexpr, # Dimension of each key/value head
|
||||
SEQ_BLOCK_SIZE: tl.constexpr, # Block size used for tiling the sequence dim.
|
||||
HEAD_BLOCK_SIZE: tl.constexpr, # pad to 16 if HEAD_RATIO is < 16 to invoke tensor cores.
|
||||
SLIDING_WINDOW: tl.constexpr,
|
||||
):
|
||||
"""Attention kernel to be used for generate-only batches.
|
||||
|
||||
@ -122,7 +123,7 @@ def gqa_attention_kv_stage1(
|
||||
Supports non-power-of-2 D_HEAD
|
||||
|
||||
Uses flash decoding.
|
||||
KV-cache layout is assumed to be [Batch,Seq, Head, Dim]
|
||||
KV-cache layout is assumed to be [Batch, Seq, Head, Dim]
|
||||
1. Fetch the K-cache from 0 to input_pos
|
||||
2. Fetch the V-cache from 0 to input_pos
|
||||
3. A = Q*K^T [1,D_HEAD] * [1,seq_len,D_HEAD] -> [1, seq_len]
|
||||
@ -145,10 +146,20 @@ def gqa_attention_kv_stage1(
|
||||
|
||||
# The number of Q heads that map to each KV head.
|
||||
HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS # This needs to be a power-of-2
|
||||
if seq_start_pos > kv_position:
|
||||
return
|
||||
seq_offsets = seq_start_pos + tl.arange(0, SEQ_BLOCK_SIZE)
|
||||
seq_mask = seq_offsets <= kv_position
|
||||
|
||||
# Apply sliding window constraints
|
||||
if SLIDING_WINDOW > 0:
|
||||
# For sliding window, limit the sequence range
|
||||
sliding_start = tl.maximum(0, kv_position - SLIDING_WINDOW + 1)
|
||||
if seq_start_pos + SEQ_BLOCK_SIZE <= sliding_start or seq_start_pos > kv_position:
|
||||
return
|
||||
seq_offsets = seq_start_pos + tl.arange(0, SEQ_BLOCK_SIZE)
|
||||
seq_mask = (seq_offsets <= kv_position) & (seq_offsets >= sliding_start)
|
||||
else:
|
||||
if seq_start_pos > kv_position:
|
||||
return
|
||||
seq_offsets = seq_start_pos + tl.arange(0, SEQ_BLOCK_SIZE)
|
||||
seq_mask = seq_offsets <= kv_position
|
||||
|
||||
# Need to pad the head dim to 16 if HEAD_RATIO is < 16 so that tensor cores can be invoked
|
||||
#
|
||||
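The kernel above follows the flash-decoding pattern sketched in its docstring: stage 1 produces per-block partial outputs plus their log-sum-exp, and stage 2 merges them. A simplified numerical view of that merge (assuming each block output is already normalized within its block; the kernel's exact bookkeeping differs):

    import torch

    block_out = torch.tensor([[0.2, 0.8], [0.5, 0.5]])  # per-block normalized outputs (2 blocks, d=2)
    block_lse = torch.tensor([0.7, 1.9])                # per-block logsumexp of the attention logits
    m = block_lse.max()
    w = torch.exp(block_lse - m)
    merged = (w[:, None] * block_out).sum(0) / w.sum()  # softmax-weighted merge across blocks
    print(merged)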
@ -358,6 +369,8 @@ def attention_kv_stage2(
|
||||
N_HEADS: tl.constexpr,
|
||||
D_HEAD: tl.constexpr,
|
||||
SEQ_BLOCK_SIZE: tl.constexpr, # Nearest power of 2 for num_blocks
|
||||
HAS_SINKS: tl.constexpr,
|
||||
sinks_ptr,
|
||||
):
|
||||
# There are batch * N_HEADS programs
|
||||
batch_id = tl.program_id(axis=0)
|
||||
@ -382,6 +395,11 @@ def attention_kv_stage2(
|
||||
sumexp = tl.exp(logsumexp - max_logsumexp) # [NUM_BLOCKS_POW2]
|
||||
|
||||
aggregate_sumexp = tl.sum(sumexp, axis=0)
|
||||
# Add sinks contribution to the softmax denominator
|
||||
if HAS_SINKS:
|
||||
sinks_val = tl.load(sinks_ptr + batch_id * N_HEADS + head_id)
|
||||
sinks_exp = tl.exp(sinks_val - max_logsumexp)
|
||||
aggregate_sumexp += sinks_exp
|
||||
|
||||
values_offsets = block_offsets[:, None] * D_HEAD + dhead_offsets[None, :]
|
||||
values_mask = block_mask[:, None] * dhead_mask[None, :]
|
||||
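The HAS_SINKS branch above only enlarges the softmax denominator; a tiny sketch of the effect (arbitrary numbers):

    import torch

    logits = torch.tensor([2.0, 1.0, 0.5])
    sink = torch.tensor(1.5)
    m = torch.maximum(logits.max(), sink)
    denom = torch.exp(logits - m).sum() + torch.exp(sink - m)
    weights = torch.exp(logits - m) / denom
    print(weights, weights.sum())  # sum < 1: the remaining probability mass goes to the sink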
@ -573,6 +591,9 @@ def context_attention_kv_flattened(
|
||||
V_D_HEAD: tl.constexpr, # Dimension of each value head.
|
||||
SEQ_BLOCK: tl.constexpr,
|
||||
MAX_SEQ_LENGTH: tl.constexpr,
|
||||
SLIDING_WINDOW: tl.constexpr, # Sliding window size, -1 means no sliding window
|
||||
HAS_SINKS: tl.constexpr,
|
||||
sinks_ptr,
|
||||
):
|
||||
"""Kernel for context phase.
|
||||
|
||||
@ -623,7 +644,15 @@ def context_attention_kv_flattened(
|
||||
# input_pos_ptr stores the location at which kv must be written back for the given batch.
|
||||
kv_position = tl.load(input_pos_ptr + batch_id)
|
||||
num_blocks = (kv_position + seq_len + SEQ_BLOCK - 1) // SEQ_BLOCK
|
||||
for s in range(0, num_blocks + 1, 1):
|
||||
start = 0
|
||||
if SLIDING_WINDOW > 0:
|
||||
# Use the LAST query in this block for more conservative start calculation
|
||||
last_q_pos = (
|
||||
(seq_block_id + 1) * SEQ_BLOCK - 1 + kv_position
|
||||
) # Last query's absolute position
|
||||
earliest_kv_pos = max(0, last_q_pos - SLIDING_WINDOW + 1)
|
||||
start = max(0, earliest_kv_pos // SEQ_BLOCK)
|
||||
for s in range(start, num_blocks + 1):
|
||||
kv_seq_offsets = s * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
|
||||
kv_seq_mask = kv_seq_offsets < (kv_position + seq_len)
|
||||
|
||||
@ -637,9 +666,17 @@ def context_attention_kv_flattened(
|
||||
)
|
||||
qk = tl.zeros([SEQ_BLOCK, SEQ_BLOCK], dtype=tl.float32)
|
||||
qk += tl.dot(q, k.trans())
|
||||
qk = tl.where(
|
||||
(seq_offsets[:, None] + kv_position) >= kv_seq_offsets[None, :], qk, float("-inf")
|
||||
)
|
||||
# Apply causal mask
|
||||
causal_mask = (seq_offsets[:, None] + kv_position) >= kv_seq_offsets[None, :]
|
||||
# Apply sliding window mask if enabled
|
||||
if SLIDING_WINDOW > 0:
|
||||
sliding_window_mask = kv_seq_offsets[None, :] >= (
|
||||
seq_offsets[:, None] + kv_position - SLIDING_WINDOW + 1
|
||||
)
|
||||
combined_mask = sliding_window_mask & causal_mask
|
||||
else:
|
||||
combined_mask = causal_mask
|
||||
qk = tl.where(combined_mask, qk, float("-inf"))
|
||||
qk *= SCALE
|
||||
# rowmax
|
||||
m_ij = tl.maximum(tl.max(qk, 1), lse_i)
|
||||
@ -662,6 +699,16 @@ def context_attention_kv_flattened(
|
||||
l_i_new = tl.exp(lse_i - m_ij) + l_ij
|
||||
lse_i = m_ij + tl.log(l_i_new)
|
||||
|
||||
# Add sinks contribution to the final softmax calculation
|
||||
if HAS_SINKS:
|
||||
sinks_val = tl.load(sinks_ptr + batch_id * N_HEADS + head_id)
|
||||
m_sinks = tl.maximum(m_i, sinks_val)
|
||||
acc_scale = tl.exp(m_i - m_sinks)
|
||||
acc = acc * acc_scale[:, None]
|
||||
l_sinks = tl.exp(lse_i - m_sinks) + tl.exp(sinks_val - m_sinks)
|
||||
lse_i = m_sinks + tl.log(l_sinks)
|
||||
m_i = m_sinks
|
||||
|
||||
o_scale = tl.exp(m_i - lse_i)
|
||||
|
||||
acc = acc * o_scale[:, None]
|
||||
|
||||
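The combined mask built in the context kernel above is the intersection of the causal mask and the sliding-window mask; a dense toy version for checking the logic (window size and offsets are made up):

    import torch

    kv_position = 2                        # queries start at absolute position 2
    q_pos = torch.arange(4) + kv_position  # query positions 2..5
    kv_pos = torch.arange(6)               # kv positions 0..5
    window = 3
    causal = q_pos[:, None] >= kv_pos[None, :]
    in_window = kv_pos[None, :] >= (q_pos[:, None] - window + 1)
    print((causal & in_window).int())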
5
tensorrt_llm/_torch/auto_deploy/export/__init__.py
Normal file
@ -0,0 +1,5 @@
"""AutoDeploy's modular export patch system."""

from . import library  # ensure all patches are registered
from .export import *
from .interface import *
284
tensorrt_llm/_torch/auto_deploy/export/export.py
Normal file
@ -0,0 +1,284 @@
|
||||
"""Main export functionality with utilities for torch.export."""
|
||||
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.export as te
|
||||
import torch.nn as nn
|
||||
from torch import fx
|
||||
|
||||
from ..transformations._graph import (
|
||||
canonicalize_graph,
|
||||
lift_to_meta,
|
||||
load_buffers_and_params,
|
||||
tree_to,
|
||||
)
|
||||
from ..utils.logger import ad_logger
|
||||
from ..utils.node_utils import is_op
|
||||
from .interface import ExportPatchRegistry, apply_export_patches
|
||||
|
||||
try:
|
||||
from modelopt.torch.quantization.utils import export_torch_mode as torch_export_context
|
||||
except ImportError:
|
||||
torch_export_context = nullcontext
|
||||
|
||||
|
||||
def _clean_up_device_info(gm: fx.GraphModule) -> None:
|
||||
"""Correct device information in the graph."""
|
||||
devices = {t.device for _, t in gm.named_parameters()}
|
||||
if len(devices) == 0:
|
||||
return
|
||||
elif len(devices) > 1:
|
||||
raise AssertionError("All parameters should be on the same device.")
|
||||
device = devices.pop()
|
||||
meta_device = torch.device("meta")
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if any(a == meta_device for a in node.args):
|
||||
new_args = list(node.args)
|
||||
new_args = [a if a != meta_device else device for a in new_args]
|
||||
node.args = tuple(new_args)
|
||||
if any(a == meta_device for a in node.kwargs.values()):
|
||||
new_kwargs = dict(node.kwargs)
|
||||
new_kwargs = {k: v if v != meta_device else device for k, v in new_kwargs.items()}
|
||||
node.kwargs = new_kwargs
|
||||
|
||||
canonicalize_graph(gm)
|
||||
|
||||
|
||||
def _load_hook_for_deduplication(
|
||||
state_dict, prefix, *args, param_key_remaining: str, param_key_removed: str
|
||||
):
|
||||
"""Check for removed param key and and put it into the key that is remaining."""
|
||||
ad_logger.debug(f"Loading hook for deduplication: {param_key_remaining} <- {param_key_removed}")
|
||||
k_remaining = prefix + param_key_remaining
|
||||
k_removed = prefix + param_key_removed
|
||||
if k_removed in state_dict:
|
||||
state_dict[k_remaining] = state_dict.pop(k_removed)
|
||||
|
||||
|
||||
def _deduplicate_params_and_buffers(gm: fx.GraphModule) -> None:
|
||||
"""This will de-duplicate params and buffers that share the same tensor."""
|
||||
# get all get_attr nodes
|
||||
get_attr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]
|
||||
|
||||
# sort by id of target
|
||||
targets: Dict[int, List[fx.Node]] = defaultdict(list)
|
||||
for n in get_attr_nodes:
|
||||
submod, _, name = n.target.rpartition(".")
|
||||
t_target = getattr(gm.get_submodule(submod), name)
|
||||
targets[id(t_target)].append(n)
|
||||
# now replace all instances of the same tensor with the same get_attr node (idx 0 in the list)
|
||||
for nodes in targets.values():
|
||||
node_kept = nodes[0]
|
||||
for n in nodes[1:]:
|
||||
n.replace_all_uses_with(node_kept)
|
||||
gm.graph.erase_node(n)
|
||||
|
||||
# remove the param/buffer from the submodule
|
||||
submod, _, name = n.target.rpartition(".")
|
||||
delattr(gm.get_submodule(submod), name)
|
||||
|
||||
# add load hooks to also load the weights correctly
|
||||
gm._register_load_state_dict_pre_hook(
|
||||
partial(
|
||||
_load_hook_for_deduplication,
|
||||
param_key_remaining=str(node_kept.target),
|
||||
param_key_removed=str(n.target),
|
||||
)
|
||||
)
|
||||
|
||||
ad_logger.debug(f"Deduplicated: {n.target} --> {node_kept.target}")
|
||||
|
||||
canonicalize_graph(gm)
|
||||
|
||||
|
||||
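The deduplication load hook boils down to a key rename in the incoming state dict; a minimal sketch of that remapping (plain dict, string stand-ins for tensors):

    def remap_removed_key(state_dict, prefix, key_remaining, key_removed):
        # Move a value saved under the removed alias to the surviving key before loading.
        k_remaining, k_removed = prefix + key_remaining, prefix + key_removed
        if k_removed in state_dict:
            state_dict[k_remaining] = state_dict.pop(k_removed)

    sd = {"blocks.0.alias_b.weight": "W"}
    remap_removed_key(sd, "blocks.0.", "alias_a.weight", "alias_b.weight")
    assert sd == {"blocks.0.alias_a.weight": "W"}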
def _add_missing_load_hooks(gm: fx.GraphModule, model: nn.Module) -> None:
|
||||
"""Adds back the state dict load hooks stripped away during export."""
|
||||
hooks = {
|
||||
k: mod._load_state_dict_pre_hooks
|
||||
for k, mod in model.named_modules()
|
||||
if mod._load_state_dict_pre_hooks
|
||||
}
|
||||
|
||||
for mod_name, mod in gm.named_modules():
|
||||
if mod_name in hooks:
|
||||
for hook in hooks.pop(mod_name).values():
|
||||
mod._register_load_state_dict_pre_hook(hook.hook, with_module=hook.with_module)
|
||||
assert not (bool(hooks)), f"""Mismatch in names of exported and source modules with hooks.
The following module names were not found in the exported module: {list(hooks.keys())}"""
|
||||
|
||||
|
||||
def _add_load_hook_for_aliased_params(gm: fx.GraphModule, model: nn.Module) -> None:
|
||||
"""
|
||||
Add a load hook to handle aliased parameters in the model.
|
||||
|
||||
When parameters are aliased (multiple parameter names point to the same tensor),
|
||||
we need to ensure all aliases get the same value during loading. This hook:
|
||||
1. Identifies groups of aliased parameters
|
||||
2. For each group, finds a valid parameter value from the state dict
|
||||
3. Applies that value to all aliases in the group
|
||||
|
||||
Args:
|
||||
gm: The graph module to add the hook to
|
||||
model: The source model containing the original parameter aliases
|
||||
"""
|
||||
|
||||
def find_valid_param_value(
|
||||
state_dict: Dict[str, torch.Tensor], param_names: List[str]
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""Find a valid parameter value from state dict for a group of aliased parameters.
|
||||
|
||||
Args:
|
||||
state_dict: The state dict being loaded
|
||||
param_names: List of parameter names that are aliases of each other
|
||||
|
||||
Returns:
|
||||
A valid tensor value if found, None otherwise
|
||||
"""
|
||||
# First try to find a non-meta tensor value
|
||||
value = None
|
||||
for name in param_names:
|
||||
if name in state_dict:
|
||||
value = state_dict[name]
|
||||
if value.device.type != "meta":
|
||||
return value
|
||||
|
||||
return value
|
||||
|
||||
def aliasing_load_pre_hook(state_dict: Dict[str, torch.Tensor], prefix: str, *args, **kwargs):
|
||||
"""Load hook that ensures aliased parameters get the same value."""
|
||||
for group in aliased_groups:
|
||||
# Find a valid value for this group of aliases
|
||||
value = find_valid_param_value(state_dict, group)
|
||||
|
||||
if value is not None:
|
||||
# Apply the value to all aliases
|
||||
for name in group:
|
||||
state_dict[name] = value
|
||||
|
||||
ad_logger.debug(f"Applied value from {group[0]} to aliased parameters: {group}")
|
||||
|
||||
# Find all parameter aliases in the source model
|
||||
param_to_names = defaultdict(list)
|
||||
for name, param in model.named_parameters(remove_duplicate=False):
|
||||
param_to_names[id(param)].append(name)
|
||||
|
||||
# Filter to only groups with multiple aliases
|
||||
aliased_groups = [names for names in param_to_names.values() if len(names) > 1]
|
||||
|
||||
if not aliased_groups:
|
||||
return
|
||||
|
||||
# Register the hook
|
||||
gm._register_load_state_dict_pre_hook(aliasing_load_pre_hook)
|
||||
|
||||
|
||||
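The alias discovery above groups parameter names by tensor identity; the same grouping can be checked on a toy module (not an exported graph):

    from collections import defaultdict

    import torch.nn as nn

    linear = nn.Linear(4, 4, bias=False)
    tied = nn.Module()
    tied.a = linear
    tied.b = linear  # same submodule registered twice -> aliased weight

    groups = defaultdict(list)
    for name, param in tied.named_parameters(remove_duplicate=False):
        groups[id(param)].append(name)
    print([names for names in groups.values() if len(names) > 1])  # [['a.weight', 'b.weight']]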
def _clean_up_assertions(gm: fx.GraphModule):
|
||||
"""This transformations removes shape checks and assertions from the graph."""
|
||||
check_ops = {
|
||||
torch.ops.aten._assert_scalar,
|
||||
torch.ops.aten.sym_constrain_range,
|
||||
torch.ops.aten.sym_constrain_range_for_size,
|
||||
torch.ops.aten._assert_tensor_metadata,
|
||||
# torch.ops.aten._functional_sym_constrain_range,
|
||||
# torch.ops.aten._functional_sym_constrain_range_for_size
|
||||
}
|
||||
graph: fx.Graph = gm.graph
|
||||
for node in reversed(graph.nodes):
|
||||
if len(node.users) > 0 or not is_op(node, check_ops):
|
||||
continue
|
||||
graph.erase_node(node)
|
||||
canonicalize_graph(gm)
|
||||
|
||||
|
||||
def torch_export_to_gm(
|
||||
model: nn.Module,
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
clone: bool = False, # clone or don't clone the model state_dict
|
||||
*,
|
||||
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
|
||||
strict: bool = False,
|
||||
patch_configs: Optional[Dict[str, Union[dict, Any]]] = None,
|
||||
patch_list: Optional[List[str]] = None,
|
||||
) -> fx.GraphModule:
|
||||
"""torch's export with wrapping into GraphModule + useful additions to the resulting module.
|
||||
|
||||
This utility improves over stock torch.export.export in the following aspects:
|
||||
|
||||
1. Provide patches for certain corner cases that torch.export does not support.
|
||||
2. Standardize the export process to strictly run on the meta device.
|
||||
3. Automatically extract the GraphModule from the exported program.
|
||||
4. Retain load hooks for state_dict loading from the original module.
|
||||
5. Manage parameter aliasing in the model.
|
||||
6. Remove assertions from the graph.
|
||||
|
||||
Args:
|
||||
model: The model to export
|
||||
args: Arguments for the model
|
||||
kwargs: Keyword arguments for the model
|
||||
clone: Whether to clone the model state_dict
|
||||
dynamic_shapes: Dynamic shapes for the export
|
||||
strict: Whether to use strict mode for export
|
||||
patch_configs: Optional patch configurations. If None, all registered patches
|
||||
will be applied with default settings.
|
||||
patch_list: Optional list of patch names to apply with default settings.
|
||||
Cannot be used together with patch_configs.
|
||||
"""
|
||||
# Validate that both patch_configs and patch_list are not provided simultaneously
|
||||
if patch_configs is not None and patch_list is not None:
|
||||
raise ValueError("Cannot specify both patch_configs and patch_list. Use only one.")
|
||||
|
||||
# Handle patch configuration
|
||||
if patch_list is not None:
|
||||
# Convert patch_list to patch_configs format
|
||||
patch_configs = {patch_name: {} for patch_name in patch_list}
|
||||
elif patch_configs is None:
|
||||
# Default patch configurations - apply all registered patches with default settings
|
||||
patch_configs = {patch_name: {} for patch_name in ExportPatchRegistry.list_patches()}
|
||||
|
||||
# run export with patches and lifted to meta
|
||||
with apply_export_patches(patch_configs), lift_to_meta(model) as state_dict:
|
||||
# clean up args, kwargs and move to correct device
|
||||
args, kwargs = tree_to((args, kwargs or {}), device="meta")
|
||||
|
||||
# NOTE (lucaslie): export is VERY sensitive to the location of the inference_mode
|
||||
# context manager. Do NOT move it unless absolutely necessary.
|
||||
with torch.inference_mode():
|
||||
ep = te.export(model, args, kwargs, dynamic_shapes=dynamic_shapes, strict=strict)
|
||||
egm = ep.module()
|
||||
assert isinstance(egm, fx.GraphModule)
|
||||
|
||||
# load state_dict into egm
|
||||
# NOTE: export might have removed unused params/buffers (hence we allow unexpected keys)
|
||||
load_buffers_and_params(
|
||||
egm, state_dict, strict_missing=True, strict_unexpected=False, clone=clone
|
||||
)
|
||||
|
||||
# Export strips away all methods not traced during forward. The model could have
|
||||
# load hooks that contain logic for correct state_dict loading. We need to add those
|
||||
# hooks back to the exported graph module.
|
||||
_add_missing_load_hooks(egm, model)
|
||||
|
||||
# Add load hook to correctly load parameters that are aliased in the source model.
|
||||
# deduplicate params and buffers
|
||||
# TODO (lucaslie, suyoggupta): seems there is some overlap here. I believe we should just have
|
||||
# the deduplicate function and extend it to handle reading from state dict for any name.
|
||||
_add_load_hook_for_aliased_params(egm, model)
|
||||
_deduplicate_params_and_buffers(egm)
|
||||
|
||||
# clean up devices in the graph
|
||||
# This is a consequence of lifting to meta during export.
|
||||
_clean_up_device_info(egm)
|
||||
|
||||
# clean up checks --> generally the sanity checks are overly conservative and we can remove them
|
||||
_clean_up_assertions(egm)
|
||||
|
||||
# show exported graph
|
||||
ad_logger.debug("exported graph: " + str(egm))
|
||||
|
||||
return egm
|
||||
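A hedged usage sketch of torch_export_to_gm on a tiny module, selecting patches by name via patch_list (the import path follows this diff; whether this exact toy module exports cleanly is not verified here):

    import torch
    import torch.nn as nn

    from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
    gm = torch_export_to_gm(
        model,
        args=(torch.randn(2, 8),),
        patch_list=["autocast_noop", "sdpa_kernel_noop"],  # patch names registered in the library below
    )
    print(gm.graph)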
249
tensorrt_llm/_torch/auto_deploy/export/interface.py
Normal file
@ -0,0 +1,249 @@
|
||||
"""The interface for all export patches.
|
||||
|
||||
This module defines the base classes and interfaces for all export patches.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, List, Type, Union, final
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..utils.logger import ad_logger
|
||||
|
||||
|
||||
class ExportPatchError(Exception):
|
||||
"""An exception raised when an export patch fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ExportPatchConfig(BaseModel):
|
||||
"""Base configuration class for export patches."""
|
||||
|
||||
model_config = {
|
||||
"extra": "allow", # Allow subclasses to add more fields
|
||||
}
|
||||
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether to enable this patch.",
|
||||
)
|
||||
skip_on_error: bool = Field(
|
||||
default=False,
|
||||
description="Whether to skip the patch if an error occurs during application.",
|
||||
)
|
||||
|
||||
|
||||
class BaseExportPatch(ABC):
|
||||
"""Base class for all export patches.
|
||||
|
||||
Export patches are context managers that apply temporary modifications
|
||||
to the global state during torch.export, then revert them afterwards.
|
||||
"""
|
||||
|
||||
config: ExportPatchConfig
|
||||
_patch_key: str # Set by ExportPatchRegistry.register() decorator
|
||||
|
||||
@classmethod
|
||||
def get_patch_key(cls) -> str:
|
||||
"""Get the short name of the patch."""
|
||||
if hasattr(cls, "_patch_key"):
|
||||
return cls._patch_key
|
||||
raise NotImplementedError(
|
||||
f"Patch class {cls.__name__} must be registered with ExportPatchRegistry.register() "
|
||||
"or manually implement get_patch_key()"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[ExportPatchConfig]:
|
||||
"""Get the configuration class for the patch."""
|
||||
return ExportPatchConfig
|
||||
|
||||
@final
|
||||
def __init__(self, config: ExportPatchConfig):
|
||||
"""Initialize the patch.
|
||||
|
||||
Args:
|
||||
config: The configuration for the patch.
|
||||
"""
|
||||
if not isinstance(config, self.get_config_class()):
|
||||
config = self.get_config_class()(**config.model_dump())
|
||||
self.config = config
|
||||
self.original_values = {}
|
||||
self._post_init()
|
||||
|
||||
def _post_init(self):
|
||||
"""Post-initialization hook that can be overridden by subclasses."""
|
||||
pass
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def from_kwargs(cls, **kwargs) -> "BaseExportPatch":
|
||||
"""Create a patch from kwargs."""
|
||||
config = cls.get_config_class()(**kwargs)
|
||||
return cls(config=config)
|
||||
|
||||
@final
|
||||
def __enter__(self):
|
||||
"""Enter the context manager and apply the patch."""
|
||||
if not self.config.enabled:
|
||||
ad_logger.debug(f"Patch {self.get_patch_key()} is disabled, skipping")
|
||||
return self
|
||||
|
||||
try:
|
||||
ad_logger.debug(f"Applying patch: {self.get_patch_key()}")
|
||||
self._apply_patch()
|
||||
except Exception as e:
|
||||
error_msg = f"Patch {self.get_patch_key()} failed to apply"
|
||||
if self.config.skip_on_error:
|
||||
ad_logger.warning(f"{error_msg}: {e}")
|
||||
else:
|
||||
raise ExportPatchError(error_msg) from e
|
||||
|
||||
return self
|
||||
|
||||
@final
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Exit the context manager and revert the patch."""
|
||||
if not self.config.enabled:
|
||||
return
|
||||
|
||||
try:
|
||||
ad_logger.debug(f"Reverting patch: {self.get_patch_key()}")
|
||||
self._revert_patch()
|
||||
except Exception as e:
|
||||
error_msg = f"Patch {self.get_patch_key()} failed to revert"
|
||||
if self.config.skip_on_error:
|
||||
ad_logger.warning(f"{error_msg}: {e}")
|
||||
else:
|
||||
raise ExportPatchError(error_msg) from e
|
||||
|
||||
@abstractmethod
|
||||
def _apply_patch(self):
|
||||
"""Apply the patch. Should store original values in self.original_values."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _revert_patch(self):
|
||||
"""Revert the patch using stored original values."""
|
||||
pass
|
||||
|
||||
|
||||
class ContextManagerPatch(BaseExportPatch):
|
||||
"""A patch that wraps an existing context manager.
|
||||
|
||||
This allows easy registration of context managers as patches without
|
||||
having to implement the full BaseExportPatch interface.
|
||||
|
||||
Subclasses must implement `init_context_manager()` to return the context manager.
|
||||
"""
|
||||
|
||||
def _post_init(self):
|
||||
self.context_manager: Any = None
|
||||
|
||||
@abstractmethod
|
||||
def init_context_manager(self) -> Any:
|
||||
"""Initialize and return the context manager.
|
||||
|
||||
Returns:
|
||||
A context manager that will be used during export.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _apply_patch(self):
|
||||
"""Apply the patch by entering the context manager."""
|
||||
self.context_manager = self.init_context_manager()
|
||||
self.context_manager.__enter__()
|
||||
|
||||
def _revert_patch(self):
|
||||
"""Revert the patch by exiting the context manager."""
|
||||
if self.context_manager is not None:
|
||||
self.context_manager.__exit__(None, None, None)
|
||||
self.context_manager = None
|
||||
|
||||
|
||||
class ExportPatchRegistry:
|
||||
"""Registry for export patches."""
|
||||
|
||||
_registry: Dict[str, Type[BaseExportPatch]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str) -> Callable[[Type[BaseExportPatch]], Type[BaseExportPatch]]:
|
||||
"""Register a patch class with the given name."""
|
||||
|
||||
def inner(patch_cls: Type[BaseExportPatch]) -> Type[BaseExportPatch]:
|
||||
cls._registry[name] = patch_cls
|
||||
# Auto-store the patch key as a class attribute
|
||||
patch_cls._patch_key = name
|
||||
return patch_cls
|
||||
|
||||
return inner
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str) -> Type[BaseExportPatch]:
|
||||
"""Get a patch class by name."""
|
||||
return cls._registry[name]
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls, name: str) -> Type[ExportPatchConfig]:
|
||||
"""Get the configuration class for a patch by name."""
|
||||
return cls.get(name).get_config_class()
|
||||
|
||||
@classmethod
|
||||
def has(cls, name: str) -> bool:
|
||||
"""Check if a patch is registered."""
|
||||
return name in cls._registry
|
||||
|
||||
@classmethod
|
||||
def create_patch(
|
||||
cls, name: str, config: Union[ExportPatchConfig, Dict[str, Any]]
|
||||
) -> BaseExportPatch:
|
||||
"""Create a patch instance by name."""
|
||||
patch_cls = cls.get(name)
|
||||
if isinstance(config, dict):
|
||||
config = patch_cls.get_config_class()(**config)
|
||||
return patch_cls(config)
|
||||
|
||||
@classmethod
|
||||
def list_patches(cls) -> List[str]:
|
||||
"""List all registered patch names."""
|
||||
return list(cls._registry.keys())
|
||||
|
||||
|
||||
@contextmanager
|
||||
def apply_export_patches(patch_configs: Dict[str, Union[ExportPatchConfig, Dict[str, Any]]]):
|
||||
"""Context manager to apply multiple patches.
|
||||
|
||||
Args:
|
||||
patch_configs: Dict mapping patch names to their configurations.
|
||||
"""
|
||||
patches = []
|
||||
|
||||
# Create patch instances
|
||||
for name, config in patch_configs.items():
|
||||
if not ExportPatchRegistry.has(name):
|
||||
raise ValueError(f"Unknown patch: {name}")
|
||||
patch = ExportPatchRegistry.create_patch(name, config)
|
||||
patches.append(patch)
|
||||
|
||||
# Apply patches using nested context managers
|
||||
if not patches:
|
||||
yield
|
||||
return
|
||||
|
||||
def _apply_patches(remaining_patches):
|
||||
if not remaining_patches:
|
||||
yield
|
||||
return
|
||||
|
||||
patch = remaining_patches[0]
|
||||
with patch:
|
||||
yield from _apply_patches(remaining_patches[1:])
|
||||
|
||||
# log applied patches
|
||||
ad_logger.debug(
|
||||
f"applying export patches: {', '.join([patch.get_patch_key() for patch in patches])}"
|
||||
)
|
||||
|
||||
yield from _apply_patches(patches)
|
||||
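Registering a custom patch only requires subclassing and decorating; a sketch built on ContextManagerPatch (the patch name "my_noop" is made up):

    from contextlib import nullcontext

    from tensorrt_llm._torch.auto_deploy.export.interface import (
        ContextManagerPatch,
        ExportPatchRegistry,
        apply_export_patches,
    )

    @ExportPatchRegistry.register("my_noop")
    class MyNoopPatch(ContextManagerPatch):
        def init_context_manager(self):
            return nullcontext()

    with apply_export_patches({"my_noop": {}}):
        pass  # torch.export would run here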
16
tensorrt_llm/_torch/auto_deploy/export/library/__init__.py
Normal file
@ -0,0 +1,16 @@
"""AutoDeploy's library of export patches.

This file ensures that all publicly listed files/patches in the library folder are auto-imported
and the corresponding patches are registered.
"""

import importlib
import pkgutil

__all__ = []

for _, module_name, is_pkg in pkgutil.iter_modules(__path__):
    if module_name.startswith("_"):
        continue
    __all__.append(module_name)
    importlib.import_module(f"{__name__}.{module_name}")
@ -0,0 +1,28 @@
|
||||
"""Patch to make torch.autocast a no-op during export."""
|
||||
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
|
||||
from ..interface import BaseExportPatch, ExportPatchRegistry
|
||||
|
||||
|
||||
@ExportPatchRegistry.register("autocast_noop")
|
||||
class AutocastNoopPatch(BaseExportPatch):
|
||||
"""Patch torch.autocast to be a no-op during export.
|
||||
|
||||
This patch replaces torch.autocast with a null context manager, since autocast
contexts can interfere with export.
|
||||
"""
|
||||
|
||||
def _apply_patch(self):
|
||||
"""Apply the autocast no-op patch."""
|
||||
# Store original function
|
||||
self.original_values["torch.autocast"] = torch.autocast
|
||||
|
||||
# Apply patch
|
||||
torch.autocast = lambda *args, **kwargs: nullcontext()
|
||||
|
||||
def _revert_patch(self):
|
||||
"""Revert the autocast no-op patch."""
|
||||
torch.autocast = self.original_values["torch.autocast"]
|
||||
35
tensorrt_llm/_torch/auto_deploy/export/library/linear.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""Patch for F.linear to use simpler implementation during export."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..interface import BaseExportPatch, ExportPatchRegistry
|
||||
|
||||
|
||||
@ExportPatchRegistry.register("linear")
|
||||
class LinearPatch(BaseExportPatch):
|
||||
"""Patch F.linear to use a simpler implementation for export.
|
||||
|
||||
This patch replaces F.linear with a version that avoids exporting
|
||||
view operations used to flatten/unflatten multiple batch dimensions.
|
||||
"""
|
||||
|
||||
def _apply_patch(self):
|
||||
"""Apply the linear patch."""
|
||||
# Store original function
|
||||
self.original_values["F.linear"] = F.linear
|
||||
|
||||
# Create patched function
|
||||
def _torch_linear_patch(
|
||||
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.auto_deploy.torch_linear_simple(input, weight, bias)
|
||||
|
||||
# Apply patch
|
||||
F.linear = _torch_linear_patch
|
||||
|
||||
def _revert_patch(self):
|
||||
"""Revert the linear patch."""
|
||||
F.linear = self.original_values["F.linear"]
|
||||
@ -0,0 +1,23 @@
"""Patch for modelopt's torch_export_context."""

from contextlib import nullcontext

from ..interface import ContextManagerPatch, ExportPatchRegistry


@ExportPatchRegistry.register("modelopt_context")
class ModeloptContextPatch(ContextManagerPatch):
    """Patch to apply modelopt's torch_export_context during export.

    This patch applies the modelopt quantization context manager around
    the export process when available, otherwise uses a null context.
    """

    def init_context_manager(self):
        """Initialize and return the modelopt context manager or nullcontext if not available."""
        try:
            from modelopt.torch.quantization.utils import export_torch_mode as torch_export_context

            return torch_export_context()
        except ImportError:
            return nullcontext()
27
tensorrt_llm/_torch/auto_deploy/export/library/sdpa.py
Normal file
@ -0,0 +1,27 @@
|
||||
"""Patch for F.scaled_dot_product_attention to use custom op."""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..interface import BaseExportPatch, ExportPatchRegistry
|
||||
|
||||
|
||||
@ExportPatchRegistry.register("sdpa")
|
||||
class SdpaPatch(BaseExportPatch):
|
||||
"""Patch F.scaled_dot_product_attention to use custom op during export.
|
||||
|
||||
This patch ensures that scaled_dot_product_attention is represented consistently
|
||||
in the exported graph by using a custom operation.
|
||||
"""
|
||||
|
||||
def _apply_patch(self):
|
||||
"""Apply the SDPA patch."""
|
||||
# Store original function
|
||||
self.original_values["F.scaled_dot_product_attention"] = F.scaled_dot_product_attention
|
||||
|
||||
# Apply patch
|
||||
F.scaled_dot_product_attention = torch.ops.auto_deploy.torch_attention_sdpa
|
||||
|
||||
def _revert_patch(self):
|
||||
"""Revert the SDPA patch."""
|
||||
F.scaled_dot_product_attention = self.original_values["F.scaled_dot_product_attention"]
|
||||
@ -0,0 +1,28 @@
|
||||
"""Patch to make torch.nn.attention.sdpa_kernel a no-op during export."""
|
||||
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
|
||||
from ..interface import BaseExportPatch, ExportPatchRegistry
|
||||
|
||||
|
||||
@ExportPatchRegistry.register("sdpa_kernel_noop")
|
||||
class SdpaKernelNoopPatch(BaseExportPatch):
|
||||
"""Patch torch.nn.attention.sdpa_kernel to be a no-op during export.
|
||||
|
||||
This patch replaces torch.nn.attention.sdpa_kernel with a null context manager,
since the sdpa_kernel context can interfere with export.
|
||||
"""
|
||||
|
||||
def _apply_patch(self):
|
||||
"""Apply the sdpa_kernel no-op patch."""
|
||||
# Store original function
|
||||
self.original_values["torch.nn.attention.sdpa_kernel"] = torch.nn.attention.sdpa_kernel
|
||||
|
||||
# Apply patch
|
||||
torch.nn.attention.sdpa_kernel = lambda *args, **kwargs: nullcontext()
|
||||
|
||||
def _revert_patch(self):
|
||||
"""Revert the sdpa_kernel no-op patch."""
|
||||
torch.nn.attention.sdpa_kernel = self.original_values["torch.nn.attention.sdpa_kernel"]
|
||||
@ -0,0 +1,33 @@
|
||||
"""Patch for torch.tensor to handle 0.0 on meta device."""
|
||||
|
||||
import torch
|
||||
|
||||
from ..interface import BaseExportPatch, ExportPatchRegistry
|
||||
|
||||
|
||||
@ExportPatchRegistry.register("tensor_meta_device")
|
||||
class TensorMetaDevicePatch(BaseExportPatch):
|
||||
"""Patch torch.tensor to handle 0.0 on meta device.
|
||||
|
||||
This patch addresses an issue where torch.tensor(0.0, device="meta")
|
||||
doesn't work and needs to be replaced with torch.zeros((), device="meta").
|
||||
"""
|
||||
|
||||
def _apply_patch(self):
|
||||
"""Apply the tensor meta device patch."""
|
||||
# Store original function
|
||||
self.original_values["torch.tensor"] = torch.tensor
|
||||
|
||||
# Create patched function
|
||||
def _torch_tensor_patch(data, **kwargs):
|
||||
device = kwargs.get("device", None)
|
||||
if data == 0.0 and device is not None and torch.device(device) == torch.device("meta"):
|
||||
return torch.zeros((), **kwargs)
|
||||
return self.original_values["torch.tensor"](data, **kwargs)
|
||||
|
||||
# Apply patch
|
||||
torch.tensor = _torch_tensor_patch
|
||||
|
||||
def _revert_patch(self):
|
||||
"""Revert the tensor meta device patch."""
|
||||
torch.tensor = self.original_values["torch.tensor"]
|
||||
@ -0,0 +1,43 @@
|
||||
"""Patch for nn.ModuleList.__getitem__ to handle slicing during export."""
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from ..interface import BaseExportPatch, ExportPatchRegistry
|
||||
|
||||
|
||||
@ExportPatchRegistry.register("torch_modulelist_getitem")
|
||||
class TorchModuleListGetitemPatch(BaseExportPatch):
|
||||
"""Patch nn.ModuleList.__getitem__ to handle slicing during export.
|
||||
|
||||
This patch addresses a PyTorch issue where nn.ModuleList.__getitem__ with slice
|
||||
indexing doesn't work correctly during export. The workaround returns a simple
|
||||
list for slice operations.
|
||||
|
||||
Reference: https://github.com/pytorch/pytorch/issues/142439
|
||||
"""
|
||||
|
||||
def _apply_patch(self):
|
||||
"""Apply the ModuleList getitem patch."""
|
||||
# Store original function
|
||||
self.original_values["nn.ModuleList.__getitem__"] = nn.ModuleList.__getitem__
|
||||
|
||||
# Capture the original function for use in closure
|
||||
original_getitem = nn.ModuleList.__getitem__
|
||||
|
||||
# Create patched function
|
||||
def _torch_modulelist_getitem_patch(self: nn.ModuleList, idx):
|
||||
if isinstance(idx, slice):
|
||||
# return a simple list.
|
||||
# NOTE: this only works for use cases where the sliced module list is accessed
# like a regular list, e.g., by iterating over it in a for-loop. For most other uses, this hack will not work.
|
||||
return list(self._modules.values())[idx]
|
||||
else:
|
||||
# Call the original function
|
||||
return original_getitem(self, idx)
|
||||
|
||||
# Apply patch (type ignore needed as return type differs for slice case)
|
||||
nn.ModuleList.__getitem__ = _torch_modulelist_getitem_patch # type: ignore
|
||||
|
||||
def _revert_patch(self):
|
||||
"""Revert the ModuleList getitem patch."""
|
||||
nn.ModuleList.__getitem__ = self.original_values["nn.ModuleList.__getitem__"]
|
||||
@ -0,0 +1,33 @@
|
||||
"""Patch for torch.where to handle case where only condition is provided."""
|
||||
|
||||
import torch
|
||||
|
||||
from ..interface import BaseExportPatch, ExportPatchRegistry
|
||||
|
||||
|
||||
@ExportPatchRegistry.register("torch_where")
|
||||
class TorchWherePatch(BaseExportPatch):
|
||||
"""Patch torch.where to handle the case where only condition is provided.
|
||||
|
||||
This patch addresses the issue where torch.where(condition) should return
|
||||
torch.nonzero(condition, as_tuple=True) but the export process doesn't
|
||||
handle this correctly.
|
||||
"""
|
||||
|
||||
def _apply_patch(self):
|
||||
"""Apply the torch.where patch."""
|
||||
# Store original function
|
||||
self.original_values["torch.where"] = torch.where
|
||||
|
||||
# Create patched function
|
||||
def _torch_where_patch(condition: torch.Tensor, *args, **kwargs):
|
||||
if len(args) == 0 and len(kwargs) == 0:
|
||||
return torch.nonzero(condition, as_tuple=True)
|
||||
return self.original_values["torch.where"](condition, *args, **kwargs)
|
||||
|
||||
# Apply patch
|
||||
torch.where = _torch_where_patch
|
||||
|
||||
def _revert_patch(self):
|
||||
"""Revert the torch.where patch."""
|
||||
torch.where = self.original_values["torch.where"]
|
||||
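The patch relies on the documented single-argument behavior of torch.where; a quick check of that equivalence:

    import torch

    cond = torch.tensor([[True, False], [False, True]])
    expected = torch.nonzero(cond, as_tuple=True)
    assert all(torch.equal(a, b) for a, b in zip(torch.where(cond), expected))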
@ -0,0 +1,78 @@
|
||||
"""Patch for transformers SDPA mask to be export-compatible."""
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ..interface import BaseExportPatch, ExportPatchRegistry
|
||||
|
||||
|
||||
def _transformers_version() -> str:
|
||||
"""Get the version of transformers."""
|
||||
return version.parse(importlib.metadata.version("transformers")).base_version
|
||||
|
||||
|
||||
@ExportPatchRegistry.register("transformers_sdpa_mask")
|
||||
class TransformersSdpaMaskPatch(BaseExportPatch):
|
||||
"""Patch transformers.masking_utils.sdpa_mask to be export-compatible.
|
||||
|
||||
This patch replaces the transformers SDPA mask implementation with an
|
||||
export-compatible version for transformers >= 4.53.0.
|
||||
"""
|
||||
|
||||
def _apply_patch(self):
|
||||
"""Apply the transformers SDPA mask patch."""
|
||||
# this patch is only needed+compatible for transformers >= 4.53.0
|
||||
if version.parse(_transformers_version()) < version.parse("4.53.0"):
|
||||
return # Skip patch for older versions
|
||||
|
||||
try:
|
||||
# imports only after version check
|
||||
from transformers import masking_utils
|
||||
from transformers.integrations.executorch import sdpa_mask_without_vmap
|
||||
|
||||
# recall original implementation
|
||||
self.original_values["masking_utils.sdpa_mask"] = masking_utils.sdpa_mask
|
||||
|
||||
# patch function and mask attention interface
|
||||
masking_utils.sdpa_mask = sdpa_mask_without_vmap
|
||||
|
||||
if "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._local_mapping:
|
||||
self.original_values["sdpa_local_original"] = (
|
||||
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._local_mapping["sdpa"]
|
||||
)
|
||||
else:
|
||||
self.original_values["sdpa_local_original"] = None
|
||||
|
||||
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = sdpa_mask_without_vmap
|
||||
|
||||
except ImportError:
|
||||
# If transformers is not available or doesn't have required modules, skip patch
|
||||
pass
|
||||
|
||||
def _revert_patch(self):
|
||||
"""Revert the transformers SDPA mask patch."""
|
||||
# this patch is only needed+compatible for transformers >= 4.53.0
|
||||
if version.parse(_transformers_version()) < version.parse("4.53.0"):
|
||||
return # Skip revert for older versions
|
||||
|
||||
try:
|
||||
# imports only after version check
|
||||
from transformers import masking_utils
|
||||
|
||||
# revert patches
|
||||
if "masking_utils.sdpa_mask" in self.original_values:
|
||||
masking_utils.sdpa_mask = self.original_values["masking_utils.sdpa_mask"]
|
||||
|
||||
if "sdpa_local_original" in self.original_values:
|
||||
if self.original_values["sdpa_local_original"] is None:
|
||||
if "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._local_mapping:
|
||||
del masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
|
||||
else:
|
||||
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = self.original_values[
|
||||
"sdpa_local_original"
|
||||
]
|
||||
|
||||
except ImportError:
|
||||
# If transformers is not available, skip revert
|
||||
pass
|
||||
@ -1,35 +1,60 @@
|
||||
import json
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from pydantic import Field, ValidationInfo, field_validator, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, _ParallelConfig
|
||||
from ...llmapi.utils import get_type_repr
|
||||
from .models import ModelFactory, ModelFactoryRegistry
|
||||
from .transform.interface import TransformConfig
|
||||
from .utils._config import DynamicYamlMixInForSettings
|
||||
|
||||
PathLike = Union[str, Path]
|
||||
|
||||
|
||||
def _try_decode_dict_with_str_values(value: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Try to parse string values as JSON to convert to native types if possible."""
|
||||
for k, v in value.items():
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
value[k] = json.loads(v)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
def _get_config_dict() -> SettingsConfigDict:
|
||||
return SettingsConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
yaml_file=str(files("tensorrt_llm._torch.auto_deploy.config") / "default.yaml"),
|
||||
nested_model_default_partial_update=True,
|
||||
)
|
||||
|
||||
|
||||
def _check_for_default_value_only(
|
||||
cls: Type[BaseSettings], value: Any, info: ValidationInfo, msg: str
|
||||
) -> Any:
|
||||
"""Check if the value is the default value for the field.
|
||||
|
||||
If the value is not the default value, raise a ValueError.
|
||||
"""
|
||||
field_name = info.field_name
|
||||
assert field_name is not None, "field_name should be set for validated field."
|
||||
if value != cls.model_fields[field_name].get_default(call_default_factory=True):
|
||||
raise ValueError(msg)
|
||||
return value
|
||||
|
||||
|
||||
class LlmArgs(BaseLlmArgs):
|
||||
"""LLM arguments specifically for AutoDeploy backend.
|
||||
class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
|
||||
"""An argument class stripped down to AutoDeploy-specific configurations.
|
||||
|
||||
This class extends BaseLlmArgs with AutoDeploy-specific configuration options.
|
||||
AutoDeploy provides automatic deployment and optimization of language models
|
||||
with various attention backends and optimization strategies.
|
||||
This class can be used as a drop-in replacement to simplify configuring the AutoDeploy backend and
|
||||
should be used in place of LlmArgs unless more advanced features are needed.
|
||||
|
||||
It is compatible with AutoDeploy's LLM API (``tensorrt_llm._torch.auto_deploy.llm.LLM``) and
|
||||
exposes the full set of parameters used in AutoDeploy's ``InferenceOptimizer``.
|
||||
"""
|
||||
|
||||
model_config = _get_config_dict()
|
||||
|
||||
### MODEL AND TOKENIZER FACTORY ################################################################
|
||||
model: PathLike = Field(
|
||||
description="The path to the model checkpoint or the model name from the Hugging Face Hub."
|
||||
)
|
||||
|
||||
model_factory: Literal["AutoModelForCausalLM", "AutoModelForImageTextToText"] = Field(
|
||||
default="AutoModelForCausalLM",
|
||||
description="The model factory to use for loading the model.",
|
||||
@ -56,7 +81,7 @@ class LlmArgs(BaseLlmArgs):
|
||||
"Defaults to the same device as the rest of the pipeline.",
|
||||
)
|
||||
|
||||
tokenizer: Optional[Union[str, Path]] = Field(
|
||||
tokenizer: Optional[PathLike] = Field(
|
||||
description="The tokenizer",
|
||||
default=None,
|
||||
repr=False,
|
||||
@ -70,13 +95,14 @@ class LlmArgs(BaseLlmArgs):
|
||||
"https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama_fast.py#L127.",
|
||||
)
|
||||
|
||||
skip_tokenizer_init: bool = Field(
|
||||
default=False, description="Whether to skip the tokenizer initialization."
|
||||
)
|
||||
|
||||
### RUNTIME FEATURES ###########################################################################
|
||||
disable_overlap_scheduler: bool = Field(
|
||||
default=True,
|
||||
description="Disable the overlap scheduler. This is a temporary field until the overlap "
|
||||
"scheduler is supported (https://github.com/NVIDIA/TensorRT-LLM/issues/4364).",
|
||||
frozen=True,
|
||||
repr=False,
|
||||
default=False,
|
||||
description="Disable the overlap scheduler in trtllm runtime",
|
||||
)
|
||||
|
||||
enable_mixed_sampler: bool = Field(
|
||||
@ -102,8 +128,14 @@ class LlmArgs(BaseLlmArgs):
|
||||
"supported in AutoDeploy.",
|
||||
)
|
||||
|
||||
# INFERENCE OPTIMIZER CONFIG ###################################################################
|
||||
attn_backend: Literal["flashinfer", "triton"] = Field(
|
||||
max_beam_width: int = Field(
|
||||
default=1,
|
||||
description="The maximum beam width. >1 is not supported by AutoDeploy.",
|
||||
frozen=True,
|
||||
)
|
||||
|
||||
### INFERENCE OPTIMIZER CONFIG #################################################################
|
||||
attn_backend: Literal["flashinfer", "triton", "torch"] = Field(
|
||||
default="flashinfer", description="Attention backend to use."
|
||||
)
|
||||
|
||||
@ -138,18 +170,75 @@ class LlmArgs(BaseLlmArgs):
|
||||
|
||||
visualize: bool = Field(default=False, description="Whether to visualize the model graph.")
|
||||
|
||||
### NEW INFERENCE OPTIMIZER CONFIG #############################################################
|
||||
transforms: Dict[str, TransformConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="A dictionary of transform configurations. The key is the transform name and "
|
||||
"the value is the transform configuration.",
|
||||
)
|
||||
|
||||
### SEQUENCE INTERFACE CONFIG ##################################################################
|
||||
max_input_len: int = Field(default=1024, description="The maximum input length.")
|
||||
max_num_tokens: Optional[int] = Field(default=None, description="The maximum number of tokens.")
|
||||
max_seq_len: int = Field(default=512, ge=1, description="The maximum sequence length.")
|
||||
max_batch_size: int = Field(default=8, ge=1, description="The maximum batch size.")
|
||||
attn_page_size: int = Field(
|
||||
default=64,
|
||||
ge=1,
|
||||
description="Page size for attention (tokens_per_block). For triton "
|
||||
"backend, this should equal max_seq_len. Temporary field until tokens_per_block gets "
|
||||
description="Page size for attention (tokens_per_block). For triton and torch "
|
||||
"backends, this should equal max_seq_len. Temporary field until tokens_per_block gets "
|
||||
"properly passed through.",
|
||||
)
|
||||
|
||||
### !!! DO NOT USE !!! #########################################################################
|
||||
### VALIDATION #################################################################################
|
||||
@model_validator(mode="after")
|
||||
def update_attn_page_size(self):
|
||||
# NOTE: force attn_page_size to equal max_seq_len for the triton and torch backends
|
||||
if self.attn_backend == "triton" or self.attn_backend == "torch":
|
||||
self.attn_page_size = self.max_seq_len
|
||||
return self
|
||||
|
||||
### UTILITY METHODS ############################################################################
|
||||
def create_factory(self) -> ModelFactory:
|
||||
"""Create a model factory from the arguments."""
|
||||
|
||||
# TODO (lucaslie): consider supporting Path objects in the model factory
|
||||
return ModelFactoryRegistry.get(self.model_factory)(
|
||||
model=str(self.model),
|
||||
model_kwargs=self.model_kwargs,
|
||||
tokenizer=None if self.tokenizer is None else str(self.tokenizer),
|
||||
tokenizer_kwargs=self.tokenizer_kwargs,
|
||||
skip_loading_weights=self.skip_loading_weights,
|
||||
max_seq_len=self.max_seq_len,
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert the arguments to a dictionary."""
|
||||
return self.model_dump()
|
||||
|
||||
def to_llm_args(self) -> "LlmArgs":
|
||||
"""Convert the arguments to a LlmArgs instance that is used for the LLM API."""
|
||||
return LlmArgs(**self.to_dict())
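For illustration only, a minimal sketch of configuring the AutoDeploy backend through this class (the model name, module path, and call pattern are assumptions for the example, not part of this change):

from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig  # path inferred from the imports above

ad_config = AutoDeployConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical HF checkpoint
    model_factory="AutoModelForCausalLM",
    attn_backend="flashinfer",
    max_seq_len=2048,
    max_batch_size=8,
)
factory = ad_config.create_factory()  # ModelFactory for model + tokenizer
llm_args = ad_config.to_llm_args()    # full LlmArgs when the LLM API needs it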
|
||||
|
||||
|
||||
class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings):
|
||||
"""LlmArgs config class for providing full expert configurability of the AutoDeploy backend.
|
||||
|
||||
Specifically, this class extends AutoDeployConfig with all the fields from BaseLlmArgs for
|
||||
providing configurability beyond what is provided by AutoDeployConfig.
|
||||
|
||||
Just like AutoDeployConfig, this class is compatible with AutoDeploy's LLM API
|
||||
(``tensorrt_llm._torch.auto_deploy.llm.LLM``) but provides greater configurability.
|
||||
|
||||
NOTE: this class should only be used directly for advanced use cases. For most use cases,
|
||||
AutoDeployConfig should be used instead.
|
||||
|
||||
NOTE: this class may expose redundant fields from BaseLlmArgs or fields that are ignored or
|
||||
have overlapping functionality with AutoDeployConfig. Please be careful when using this class.
|
||||
"""
|
||||
|
||||
model_config = _get_config_dict()
|
||||
|
||||
build_config: Optional[object] = Field(
|
||||
default_factory=lambda: BuildConfig(),
|
||||
description="!!! DO NOT USE !!! Internal only; needed for BaseLlmArgs compatibility.",
|
||||
@ -173,16 +262,25 @@ class LlmArgs(BaseLlmArgs):
|
||||
### VALIDATION #################################################################################
|
||||
@field_validator("build_config", mode="before")
|
||||
@classmethod
|
||||
def ensure_no_build_config(cls, value: Any) -> Any:
|
||||
if value is not None:
|
||||
raise ValueError("build_config is not used")
|
||||
return value
|
||||
def ensure_no_build_config(cls, value: Any, info: ValidationInfo) -> Any:
|
||||
msg = "build_config is not in use by AutoDeploy's LlmArgs"
|
||||
return _check_for_default_value_only(cls, value, info, msg)
|
||||
|
||||
@field_validator("model_kwargs", "tokenizer_kwargs", mode="after")
|
||||
@field_validator(
|
||||
"tensor_parallel_size",
|
||||
"pipeline_parallel_size",
|
||||
"context_parallel_size",
|
||||
"moe_cluster_parallel_size",
|
||||
"moe_tensor_parallel_size",
|
||||
"moe_expert_parallel_size",
|
||||
"enable_attention_dp",
|
||||
"cp_config",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def validate_model_kwargs(cls, value: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Try to parse string values as JSON to convert to native types if possible."""
|
||||
return _try_decode_dict_with_str_values(value)
|
||||
def ensure_no_custom_parallel_config(cls, value: Any, info: ValidationInfo) -> Any:
|
||||
msg = "AutoDeploy only supports parallelization via the `world_size` argument."
|
||||
return _check_for_default_value_only(cls, value, info, msg)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_parallel_config(self):
|
||||
@ -192,7 +290,6 @@ class LlmArgs(BaseLlmArgs):
|
||||
rank to automatically shard the model. This is just to ensure that other objects in the
|
||||
runtime that may read parallel_config can do so.
|
||||
"""
|
||||
# setup parallel config
|
||||
self._parallel_config = _ParallelConfig(
|
||||
auto_parallel=True, gpus_per_node=self.gpus_per_node
|
||||
)
|
||||
@ -204,26 +301,7 @@ class LlmArgs(BaseLlmArgs):
|
||||
"""Skip tokenizer initialization in config. We do this in the AutoDeploy LLM class."""
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def update_attn_page_size(self):
|
||||
# NOTE force attn_page_size to equal max_seq_len for triton backend
|
||||
if self.attn_backend == "triton":
|
||||
self.attn_page_size = self.max_seq_len
|
||||
return self
|
||||
|
||||
### UTILITY METHODS ############################################################################
|
||||
def create_factory(self) -> ModelFactory:
|
||||
"""Create a model factory from the arguments."""
|
||||
|
||||
return ModelFactoryRegistry.get(self.model_factory)(
|
||||
model=self.model,
|
||||
model_kwargs=self.model_kwargs,
|
||||
tokenizer=self.tokenizer,
|
||||
tokenizer_kwargs=self.tokenizer_kwargs,
|
||||
skip_loading_weights=self.skip_loading_weights,
|
||||
max_seq_len=self.max_seq_len,
|
||||
)
|
||||
|
||||
# TODO: Remove this after the PyTorch backend is fully migrated to LlmArgs from ExecutorConfig
|
||||
def get_pytorch_backend_config(self) -> "LlmArgs":
|
||||
"""Return the LlmArgs (self) object."""
|
||||
|
||||
@ -1,7 +1,2 @@
|
||||
from . import hf
|
||||
from .decilm import *
|
||||
from .deepseek import *
|
||||
from . import hf, patches
|
||||
from .factory import *
|
||||
from .mixtral import *
|
||||
from .phi import *
|
||||
from .qwen3 import *
|
||||
|
||||
@ -211,9 +211,7 @@ class ModelFactoryRegistry:
|
||||
_registry: Dict[str, Type[ModelFactory]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
cls: Type[ModelFactory], name: str
|
||||
) -> Callable[[Type[ModelFactory]], Type[ModelFactory]]:
|
||||
def register(cls, name: str) -> Callable[[Type[ModelFactory]], Type[ModelFactory]]:
|
||||
def inner(fn: Type[ModelFactory]) -> Type[ModelFactory]:
|
||||
cls._registry[name] = fn
|
||||
return fn
|
||||
|
||||
@ -28,6 +28,7 @@ from transformers.utils import (
|
||||
)
|
||||
|
||||
from ..custom_ops.attention_interface import CacheConfig
|
||||
from ..utils._config import deep_merge_dicts
|
||||
from ..utils.logger import ad_logger
|
||||
from .factory import ModelFactory, ModelFactoryRegistry
|
||||
|
||||
@ -62,25 +63,27 @@ def hf_load_state_dict_with_device(device: DeviceLikeType):
|
||||
|
||||
@ModelFactoryRegistry.register("AutoModelForCausalLM")
|
||||
class AutoModelForCausalLMFactory(ModelFactory):
|
||||
_tokenizer_defaults = {
|
||||
"legacy": False,
|
||||
"padding_side": "left",
|
||||
"truncation_side": "left",
|
||||
"trust_remote_code": True,
|
||||
"use_fast": True,
|
||||
}
|
||||
|
||||
_model_defaults = {
|
||||
"use_cache": False,
|
||||
"max_position_embeddings": 1024,
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._quant_config: Optional[Dict] = None
|
||||
|
||||
# Relevant default tokenizer kwargs for HF-style tokenizer
|
||||
defaults = {
|
||||
"legacy": False,
|
||||
"padding_side": "left",
|
||||
"truncation_side": "left",
|
||||
"trust_remote_code": True,
|
||||
"use_fast": True,
|
||||
}
|
||||
self.tokenizer_kwargs = {**defaults, **self.tokenizer_kwargs}
|
||||
|
||||
# NEVER use cache
|
||||
self.model_kwargs["use_cache"] = False
|
||||
# Ensure max_seq_len is propagated to model_kwargs
|
||||
self.model_kwargs["max_position_embeddings"] = self.max_seq_len
|
||||
# Ingest defaults for tokenizer and model kwargs
|
||||
self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs)
|
||||
self.model_kwargs = deep_merge_dicts(self._model_defaults, self.model_kwargs)
|
||||
|
||||
# special handling for torch_dtype in model_kwargs since HF does not correctly update
|
||||
# torch_dtype string to an actual torch.dtype object (only with default)
|
||||
@ -114,7 +117,7 @@ class AutoModelForCausalLMFactory(ModelFactory):
|
||||
|
||||
def _recursive_update_config(self, config: PretrainedConfig, update_dict: Dict[str, Any]):
|
||||
"""
|
||||
Recursively update a PretrainedConfig object with values from update_dict.
|
||||
Deep-merge a PretrainedConfig object with values from update_dict.
|
||||
|
||||
Args:
|
||||
config: PretrainedConfig object to update
|
||||
@ -302,7 +305,13 @@ class AutoModelForCausalLMFactory(ModelFactory):
|
||||
ckpt_file = self._get_checkpoint_file(self.model)
|
||||
# reuse the load checkpoint utility from accelerate
|
||||
with hf_load_state_dict_with_device(device):
|
||||
load_checkpoint_in_model(model, checkpoint=ckpt_file)
|
||||
# Set `full_state_dict=False` to skip Accelerate's FSDP weight sync logic.
|
||||
# Internally, load_checkpoint_in_model → set_model_state_dict → _load_model_state_dict,
|
||||
# which collects local model params, syncs weights from checkpoint, and applies them via
|
||||
# model.load_state_dict.
|
||||
# This sync step can interfere with load_hooks by mixing raw checkpoint weights and
|
||||
# model-transformed weights, leading to unexpected key mismatches or format issues.
|
||||
load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)
|
||||
|
||||
def _load_quantization_config(self):
|
||||
"""Load the quantization config from the model directory if not done already."""
|
||||
@ -326,21 +335,14 @@ class AutoModelForCausalLMFactory(ModelFactory):
|
||||
|
||||
@ModelFactoryRegistry.register("AutoModelForImageTextToText")
|
||||
class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# additional heuristic to propagate "important keys"
|
||||
# TODO (lucaslie): WAR until we have better support on dashboard to control model_kwargs
|
||||
keys_to_propagate = [
|
||||
"num_hidden_layers",
|
||||
"max_position_embeddings",
|
||||
"use_cache",
|
||||
"torch_dtype",
|
||||
]
|
||||
self.model_kwargs["text_config"] = self.model_kwargs.get("text_config", {})
|
||||
for key in keys_to_propagate:
|
||||
if key in self.model_kwargs:
|
||||
self.model_kwargs["text_config"][key] = self.model_kwargs[key]
|
||||
_model_defaults = {
|
||||
"use_cache": False,
|
||||
"max_position_embeddings": 1024,
|
||||
"text_config": {
|
||||
"max_position_embeddings": 1024,
|
||||
"use_cache": False,
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def automodel_from_config(self):
|
||||
|
||||
tensorrt_llm/_torch/auto_deploy/models/patches/__init__.py (new file, 16 lines)
@ -0,0 +1,16 @@
|
||||
"""AutoDeploy's library of export patches for models.
|
||||
|
||||
This file ensures that all publicly listed files/patches in the library folder are auto-imported
|
||||
and the corresponding patches are registered.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import pkgutil
|
||||
|
||||
__all__ = []
|
||||
|
||||
for _, module_name, is_pkg in pkgutil.iter_modules(__path__):
|
||||
if module_name.startswith("_"):
|
||||
continue
|
||||
__all__.append(module_name)
|
||||
importlib.import_module(f"{__name__}.{module_name}")
|
||||
@ -12,4 +12,5 @@ def _from_pretrained_patched(pretrained_model_name_or_path, **kwargs):
|
||||
return _orig_from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
|
||||
# TODO: figure out how this can be incorporated into the export patch system
|
||||
AutoConfig.from_pretrained = _from_pretrained_patched
|
||||
@ -181,4 +181,5 @@ def get_model_from_config_patched(config, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
# TODO: figure out how this can be incorporated into the export patch system
|
||||
AutoModelForCausalLM.from_config = get_model_from_config_patched
|
||||
@ -5,6 +5,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
from ...export.interface import BaseExportPatch, ExportPatchRegistry
|
||||
|
||||
|
||||
def _forward_moe(self: MixtralSparseMoeBlock, hidden_states: torch.Tensor):
|
||||
# check if we can apply the patch
|
||||
@ -46,5 +48,28 @@ def _forward_moe(self: MixtralSparseMoeBlock, hidden_states: torch.Tensor):
|
||||
return final_hidden_states, router_logits
|
||||
|
||||
|
||||
MixtralSparseMoeBlock._original_forward = MixtralSparseMoeBlock.forward
|
||||
MixtralSparseMoeBlock.forward = _forward_moe
|
||||
@ExportPatchRegistry.register("hf_mixtral_moe")
|
||||
class MixtralMoePatch(BaseExportPatch):
|
||||
"""Patch for Mixtral MoE to make it compatible with torch.export.
|
||||
|
||||
This patch replaces the forward method of MixtralSparseMoeBlock with
|
||||
a version that uses the torch_moe custom operator for better export compatibility.
|
||||
"""
|
||||
|
||||
def _apply_patch(self):
|
||||
"""Apply the Mixtral MoE patch."""
|
||||
# Store original forward method
|
||||
self.original_values["MixtralSparseMoeBlock.forward"] = MixtralSparseMoeBlock.forward
|
||||
|
||||
# Apply patch by replacing the forward method
|
||||
MixtralSparseMoeBlock._original_forward = MixtralSparseMoeBlock.forward # type: ignore
|
||||
MixtralSparseMoeBlock.forward = _forward_moe # type: ignore
|
||||
|
||||
def _revert_patch(self):
|
||||
"""Revert the Mixtral MoE patch."""
|
||||
# Restore original forward method
|
||||
MixtralSparseMoeBlock.forward = self.original_values["MixtralSparseMoeBlock.forward"] # type: ignore
|
||||
|
||||
# Clean up the temporary attribute
|
||||
if hasattr(MixtralSparseMoeBlock, "_original_forward"):
|
||||
delattr(MixtralSparseMoeBlock, "_original_forward")
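As a hedged illustration (the registered name is taken from this diff; the exact wiring into the export transform's ``patch_list`` option is assumed), such a patch can be selected by name when configuring the export step:

# hypothetical InferenceOptimizer config snippet limiting export patches to the Mixtral MoE patch
transforms = {
    "export_to_gm": {"stage": "export", "patch_list": ["hf_mixtral_moe"]},
}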
|
||||
@ -173,4 +173,5 @@ def get_model_from_config_patched(config, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
# TODO: figure out how this can be incorporated into the export patch system
|
||||
AutoModelForCausalLM.from_config = get_model_from_config_patched
|
||||
@ -5,6 +5,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
|
||||
from ...export.interface import BaseExportPatch, ExportPatchRegistry
|
||||
|
||||
|
||||
def _forward_moe(self: Qwen3MoeSparseMoeBlock, hidden_states: torch.Tensor):
|
||||
# check if we can apply the patch
|
||||
@ -43,5 +45,28 @@ def _forward_moe(self: Qwen3MoeSparseMoeBlock, hidden_states: torch.Tensor):
|
||||
return final_hidden_states, router_logits
|
||||
|
||||
|
||||
Qwen3MoeSparseMoeBlock._original_forward = Qwen3MoeSparseMoeBlock.forward
|
||||
Qwen3MoeSparseMoeBlock.forward = _forward_moe
|
||||
@ExportPatchRegistry.register("hf_qwen3_moe")
|
||||
class Qwen3MoePatch(BaseExportPatch):
|
||||
"""Patch for Qwen3 MoE to make it compatible with torch.export and reduce export time.
|
||||
|
||||
This patch replaces the forward method of Qwen3MoeSparseMoeBlock with
|
||||
a version that uses the torch_moe custom operator for better export compatibility.
|
||||
"""
|
||||
|
||||
def _apply_patch(self):
|
||||
"""Apply the Qwen3 MoE patch."""
|
||||
# Store original forward method
|
||||
self.original_values["Qwen3MoeSparseMoeBlock.forward"] = Qwen3MoeSparseMoeBlock.forward
|
||||
|
||||
# Apply patch by replacing the forward method
|
||||
Qwen3MoeSparseMoeBlock._original_forward = Qwen3MoeSparseMoeBlock.forward # type: ignore
|
||||
Qwen3MoeSparseMoeBlock.forward = _forward_moe # type: ignore
|
||||
|
||||
def _revert_patch(self):
|
||||
"""Revert the Qwen3 MoE patch."""
|
||||
# Restore original forward method
|
||||
Qwen3MoeSparseMoeBlock.forward = self.original_values["Qwen3MoeSparseMoeBlock.forward"] # type: ignore
|
||||
|
||||
# Clean up the temporary attribute
|
||||
if hasattr(Qwen3MoeSparseMoeBlock, "_original_forward"):
|
||||
delattr(Qwen3MoeSparseMoeBlock, "_original_forward")
|
||||
@ -25,7 +25,7 @@ from ...pyexecutor.scheduler import (
|
||||
)
|
||||
from ..custom_ops.attention_interface import SequenceInfo
|
||||
from ..distributed import common as dist
|
||||
from ..llm_args import LlmArgs
|
||||
from ..llm_args import AutoDeployConfig, LlmArgs
|
||||
from ..transformations.transform import InferenceOptimizer
|
||||
from ..utils.logger import ad_logger
|
||||
from .interface import CachedSequenceInterface, GetInferenceModel
|
||||
@ -82,14 +82,17 @@ class ADEngine(ModelEngine):
|
||||
return self.cache_seq_interface.device
|
||||
|
||||
@classmethod
|
||||
def build_from_config(cls, ad_config: LlmArgs):
|
||||
"""Build the ADEngine using the AD LlmArgs that gets passed through from the LLM."""
|
||||
def build_from_config(cls, ad_config: AutoDeployConfig):
|
||||
"""Build the ADEngine using the AutoDeployConfig that gets passed through from the LLM."""
|
||||
|
||||
max_batch_size = ad_config.max_batch_size
|
||||
max_seq_len = ad_config.max_seq_len
|
||||
attn_page_size = ad_config.attn_page_size
|
||||
max_num_tokens = ad_config.max_num_tokens
|
||||
ad_logger.info(f"{max_seq_len=}, {max_batch_size=}, {attn_page_size=}, {max_num_tokens=}")
|
||||
max_beam_width = ad_config.max_beam_width
|
||||
ad_logger.info(
|
||||
f"{max_seq_len=}, {max_batch_size=}, {attn_page_size=}, {max_num_tokens=}, {max_beam_width=}"
|
||||
)
|
||||
|
||||
# initialize seq info object
|
||||
seq_info = SequenceInfo(
|
||||
@ -111,7 +114,7 @@ class ADEngine(ModelEngine):
|
||||
)
|
||||
|
||||
# construct engine
|
||||
return cls(build_and_optimize, seq_info, device)
|
||||
return cls(build_and_optimize, seq_info, device, max_beam_width)
|
||||
|
||||
@torch.inference_mode()
|
||||
def __init__(
|
||||
@ -119,6 +122,7 @@ class ADEngine(ModelEngine):
|
||||
get_inference_model: GetInferenceModel,
|
||||
seq_info: SequenceInfo,
|
||||
device: DeviceLikeType,
|
||||
max_beam_width: int = 1,
|
||||
) -> None:
|
||||
"""Initialize the engine with model and sequence information."""
|
||||
# NOTE (lucaslie): create a fake Namespace to satisfy PyExecutor requirements...
|
||||
@ -131,6 +135,7 @@ class ADEngine(ModelEngine):
|
||||
self.iter_counter = 0
|
||||
|
||||
# NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
|
||||
self.max_beam_width = max_beam_width
|
||||
self.enable_attention_dp = False
|
||||
|
||||
# construct cache sequence interface
|
||||
@ -147,19 +152,25 @@ class ADEngine(ModelEngine):
|
||||
|
||||
@nvtx_range("ad_prepare_inputs")
|
||||
def _prepare_inputs(
|
||||
self, scheduled_requests: ScheduledRequests, resource_manager: ResourceManager
|
||||
) -> bool:
|
||||
self,
|
||||
scheduled_requests: ScheduledRequests,
|
||||
resource_manager: ResourceManager,
|
||||
new_tokens: Optional[torch.Tensor] = None,
|
||||
) -> List[bool]:
|
||||
"""Prepare inputs for AD Model from scheduled requests."""
|
||||
# cache manager
|
||||
kv_cache_manager = resource_manager.get_resource_manager(
|
||||
ResourceManagerType.KV_CACHE_MANAGER
|
||||
)
|
||||
|
||||
# requests in order of context, extend (generate with draft), generate
|
||||
# requests in order of context, generate
|
||||
context_requests = scheduled_requests.context_requests
|
||||
extend_requests = [r for r in scheduled_requests.generation_requests if r.draft_tokens]
|
||||
gen_requests = [r for r in scheduled_requests.generation_requests if not r.draft_tokens]
|
||||
|
||||
# new_tokens is a tensor on the device; we need to convert it to a list of lists.
# can we avoid this additional gpu->cpu transfer?
|
||||
new_tokens_list = new_tokens.flatten().cpu().tolist() if new_tokens is not None else None
|
||||
|
||||
# info to be extracted
|
||||
input_ids: List[List[int]] = []
|
||||
input_pos: List[int] = []
|
||||
@ -172,24 +183,27 @@ class ADEngine(ModelEngine):
|
||||
input_ids.append(request.get_tokens(0))
|
||||
input_pos.append(request.context_current_position)
|
||||
|
||||
# only return last logit
|
||||
request.py_batch_idx = request.seq_slot
|
||||
last_logit_only.append(True)
|
||||
|
||||
# look at extend+generate requests next
|
||||
for request in chain(extend_requests, gen_requests):
|
||||
# store input ids and pos of first token in sequence
|
||||
input_ids.append([request.get_token(0, request.get_num_tokens(0) - 1)])
|
||||
input_pos.append(request.max_beam_num_tokens - 1)
|
||||
# look at generate requests next
|
||||
# TODO: we should also handle extend requests (for speculative decoding) here
|
||||
for request in gen_requests:
|
||||
# new_tokens are provided when the overlap scheduler is enabled.
|
||||
if new_tokens_list is None or request.is_dummy or request.py_batch_idx is None:
|
||||
input_ids.append([request.get_token(0, request.get_num_tokens(0) - 1)])
|
||||
input_pos.append(request.max_beam_num_tokens - 1)
|
||||
else:
|
||||
input_ids.append([new_tokens_list[request.py_batch_idx]])
|
||||
input_pos.append(request.max_beam_num_tokens)
|
||||
|
||||
# check for draft tokens
|
||||
if request.draft_tokens:
|
||||
input_ids[-1].extend([t for t in request.draft_tokens])
|
||||
request.py_batch_idx = request.seq_slot
|
||||
|
||||
# return all logits
|
||||
last_logit_only.append(False)
|
||||
|
||||
# extract cache information for all requests
|
||||
for request in chain(context_requests, extend_requests, gen_requests):
|
||||
for request in chain(context_requests, gen_requests):
|
||||
# get cache indices
|
||||
cache_indices = kv_cache_manager.get_cache_indices(request)
|
||||
page_assignments.append(cache_indices)
|
||||
@ -199,7 +213,6 @@ class ADEngine(ModelEngine):
|
||||
si.nest_sequences(input_ids)
|
||||
si.update_pos(input_pos, reset=True)
|
||||
si.assign_cache_loc(page_assignments)
|
||||
|
||||
return last_logit_only
|
||||
|
||||
def _compute_logits(self) -> List[torch.Tensor]:
|
||||
@ -224,7 +237,8 @@ class ADEngine(ModelEngine):
|
||||
):
|
||||
"""Run forward from scheduled requests; main entrypoint that gets called by the executor."""
|
||||
# convert requests and store in sequence info object
|
||||
last_logit_only = self._prepare_inputs(scheduled_requests, resource_manager)
|
||||
new_tokens = getattr(new_tokens_device, "new_tokens", None)
|
||||
last_logit_only = self._prepare_inputs(scheduled_requests, resource_manager, new_tokens)
|
||||
|
||||
# compute all logits
|
||||
logits = self._compute_logits()
|
||||
@ -303,7 +317,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
|
||||
max_seq_len=ad_config.max_seq_len,
|
||||
max_draft_len=max_draft_len,
|
||||
max_num_sequences=max_num_sequences,
|
||||
max_beam_width=executor_config.max_beam_width,
|
||||
max_beam_width=ad_config.max_beam_width,
|
||||
enable_mixed_sampler=ad_config.enable_mixed_sampler,
|
||||
)
|
||||
sampler = TorchSampler(sampler_args)
|
||||
|
||||
tensorrt_llm/_torch/auto_deploy/transform/__init__.py (new file, 4 lines)
@ -0,0 +1,4 @@
|
||||
"""AutoDeploy's modular graph transform + inference optimizer pipeline."""
|
||||
|
||||
from . import library # ensure all transforms are registered
|
||||
from .interface import *
|
||||
tensorrt_llm/_torch/auto_deploy/transform/interface.py (new file, 361 lines)
@ -0,0 +1,361 @@
|
||||
"""The interface for all transforms.
|
||||
|
||||
This module defines the base classes and interfaces for all transforms.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import total_ordering
|
||||
from typing import Any, Callable, Dict, Mapping, Tuple, Type, Union, final
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ..models.factory import ModelFactory
|
||||
from ..shim.interface import CachedSequenceInterface
|
||||
from ..transformations._graph import canonicalize_graph, lift_to_meta
|
||||
from ..utils.logger import ad_logger
|
||||
|
||||
|
||||
class TransformError(Exception):
|
||||
"""An exception raised when a transform fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@total_ordering
|
||||
class Stages(Enum):
|
||||
"""Enumerated (ordered!) stages of the transformation pipeline.
|
||||
|
||||
This is used to classify and pre-order transforms.
|
||||
"""
|
||||
|
||||
FACTORY = "factory" # factory stage for building the model
|
||||
EXPORT = "export" # export stage for exporting the model to a graph module
|
||||
POST_EXPORT = "post_export" # low-level cleanups of the exported graph
|
||||
PATTERN_MATCHER = "pattern_matcher" # high-level pattern matching to standardize graph
|
||||
SHARDING = "sharding" # auto-sharding of the graph
|
||||
WEIGHT_LOAD = "weight_load" # loading of the model weights
|
||||
POST_LOAD_FUSION = "post_load_fusion" # post-loading fusion and perf optimizations of the graph
|
||||
CACHE_INIT = "cache_init" # initialization of cached attention + (KV) cache initialization
|
||||
COMPILE = "compile" # graph compilation stage using low-level compilers like torch.compile
|
||||
|
||||
def __lt__(self, other):
|
||||
"""Enable sorting by definition order."""
|
||||
if self.__class__ is other.__class__:
|
||||
return list(self.__class__).index(self) < list(other.__class__).index(other)
|
||||
return NotImplemented
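For illustration, ``total_ordering`` plus this ``__lt__`` makes stages compare by declaration order, so transforms can be pre-sorted by stage; a minimal sanity check based only on the enum above:

assert Stages.FACTORY < Stages.EXPORT < Stages.COMPILE
assert sorted([Stages.COMPILE, Stages.FACTORY, Stages.SHARDING]) == [
    Stages.FACTORY,
    Stages.SHARDING,
    Stages.COMPILE,
]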
|
||||
|
||||
|
||||
class TransformConfig(BaseModel):
|
||||
"""A simple configuration class that can be extended by a transform for configurability."""
|
||||
|
||||
model_config = {
|
||||
# to provide an easy way to do config validation of child config classes with more fields
|
||||
"extra": "allow",
|
||||
}
|
||||
|
||||
### MANDATORY CONFIG ###########################################################################
|
||||
stage: Stages = Field(
|
||||
description="The stage of the transformation pipeline where this transform should run.",
|
||||
)
|
||||
|
||||
### OPTIONAL CONFIG ###########################################################################
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether to enable this transform.",
|
||||
)
|
||||
skip_on_error: bool = Field(
|
||||
default=False,
|
||||
description="Whether to skip the transform if an error occurs.",
|
||||
)
|
||||
|
||||
run_graph_cleanup: bool = Field(
|
||||
default=True,
|
||||
description="Whether to run graph cleanup/canonicalization after this transform.",
|
||||
)
|
||||
run_shape_prop: bool = Field(
|
||||
default=False,
|
||||
description="Whether to run shape propagation after this transform.",
|
||||
)
|
||||
|
||||
requires_clean_graph: bool = Field(
|
||||
default=True,
|
||||
description="Whether this transform requires the graph to be clean before it is applied.",
|
||||
)
|
||||
requires_shape_prop: bool = Field(
|
||||
default=False,
|
||||
description="Whether this transform requires shape propagation before it is applied.",
|
||||
)
|
||||
|
||||
|
||||
AutodeployMeta = Dict[str, Any]
|
||||
_UntypedInferenceOptimizerConfig = Dict[str, Any]
|
||||
StrictInferenceOptimizerConfig = Dict[str, TransformConfig]
|
||||
InferenceOptimizerConfig = Mapping[str, Union[TransformConfig, _UntypedInferenceOptimizerConfig]]
|
||||
|
||||
|
||||
class TransformInfo(BaseModel):
|
||||
"""Information about the result of a transform."""
|
||||
|
||||
model_config = {
|
||||
"frozen": True, # Make the model immutable after creation
|
||||
}
|
||||
|
||||
skipped: bool = Field(
|
||||
description="Whether the transform was skipped.",
|
||||
)
|
||||
num_matches: int = Field(
|
||||
description="Number of matches found.",
|
||||
)
|
||||
is_clean: bool = Field(
|
||||
default=False,
|
||||
description="Whether the graph is clean after the transform. This can be set by the "
|
||||
"transform to indicate that the transform does not change the graph and it preserves the "
|
||||
"is_clean flag of the last transform.",
|
||||
)
|
||||
has_valid_shapes: bool = Field(
|
||||
default=False,
|
||||
description="Whether meta tensor shapes are valid after the transform. This can be set by "
|
||||
"the transform to indicate that the transform does not affect the shapes in the meta "
|
||||
"information of the graph. In other words, the transform does not change the shapes of the "
|
||||
"tensors in the graph and it preserves the has_valid_shapes flag of the last transform.",
|
||||
)
|
||||
|
||||
|
||||
TransformHistory = Dict[str, TransformInfo]
|
||||
|
||||
|
||||
class BaseTransform(ABC):
|
||||
"""A base class for all transforms."""
|
||||
|
||||
config: TransformConfig # overwrite type hint if other config cls is used in subclass!
|
||||
_autodeploy_meta_key: str = "_autodeploy"
|
||||
_history_key: str = "transform_history"
|
||||
_transform_key: str # Set by TransformRegistry.register() decorator
|
||||
|
||||
@classmethod
|
||||
def get_transform_key(cls) -> str:
|
||||
"""Get the short name of the transform.
|
||||
|
||||
This is used to identify the transform in the transformation pipeline.
|
||||
"""
|
||||
if hasattr(cls, "_transform_key"):
|
||||
return cls._transform_key
|
||||
raise NotImplementedError(
|
||||
f"Transform class {cls.__name__} must be registered with TransformRegistry.register() "
|
||||
"or manually implement get_transform_key()"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
"""Get the configuration class for the transform.
|
||||
|
||||
This is used to validate the configuration of the transform.
|
||||
"""
|
||||
return TransformConfig
|
||||
|
||||
@final
|
||||
def __init__(self, config: TransformConfig):
|
||||
"""Initialize the transform.
|
||||
|
||||
Args:
|
||||
config: The configuration for the transform, either as base config object or the actual
|
||||
config object.
|
||||
|
||||
To customize the initialization, override the `_post_init` method.
|
||||
"""
|
||||
if not isinstance(config, self.get_config_class()):
|
||||
config = self.get_config_class()(**config.model_dump())
|
||||
self.config = config
|
||||
self._post_init()
|
||||
|
||||
def _post_init(self):
|
||||
"""Post-initialization hook that can be overridden by subclasses."""
|
||||
pass
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def from_kwargs(cls, **kwargs) -> "BaseTransform":
|
||||
"""Create a transform from kwargs.
|
||||
|
||||
Args:
|
||||
**kwargs: The configuration for the transform.
|
||||
|
||||
Returns:
|
||||
The transform instance.
|
||||
"""
|
||||
config = cls.get_config_class()(**kwargs)
|
||||
return cls(config=config)
|
||||
|
||||
@final
|
||||
def __call__(
|
||||
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
|
||||
) -> GraphModule:
|
||||
"""Apply the transform to the graph.
|
||||
|
||||
Args:
|
||||
gm: The graph module to apply the transform to.
|
||||
cm: The cached sequence interface defining the sequence interface.
|
||||
factory: The model factory used to build the model.
|
||||
|
||||
Returns:
|
||||
GraphModule: The transformed graph module.
|
||||
|
||||
NOTE: The transform can/should modify the graph module in place if possible. Returning the
|
||||
graph is mostly to standardize the interface for transforms that cannot modify the graph
|
||||
in place (e.g. the factory or export transform).
|
||||
|
||||
This method is the main entry point for any transforms and is called by the
|
||||
InferenceOptimizer pipeline.
|
||||
"""
|
||||
|
||||
# get the transform key
|
||||
t_name = self.get_transform_key()
|
||||
|
||||
# retrieve autodeploy metadata from the graphmodule
|
||||
autodeploy_meta = self._get_autodeploy_meta(gm)
|
||||
|
||||
# retrieve transform history and last transform info
|
||||
history: TransformHistory = autodeploy_meta.get(self._history_key, {})
|
||||
h_keys = list(history.keys()) # preserves order of insertion/transform execution
|
||||
info_last = history[h_keys[-1]] if h_keys else TransformInfo(skipped=False, num_matches=0)
|
||||
|
||||
# show debug info for debug config
|
||||
ad_logger.debug(f"{t_name} config: {self.config}")
|
||||
|
||||
# run or skip the transform
|
||||
if self.config.enabled:
|
||||
# run graph pre-cleanup
|
||||
self._run_pre_cleanup(gm, info_last)
|
||||
|
||||
# run the transform in an error-handling wrapper
|
||||
try:
|
||||
gm, info = self._apply(gm, cm, factory)
|
||||
except Exception as e:
|
||||
error_msg = f"Transform {t_name} failed"
|
||||
if self.config.skip_on_error:
|
||||
ad_logger.warning(f"{error_msg}: {e}")
|
||||
info = TransformInfo(skipped=True, num_matches=0)
|
||||
else:
|
||||
raise TransformError(error_msg) from e
|
||||
|
||||
# run graph post-cleanup
|
||||
info = self._run_post_cleanup(gm, info)
|
||||
else:
|
||||
# skip the transform and set info object using the last transform info
|
||||
info_dict = info_last.model_dump()
|
||||
info_dict["skipped"] = True
|
||||
info_dict["num_matches"] = 0
|
||||
info = TransformInfo(**info_dict)
|
||||
|
||||
# log the result of the transform
|
||||
log_msgs = [
|
||||
f"stage={self.config.stage.value}",
|
||||
f"transform={t_name}",
|
||||
"skipped=True" if info.skipped else f"num_matches={info.num_matches}",
|
||||
f"is_clean={info.is_clean}",
|
||||
f"has_valid_shapes={info.has_valid_shapes}",
|
||||
]
|
||||
ad_logger.info(", ".join(log_msgs))
|
||||
ad_logger.debug(f"Graph after {t_name}: {gm}")
|
||||
|
||||
# update + store new meta data
|
||||
history[t_name] = info
|
||||
autodeploy_meta[self._history_key] = history
|
||||
self._set_autodeploy_meta(gm, autodeploy_meta)
|
||||
|
||||
# return the graph module
|
||||
return gm
|
||||
|
||||
@final
|
||||
def _get_autodeploy_meta(self, gm: GraphModule) -> AutodeployMeta:
|
||||
"""Get the autodeploy metadata from the graphmodule."""
|
||||
return gm.meta.get(self._autodeploy_meta_key, {})
|
||||
|
||||
@final
|
||||
def _set_autodeploy_meta(self, gm: GraphModule, autodeploy_meta: AutodeployMeta) -> None:
|
||||
"""Set the autodeploy metadata in the graphmodule."""
|
||||
gm.meta[self._autodeploy_meta_key] = autodeploy_meta
|
||||
|
||||
@final
|
||||
def _run_pre_cleanup(self, gm: GraphModule, info: TransformInfo) -> None:
|
||||
"""Run graph cleanup before the transform.
|
||||
|
||||
This is used to ensure the transform is applied to a clean graph as needed by the transform.
|
||||
"""
|
||||
if not self.config.requires_clean_graph:
|
||||
return
|
||||
|
||||
# check whether to run cleanup depending on the config and info
|
||||
if self.config.requires_shape_prop and not (info.is_clean and info.has_valid_shapes):
|
||||
with lift_to_meta(gm):
|
||||
canonicalize_graph(gm, shape_prop=True)
|
||||
elif self.config.requires_clean_graph and not info.is_clean:
|
||||
canonicalize_graph(gm)
|
||||
|
||||
@final
|
||||
def _run_post_cleanup(self, gm: GraphModule, info: TransformInfo) -> TransformInfo:
|
||||
"""Run graph cleanup after the transform.
|
||||
|
||||
Cleanup is done as requested in the config and we will update the graph module and info
|
||||
accordingly.
|
||||
|
||||
Returns:
|
||||
Updated TransformInfo with cleanup status.
|
||||
"""
|
||||
if not self.config.run_graph_cleanup:
|
||||
return info
|
||||
|
||||
# check whether to run cleanup depending on the config and info
|
||||
if self.config.run_shape_prop and not (info.is_clean and info.has_valid_shapes):
|
||||
with lift_to_meta(gm):
|
||||
canonicalize_graph(gm, shape_prop=True)
|
||||
elif self.config.run_graph_cleanup and not info.is_clean:
|
||||
canonicalize_graph(gm)
|
||||
|
||||
# create new info object with updated cleanup status
|
||||
info_dict = info.model_dump()
|
||||
info_dict["is_clean"] |= self.config.run_graph_cleanup
|
||||
info_dict["has_valid_shapes"] |= self.config.run_shape_prop
|
||||
return TransformInfo(**info_dict)
|
||||
|
||||
@abstractmethod
|
||||
def _apply(
|
||||
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
"""Apply the transform to the graph.
|
||||
|
||||
This is the core method that should be implemented by subclasses.
|
||||
"""
|
||||
|
||||
|
||||
class TransformRegistry:
|
||||
"""A registry for all transforms."""
|
||||
|
||||
_registry: Dict[str, Type[BaseTransform]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str) -> Callable[[Type[BaseTransform]], Type[BaseTransform]]:
|
||||
def inner(fn: Type[BaseTransform]) -> Type[BaseTransform]:
|
||||
cls._registry[name] = fn
|
||||
# Auto-store the transform key as a class attribute
|
||||
fn._transform_key = name
|
||||
return fn
|
||||
|
||||
return inner
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str) -> Type[BaseTransform]:
|
||||
"""Get the transform class by name."""
|
||||
return cls._registry[name]
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls, name: str) -> Type[TransformConfig]:
|
||||
"""Get the configuration class for a transform by name."""
|
||||
return cls.get(name).get_config_class()
|
||||
|
||||
@classmethod
|
||||
def has(cls, name: str) -> bool:
|
||||
"""Check if a transform is registered."""
|
||||
return name in cls._registry
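As a hedged sketch (not part of this diff), a new transform would typically be added by registering a subclass under a fresh key; the decorator stores the key on the class and the optimizer can then instantiate it by name:

@TransformRegistry.register("noop_example")  # hypothetical key, for illustration only
class NoopExample(BaseTransform):
    """A do-nothing transform used purely to demonstrate the registry."""

    def _apply(self, gm, cm, factory):
        # leave the graph untouched and report zero matches, preserving cleanliness flags
        return gm, TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True)


t = TransformRegistry.get("noop_example").from_kwargs(stage="post_export")
gm = t(gm, cm, factory)  # gm/cm/factory assumed to exist as in the pipeline below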
|
||||
@ -0,0 +1,16 @@
|
||||
"""AutoDeploy's library of transforms.
|
||||
|
||||
This file ensures that all publicly listed files/transforms in the library folder are auto-imported
|
||||
and the corresponding transforms are registered.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import pkgutil
|
||||
|
||||
__all__ = []
|
||||
|
||||
for _, module_name, is_pkg in pkgutil.iter_modules(__path__):
|
||||
if module_name.startswith("_"):
|
||||
continue
|
||||
__all__.append(module_name)
|
||||
importlib.import_module(f"{__name__}.{module_name}")
|
||||
@ -0,0 +1,41 @@
|
||||
"""A simple wrapper transform to build a model via the model factory."""
|
||||
|
||||
from typing import Tuple, Type
|
||||
|
||||
from pydantic import Field
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ..interface import BaseTransform, TransformConfig, TransformInfo, TransformRegistry
|
||||
|
||||
|
||||
class BuildModelConfig(TransformConfig):
|
||||
"""Configuration for the build model transform."""
|
||||
|
||||
device: str = Field(default="meta", description="The device to build the model on.")
|
||||
|
||||
|
||||
@TransformRegistry.register("build_model")
|
||||
class BuildModel(BaseTransform):
|
||||
"""A simple wrapper transform to build a model via the model factory."""
|
||||
|
||||
config: BuildModelConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
return BuildModelConfig
|
||||
|
||||
def _apply(
|
||||
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
# build the model
|
||||
model = factory.build_model(self.config.device)
|
||||
|
||||
# as a wrapper to satisfy the interface, we register the model as a submodule
|
||||
gm.add_module("factory_model", model)
|
||||
|
||||
# by convention, we say this fake graph module is always clean
|
||||
info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
|
||||
|
||||
return gm, info
|
||||
@ -0,0 +1,49 @@
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, GraphModule
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ..interface import BaseTransform, TransformInfo, TransformRegistry
|
||||
|
||||
|
||||
# TODO (lucaslie): consider reconfiguring this transform to run before we switch to flattened
|
||||
# sequences which is done in update_in_out_nodes at the moment.
|
||||
@TransformRegistry.register("cleanup_input_constraints")
|
||||
class CleanupInputConstraints(BaseTransform):
|
||||
"""Cleanup input constraints from the graph.
|
||||
|
||||
This transformation updates the input constraints of the graph. Specifically, we want to
|
||||
account for flattened sequences and hence the max constraint should be updated to reflect the
|
||||
flattened sequence length.
|
||||
"""
|
||||
|
||||
def _apply(
|
||||
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
graph: Graph = gm.graph
|
||||
input_node = graph.find_nodes(op="placeholder")[0]
|
||||
sym_shape: torch.Size = input_node.meta["val"].shape
|
||||
|
||||
# get expressions in the symbolic shape
|
||||
vrs: List[ValueRanges] = []
|
||||
for s in sym_shape:
|
||||
if isinstance(s, int):
|
||||
vrs.append(ValueRanges(0, s))
|
||||
elif isinstance(s, torch.SymInt):
|
||||
vrs.append(gm.range_constraints[s.node.expr])
|
||||
else:
|
||||
raise TypeError(f"Unexpected type {type(s)} in symbolic shape.")
|
||||
|
||||
# update the max constraint for each vr
|
||||
max_total = math.prod(vr.upper for vr in vrs)
|
||||
for vr in vrs:
|
||||
object.__setattr__(vr, "upper", max_total)
|
||||
|
||||
# store info object about the transform
|
||||
info = TransformInfo(skipped=False, num_matches=len(vrs))
|
||||
|
||||
return gm, info
|
||||
@ -0,0 +1,52 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils.node_utils import is_op
|
||||
from ..interface import BaseTransform, TransformInfo, TransformRegistry
|
||||
|
||||
|
||||
@TransformRegistry.register("cleanup_noop_add")
|
||||
class CleanupNoopAdd(BaseTransform):
|
||||
"""Eliminate add nodes from the graph that are no-ops.
|
||||
|
||||
This would be any node that is just adding 0 to the input tensor. We can safely remove those.
|
||||
|
||||
NOTE: this function has one failure mode when the op ``out = tensor + zero_tensor`` is used
|
||||
in such a way that ``out`` will be broadcast to the shape of zero_tensor. After removing this op,
``out`` won't have the right shape anymore. This should be a rare case and we can handle it
when it comes up or disable this transform.
|
||||
"""
|
||||
|
||||
def _apply(
|
||||
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
num_matches = 0
|
||||
for node in gm.graph.nodes:
|
||||
# looking for add nodes
|
||||
if not is_op(node, torch.ops.aten.add):
|
||||
continue
|
||||
# only handling this parameter combination for now
|
||||
if len(node.all_input_nodes) != 2:
|
||||
continue
|
||||
|
||||
# check if any of the input nodes is just a constant tensor with value 0
|
||||
if is_op(node.all_input_nodes[0], torch.ops.aten.zeros):
|
||||
zero_node, true_node = node.all_input_nodes
|
||||
elif is_op(node.all_input_nodes[1], torch.ops.aten.zeros):
|
||||
true_node, zero_node = node.all_input_nodes
|
||||
else:
|
||||
continue
|
||||
|
||||
# do the replacement and clean-up
|
||||
node.replace_all_uses_with(true_node)
|
||||
gm.graph.erase_node(node)
|
||||
num_matches += 1
|
||||
|
||||
# store info object about the transform
|
||||
info = TransformInfo(skipped=False, num_matches=num_matches)
|
||||
|
||||
return gm, info
|
||||
@ -0,0 +1,49 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils.node_utils import is_op
|
||||
from ..interface import BaseTransform, TransformInfo, TransformRegistry
|
||||
|
||||
|
||||
@TransformRegistry.register("cleanup_noop_slice")
|
||||
class CleanupNoopSlice(BaseTransform):
|
||||
"""Remove no-op slice nodes from the graph.
|
||||
|
||||
Those will be nodes that are used to represent a slice operation like ``t[:, :5]``. The graph IR
|
||||
will represent it as ``t[:][:5]``, i.e., two nodes and the first slice being a no-op. This
|
||||
function gets rid of such instances.
|
||||
"""
|
||||
|
||||
def _apply(
|
||||
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
num_matches = 0
|
||||
for node in gm.graph.nodes:
|
||||
# looking for slice nodes
|
||||
if not is_op(node, torch.ops.aten.slice):
|
||||
continue
|
||||
# only handling this parameter combination for now
|
||||
# 4 args will be (input, dim, start, end)
|
||||
if len(node.args) != 4 or len(node.kwargs) != 0:
|
||||
continue
|
||||
# check if dim is just an integer
|
||||
if not isinstance(node.args[1], int):
|
||||
continue
|
||||
# check if the slice op is indeed a no-op
|
||||
if node.args[2] != 0 or node.args[3] != torch.iinfo(torch.long).max:
|
||||
continue
|
||||
# extract input tensor node and remove the slice node
|
||||
in_node = node.args[0]
|
||||
assert [in_node] == node.all_input_nodes, "Slice node has unexpected input nodes."
|
||||
node.replace_all_uses_with(in_node)
|
||||
gm.graph.erase_node(node)
|
||||
num_matches += 1
|
||||
|
||||
# store info object about the transform
|
||||
info = TransformInfo(skipped=False, num_matches=num_matches)
|
||||
|
||||
return gm, info
|
||||
@ -0,0 +1,71 @@
|
||||
"""A simple wrapper transform to export a model to a graph module."""
|
||||
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
from pydantic import Field
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...export import torch_export_to_gm
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ..interface import BaseTransform, TransformConfig, TransformInfo, TransformRegistry
|
||||
|
||||
|
||||
class ExportToGMConfig(TransformConfig):
|
||||
"""Configuration for the export to graph module transform."""
|
||||
|
||||
strict: bool = Field(
|
||||
description="Whether to export in strict mode. NOTE: we generally export in non-strict mode"
|
||||
"for now as it relaxes some assumptions around tracing. Strict mode uses torchdynamo"
|
||||
"(symbolic bytecode analysis), which can be brittle since it relies on the exact bytecode"
|
||||
"representation of the model see here as well: https://pytorch.org/docs/stable/export.html#non-strict-export",
|
||||
default=False,
|
||||
)
|
||||
clone_state_dict: bool = Field(
|
||||
description="Whether to clone the state_dict of the model. This is useful to avoid"
|
||||
"modifying the original state_dict of the model.",
|
||||
default=False,
|
||||
)
|
||||
patch_list: Optional[List[str]] = Field(
|
||||
description="List of patch names to apply with export. "
|
||||
"Default is to apply all registered patches.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("export_to_gm")
|
||||
class ExportToGM(BaseTransform):
|
||||
"""A simple wrapper transform to export a model to a graph module."""
|
||||
|
||||
config: ExportToGMConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
return ExportToGMConfig
|
||||
|
||||
def _apply(
|
||||
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
# at this point we assume the gm is just a dummy graph module
|
||||
assert len(gm.graph.nodes) == 0, "Expected empty graph module."
|
||||
|
||||
# retrieve the actual model from the dummy graph module
|
||||
model = gm.get_submodule("factory_model")
|
||||
|
||||
# set the example sequence
|
||||
cm.info.set_example_sequence()
|
||||
|
||||
# export the model to a graph module
|
||||
gm = torch_export_to_gm(
|
||||
model,
|
||||
args=cm.args,
|
||||
dynamic_shapes=cm.dynamic_shapes,
|
||||
clone=self.config.clone_state_dict,
|
||||
strict=self.config.strict,
|
||||
patch_list=self.config.patch_list,
|
||||
)
|
||||
|
||||
# this is a clean graph by definition since it was just exported
|
||||
info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
|
||||
|
||||
return gm, info
|
||||
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (new file, 76 lines)
@ -0,0 +1,76 @@
|
||||
"""High-level entrypoint to transform a model into an efficient inference model."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.fx import Graph, GraphModule
|
||||
|
||||
from ..models.factory import ModelFactory
|
||||
from ..shim.interface import CachedSequenceInterface
|
||||
from .interface import (
|
||||
InferenceOptimizerConfig,
|
||||
Stages,
|
||||
StrictInferenceOptimizerConfig,
|
||||
TransformConfig,
|
||||
TransformRegistry,
|
||||
)
|
||||
|
||||
|
||||
class InferenceOptimizer:
|
||||
def __init__(self, factory: ModelFactory, config: InferenceOptimizerConfig):
|
||||
self.factory = factory
|
||||
self.config = self._clean_config(config)
|
||||
|
||||
def _clean_config(self, config: InferenceOptimizerConfig) -> StrictInferenceOptimizerConfig:
|
||||
"""Get a typed checked ("strict") config with sorted keys according to stages."""
|
||||
# convert to nested kwargs, no TransformConfig objects allowed
|
||||
nested_kwargs = {
|
||||
k: v.model_dump() if isinstance(v, TransformConfig) else v for k, v in config.items()
|
||||
}
|
||||
# sort by stage
|
||||
keys_sorted = sorted(nested_kwargs.keys(), key=lambda k: Stages(nested_kwargs[k]["stage"]))
|
||||
# create strict config with correct config classes and correct order
|
||||
strict_config: StrictInferenceOptimizerConfig = {
|
||||
k: TransformRegistry.get_config_class(k)(**nested_kwargs[k]) for k in keys_sorted
|
||||
}
|
||||
# return strict config
|
||||
return strict_config
|
||||
|
||||
@staticmethod
|
||||
def _init_gm() -> GraphModule:
|
||||
"""Initialize a fake graph module.
|
||||
|
||||
This is a dummy graph module that will be used to kick off the transforms.
|
||||
"""
|
||||
return GraphModule(nn.Module(), Graph())
|
||||
|
||||
def __call__(
|
||||
self, cm: CachedSequenceInterface, gm: Optional[GraphModule] = None
|
||||
) -> GraphModule:
|
||||
"""Transform a model into an optimized inference model.
|
||||
|
||||
Args:
|
||||
cm: The cached sequence interface defining the sequence interface.
|
||||
|
||||
Returns:
|
||||
A GraphModule representing the optimized inference model.
|
||||
"""
|
||||
############################################################################################
|
||||
# RUN THROUGH CONFIGURED TRANSFORMATIONS
|
||||
############################################################################################
|
||||
|
||||
# start with an empty fake graph module if not provided
|
||||
if gm is None:
|
||||
gm = self._init_gm()
|
||||
|
||||
# iterate over all transforms sorted by stage in the config
|
||||
for t_name, t_config in self.config.items():
|
||||
# instantiate transform
|
||||
transform = TransformRegistry.get(t_name)(t_config)
|
||||
# run transform
|
||||
gm = transform(gm, cm, self.factory)
|
||||
|
||||
############################################################################################
|
||||
# RETURN OPTIMIZED GRAPH
|
||||
############################################################################################
|
||||
return gm
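For illustration, a hedged sketch of assembling and running such a config (transform names and options come from the library files in this diff; ``factory`` and ``cm`` are assumed to be a ModelFactory and CachedSequenceInterface as above):

config = {
    "export_to_gm": {"stage": "export", "strict": False},
    "build_model": {"stage": "factory", "device": "meta"},
    "cleanup_noop_slice": {"stage": "post_export"},
    "cleanup_noop_add": {"stage": "post_export"},
}
optimizer = InferenceOptimizer(factory, config)  # keys are re-sorted by stage: factory -> export -> post_export
gm = optimizer(cm)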
|
||||
@ -0,0 +1 @@
|
||||
"""V1 Graph Transformations Module --> will be deprecated and replaced by auto_deploy.transform."""
|
||||
@ -59,7 +59,7 @@ def load_buffers_and_params(
|
||||
if clone:
|
||||
v_new = v.detach().clone()
|
||||
if isinstance(v, torch.nn.Parameter):
|
||||
v_new = nn.Parameter(v_new)
|
||||
v_new = nn.Parameter(v_new, requires_grad=False)
|
||||
else:
|
||||
v_new = state_dict[k]
|
||||
setattr(submod, name, v_new)
|
||||
@ -192,7 +192,7 @@ def _canonicalize_single_gm(
|
||||
|
||||
def canonicalize_graph(
|
||||
gm: GraphModule, shape_prop: bool = False, args_static: Optional[Tuple[Any, ...]] = None
|
||||
) -> GraphModule:
|
||||
) -> None:
|
||||
"""Canonicalize the graph of the given GraphModule.
|
||||
|
||||
Args:
|
||||
@ -217,8 +217,6 @@ def canonicalize_graph(
|
||||
|
||||
ad_logger.debug(f"After canonicalizing: {gm}")
|
||||
|
||||
return gm
|
||||
|
||||
|
||||
def add_graph_input(
|
||||
gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None
|
||||
|
||||
@ -1,488 +0,0 @@
|
||||
import importlib.metadata
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.export as te
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from packaging import version
|
||||
from torch import fx
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
from ..utils.logger import ad_logger
|
||||
from ..utils.node_utils import is_op
|
||||
from ._graph import canonicalize_graph, lift_to_meta, load_buffers_and_params, tree_to
|
||||
|
||||
try:
|
||||
from modelopt.torch.quantization.utils import export_torch_mode as torch_export_context
|
||||
except ImportError:
|
||||
torch_export_context = nullcontext
|
||||
|
||||
|
||||
def _clean_up_no_op_slice_nodes(gm: fx.GraphModule):
|
||||
"""Remove no-op slice nodes from the graph.
|
||||
|
||||
Those will be nodes that are used to represent a slice operation like ``t[:, :5]``. The graph IR
|
||||
will represent it as ``t[:][:5]``, i.e., two nodes and the first slice being a no-op. This
|
||||
function gets rid of such instances.
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
# looking for slice nodes
|
||||
if not is_op(node, torch.ops.aten.slice):
|
||||
continue
|
||||
# only handling this parameter combination for now
|
||||
# 4 args will be (input, dim, start, end)
|
||||
if len(node.args) != 4 or len(node.kwargs) != 0:
|
||||
continue
|
||||
# check if dim is just an integer
|
||||
if not isinstance(node.args[1], int):
|
||||
continue
|
||||
# check if the slice op is indeed a no-op
|
||||
if node.args[2] != 0 or node.args[3] != torch.iinfo(torch.long).max:
|
||||
continue
|
||||
# extract input tensor node and remove the slice node
|
||||
in_node = node.args[0]
|
||||
assert [in_node] == node.all_input_nodes, "Slice node has unexpected input nodes."
|
||||
node.replace_all_uses_with(in_node)
|
||||
gm.graph.erase_node(node)
|
||||
|
||||
canonicalize_graph(gm)
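To make the pattern concrete, a small runnable sketch with toy tensors (not taken from the codebase) showing what a no-op slice looks like at the aten level:

import torch

t = torch.arange(12).reshape(3, 4)
# ``t[:, :3]`` may appear in the graph as two slice nodes: a full-range slice over dim 0
# followed by the real slice over dim 1. The first one is the no-op removed above.
noop = torch.ops.aten.slice(t, 0, 0, torch.iinfo(torch.long).max)  # t[:]
real = torch.ops.aten.slice(noop, 1, 0, 3)                         # [:, :3]
assert torch.equal(noop, t)          # the first slice is a no-op
assert torch.equal(real, t[:, :3])   # only the second slice matters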
|
||||
|
||||
|
||||
def _eliminate_no_op_add_nodes(gm: fx.GraphModule):
|
||||
"""Eliminate add nodes from the graph that are no-ops.
|
||||
|
||||
This would be any node that is just adding 0 to the input tensor. We can safely remove those.
|
||||
|
||||
NOTE: this function has one failure mode when the op ``out = tensor + zero_tensor`` is used
in such a way that ``out`` is broadcast to the shape of ``zero_tensor``. After removing this op,
``out`` won't have the right shape anymore. This should be a rare case and we can handle it
when it comes up.
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
# looking for add nodes
|
||||
if not is_op(node, torch.ops.aten.add):
|
||||
continue
|
||||
# only handling this parameter combination for now
|
||||
if len(node.all_input_nodes) != 2:
|
||||
continue
|
||||
|
||||
# check if any of the input nodes is just a constant tensor with value 0
|
||||
if is_op(node.all_input_nodes[0], torch.ops.aten.zeros):
|
||||
zero_node, true_node = node.all_input_nodes
|
||||
elif is_op(node.all_input_nodes[1], torch.ops.aten.zeros):
|
||||
true_node, zero_node = node.all_input_nodes
|
||||
else:
|
||||
continue
|
||||
|
||||
# do the replacement and clean-up
|
||||
node.replace_all_uses_with(true_node)
|
||||
gm.graph.erase_node(node)
|
||||
|
||||
canonicalize_graph(gm)
|
||||
|
||||
|
||||
def _clean_up_device_info(gm: fx.GraphModule):
|
||||
"""Correct device information in the graph."""
|
||||
devices = {t.device for _, t in gm.named_parameters()}
|
||||
if len(devices) == 0:
|
||||
return
|
||||
elif len(devices) > 1:
|
||||
raise AssertionError("All parameters should be on the same device.")
|
||||
device = devices.pop()
|
||||
meta_device = torch.device("meta")
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if any(a == meta_device for a in node.args):
|
||||
new_args = list(node.args)
|
||||
new_args = [a if a != meta_device else device for a in new_args]
|
||||
node.args = tuple(new_args)
|
||||
if any(a == meta_device for a in node.kwargs.values()):
|
||||
new_kwargs = dict(node.kwargs)
|
||||
new_kwargs = {k: v if v != meta_device else device for k, v in new_kwargs.items()}
|
||||
node.kwargs = new_kwargs
|
||||
|
||||
canonicalize_graph(gm)
|
||||
|
||||
|
||||
def _load_hook_for_deduplication(
|
||||
state_dict, prefix, *args, param_key_remaining: str, param_key_removed: str
|
||||
):
|
||||
"""Check for removed param key and and put it into the key that is remaining."""
|
||||
ad_logger.debug(f"Loading hook for deduplication: {param_key_remaining} <- {param_key_removed}")
|
||||
k_remaining = prefix + param_key_remaining
|
||||
k_removed = prefix + param_key_removed
|
||||
if k_removed in state_dict:
|
||||
state_dict[k_remaining] = state_dict.pop(k_removed)
|
||||
|
||||
|
||||
def _deduplicate_params_and_buffers(gm: fx.GraphModule):
|
||||
"""This will de-duplicate params and buffers that share the same tensor."""
|
||||
# get all get_attr nodes
|
||||
get_attr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]
|
||||
|
||||
# sort by id of target
|
||||
targets: Dict[int, List[fx.Node]] = defaultdict(list)
|
||||
for n in get_attr_nodes:
|
||||
submod, _, name = n.target.rpartition(".")
|
||||
t_target = getattr(gm.get_submodule(submod), name)
|
||||
targets[id(t_target)].append(n)
|
||||
# now replace all instances of the same tensor with the same get_attr node (idx 0 in the list)
|
||||
for nodes in targets.values():
|
||||
node_kept = nodes[0]
|
||||
for n in nodes[1:]:
|
||||
n.replace_all_uses_with(node_kept)
|
||||
gm.graph.erase_node(n)
|
||||
|
||||
# remove the param/buffer from the submodule
|
||||
submod, _, name = n.target.rpartition(".")
|
||||
delattr(gm.get_submodule(submod), name)
|
||||
|
||||
# add load hooks to also load the weights correctly
|
||||
gm._register_load_state_dict_pre_hook(
|
||||
partial(
|
||||
_load_hook_for_deduplication,
|
||||
param_key_remaining=node_kept.target,
|
||||
param_key_removed=n.target,
|
||||
)
|
||||
)
|
||||
|
||||
ad_logger.debug(f"Deduplicated: {n.target} --> {node_kept.target}")
|
||||
|
||||
canonicalize_graph(gm)
|
||||
|
||||
|
||||
def _clean_up_checks(gm: fx.GraphModule):
|
||||
"""This transformations removes shape checks and assertions from the graph."""
|
||||
check_ops = {
|
||||
torch.ops.aten._assert_scalar,
|
||||
torch.ops.aten.sym_constrain_range,
|
||||
torch.ops.aten.sym_constrain_range_for_size,
|
||||
torch.ops.aten._assert_tensor_metadata,
|
||||
# torch.ops.aten._functional_sym_constrain_range,
|
||||
# torch.ops.aten._functional_sym_constrain_range_for_size
|
||||
}
|
||||
graph: fx.Graph = gm.graph
|
||||
for node in reversed(graph.nodes):
|
||||
if len(node.users) > 0 or not is_op(node, check_ops):
|
||||
continue
|
||||
graph.erase_node(node)
|
||||
canonicalize_graph(gm)
|
||||
|
||||
|
||||
def _clean_up_input_constraints(gm: fx.GraphModule):
|
||||
"""This transformations updates the input constraints of the graph.
|
||||
|
||||
Specifically, we want to account for flattened sequences and hence the max constraint should
|
||||
be updated to reflect the flattened sequence length.
|
||||
"""
|
||||
graph: fx.Graph = gm.graph
|
||||
input_node = graph.find_nodes(op="placeholder")[0]
|
||||
sym_shape: torch.Size = input_node.meta["val"].shape
|
||||
|
||||
# get expressions in the symbolic shape
|
||||
vrs: List[ValueRanges] = []
|
||||
for s in sym_shape:
|
||||
if isinstance(s, int):
|
||||
vrs.append(ValueRanges(0, s))
|
||||
elif isinstance(s, torch.SymInt):
|
||||
vrs.append(gm.range_constraints[s.node.expr])
|
||||
else:
|
||||
raise TypeError(f"Unexpected type {type(s)} in symbolic shape.")
|
||||
|
||||
# update the max constraint for each vr
|
||||
max_total = math.prod(vr.upper for vr in vrs)
|
||||
for vr in vrs:
|
||||
object.__setattr__(vr, "upper", max_total)
|
||||
|
||||
canonicalize_graph(gm)
|
||||
|
||||
|
||||
# TODO: remove once https://github.com/pytorch/pytorch/issues/140710 is resolved
|
||||
def _torch_where_patch(condition: torch.Tensor, *args, **kwargs):
|
||||
if len(args) == 0 and len(kwargs) == 0:
|
||||
return torch.nonzero(condition, as_tuple=True)
|
||||
return _torch_where_patch.where_original(condition, *args, **kwargs)
|
||||
|
||||
|
||||
_torch_where_patch.where_original = torch.where
|
||||
|
||||
|
||||
def _torch_linear_patch(
|
||||
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.auto_deploy.torch_linear_simple(input, weight, bias)
|
||||
|
||||
|
||||
# TODO: remove once https://github.com/pytorch/pytorch/issues/142439 is resolved
|
||||
def _torch_modulelist_getitem_patch(self: nn.ModuleList, idx):
|
||||
if isinstance(idx, slice):
|
||||
# return a simple list.
|
||||
# NOTE: this obviously only works for use cases where we access the sliced module list
# like a regular list, e.g., in a for-loop. For most other things, this hack will not work.
|
||||
return list(self._modules.values())[idx]
|
||||
else:
|
||||
return _torch_modulelist_getitem_patch.getitem_original(self, idx)
|
||||
|
||||
|
||||
_torch_modulelist_getitem_patch.getitem_original = nn.ModuleList.__getitem__
|
||||
|
||||
|
||||
def _torch_tensor_patch(data, **kwargs):
|
||||
"""Patch torch.tensor to handle 0.0 on meta device.
|
||||
|
||||
``torch.tensor(0.0, device="meta")`` does not work and hence we are patching it to use
|
||||
``torch.zeros((), device="meta")`` instead, which is equivalent.
|
||||
"""
|
||||
device = kwargs.get("device", None)
|
||||
if data == 0.0 and device is not None and torch.device(device) == torch.device("meta"):
|
||||
return torch.zeros((), **kwargs)
|
||||
return _torch_tensor_patch.tensor_original(data, **kwargs)
|
||||
|
||||
|
||||
_torch_tensor_patch.tensor_original = torch.tensor
|
||||
|
||||
|
||||
def _transformers_version() -> str:
|
||||
"""Get the version of transformers."""
|
||||
return version.parse(importlib.metadata.version("transformers")).base_version
|
||||
|
||||
|
||||
# TODO (@lucaslie): https://github.com/NVIDIA/TensorRT-LLM/issues/5728
|
||||
# not great that this patch is here but it's the least invasive change until we make headway on the
|
||||
# above issue.
|
||||
@contextmanager
|
||||
def _transformers_sdpa_mask_patch():
|
||||
"""Patch transformers.masking_utils.sdpa_mask to be export-compatible."""
|
||||
# this patch is only needed+compatible for transformers >= 4.53.0
|
||||
if version.parse(_transformers_version()) < version.parse("4.53.0"):
|
||||
yield # Just yield without doing anything (like nullcontext)
|
||||
return
|
||||
|
||||
# imports only after version check
|
||||
from transformers import masking_utils
|
||||
from transformers.integrations.executorch import sdpa_mask_without_vmap
|
||||
|
||||
# recall original implementation
|
||||
sdpa_mask_original = masking_utils.sdpa_mask
|
||||
|
||||
# patch function and mask attention interface
|
||||
masking_utils.sdpa_mask = sdpa_mask_without_vmap
|
||||
if "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._local_mapping:
|
||||
sdpa_local_original = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._local_mapping["sdpa"]
|
||||
else:
|
||||
sdpa_local_original = None
|
||||
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = sdpa_mask_without_vmap
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# revert patches
|
||||
masking_utils.sdpa_mask = sdpa_mask_original
|
||||
if sdpa_local_original is None:
|
||||
del masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
|
||||
else:
|
||||
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = sdpa_local_original
|
||||
|
||||
|
||||
def add_missing_load_hooks(gm: fx.GraphModule, model: nn.Module) -> fx.GraphModule:
|
||||
"""Adds back the state dict load hooks stripped away during export."""
|
||||
hooks = {
|
||||
k: mod._load_state_dict_pre_hooks
|
||||
for k, mod in model.named_modules()
|
||||
if mod._load_state_dict_pre_hooks
|
||||
}
|
||||
|
||||
for mod_name, mod in gm.named_modules():
|
||||
if mod_name in hooks:
|
||||
for hook in hooks.pop(mod_name).values():
|
||||
mod._register_load_state_dict_pre_hook(hook.hook, with_module=hook.with_module)
|
||||
assert not (bool(hooks)), f"""Mismatch in names of exported and source modules with hooks.
|
||||
The following module names were not found in exported module {list(hooks.keys())}"""
|
||||
|
||||
return gm
|
||||
|
||||
|
||||
def add_load_hook_for_aliased_params(gm: fx.GraphModule, model: nn.Module):
|
||||
"""
|
||||
Add a load hook to handle aliased parameters in the model.
|
||||
|
||||
When parameters are aliased (multiple parameter names point to the same tensor),
|
||||
we need to ensure all aliases get the same value during loading. This hook:
|
||||
1. Identifies groups of aliased parameters
|
||||
2. For each group, finds a valid parameter value from the state dict
|
||||
3. Applies that value to all aliases in the group
|
||||
|
||||
Args:
|
||||
gm: The graph module to add the hook to
|
||||
model: The source model containing the original parameter aliases
|
||||
"""
|
||||
# Find all parameter aliases in the source model
|
||||
param_to_names = defaultdict(list)
|
||||
for name, param in model.named_parameters(remove_duplicate=False):
|
||||
param_to_names[id(param)].append(name)
|
||||
|
||||
# Filter to only groups with multiple aliases
|
||||
aliased_groups = [names for names in param_to_names.values() if len(names) > 1]
|
||||
|
||||
if not aliased_groups:
|
||||
return gm # No aliases to handle
|
||||
|
||||
def find_valid_param_value(
|
||||
state_dict: Dict[str, torch.Tensor], param_names: List[str]
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""Find a valid parameter value from state dict for a group of aliased parameters.
|
||||
|
||||
Args:
|
||||
state_dict: The state dict being loaded
|
||||
param_names: List of parameter names that are aliases of each other
|
||||
|
||||
Returns:
|
||||
A valid tensor value if found, None otherwise
|
||||
"""
|
||||
# First try to find a non-meta tensor value
|
||||
value = None
|
||||
for name in param_names:
|
||||
if name in state_dict:
|
||||
value = state_dict[name]
|
||||
if value.device.type != "meta":
|
||||
return value
|
||||
|
||||
return value
|
||||
|
||||
def aliasing_load_pre_hook(state_dict: Dict[str, torch.Tensor], prefix: str, *args, **kwargs):
|
||||
"""Load hook that ensures aliased parameters get the same value."""
|
||||
for group in aliased_groups:
|
||||
# Find a valid value for this group of aliases
|
||||
value = find_valid_param_value(state_dict, group)
|
||||
assert value is not None, (
|
||||
f"No valid value found in state dict for aliased parameters: {group}"
|
||||
)
|
||||
|
||||
# Apply the value to all aliases
|
||||
for name in group:
|
||||
state_dict[name] = value
|
||||
|
||||
ad_logger.debug(f"Applied value from {group[0]} to aliased parameters: {group}")
|
||||
|
||||
# Register the hook
|
||||
gm._register_load_state_dict_pre_hook(aliasing_load_pre_hook)
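As a concrete illustration of the alias-discovery step above, a minimal sketch; the `Tied` module is invented for this example:

from collections import defaultdict

import torch.nn as nn


class Tied(nn.Module):
    """Toy model whose embedding and output head share one weight tensor."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(8, 4)
        self.lm_head = nn.Linear(4, 8, bias=False)
        self.lm_head.weight = self.embed.weight  # alias the two parameters


param_to_names = defaultdict(list)
for name, param in Tied().named_parameters(remove_duplicate=False):
    param_to_names[id(param)].append(name)

aliased_groups = [names for names in param_to_names.values() if len(names) > 1]
print(aliased_groups)  # [['embed.weight', 'lm_head.weight']]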
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def torch_export(model: nn.Module, *export_args, **export_kwargs) -> te.ExportedProgram:
|
||||
"""Just like torch.export except we decorate it to be in inference_mode."""
|
||||
with torch_export_context():
|
||||
ep = te.export(model, *export_args, **export_kwargs)
|
||||
|
||||
# return the result
|
||||
return ep
|
||||
|
||||
|
||||
def torch_export_to_gm(
|
||||
model: nn.Module,
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
clone: bool = False, # clone or don't clone the model state_dict
|
||||
**export_kwargs,
|
||||
) -> fx.GraphModule:
|
||||
"""torch_export with wrapping into GraphModule + useful additions to the resulting module."""
|
||||
# we need to better control how F.scaled_dot_product_attention is represented in the graph
|
||||
# there is no guarantee how it is represented and we need to make sure it is easily identifiable
|
||||
# in the graph.
|
||||
sdpa_original = F.scaled_dot_product_attention
|
||||
F.scaled_dot_product_attention = torch.ops.auto_deploy.torch_attention_sdpa
|
||||
|
||||
# We overwrite the linear functional as well. This basically avoids exporting the view ops
|
||||
# that are used to flatten/unflatten multiple batch dimensions of the input tensor.
|
||||
linear_original = F.linear
|
||||
# patch linear → always supply bias
|
||||
F.linear = _torch_linear_patch
|
||||
|
||||
# patch torch.where(condition) to torch.nonzero(condition, as_tuple=True)
|
||||
torch.where = _torch_where_patch
|
||||
|
||||
# patch nn.ModuleList.__getitem__ to handle slicing
|
||||
nn.ModuleList.__getitem__ = _torch_modulelist_getitem_patch
|
||||
|
||||
# overwrite autocast/sdpa contextmanagers to be no-ops
|
||||
autocast_original = torch.autocast
|
||||
sdpa_kernel_original = torch.nn.attention.sdpa_kernel
|
||||
torch.autocast = lambda *args, **kwargs: nullcontext()
|
||||
torch.nn.attention.sdpa_kernel = lambda *args, **kwargs: nullcontext()
|
||||
|
||||
# patch torch.tensor to handle 0.0 on meta device
|
||||
torch.tensor = _torch_tensor_patch
|
||||
|
||||
# run export with sdpa masking patch and lifted to meta
|
||||
with _transformers_sdpa_mask_patch():
|
||||
with lift_to_meta(model) as state_dict:
|
||||
# clean up args, kwargs and move to correct device
|
||||
args, kwargs = tree_to((args, kwargs or {}), device="meta")
|
||||
|
||||
# NOTE: we always export in non-strict mode for now as it relaxes some
|
||||
# assumptions around tracing. Strict mode uses torchdynamo (symbolic bytecode analysis),
|
||||
# which can be brittle since it relies on the exact bytecode representation of the model
|
||||
# see here as well: https://pytorch.org/docs/stable/export.html#non-strict-export
|
||||
export_kwargs["strict"] = False
|
||||
|
||||
# run export and extract graph module
|
||||
egm: fx.GraphModule = torch_export(model, args, kwargs, **export_kwargs).module()
|
||||
|
||||
# load state_dict into egm
|
||||
# NOTE: export might have removed unused params/buffers (hence we allow unexpected keys)
|
||||
load_buffers_and_params(
|
||||
egm, state_dict, strict_missing=True, strict_unexpected=False, clone=clone
|
||||
)
|
||||
|
||||
# revert sdpa back to original
|
||||
F.scaled_dot_product_attention = sdpa_original
|
||||
|
||||
# revert linear back to original
|
||||
F.linear = linear_original
|
||||
|
||||
# revert torch.where patch
|
||||
torch.where = _torch_where_patch.where_original
|
||||
|
||||
# revert nn.ModuleList.__getitem__ patch
|
||||
nn.ModuleList.__getitem__ = _torch_modulelist_getitem_patch.getitem_original
|
||||
|
||||
# revert autocast/sdpa back to original
|
||||
torch.autocast = autocast_original
|
||||
torch.nn.attention.sdpa_kernel = sdpa_kernel_original
|
||||
|
||||
# revert torch.tensor patch
|
||||
torch.tensor = _torch_tensor_patch.tensor_original
|
||||
|
||||
# Export strips away all methods not traced during forward. The model could have
|
||||
# load hooks that contain logic for correct state_dict loading. We need to add those
|
||||
# hooks back to the exported graph module.
|
||||
add_missing_load_hooks(egm, model)
|
||||
|
||||
# Export will have LOTS of no-op slice nodes. Let's remove them to clean up the graph
|
||||
# representation
|
||||
_clean_up_no_op_slice_nodes(egm)
|
||||
|
||||
# Export does not clean "no-op" element-wise add nodes. We can safely remove those.
|
||||
_eliminate_no_op_add_nodes(egm)
|
||||
|
||||
# clean up devices in the graph
|
||||
_clean_up_device_info(egm)
|
||||
|
||||
# Add load hook to correctly load parameters that are aliased in the source model.
|
||||
add_load_hook_for_aliased_params(egm, model)
|
||||
|
||||
# deduplicate params and buffers
|
||||
_deduplicate_params_and_buffers(egm)
|
||||
|
||||
# clean up shape checks and assertions
|
||||
_clean_up_checks(egm)
|
||||
|
||||
# clean up input constraints
|
||||
_clean_up_input_constraints(egm)
|
||||
|
||||
return egm
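A hedged usage sketch of the export helper above; the toy `nn.Sequential` model is illustrative and it assumes the AutoDeploy custom ops referenced by the patches are importable:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
egm = torch_export_to_gm(model, args=(torch.randn(2, 8),))
# Inspect the exported graph, with auto_deploy ops standing in for F.linear / F.scaled_dot_product_attention.
egm.print_readable()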
|
||||
@ -3,11 +3,12 @@
|
||||
from .attention import *
|
||||
from .collectives import *
|
||||
from .eliminate_redundant_transposes import *
|
||||
from .ep_sharding import *
|
||||
from .fused_moe import *
|
||||
from .fusion import *
|
||||
from .kvcache import *
|
||||
from .quantization import *
|
||||
from .quantize_moe import *
|
||||
from .rms_norm import *
|
||||
from .rope import *
|
||||
from .sharding import *
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ from ...utils.node_utils import is_op
|
||||
from .._graph import canonicalize_graph
|
||||
|
||||
|
||||
def match_repeat_kv(gm: GraphModule) -> GraphModule:
|
||||
def match_repeat_kv(gm: GraphModule) -> None:
|
||||
"""
|
||||
Match and replace the repeat_kv pattern in fx graphs.
|
||||
|
||||
@ -36,13 +36,11 @@ def match_repeat_kv(gm: GraphModule) -> GraphModule:
|
||||
|
||||
# Clean up the graph if we made any replacements
|
||||
if num_kv_patterns:
|
||||
gm = canonicalize_graph(gm)
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found {num_kv_patterns} repeat_kv patterns")
|
||||
|
||||
return gm
|
||||
|
||||
|
||||
def match_eager_attention(gm: GraphModule) -> GraphModule:
|
||||
def match_eager_attention(gm: GraphModule) -> None:
|
||||
"""
|
||||
Match and replace the eager attention pattern in fx graphs.
|
||||
|
||||
@ -68,12 +66,11 @@ def match_eager_attention(gm: GraphModule) -> GraphModule:
|
||||
|
||||
# Clean up the graph if we made any replacements
|
||||
if num_eager_patterns:
|
||||
gm = canonicalize_graph(gm)
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found {num_eager_patterns} eager attention patterns")
|
||||
return gm
|
||||
|
||||
|
||||
def match_grouped_attention(gm: GraphModule) -> GraphModule:
|
||||
def match_grouped_attention(gm: GraphModule) -> None:
|
||||
"""
|
||||
Match and replace the grouped attention pattern in fx graphs.
|
||||
|
||||
@ -101,12 +98,11 @@ def match_grouped_attention(gm: GraphModule) -> GraphModule:
|
||||
|
||||
# Clean up the graph if we made any replacements
|
||||
if num_grouped_patterns:
|
||||
gm = canonicalize_graph(gm)
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found {num_grouped_patterns} grouped attention patterns")
|
||||
return gm
|
||||
|
||||
|
||||
def match_causal_attn_mask(gm: GraphModule) -> GraphModule:
|
||||
def match_causal_attn_mask(gm: GraphModule) -> None:
|
||||
"""
|
||||
Match attention operations with causal attention masks and optimize them.
|
||||
|
||||
@ -174,9 +170,8 @@ def match_causal_attn_mask(gm: GraphModule) -> GraphModule:
|
||||
|
||||
# Clean up the graph if we made any replacements
|
||||
if num_causal_patterns:
|
||||
gm = canonicalize_graph(gm)
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found {num_causal_patterns} causal mask attention patterns")
|
||||
return gm
|
||||
|
||||
|
||||
def _match_repeat_kv_pattern(reshape_node: Node) -> Optional[Dict[str, Node]]:
|
||||
@ -748,7 +743,7 @@ def _has_triu_ancestor(node: Node, offset: int = 1, depth: int = 0, max_depth: i
|
||||
return False
|
||||
|
||||
|
||||
def match_attention_layout(gm: GraphModule, attention_op: Type[AttentionDescriptor]) -> GraphModule:
|
||||
def match_attention_layout(gm: GraphModule, attention_op: Type[AttentionDescriptor]) -> None:
|
||||
"""
|
||||
Match and transform attention operations to match the layout expected by the attention backend.
|
||||
|
||||
@ -832,9 +827,7 @@ def match_attention_layout(gm: GraphModule, attention_op: Type[AttentionDescript
|
||||
|
||||
# Clean up the graph if we made any replacements
|
||||
if num_bsnd_patterns:
|
||||
gm = canonicalize_graph(gm)
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.debug(f"Transformed graph for bsnd layout: {gm}")
|
||||
|
||||
ad_logger.info(f"Found and matched {num_bsnd_patterns} attention layouts")
|
||||
|
||||
return gm
|
||||
|
||||
@ -15,7 +15,7 @@ from .._graph import canonicalize_graph
|
||||
# * version above with fused GEMMs (i.e. with a split node)
|
||||
# * all_reduce(pointwise_op(linear(x)))
|
||||
# * ...
|
||||
def fuse_collectives(gm: GraphModule) -> GraphModule:
|
||||
def fuse_collectives(gm: GraphModule) -> None:
|
||||
num_gemm_collective_fusions = 0
|
||||
ad_logger.debug("Before GEMM+Collective fusion: " + str(gm))
|
||||
|
||||
@ -54,13 +54,12 @@ def fuse_collectives(gm: GraphModule) -> GraphModule:
|
||||
gm.graph.erase_node(parent_node)
|
||||
num_gemm_collective_fusions += 1
|
||||
|
||||
gm = canonicalize_graph(gm)
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found {num_gemm_collective_fusions} GEMM+Collective fusions")
|
||||
ad_logger.debug("After GEMM+Collective fusion: " + str(gm))
|
||||
return gm
|
||||
|
||||
|
||||
def fuse_allreduce_residual_rmsnorm(gm: GraphModule) -> GraphModule:
|
||||
def fuse_allreduce_residual_rmsnorm(gm: GraphModule) -> None:
|
||||
"""Essentially, this function fuses the following operators into one allreduce trtllm implementation.
|
||||
|
||||
* target pattern:
|
||||
@ -72,7 +71,7 @@ def fuse_allreduce_residual_rmsnorm(gm: GraphModule) -> GraphModule:
|
||||
|
||||
"""
|
||||
if not is_trtllm_op_available():
|
||||
return gm
|
||||
return
|
||||
|
||||
num_ar_r_rms_fusions = 0
|
||||
ad_logger.debug("Before allreduce+residual+rmsnorm fusion: " + str(gm))
|
||||
@ -158,14 +157,11 @@ def fuse_allreduce_residual_rmsnorm(gm: GraphModule) -> GraphModule:
|
||||
nonlocal num_ar_r_rms_fusions
|
||||
num_ar_r_rms_fusions += 1
|
||||
|
||||
return
|
||||
|
||||
# Traverse all nodes
|
||||
for node in gm.graph.nodes:
|
||||
if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
|
||||
trace_and_fuse(allreduce_node=node, graph=gm.graph)
|
||||
|
||||
gm = canonicalize_graph(gm)
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found {num_ar_r_rms_fusions} allreduce+residual+rmsnorm fusions")
|
||||
ad_logger.debug("After allreduce+residual+rmsnorm fusion: " + str(gm))
|
||||
return gm
|
||||
|
||||
@ -40,7 +40,7 @@ def _are_transpose_args_same(node1: Node, node2: Node) -> bool:
|
||||
return dim1_node1 == dim1_node2 and dim2_node1 == dim2_node2
|
||||
|
||||
|
||||
def eliminate_redundant_transposes(gm: GraphModule) -> GraphModule:
|
||||
def eliminate_redundant_transposes(gm: GraphModule) -> None:
|
||||
"""Eliminate redundant transpose operations in the graph.
|
||||
|
||||
This transformation identifies pairs of consecutive transpose operations with
|
||||
@ -107,7 +107,6 @@ def eliminate_redundant_transposes(gm: GraphModule) -> GraphModule:
|
||||
# Clean up the graph
|
||||
if nodes_to_eliminate:
|
||||
gm.graph.eliminate_dead_code()
|
||||
gm = canonicalize_graph(gm)
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found and eliminated {len(nodes_to_eliminate)} redundant transpose pairs")
|
||||
ad_logger.debug("After eliminating redundant transposes: " + str(gm))
|
||||
return gm
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
"""
|
||||
Expert Parallel Sharding for Mixture-of-Experts (MoE) Graphs.
|
||||
|
||||
This module implements graph transformations to enable expert sharding
|
||||
for Mixture-of-Experts (MoE) models in a multi-GPU setting. The sharding
|
||||
algorithm partitions the expert weights, as well as updates the routing
|
||||
components (`selected_experts` and `final_scales`), so that each GPU only
|
||||
processes a subset of experts.
|
||||
|
||||
The sharding process consists of:
|
||||
|
||||
1. Identify MoE nodes in the FX graph
|
||||
2. Compute local sharding parameters (`selected_experts` and `final_scales`) to update the routing tensors.
|
||||
3. Partition expert weight lists according to the current rank and world size,
|
||||
and replace the MoE node’s arguments with these sharded versions.
|
||||
4. Append an all_reduce node after each MoE node to aggregate outputs across devices,
|
||||
then canonicalize the modified graph.
|
||||
|
||||
"""
|
||||
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.node_utils import is_op
|
||||
from .._graph import canonicalize_graph
|
||||
|
||||
|
||||
def ep_shard(gm: GraphModule, rank: int, world_size: int) -> GraphModule:
|
||||
ad_logger.debug("Before sharding graph: " + str(gm))
|
||||
|
||||
if world_size < 2:
|
||||
ad_logger.info("Skipping sharding for single device")
|
||||
return gm
|
||||
|
||||
assert isinstance(gm, GraphModule), "Expecting GraphModule"
|
||||
num_moe_patterns = 0
|
||||
for node in list(gm.graph.nodes):
|
||||
if not is_op(node, torch.ops.auto_deploy.torch_moe):
|
||||
continue
|
||||
_insert_sharded_moe(gm, node, rank, world_size)
|
||||
num_moe_patterns += 1
|
||||
# canonicalize and return
|
||||
gm = canonicalize_graph(gm)
|
||||
|
||||
ad_logger.debug("After sharding: " + str(gm))
|
||||
ad_logger.info(f"Found {num_moe_patterns} MoE patterns")
|
||||
return gm
|
||||
|
||||
|
||||
def _insert_sharded_moe(
|
||||
gm: GraphModule,
|
||||
node: Node,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
"""Update the torch_moe node with sharded weight lists,
|
||||
sharded `selected_experts` and `final_scales(router_logics)`.
|
||||
Add an all_reduce node after the moe node.
|
||||
"""
|
||||
num_experts = len(node.args[3])
|
||||
args = list(node.args)
|
||||
|
||||
# -- Handle selected_experts and final_scales sharding --
|
||||
selected_experts = args[1]
|
||||
final_scales = args[2]
|
||||
|
||||
experts_per_rank = num_experts // world_size
|
||||
|
||||
with gm.graph.inserting_before(node):
|
||||
lower = experts_per_rank * rank
|
||||
# selected_experts_local = selected_experts - lower
|
||||
selected_experts_local = gm.graph.create_node(
|
||||
"call_function", operator.sub, args=(selected_experts, lower), kwargs={}
|
||||
)
|
||||
|
||||
# For num_experts % world_size != 0 case,
|
||||
# assign the last (num_experts % world_size) experts to the last rank
|
||||
# if rank == world_size -1:
|
||||
# rank_mask = (selected_experts // experts_per_rank) >= rank
|
||||
# else:
|
||||
# rank_mask = (selected_experts // experts_per_rank) == rank
|
||||
div_node = gm.graph.create_node(
|
||||
"call_function", operator.floordiv, args=(selected_experts, experts_per_rank), kwargs={}
|
||||
)
|
||||
comp_op = torch.ge if rank == world_size - 1 else torch.eq
|
||||
rank_mask = gm.graph.create_node("call_function", comp_op, args=(div_node, rank), kwargs={})
|
||||
|
||||
# final_scales_local = final_scales * rank_mask
|
||||
final_scales_local = gm.graph.create_node(
|
||||
"call_function", operator.mul, args=(final_scales, rank_mask), kwargs={}
|
||||
)
|
||||
|
||||
# -- Shard expert weights --
|
||||
def get_partition(lst, world_size, rank):
|
||||
num_experts = len(lst)
|
||||
expert_size_per_partition = num_experts // world_size
|
||||
expert_start = rank * expert_size_per_partition
|
||||
# For num_experts % world_size != 0 case,
|
||||
# assign the last (num_experts % world_size) experts to the last rank
|
||||
expert_end = (
|
||||
num_experts if (rank == world_size - 1) else expert_start + expert_size_per_partition
|
||||
)
|
||||
return lst[expert_start:expert_end]
|
||||
|
||||
w1_list_sharded = get_partition(args[3], world_size, rank)
|
||||
w2_list_sharded = get_partition(args[4], world_size, rank)
|
||||
w3_list_sharded = get_partition(args[5], world_size, rank)
|
||||
|
||||
# -- Update args --
|
||||
args[1] = selected_experts_local
|
||||
args[2] = final_scales_local
|
||||
args[3] = w1_list_sharded
|
||||
args[4] = w2_list_sharded
|
||||
args[5] = w3_list_sharded
|
||||
|
||||
ad_logger.debug(
|
||||
f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}."
|
||||
)
|
||||
node.args = tuple(args)
|
||||
|
||||
# -- add an all_reduce node --
|
||||
with gm.graph.inserting_after(node):
|
||||
dist_node = gm.graph.call_function(
|
||||
torch.ops.auto_deploy.torch_dist_all_reduce, args=(node,)
|
||||
)
|
||||
node.replace_all_uses_with(dist_node)
|
||||
dist_node.replace_input_with(dist_node, node)
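A worked numerical sketch of the routing update above, using made-up values; the uneven-division case handled by `>=` on the last rank changes nothing here because 8 experts divide evenly across 2 ranks:

import torch

# 8 experts split across world_size=2 ranks -> experts_per_rank=4; this is rank 1 (experts 4..7).
rank, world_size, experts_per_rank = 1, 2, 4
selected_experts = torch.tensor([[0, 5], [7, 2]])       # top-2 routing per token
final_scales = torch.tensor([[0.6, 0.4], [0.9, 0.1]])

lower = experts_per_rank * rank
selected_local = selected_experts - lower                    # [[-4, 1], [3, -2]]: negatives belong to other ranks
rank_mask = (selected_experts // experts_per_rank) == rank   # the real code uses >= on the last rank for leftovers
scales_local = final_scales * rank_mask                      # [[0.0, 0.4], [0.9, 0.0]]: remote experts are zeroed

print(selected_local)
print(scales_local)
# Each rank computes only its own experts' contributions; the trailing all_reduce sums them across ranks.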
|
||||
@ -7,10 +7,11 @@ from torch.fx import GraphModule, Node
|
||||
from ...utils.cuda_mem_tracker import cuda_memory_tracker
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.node_utils import bfs, identify_regions_between_residuals, is_linear_op, is_op
|
||||
from ...utils.quantization_utils import get_scales_and_type_from_node
|
||||
from .._graph import canonicalize_graph
|
||||
|
||||
|
||||
def match_moe_pattern(gm: GraphModule) -> GraphModule:
|
||||
def match_moe_pattern(gm: GraphModule) -> None:
|
||||
graph = gm.graph
|
||||
|
||||
ad_logger.debug("Before MoE Pattern Matching: " + str(gm))
|
||||
@ -21,8 +22,8 @@ def match_moe_pattern(gm: GraphModule) -> GraphModule:
|
||||
|
||||
for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]):
|
||||
# Step 1: Identify Expert Compute pattern
|
||||
pattern_input_nodes, pattern_output_nodes, expert_weights = _match_expert_compute_pattern(
|
||||
start_boundary, end_boundary
|
||||
(pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type) = (
|
||||
_match_expert_compute_pattern(start_boundary, end_boundary)
|
||||
)
|
||||
if not expert_weights:
|
||||
continue
|
||||
@ -56,29 +57,70 @@ def match_moe_pattern(gm: GraphModule) -> GraphModule:
|
||||
if final_hidden_state_node is None:
|
||||
continue
|
||||
|
||||
# Step 5: Insert the moe op into the graph.
|
||||
# Step 5: Insert the MoE op into the graph.
|
||||
ad_logger.debug(
|
||||
f"""Found MoE Pattern: between boundary {start_boundary} and {end_boundary}.\n
|
||||
Capturing input hidden states node: {hidden_states},
|
||||
selected_experts node: {selected_experts}, routing_weights node: {normalized_routing_weights},
|
||||
expert weights : {expert_weights} """
|
||||
f"Found MoE Pattern: between boundary {start_boundary} and {end_boundary}.\n"
|
||||
f"Input hidden states node: {hidden_states}, "
|
||||
f"selected_experts node: {selected_experts}, "
|
||||
f"routing_weights node: {normalized_routing_weights}, "
|
||||
f"expert weights: {expert_weights}, weight type: {weight_type}"
|
||||
)
|
||||
with graph.inserting_before(final_hidden_state_node):
|
||||
w1_list = expert_weights["w1"]
|
||||
w2_list = expert_weights["w2"]
|
||||
w3_list = expert_weights["w3"]
|
||||
|
||||
fused_moe_node = graph.call_function(
|
||||
torch.ops.auto_deploy.torch_moe,
|
||||
args=(
|
||||
hidden_states,
|
||||
selected_experts,
|
||||
normalized_routing_weights,
|
||||
w1_list,
|
||||
w2_list,
|
||||
w3_list,
|
||||
),
|
||||
)
|
||||
if weight_type == "fp8":
|
||||
fused_moe_node = graph.call_function(
|
||||
torch.ops.auto_deploy.torch_quant_fp8_moe,
|
||||
args=(
|
||||
hidden_states,
|
||||
selected_experts,
|
||||
normalized_routing_weights,
|
||||
w1_list,
|
||||
w2_list,
|
||||
w3_list,
|
||||
expert_scales["w1_input_scale"],
|
||||
expert_scales["w2_input_scale"],
|
||||
expert_scales["w3_input_scale"],
|
||||
expert_scales["w1_weight_scale"],
|
||||
expert_scales["w2_weight_scale"],
|
||||
expert_scales["w3_weight_scale"],
|
||||
),
|
||||
)
|
||||
elif weight_type == "fp4":
|
||||
fused_moe_node = graph.call_function(
|
||||
torch.ops.auto_deploy.torch_quant_fp4_moe,
|
||||
args=(
|
||||
hidden_states,
|
||||
selected_experts,
|
||||
normalized_routing_weights,
|
||||
w1_list,
|
||||
w2_list,
|
||||
w3_list,
|
||||
expert_scales["w1_input_scale"],
|
||||
expert_scales["w2_input_scale"],
|
||||
expert_scales["w3_input_scale"],
|
||||
expert_scales["w1_weight_scale"],
|
||||
expert_scales["w2_weight_scale"],
|
||||
expert_scales["w3_weight_scale"],
|
||||
expert_scales["w1_alpha"],
|
||||
expert_scales["w2_alpha"],
|
||||
expert_scales["w3_alpha"],
|
||||
),
|
||||
)
|
||||
else:
|
||||
fused_moe_node = graph.call_function(
|
||||
torch.ops.auto_deploy.torch_moe,
|
||||
args=(
|
||||
hidden_states,
|
||||
selected_experts,
|
||||
normalized_routing_weights,
|
||||
w1_list,
|
||||
w2_list,
|
||||
w3_list,
|
||||
),
|
||||
)
|
||||
|
||||
final_hidden_state_node.replace_all_uses_with(fused_moe_node)
|
||||
graph.erase_node(final_hidden_state_node)
|
||||
@ -88,17 +130,15 @@ def match_moe_pattern(gm: GraphModule) -> GraphModule:
|
||||
|
||||
num_moe_patterns += 1
|
||||
|
||||
gm = canonicalize_graph(gm)
|
||||
canonicalize_graph(gm)
|
||||
|
||||
ad_logger.info(f"Found {num_moe_patterns} MoE Patterns")
|
||||
ad_logger.debug("After MoE Pattern Matching: " + str(gm))
|
||||
|
||||
return gm
|
||||
|
||||
|
||||
def fuse_moe(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
def fuse_moe(gm: torch.fx.GraphModule) -> None:
|
||||
"""
|
||||
Scan the FX graph and replace all calls to torch.ops.moe.torch_moe with
|
||||
Scan the FX graph and replace all calls to torch.ops.auto_deploy.torch_moe with
|
||||
torch.ops.auto_deploy.trtllm_moe_fused.
|
||||
"""
|
||||
ad_logger.debug("Before MoE fusion: " + str(gm))
|
||||
@ -106,11 +146,10 @@ def fuse_moe(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
with cuda_memory_tracker():
|
||||
fused_key_counter = _insert_fused_moe_ops(gm)
|
||||
if fused_key_counter:
|
||||
gm = canonicalize_graph(gm)
|
||||
canonicalize_graph(gm)
|
||||
|
||||
ad_logger.info(f"Found {fused_key_counter} MoE fusions")
|
||||
ad_logger.debug("After MoE fusion: " + str(gm))
|
||||
return gm
|
||||
|
||||
|
||||
def _insert_fused_moe_ops(gm: GraphModule) -> int:
|
||||
@ -146,6 +185,7 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
|
||||
|
||||
with graph.inserting_before(node):
|
||||
new_node = graph.call_function(
|
||||
# TODO(Fridah-nv): torch.ops.auto_deploy.trtllm_moe_fused for quantized models
|
||||
torch.ops.auto_deploy.trtllm_moe_fused,
|
||||
args=(
|
||||
hidden_states,
|
||||
@ -227,6 +267,32 @@ def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]:
|
||||
return common
|
||||
|
||||
|
||||
def _extract_linear_parameters(linear_node: Node) -> tuple[Node, torch.Tensor, Optional[dict], str]:
|
||||
"""
|
||||
Given a linear op node, extract the input tensor node, weight tensor,
|
||||
any quantization scales (if the op is quantized), and return a weight type.
|
||||
|
||||
For a torch.ops.auto_deploy.torch_linear_simple.default op:
|
||||
- Returns (input_node, weight, None, "simple")
|
||||
|
||||
For a torch.ops.auto_deploy.torch_quant_fp8_linear op:
|
||||
- Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale}, "fp8")
|
||||
For a torch.ops.auto_deploy.torch_quant_fp4_linear op:
|
||||
- Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale, "alpha": alpha}, "fp4")
|
||||
"""
|
||||
input_node = linear_node.args[0]
|
||||
if is_op(linear_node, torch.ops.auto_deploy.torch_linear_simple):
|
||||
weight = linear_node.args[1]
|
||||
return input_node, weight, None, ""
|
||||
elif is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp4_linear) or is_op(
linear_node, torch.ops.auto_deploy.torch_quant_fp8_linear
):
|
||||
weight = linear_node.args[1]
|
||||
scales, quant_type = get_scales_and_type_from_node(linear_node)
|
||||
return input_node, weight, scales, quant_type
|
||||
|
||||
|
||||
def _match_expert_compute_pattern(start_boundary: Node, end_boundary: Node):
|
||||
"""
|
||||
Match the expert compute pattern between the given boundaries.
|
||||
@ -235,24 +301,39 @@ def _match_expert_compute_pattern(start_boundary: Node, end_boundary: Node):
|
||||
|
||||
(F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()
|
||||
|
||||
For each expert, the function returns:
|
||||
- pattern_input_nodes: a list of input nodes (x) used for the expert compute.
|
||||
- pattern_output_nodes: a list of final expert output nodes (the linear op with weight w2).
|
||||
- expert_weights: a dict with keys "w1", "w2", and "w3" mapping to lists of
|
||||
corresponding weight nodes from the w1, w2, and w3 branches.
|
||||
For each expert, the function extracts the input node from the w1 branch and
|
||||
collects the weight parameters from three linear ops (w1, w3, and w2 branches).
|
||||
|
||||
This function supports both:
|
||||
- torch.ops.auto_deploy.torch_linear_simple.default ops, and
|
||||
- torch.ops.auto_deploy.torch_quant_fp8_linear ops (also extracts quantization scales).
|
||||
- torch.ops.auto_deploy.torch_quant_fp4_linear ops (also extracts quantization scales).
|
||||
|
||||
Returns:
|
||||
A tuple:
|
||||
(pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type)
|
||||
|
||||
- pattern_input_nodes: List of input nodes (x) used for the expert compute.
|
||||
- pattern_output_nodes: List of final expert output nodes (the linear op with weight w2).
|
||||
- expert_weights: Dict with keys "w1", "w2", "w3" mapping to lists of weight tensors.
|
||||
- expert_scales: Dict with keys "w1_input_scale", "w1_weight_scale", etc., containing scale tensors
|
||||
(empty if weight_type is "simple").
|
||||
- weight_type: "fp8" if FP8 ops were used, "simple" otherwise.
|
||||
"""
|
||||
pattern_input_nodes, pattern_output_nodes = [], []
|
||||
expert_weights = defaultdict(list)
|
||||
expert_scales = defaultdict(list)
|
||||
weight_type = "simple" # default
|
||||
|
||||
nodes = list(start_boundary.graph.nodes)
|
||||
region_nodes = nodes[nodes.index(start_boundary) + 1 : nodes.index(end_boundary)]
|
||||
|
||||
for node in region_nodes:
|
||||
if not is_linear_op(node):
|
||||
# Accept both simple and quantized linear ops.
|
||||
if not is_linear_op(node, include_quantization=True):
|
||||
continue
|
||||
|
||||
final_linear = node
|
||||
# Must have at least one argument, and that first argument must be a Node.
|
||||
if not final_linear.args or not isinstance(final_linear.args[0], Node):
|
||||
continue
|
||||
|
||||
@ -261,47 +342,68 @@ def _match_expert_compute_pattern(start_boundary: Node, end_boundary: Node):
|
||||
continue
|
||||
|
||||
arg_a, arg_b = mul_node.args[:2]
|
||||
# Pick the silu op from either arg_a or arg_b.
|
||||
silu_node = (
|
||||
arg_a
|
||||
if (isinstance(arg_a, Node) and is_op(arg_a, torch.ops.aten.silu))
|
||||
if is_op(arg_a, torch.ops.aten.silu)
|
||||
else arg_b
|
||||
if (isinstance(arg_b, Node) and is_op(arg_b, torch.ops.aten.silu))
|
||||
if is_op(arg_b, torch.ops.aten.silu)
|
||||
else None
|
||||
)
|
||||
if silu_node is None:
|
||||
continue
|
||||
|
||||
if not (
|
||||
silu_node.args
|
||||
and isinstance(silu_node.args[0], Node)
|
||||
and is_linear_op(silu_node.args[0])
|
||||
):
|
||||
if not (silu_node.args and is_linear_op(silu_node.args[0], include_quantization=True)):
|
||||
continue
|
||||
linear_w1_node = silu_node.args[0]
|
||||
|
||||
# The other branch should be a linear op (w3 branch).
|
||||
linear_w3_node = arg_b if arg_a is silu_node else arg_a
|
||||
if not (isinstance(linear_w3_node, Node) and is_linear_op(linear_w3_node)):
|
||||
if not is_linear_op(linear_w3_node, include_quantization=True):
|
||||
continue
|
||||
if not (linear_w1_node.args and linear_w3_node.args):
|
||||
continue
|
||||
|
||||
input_node_w1 = linear_w1_node.args[0]
|
||||
weight_w1 = linear_w1_node.args[1] if len(linear_w1_node.args) > 1 else None
|
||||
weight_w3 = linear_w3_node.args[1] if len(linear_w3_node.args) > 1 else None
|
||||
weight_w2 = final_linear.args[1] if len(final_linear.args) > 1 else None
|
||||
# Extract parameters from each linear op.
|
||||
input_node_w1, weight_w1, quant_params_w1, wt_type_w1 = _extract_linear_parameters(
|
||||
linear_w1_node
|
||||
)
|
||||
_, weight_w3, quant_params_w3, wt_type_w3 = _extract_linear_parameters(linear_w3_node)
|
||||
_, weight_w2, quant_params_w2, wt_type_w2 = _extract_linear_parameters(final_linear)
|
||||
|
||||
if None in (weight_w1, weight_w3, weight_w2):
|
||||
continue
|
||||
|
||||
# Ensure the weight type is consistent across branches.
|
||||
if wt_type_w1 != wt_type_w3 or wt_type_w1 != wt_type_w2:
|
||||
continue
|
||||
weight_type = wt_type_w1
|
||||
|
||||
pattern_input_nodes.append(input_node_w1)
|
||||
pattern_output_nodes.append(final_linear)
|
||||
expert_weights["w1"].append(weight_w1)
|
||||
expert_weights["w3"].append(weight_w3)
|
||||
expert_weights["w2"].append(weight_w2)
|
||||
|
||||
return pattern_input_nodes, pattern_output_nodes, expert_weights
|
||||
# TODO: sanity check that all experts have same weight type
|
||||
if weight_type == "fp8":
|
||||
expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"])
|
||||
expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"])
|
||||
expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"])
|
||||
expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"])
|
||||
expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"])
|
||||
expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"])
|
||||
elif weight_type == "fp4":
|
||||
expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"])
|
||||
expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"])
|
||||
expert_scales["w1_alpha"].append(quant_params_w1["alpha"])
|
||||
expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"])
|
||||
expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"])
|
||||
expert_scales["w3_alpha"].append(quant_params_w3["alpha"])
|
||||
expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"])
|
||||
expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"])
|
||||
expert_scales["w2_alpha"].append(quant_params_w2["alpha"])
|
||||
|
||||
return pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type
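For reference, the expert-compute pattern described in the docstring above, written out directly on toy tensors with illustrative shapes:

import torch
import torch.nn.functional as F

x = torch.randn(5, 16)                              # tokens routed to one expert
w1, w3 = torch.randn(32, 16), torch.randn(32, 16)   # gate and up projections
w2 = torch.randn(16, 32)                            # down projection

out = (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()  # the per-expert pattern being matched
print(out.shape)  # torch.Size([5, 16])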
|
||||
|
||||
|
||||
def _find_final_hidden_state_node(
|
||||
@ -376,7 +478,7 @@ def _extract_index_branches_from_expert_outputs(
|
||||
if not mul or len(mul.args) < 2:
|
||||
continue
|
||||
idx_node = mul.args[1]
|
||||
if not (isinstance(idx_node, Node) and is_op(idx_node, torch.ops.aten.index)):
|
||||
if not is_op(idx_node, torch.ops.aten.index):
|
||||
continue
|
||||
routing_branches.append(idx_node.args[0])
|
||||
experts = idx_node.args[1]
|
||||
|
||||
@ -116,7 +116,7 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node
|
||||
gm.delete_all_unused_submodules()
|
||||
|
||||
|
||||
def fuse_gemms(gm: GraphModule) -> GraphModule:
|
||||
def fuse_gemms(gm: GraphModule) -> None:
|
||||
ad_logger.info("GEMM fusion")
|
||||
ad_logger.debug("Before GEMM fusion: " + str(gm))
|
||||
# sort linear nodes by parent node
|
||||
@ -139,8 +139,7 @@ def fuse_gemms(gm: GraphModule) -> GraphModule:
|
||||
_insert_fused_gemm(gm, idx := idx + 1, parent_node, lin_children)
|
||||
|
||||
# clean up and return
|
||||
gm = canonicalize_graph(gm)
|
||||
canonicalize_graph(gm)
|
||||
|
||||
ad_logger.debug("After GEMM fusion: " + str(gm))
|
||||
torch.cuda.empty_cache()
|
||||
return gm
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""Graph transformation to automatically add kv cache into fused MHA op."""
|
||||
|
||||
import operator
|
||||
from typing import Dict
|
||||
from typing import Dict, Type
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, GraphModule, Node
|
||||
@ -14,7 +14,7 @@ from ...utils.node_utils import get_all_input_output_nodes, is_op
|
||||
from .._graph import add_graph_input, canonicalize_graph
|
||||
|
||||
|
||||
def update_in_out_nodes(egm: GraphModule, cm: CachedSequenceInterface) -> GraphModule:
|
||||
def update_in_out_nodes(egm: GraphModule, cm: CachedSequenceInterface) -> None:
|
||||
"""Modify the graph module by adding new input nodes and canonicalizing the graph.
|
||||
|
||||
The new input nodes correspond to the extra arguments needed for cached and flattened attention.
|
||||
@ -22,9 +22,6 @@ def update_in_out_nodes(egm: GraphModule, cm: CachedSequenceInterface) -> GraphM
|
||||
Args:
|
||||
egm: The graph module to analyze and modify.
|
||||
cm: Cached sequence interface containing extra argument information.
|
||||
|
||||
Returns:
|
||||
The updated GraphModule with new input nodes and a canonicalized graph.
|
||||
"""
|
||||
# loop through nodes to get input, output, and get_attr nodes
|
||||
input_nodes, output_nodes = get_all_input_output_nodes(egm.graph)
|
||||
@ -45,17 +42,15 @@ def update_in_out_nodes(egm: GraphModule, cm: CachedSequenceInterface) -> GraphM
|
||||
input_nodes.append(add_graph_input(egm, name))
|
||||
ad_logger.info(f"Added {len(new_args)} new input nodes for cached attention metadata")
|
||||
|
||||
egm = canonicalize_graph(egm)
|
||||
|
||||
return egm
|
||||
canonicalize_graph(egm)
|
||||
|
||||
|
||||
def insert_cached_attention(
|
||||
egm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
attn_descriptor: AttentionDescriptor,
|
||||
attn_descriptor: Type[AttentionDescriptor],
|
||||
cache_config: CacheConfig,
|
||||
) -> GraphModule:
|
||||
) -> None:
|
||||
"""Replace uncached source attention node with corresponding cached attn node."""
|
||||
# Get all attention nodes and their info objects
|
||||
source_op = attn_descriptor.get_source_attention_op()
|
||||
@ -68,7 +63,7 @@ def insert_cached_attention(
|
||||
|
||||
if not source_attn_nodes:
|
||||
# If there are no nodes for kv cache insertion found, return current graph
|
||||
return egm
|
||||
return
|
||||
|
||||
# Sanity check
|
||||
if cm.info.is_paged:
|
||||
@ -131,15 +126,13 @@ def insert_cached_attention(
|
||||
graph.erase_node(attn_node)
|
||||
num_cached_attn_replacements += 1
|
||||
|
||||
egm = canonicalize_graph(egm)
|
||||
canonicalize_graph(egm)
|
||||
ad_logger.info(
|
||||
f"Replaced {num_cached_attn_replacements} {source_op} ops "
|
||||
f"with {attn_descriptor.get_cached_attention_op()}"
|
||||
)
|
||||
ad_logger.debug(f"After inserting {attn_descriptor=} with cache: {egm}")
|
||||
|
||||
return egm
|
||||
|
||||
|
||||
def resize_kv_cache(
|
||||
egm: GraphModule,
|
||||
@ -150,8 +143,13 @@ def resize_kv_cache(
|
||||
|
||||
free_mem_ratio specifies the fraction of available memory to occupy.
|
||||
"""
|
||||
free_mem, total_mem = torch.cuda.mem_get_info()
|
||||
ad_logger.info(f"Free memory: {free_mem}, Total memory: {total_mem}")
|
||||
|
||||
def _get_mem_info_in_mb():
|
||||
free_mem, total_mem = torch.cuda.mem_get_info()
|
||||
return free_mem // 1024**2, total_mem // 1024**2
|
||||
|
||||
free_mem, total_mem = _get_mem_info_in_mb()
|
||||
ad_logger.info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
|
||||
current_cache_size = cm.current_cache_size_bytes()
|
||||
current_num_pages = cm.info.num_pages
|
||||
ad_logger.info(
|
||||
@ -165,14 +163,16 @@ def resize_kv_cache(
|
||||
try:
|
||||
# Let's run a forward pass to get the memory usage
|
||||
cm.info._set_max_num_tokens_sample()
|
||||
free_mem_pre, _ = torch.cuda.mem_get_info()
|
||||
ad_logger.info(f"Free memory before forward pass: {free_mem_pre}")
|
||||
free_mem_pre, _ = _get_mem_info_in_mb()
|
||||
ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}")
|
||||
|
||||
egm(*cm.args)
|
||||
free_mem_post, _ = torch.cuda.mem_get_info()
|
||||
ad_logger.info(f"Free memory after forward pass: {free_mem_post}")
|
||||
|
||||
free_mem_post, _ = _get_mem_info_in_mb()
|
||||
ad_logger.info(f"Free memory after forward pass (MB): {free_mem_post}")
|
||||
|
||||
memory_for_forward_pass = free_mem_pre - free_mem_post
|
||||
ad_logger.info(f"Memory for forward pass: {memory_for_forward_pass}")
|
||||
ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}")
|
||||
|
||||
new_cache_size = free_mem_post * free_mem_ratio + current_cache_size
|
||||
new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
|
||||
|
||||
@ -11,7 +11,6 @@ from ...utils.node_utils import (
|
||||
get_quantization_params_from_linear_node,
|
||||
is_bmm_op,
|
||||
is_linear_op,
|
||||
is_match,
|
||||
)
|
||||
from ...utils.quantization_utils import (
|
||||
QuantizationImpl,
|
||||
@ -19,6 +18,7 @@ from ...utils.quantization_utils import (
|
||||
is_quantized_graph,
|
||||
is_quantized_op,
|
||||
remove_output_quantizers,
|
||||
should_skip_quantization,
|
||||
)
|
||||
from .._graph import canonicalize_graph
|
||||
|
||||
@ -169,23 +169,22 @@ def _insert_quantized_bmm(
|
||||
node.args = (*node.args, *scale_values)
|
||||
|
||||
|
||||
def quantize(gm: GraphModule, quant_config: Dict[str, Any]):
|
||||
"""Quantize the GraphModule and replace linear and bmm with quantized versions."""
|
||||
def quantize(gm: GraphModule, quant_config: Dict[str, Any]) -> None:
|
||||
"""Quantize the GraphModule and replace linear with quantized linear."""
|
||||
# extract info from quant_config
|
||||
is_quant_graph = is_quantized_graph(gm)
|
||||
quant_algo = quant_config.get("quant_algo")
|
||||
skip = quant_config.get("exclude_modules", [])
|
||||
excluded_patterns = quant_config.get("exclude_modules", [])
|
||||
|
||||
# no quantization to do
|
||||
if not (is_quant_graph or quant_config):
|
||||
ad_logger.info("No quantization to do.")
|
||||
return gm
|
||||
return
|
||||
|
||||
# tracking quantized operations in the graph
|
||||
quantized_nodes: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
for n in gm.graph.nodes:
|
||||
# check if we should skip this node
|
||||
if is_match(n, skip):
|
||||
if should_skip_quantization(n, excluded_patterns):
|
||||
continue
|
||||
|
||||
# Process linear operations
|
||||
@ -215,10 +214,8 @@ def quantize(gm: GraphModule, quant_config: Dict[str, Any]):
|
||||
if is_quant_graph:
|
||||
remove_output_quantizers(gm)
|
||||
|
||||
gm = canonicalize_graph(gm)
|
||||
canonicalize_graph(gm)
|
||||
for quant_algo in quantized_nodes:
|
||||
for op_type, count in quantized_nodes[quant_algo].items():
|
||||
ad_logger.info(f"Found {count} {quant_algo} quantized {op_type} nodes.")
|
||||
ad_logger.debug("After quantization: " + str(gm))
|
||||
|
||||
return gm
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.node_utils import is_op
|
||||
from ...utils.quantization_utils import QuantizationImpl, should_skip_quantization
|
||||
from .._graph import canonicalize_graph
|
||||
|
||||
quantized_moe_op_map = {
|
||||
"FP8": torch.ops.auto_deploy.torch_quant_fp8_moe,
|
||||
"NVFP4": torch.ops.auto_deploy.torch_quant_fp4_moe,
|
||||
}
|
||||
|
||||
|
||||
def _quantize_moe_node(
|
||||
gm: GraphModule,
|
||||
node: Node,
|
||||
quant_impl: QuantizationImpl,
|
||||
quantized_op: Callable[..., Node],
|
||||
):
|
||||
"""
|
||||
Replace a torch.ops.auto_deploy.torch_moe node with its quantized version,
|
||||
quantizing each expert weight list and registering scales + hooks.
|
||||
Automatically handles different scale configurations per quantization type.
|
||||
"""
|
||||
w1_names, w2_names, w3_names = _extract_moe_weight_param_lists(node)
|
||||
|
||||
scale_keys = quant_impl.scale_names()
|
||||
|
||||
def quantize_param_list(weight_names: List[str]) -> Tuple[List[Node], List[List[Node]]]:
|
||||
new_attrs = []
|
||||
scale_nodes_group = []
|
||||
for name in weight_names:
|
||||
orig_weight = gm.get_parameter(name)
|
||||
new_weight = quant_impl.quantize_weight(orig_weight)
|
||||
|
||||
# Replace parameter in submodule
|
||||
modname, _, attrname = name.rpartition(".")
|
||||
submod = gm.get_submodule(modname)
|
||||
setattr(submod, attrname, nn.Parameter(new_weight, requires_grad=False))
|
||||
|
||||
# Register new scale buffers
|
||||
for scale_name, scale_val in quant_impl.default_scales(orig_weight.shape).items():
|
||||
submod.register_buffer(scale_name, scale_val)
|
||||
|
||||
# Register load hook
|
||||
gm._register_load_state_dict_pre_hook(partial(quant_impl.load_hook, weight_name=name))
|
||||
|
||||
# Create get_attr nodes for new param and each scale
|
||||
with gm.graph.inserting_before(node):
|
||||
new_weight_attr = gm.graph.get_attr(name)
|
||||
new_attrs.append(new_weight_attr)
|
||||
scales = [gm.graph.get_attr(modname + "." + s) for s in scale_keys]
|
||||
scale_nodes_group.append(scales)
|
||||
|
||||
return new_attrs, scale_nodes_group
|
||||
|
||||
# Quantize all three expert weights
|
||||
w1_attrs, w1_scales = quantize_param_list(w1_names)
|
||||
w2_attrs, w2_scales = quantize_param_list(w2_names)
|
||||
w3_attrs, w3_scales = quantize_param_list(w3_names)
|
||||
|
||||
# Collect scale tensors per scale type across w1, w2, w3
|
||||
def collect_scales(index: int) -> Tuple[List[Node], List[Node], List[Node]]:
|
||||
return (
|
||||
[s[index] for s in w1_scales],
|
||||
[s[index] for s in w2_scales],
|
||||
[s[index] for s in w3_scales],
|
||||
)
|
||||
|
||||
# Prepare args
|
||||
args = [
|
||||
node.args[0], # x
|
||||
node.args[1], # selected_experts
|
||||
node.args[2], # routing_weights
|
||||
w1_attrs,
|
||||
w2_attrs,
|
||||
w3_attrs,
|
||||
]
|
||||
|
||||
for idx in range(len(scale_keys)):
|
||||
s1, s2, s3 = collect_scales(idx)
|
||||
args.extend([s1, s2, s3])
|
||||
|
||||
# Replace the current node with the quantized version
|
||||
with gm.graph.inserting_after(node):
|
||||
new_node = gm.graph.call_function(
|
||||
quantized_op,
|
||||
args=tuple(args),
|
||||
)
|
||||
ad_logger.debug(f"Updating {node.name} args to {new_node.args}")
|
||||
node.replace_all_uses_with(new_node)
|
||||
gm.graph.erase_node(node)
|
||||
|
||||
|
||||
def quantize_moe(gm: GraphModule, quant_config: Dict[str, Any]) -> None:
    """
    Traverse gm, find every torch.ops.auto_deploy.torch_moe node, and replace it with the
    quantized version selected by the quant_algo from quant_config.
    """
    quant_algo = quant_config.get("quant_algo")
    if not quant_algo:
        ad_logger.info("No quantization to do.")
        return
    excluded_patterns = quant_config.get("exclude_modules", [])

    quant_impl = QuantizationImpl.create(quant_algo)
    quantized_op = quantized_moe_op_map[quant_algo]

    count = 0

    for node in list(gm.graph.nodes):
        if is_op(node, torch.ops.auto_deploy.torch_moe):
            # Skip this MoE node if any of its expert weights matches an excluded pattern.
            w1_names, w2_names, w3_names = _extract_moe_weight_param_lists(node)
            if any(
                should_skip_quantization(n, excluded_patterns)
                for n in w1_names + w2_names + w3_names
            ):
                continue
            _quantize_moe_node(gm, node, quant_impl, quantized_op)
            count += 1

    if count == 0:
        return

    gm = canonicalize_graph(gm)
    ad_logger.info(f"Found {count} {quant_algo} quantized {quantized_op} nodes.")


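For reference, a hedged sketch of how this entry point is driven; the keys mirror what the function reads above, and the concrete values are illustrative only:

```python
# Hypothetical quant_config; "quant_algo" must be a key of quantized_moe_op_map and
# "exclude_modules" holds fnmatch-style module patterns to leave unquantized.
quant_config = {
    "quant_algo": "FP8",
    "exclude_modules": ["*lm_head*"],
}
quantize_moe(gm, quant_config)  # gm: an exported torch.fx.GraphModule containing torch_moe nodes
```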
# TODO(Fridah-nv): robust handling similar to `extract_param_names_from_lin_node` or expand it
|
||||
def _extract_moe_weight_param_lists(moe_node: Node) -> Tuple[List[str], List[str], List[str]]:
|
||||
"""
|
||||
Given a torch.ops.moe.torch_moe node in gm.graph, extract three lists of
|
||||
the parameter names for w1_weight, w2_weight, and w3_weight.
|
||||
|
||||
Returns:
|
||||
(w1_names, w2_names, w3_names), each a list of strings like 'layer.expert_0.w1.weight'
|
||||
"""
|
||||
# args layout: (x, selected_experts, routing_weights, w1_list, w2_list, w3_list)
|
||||
try:
|
||||
w1_list, w2_list, w3_list = moe_node.args[3:6]
|
||||
except ValueError:
|
||||
raise RuntimeError(
|
||||
f"Expected moe_node.args to have at least 6 entries, got {len(moe_node.args)}"
|
||||
)
|
||||
|
||||
def _unwrap_list(arg) -> List[str]:
|
||||
if not isinstance(arg, (list, tuple)):
|
||||
raise TypeError(f"Expected a Python list/tuple of get_attr Nodes, got {type(arg)}")
|
||||
names: List[str] = []
|
||||
for elt in arg:
|
||||
if not isinstance(elt, Node) or elt.op != "get_attr":
|
||||
raise RuntimeError(f"Expected each list element to be a get_attr Node, got {elt}")
|
||||
names.append(elt.target)
|
||||
return names
|
||||
|
||||
w1_names = _unwrap_list(w1_list)
|
||||
w2_names = _unwrap_list(w2_list)
|
||||
w3_names = _unwrap_list(w3_list)
|
||||
|
||||
return w1_names, w2_names, w3_names
|
||||
@ -0,0 +1,113 @@
"""Graph transform to optimize RMSNorm execution with FlashInfer, Triton, or Torch backends."""

from functools import partial

import torch
from torch.fx import GraphModule

from ...utils.logger import ad_logger

# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from .._graph import canonicalize_graph

_BACKEND_OPS = {
    "flashinfer": torch.ops.auto_deploy.flashinfer_rms_norm,
    "triton": torch.ops.auto_deploy.triton_rms_norm,
    "torch": torch.ops.auto_deploy.torch_rmsnorm,
}


def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    """Implements the RMSNorm pattern for pattern matching.

    Args:
        data: Input tensor to normalize.
        weight: Scaling weights for the normalized output.
        eps: Small constant for numerical stability.

    Returns:
        Normalized and scaled tensor.
    """
    input_dtype = data.dtype
    data = data.to(torch.float32)
    variance = data.pow(2).mean(-1, keepdim=True)
    data = data * torch.rsqrt(variance + eps)
    return weight * data.to(input_dtype)


def _rms_norm_replacement(
    data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str
) -> torch.Tensor:
    """Backend-specific rms_norm implementation.

    Args:
        data: Input tensor to normalize.
        weight: Scaling weights for the normalized output.
        eps: Small constant for numerical stability.
        backend: Backend to use for RMSNorm computation ("flashinfer", "triton", or "torch").

    Returns:
        Normalized and scaled tensor using the specified backend implementation.
    """

    assert backend.lower() in _BACKEND_OPS, (
        f"Invalid {backend=}; must be one of {list(_BACKEND_OPS)}"
    )
    return _BACKEND_OPS[backend.lower()](data, weight, eps)


def fuse_rmsnorm(gm: GraphModule, backend: str = "triton") -> None:
    """Matches and replaces RMSNorm patterns in the graph with the selected backend implementation.

    This function sets up pattern matching to identify RMSNorm operations in the graph
    and replaces them with optimized implementations. It uses dummy tensors to register
    the pattern matching rules. The graph module is modified in place.

    Args:
        gm: Input graph module to transform.
        backend: Backend to use for RMSNorm computation ("flashinfer", "triton", or "torch").
    """
    if backend.lower() not in _BACKEND_OPS:
        raise ValueError(f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {backend}")
    ad_logger.info(f"Starting RMSNorm pattern matching with backend: {backend}")

    graph = gm.graph
    patterns = ADPatternMatcherPass()

    # Create dummy tensors for pattern matching
    bs = 2
    hidden_size = 512

    def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
        return [
            torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
            torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
            eps,
        ]

    # Define configurations for different data types
    configs = [
        (torch.bfloat16, torch.bfloat16),
        (torch.float16, torch.float16),
        (torch.float32, torch.float32),
    ]

    # Register patterns for each configuration
    for input_dtype, weight_dtype in configs:
        register_ad_pattern(
            search_fn=_rms_norm_pattern,
            replace_fn=partial(_rms_norm_replacement, backend=backend),
            patterns=patterns,
            dummy_args=dummy_args(input_dtype, weight_dtype),
            op_ignore_types={},
            scalar_workaround={"eps": 1e-6},
        )

    cnt = patterns.apply(graph)
    ad_logger.info(f"RMSNorm pattern count: {cnt}")
    canonicalize_graph(gm)
    ad_logger.debug("RMSNorm pattern matching completed.")
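A hedged end-to-end sketch of how this transform might be invoked. The import path for `fuse_rmsnorm` is an assumption; `torch_export_to_gm` is the export helper referenced elsewhere in this change, and a CUDA device with the auto_deploy custom ops is required:

```python
import torch
import torch.nn as nn

from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
# Assumed import path for the transform defined above.
from tensorrt_llm._torch.auto_deploy.transformations.library.rms_norm import fuse_rmsnorm


class TinyRMSNorm(nn.Module):
    def __init__(self, hidden: int = 512, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden))
        self.eps = eps

    def forward(self, x):
        var = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x_norm = x.to(torch.float32) * torch.rsqrt(var + self.eps)
        return self.weight * x_norm.to(x.dtype)


m = TinyRMSNorm().to("cuda", torch.bfloat16)
gm = torch_export_to_gm(m, args=(torch.randn(2, 512, device="cuda", dtype=torch.bfloat16),))
fuse_rmsnorm(gm, backend="triton")  # rewrites the matched pattern to the triton op in place
```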
@ -119,7 +119,7 @@ def _explicit_not_interleaved(match: Match) -> bool:
|
||||
return not any(isinstance(n, Node) and _match_input_interleave_pattern(n) for n in (q, k))
|
||||
|
||||
|
||||
def match_rope_pattern(gm: GraphModule) -> GraphModule:
|
||||
def match_rope_pattern(gm: GraphModule) -> int:
|
||||
graph = gm.graph
|
||||
patterns = ADPatternMatcherPass()
|
||||
|
||||
@ -174,12 +174,12 @@ def match_rope_pattern(gm: GraphModule) -> GraphModule:
|
||||
)
|
||||
|
||||
num_matches = patterns.apply(graph)
|
||||
gm = canonicalize_graph(gm)
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found and matched {num_matches} RoPE patterns")
|
||||
return gm, num_matches
|
||||
return num_matches
|
||||
|
||||
|
||||
def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> GraphModule:
|
||||
def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> None:
|
||||
"""
|
||||
Match and transform input and output of rope ops to the layout specified to meet requirements of optimized ops.
|
||||
Supported layout is 'bsnd' (batch, seq, head, dim).
|
||||
@ -189,7 +189,7 @@ def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> GraphMo
|
||||
ad_logger.warning(
|
||||
f"Unsupported RoPE layout '{expected_layout}'; expected '{supported}'. Skipping RoPE layout matching."
|
||||
)
|
||||
return gm
|
||||
return
|
||||
|
||||
ad_logger.info(f"Match RoPE layout to {expected_layout}")
|
||||
|
||||
@ -291,12 +291,11 @@ def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> GraphMo
|
||||
k_rope_new.args = (k_rope_old, 1, 2)
|
||||
|
||||
if num_rope_layout_matches:
|
||||
gm = canonicalize_graph(gm)
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found {num_rope_layout_matches} RoPE layout matches")
|
||||
return gm
|
||||
|
||||
|
||||
def optimize_rope(gm: GraphModule) -> GraphModule:
|
||||
def optimize_rope(gm: GraphModule) -> None:
|
||||
"""
|
||||
Scan the FX graph and replace calls to the torch-reference RoPE ops with
|
||||
the optimized `rope::flashinfer` kernel.
|
||||
@ -317,9 +316,8 @@ def optimize_rope(gm: GraphModule) -> GraphModule:
|
||||
continue
|
||||
num_rope_optimizations += 1
|
||||
if num_rope_optimizations:
|
||||
gm = canonicalize_graph(gm)
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found {num_rope_optimizations} RoPE optimizations")
|
||||
return gm
|
||||
|
||||
|
||||
def _optimize_explicit(
|
||||
|
||||
@ -18,12 +18,15 @@ Our sharding algorithm for tensor parallelism (TP) is based on the following ste
|
||||
|
||||
import math
|
||||
import operator
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from enum import IntEnum
|
||||
from functools import partial
|
||||
from typing import Callable, DefaultDict, Dict, List, Set
|
||||
from typing import Callable, DefaultDict, Dict, List, Literal, Optional, Set
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...utils.logger import ad_logger
|
||||
@ -38,6 +41,249 @@ from ...utils.quantization_utils import QuantizationImpl
|
||||
from .._graph import canonicalize_graph
|
||||
|
||||
|
||||
class SplitDimension(IntEnum):
|
||||
"""Enum for tensor split dimensions in sharding."""
|
||||
|
||||
ROW = 0 # Split along rows (first dimension)
|
||||
COLUMN = 1 # Split along columns (second dimension)
|
||||
|
||||
|
||||
class ShardingTransformInfo(BaseModel, ABC):
|
||||
"""Abstract base class for transformation configurations."""
|
||||
|
||||
model_config = ConfigDict(frozen=True) # Makes the model immutable and hashable
|
||||
|
||||
target_node: str
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
|
||||
"""
|
||||
Validate whether the transformation is valid.
|
||||
Execute right before applying the transformation.
|
||||
"""
|
||||
return True
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, gm: GraphModule, node: Node) -> None:
|
||||
"""Apply the transformation to the graph module.
|
||||
|
||||
This method must be implemented by each transformation class.
|
||||
"""
|
||||
pass
|
||||
|
||||
def check_and_apply(self, gm: GraphModule, node: Node) -> None:
|
||||
"""Check if the transformation is valid and apply it if it is."""
|
||||
if not self.validate(gm, node):
|
||||
ad_logger.warning(f"Skipping invalid transformation {self}.")
|
||||
return
|
||||
self.apply(gm, node)
|
||||
|
||||
|
||||
class TPShardingInfo(ShardingTransformInfo):
|
||||
"""Configuration for TP sharding transformations."""
|
||||
|
||||
split_dim: SplitDimension
|
||||
dist_op: Optional[Literal["all_reduce", "all_gather"]] = None
|
||||
min_local_shape: int = 1
|
||||
|
||||
def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
|
||||
"""Validate the transformation configuration."""
|
||||
if self.dist_op is not None:
|
||||
if self.split_dim == SplitDimension.ROW:
|
||||
if self.dist_op == "all_reduce":
|
||||
ad_logger.warning(
|
||||
f"Row split is only supported for all_gather. Skipping {self}."
|
||||
)
|
||||
return False
|
||||
if self.split_dim == SplitDimension.COLUMN:
|
||||
if self.dist_op == "all_gather":
|
||||
ad_logger.warning(
|
||||
f"Column split is only supported for all_reduce. Skipping {self}."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
def apply(self, gm: GraphModule, node: Node) -> None:
|
||||
"""Apply TP sharding transformation to the graph module."""
|
||||
|
||||
_insert_sharded_matmul(
|
||||
gm=gm,
|
||||
node=node,
|
||||
dim=self.split_dim.value,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
add_dist=self.dist_op is not None,
|
||||
min_local_shape=self.min_local_shape,
|
||||
)
|
||||
|
||||
|
||||
class BMMShardingInfo(ShardingTransformInfo):
|
||||
"""Configuration for BMM sharding transformations."""
|
||||
|
||||
rank: int
|
||||
world_size: int
|
||||
start_idx: int
|
||||
end_idx: int
|
||||
|
||||
def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
|
||||
"""Validate the transformation configuration."""
|
||||
if not is_op(node, torch.ops.aten.bmm):
|
||||
ad_logger.warning(f"BMM sharding is only supported for BMM nodes. Skipping {self}.")
|
||||
return False
|
||||
|
||||
# Get the input tensors
|
||||
lhs_tensor = node.args[0]
|
||||
rhs_tensor = node.args[1]
|
||||
|
||||
# Check batch sizes from meta information
|
||||
lhs_batch_size = lhs_tensor.meta["val"].shape[0]
|
||||
rhs_batch_size = rhs_tensor.meta["val"].shape[0]
|
||||
|
||||
assert lhs_batch_size == rhs_batch_size, "Batch sizes of both tensors must match"
|
||||
bmm_batch_size = lhs_batch_size
|
||||
|
||||
# Check if the distribution is balanced
|
||||
remainder = bmm_batch_size % self.world_size
|
||||
|
||||
# NOTE: our torch.ops.auto_deploy.torch_dist_all_gather doesn't support uneven splits at the moment.
|
||||
if remainder:
|
||||
ad_logger.warning(
|
||||
f"BMM batch size {bmm_batch_size} is not divisible by world size {self.world_size}. "
|
||||
f"This will result in uneven distribution of work across devices. Skipping."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
def apply(self, gm: GraphModule, node: Node) -> None:
|
||||
"""Apply BMM sharding transformation to the graph module."""
|
||||
|
||||
def handle_tensor(
|
||||
bmm_node: Node, tensor_node: Node, arg_idx: int, start_idx: int, end_idx: int
|
||||
):
|
||||
"""Unified helper function to shard either a parameter tensor or a dynamic tensor.
|
||||
|
||||
Args:
|
||||
bmm_node: The BMM node that is being processed
|
||||
tensor_node: The input tensor node to shard
|
||||
arg_idx: The argument index of the tensor in the BMM node
|
||||
start_idx: Start index for sharding
|
||||
end_idx: End index for sharding
|
||||
"""
|
||||
|
||||
# Define slice function for the sharding
|
||||
def slice_tensor(t: torch.Tensor) -> torch.Tensor:
|
||||
return t[start_idx:end_idx]
|
||||
|
||||
if tensor_node.op == "get_attr":
|
||||
# Handle parameter tensor
|
||||
weight_key = tensor_node.target
|
||||
modname, _, param_name = weight_key.rpartition(".")
|
||||
param = gm.get_parameter(weight_key)
|
||||
|
||||
# Update the parameter with its shard
|
||||
param_new = nn.Parameter(slice_tensor(param).detach().clone(), requires_grad=True)
|
||||
gm.get_submodule(modname).register_parameter(param_name, param_new)
|
||||
|
||||
# Register load state dict hook
|
||||
gm._register_load_state_dict_pre_hook(
|
||||
partial(
|
||||
_load_hook,
|
||||
f_split=slice_tensor,
|
||||
param_key=weight_key,
|
||||
param_shape=param_new.shape,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Handle dynamic tensor
|
||||
with gm.graph.inserting_before(bmm_node):
|
||||
tensor_slice = gm.graph.call_function(
|
||||
torch.ops.aten.slice.Tensor, args=(tensor_node, 0, start_idx, end_idx, 1)
|
||||
)
|
||||
# Update BMM node to use the sliced tensor
|
||||
bmm_node.update_arg(arg_idx, tensor_slice)
|
||||
|
||||
# Get the input tensors
|
||||
lhs_tensor = node.args[0]
|
||||
rhs_tensor = node.args[1]
|
||||
# Handle both tensors
|
||||
handle_tensor(node, lhs_tensor, 0, self.start_idx, self.end_idx)
|
||||
handle_tensor(node, rhs_tensor, 1, self.start_idx, self.end_idx)
|
||||
|
||||
# Add all_gather node after BMM to collect results
|
||||
with gm.graph.inserting_after(node):
|
||||
gather_node = gm.graph.call_function(
|
||||
torch.ops.auto_deploy.torch_dist_all_gather,
|
||||
args=(node, 0), # Gather along batch dimension (0)
|
||||
)
|
||||
node.replace_all_uses_with(gather_node)
|
||||
gather_node.replace_input_with(gather_node, node)
|
||||
|
||||
|
||||
class EPShardingInfo(ShardingTransformInfo):
|
||||
"""Configuration for EP sharding transformations."""
|
||||
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
|
||||
"""Validate the transformation configuration."""
|
||||
if not is_op(
|
||||
node,
|
||||
(
|
||||
torch.ops.auto_deploy.torch_moe,
|
||||
torch.ops.auto_deploy.torch_quant_fp8_moe,
|
||||
torch.ops.auto_deploy.torch_quant_fp4_moe,
|
||||
),
|
||||
):
|
||||
ad_logger.warning(f"EP sharding is only supported for MOE nodes. Skipping {self}.")
|
||||
return False
|
||||
return True
|
||||
|
||||
def apply(self, gm: GraphModule, node: Node) -> None:
|
||||
"""Apply EP sharding transformation to the graph module."""
|
||||
_insert_sharded_moe(gm, node, self.rank, self.world_size)
|
||||
|
||||
|
||||
class ShardingConfig(BaseModel):
    """Configuration for sharding the model."""

    tp_transforms: List[TPShardingInfo] = Field(default_factory=list)
    bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list)
    ep_transforms: List[EPShardingInfo] = Field(default_factory=list)


def sharding_transform_executor(gm: GraphModule, sharding_config: ShardingConfig) -> None:
    """Apply transformations to the graph module.

    Args:
        gm: Graph module to apply transformations to
        sharding_config: Transformation configuration containing the list of transformations to apply
    """
    # create a node dict for faster lookup
    node_dict = {n.name: n for n in gm.graph.nodes}

    def check_and_apply(transform: ShardingTransformInfo) -> None:
        if transform.target_node is None or transform.target_node not in node_dict:
            ad_logger.warning(
                f"Skipping transformation {transform} because target node "
                + f"{transform.target_node} not found in graph"
            )
            return
        transform.check_and_apply(gm, node_dict[transform.target_node])

    for tp_transform in sharding_config.tp_transforms:
        check_and_apply(tp_transform)
    for bmm_transform in sharding_config.bmm_transforms:
        check_and_apply(bmm_transform)
    for ep_transform in sharding_config.ep_transforms:
        check_and_apply(ep_transform)

    # canonicalize the graph after all transformations have been applied
    gm = canonicalize_graph(gm)
    ad_logger.debug("After applying sharding transformations: " + str(gm))
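A hedged sketch of how these pieces fit together: detection passes append transform infos and the executor applies them. The target node name below is hypothetical, and `gm` stands for an exported GraphModule:

```python
# Build a plan by hand (normally populated by detect_column_row_shard / detect_dp_bmm_shard /
# detect_ep_shard) and apply it to the graph.
config = ShardingConfig()
config.tp_transforms.append(
    TPShardingInfo(
        target_node="linear_7",        # hypothetical node name in gm.graph
        split_dim=SplitDimension.ROW,  # shard weight rows, gather outputs
        rank=0,
        world_size=2,
        dist_op="all_gather",
        min_local_shape=1,
    )
)
sharding_transform_executor(gm, config)
```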
|
||||
def _load_hook(
|
||||
state_dict,
|
||||
prefix,
|
||||
@ -79,8 +325,8 @@ def _insert_sharded_matmul(
|
||||
world_size: int,
|
||||
add_dist: bool = False,
|
||||
min_local_shape: int = 1,
|
||||
):
|
||||
"""Replaces the matmul node with a new matmul node that accepts sharded weights.
|
||||
) -> None:
|
||||
"""Replace the matmul node with a new matmul node that accepts sharded weights.
|
||||
|
||||
The state_dict is also updated to contain the sharded weights.
|
||||
"""
|
||||
@ -200,22 +446,37 @@ def _insert_sharded_matmul(
|
||||
dist_node.replace_input_with(dist_node, node)
|
||||
|
||||
|
||||
def _simple_shard(
|
||||
gm: GraphModule, nodes_linear: Dict[Node, List[Node]], rank: int, world_size: int
|
||||
):
|
||||
def _append_simple_shard(
|
||||
nodes_linear: Dict[Node, List[Node]],
|
||||
rank: int,
|
||||
world_size: int,
|
||||
sharding_config: ShardingConfig,
|
||||
) -> None:
|
||||
# for every linear node:
|
||||
# --> row_split (dim 0 of weight) + all_gather (dim -1 of output)
|
||||
tp_shards: List[TPShardingInfo] = []
|
||||
for node_group in nodes_linear.values():
|
||||
for n in node_group:
|
||||
_insert_sharded_matmul(gm, n, 0, rank, world_size, add_dist=True)
|
||||
tp_shards.append(
|
||||
TPShardingInfo(
|
||||
target_node=n.name,
|
||||
split_dim=SplitDimension.ROW,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
dist_op="all_gather",
|
||||
min_local_shape=1,
|
||||
)
|
||||
)
|
||||
sharding_config.tp_transforms.extend(tp_shards)
|
||||
|
||||
|
||||
def column_row_shard(
|
||||
def detect_column_row_shard(
|
||||
gm: GraphModule,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
sharding_config: ShardingConfig,
|
||||
simple_shard_only: bool = False,
|
||||
) -> GraphModule:
|
||||
) -> None:
|
||||
"""A transformation to apply sharding to the model following tensor parallelism.
|
||||
|
||||
The transformation is based on the following steps:
|
||||
@ -236,7 +497,7 @@ def column_row_shard(
|
||||
|
||||
if world_size < 2:
|
||||
ad_logger.info("Skipping sharding for single device")
|
||||
return gm
|
||||
return
|
||||
|
||||
assert isinstance(gm, GraphModule), "Expecting GraphModule"
|
||||
|
||||
@ -312,13 +573,13 @@ def column_row_shard(
|
||||
|
||||
if simple_shard_only:
|
||||
ad_logger.debug(f"Forcing Simple Shard: Linear groups: {nodes_linear}")
|
||||
_simple_shard(gm, nodes_linear, rank, world_size)
|
||||
_append_simple_shard(nodes_linear, rank, world_size, sharding_config)
|
||||
continue
|
||||
|
||||
# simple shard when we have != 2 groups of linear nodes
|
||||
if len(nodes_linear) != 2:
|
||||
ad_logger.debug(f"Linear groups: {nodes_linear}")
|
||||
_simple_shard(gm, nodes_linear, rank, world_size)
|
||||
_append_simple_shard(nodes_linear, rank, world_size, sharding_config)
|
||||
continue
|
||||
|
||||
# let's look at the unaccounted nodes. They are okay as long as they fall before the
|
||||
@ -348,7 +609,7 @@ def column_row_shard(
|
||||
# check if any unaccounted nodes are left. If so, do a simple shard
|
||||
if unaccounted_nodes or attention_related_nodes:
|
||||
ad_logger.debug(f"Unaccounted nodes: {unaccounted_nodes}")
|
||||
_simple_shard(gm, nodes_linear, rank, world_size)
|
||||
_append_simple_shard(nodes_linear, rank, world_size, sharding_config)
|
||||
continue
|
||||
|
||||
# If we can account for all sharded nodes, we can do a two-way shard
|
||||
@ -360,7 +621,7 @@ def column_row_shard(
|
||||
# Column-row shard boundary region detection is probably wrong - there should be
|
||||
# only one attention operation. Fall back to simple shard.
|
||||
ad_logger.debug(f"More than one attention node: {unaccounted_nodes}")
|
||||
_simple_shard(gm, nodes_linear, rank, world_size)
|
||||
_append_simple_shard(nodes_linear, rank, world_size, sharding_config)
|
||||
continue
|
||||
# Extract head dimension. We cannot shard below the head_dim size.
|
||||
# Assume that head_dim is the last (innermost) dimension of the tensor
|
||||
@ -369,19 +630,27 @@ def column_row_shard(
|
||||
min_local_shape = 1
|
||||
for i, group in enumerate(nodes_linear.values()):
|
||||
for n in group:
|
||||
_insert_sharded_matmul(
|
||||
gm, n, i, rank, world_size, add_dist=i > 0, min_local_shape=min_local_shape
|
||||
if i > 0:
|
||||
dist_op = "all_reduce"
|
||||
else:
|
||||
dist_op = None
|
||||
sharding_config.tp_transforms.append(
|
||||
TPShardingInfo(
|
||||
target_node=n.name,
|
||||
split_dim=i,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
dist_op=dist_op,
|
||||
min_local_shape=min_local_shape,
|
||||
)
|
||||
)
|
||||
|
||||
# canonicalize and return
|
||||
if num_shards:
|
||||
gm = canonicalize_graph(gm)
|
||||
ad_logger.debug("After sharding: " + str(gm))
|
||||
ad_logger.info(f"Found {num_shards} TP shards")
|
||||
return gm
|
||||
|
||||
|
||||
def dp_bmm_shard(gm: GraphModule, rank: int, world_size: int) -> GraphModule:
|
||||
def detect_dp_bmm_shard(
|
||||
gm: GraphModule, rank: int, world_size: int, sharding_config: ShardingConfig
|
||||
) -> None:
|
||||
"""A transformation to apply sharding to batched matrix multiplications in the graph.
|
||||
|
||||
We'll shard the BMM nodes by slicing the batch dimension of input tensors into world_size number of slices.
|
||||
@ -394,57 +663,12 @@ def dp_bmm_shard(gm: GraphModule, rank: int, world_size: int) -> GraphModule:
|
||||
|
||||
if world_size < 2:
|
||||
ad_logger.info("Skipping sharding for single device")
|
||||
return gm
|
||||
return
|
||||
|
||||
assert isinstance(gm, GraphModule), "Expecting GraphModule"
|
||||
|
||||
num_bmm_shards = 0
|
||||
|
||||
def handle_tensor(
|
||||
bmm_node: Node, tensor_node: Node, arg_idx: int, start_idx: int, end_idx: int
|
||||
):
|
||||
"""Unified helper function to shard either a parameter tensor or a dynamic tensor.
|
||||
|
||||
Args:
|
||||
bmm_node: The BMM node that is being processed
|
||||
tensor_node: The input tensor node to shard
|
||||
arg_idx: The argument index of the tensor in the BMM node
|
||||
start_idx: Start index for sharding
|
||||
end_idx: End index for sharding
|
||||
"""
|
||||
|
||||
# Define slice function for the sharding
|
||||
def slice_tensor(t: torch.Tensor) -> torch.Tensor:
|
||||
return t[start_idx:end_idx]
|
||||
|
||||
if tensor_node.op == "get_attr":
|
||||
# Handle parameter tensor
|
||||
weight_key = tensor_node.target
|
||||
modname, _, param_name = weight_key.rpartition(".")
|
||||
param = gm.get_parameter(weight_key)
|
||||
|
||||
# Update the parameter with its shard
|
||||
param_new = nn.Parameter(slice_tensor(param).detach().clone(), requires_grad=True)
|
||||
gm.get_submodule(modname).register_parameter(param_name, param_new)
|
||||
|
||||
# Register load state dict hook
|
||||
gm._register_load_state_dict_pre_hook(
|
||||
partial(
|
||||
_load_hook,
|
||||
f_split=slice_tensor,
|
||||
param_key=weight_key,
|
||||
param_shape=param_new.shape,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Handle dynamic tensor
|
||||
with gm.graph.inserting_before(bmm_node):
|
||||
tensor_slice = gm.graph.call_function(
|
||||
torch.ops.aten.slice.Tensor, args=(tensor_node, 0, start_idx, end_idx, 1)
|
||||
)
|
||||
# Update BMM node to use the sliced tensor
|
||||
bmm_node.update_arg(arg_idx, tensor_slice)
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if not is_op(node, {torch.ops.aten.bmm}):
|
||||
continue
|
||||
@ -482,23 +706,19 @@ def dp_bmm_shard(gm: GraphModule, rank: int, world_size: int) -> GraphModule:
|
||||
start_idx = remainder + rank * base_size
|
||||
end_idx = start_idx + base_size
|
||||
|
||||
sharding_config.bmm_transforms.append(
|
||||
BMMShardingInfo(
|
||||
target_node=node.name,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
start_idx=start_idx,
|
||||
end_idx=end_idx,
|
||||
)
|
||||
)
|
||||
ad_logger.debug(
|
||||
f"Sharding BMM for rank {rank}: batch_size={bmm_batch_size}, start_idx={start_idx}, end_idx={end_idx}"
|
||||
)
|
||||
|
||||
# Handle both tensors
|
||||
handle_tensor(node, lhs_tensor, 0, start_idx, end_idx)
|
||||
handle_tensor(node, rhs_tensor, 1, start_idx, end_idx)
|
||||
|
||||
# Add all_gather node after BMM to collect results
|
||||
with gm.graph.inserting_after(node):
|
||||
gather_node = gm.graph.call_function(
|
||||
torch.ops.auto_deploy.torch_dist_all_gather,
|
||||
args=(node, 0), # Gather along batch dimension (0)
|
||||
)
|
||||
node.replace_all_uses_with(gather_node)
|
||||
gather_node.replace_input_with(gather_node, node)
|
||||
|
||||
num_bmm_shards += 1
|
||||
|
||||
# Canonicalize and return
|
||||
@ -506,4 +726,123 @@ def dp_bmm_shard(gm: GraphModule, rank: int, world_size: int) -> GraphModule:
|
||||
gm = canonicalize_graph(gm)
|
||||
ad_logger.debug("After sharding BMM: " + str(gm))
|
||||
ad_logger.info(f"Found {num_bmm_shards} BMM shards")
|
||||
return gm
|
||||
|
||||
|
||||
def detect_ep_shard(
|
||||
gm: GraphModule, rank: int, world_size: int, sharding_config: ShardingConfig
|
||||
) -> None:
|
||||
ad_logger.debug("Before sharding graph: " + str(gm))
|
||||
|
||||
if world_size < 2:
|
||||
ad_logger.info("Skipping sharding for single device")
|
||||
return
|
||||
|
||||
assert isinstance(gm, GraphModule), "Expecting GraphModule"
|
||||
num_moe_patterns = 0
|
||||
for node in list(gm.graph.nodes):
|
||||
if not is_op(
|
||||
node,
|
||||
(
|
||||
torch.ops.auto_deploy.torch_moe,
|
||||
torch.ops.auto_deploy.torch_quant_fp8_moe,
|
||||
torch.ops.auto_deploy.torch_quant_fp4_moe,
|
||||
),
|
||||
):
|
||||
continue
|
||||
sharding_config.ep_transforms.append(
|
||||
EPShardingInfo(
|
||||
target_node=node.name,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
)
|
||||
num_moe_patterns += 1
|
||||
|
||||
ad_logger.info(f"Found {num_moe_patterns} MoE patterns")
|
||||
|
||||
|
||||
def _insert_sharded_moe(
|
||||
gm: GraphModule,
|
||||
node: Node,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
"""Update the torch_moe node with sharded weight lists,
|
||||
sharded `selected_experts` and `final_scales` (router logits).
|
||||
Add an all_reduce node after the moe node.
|
||||
"""
|
||||
quant_impl = QuantizationImpl.create(node)
|
||||
scale_names = quant_impl.scale_names() if quant_impl else []
|
||||
|
||||
num_experts = len(node.args[3])
|
||||
args = list(node.args)
|
||||
|
||||
# -- Handle selected_experts and final_scales sharding --
|
||||
selected_experts = args[1]
|
||||
final_scales = args[2]
|
||||
|
||||
experts_per_rank = num_experts // world_size
|
||||
|
||||
with gm.graph.inserting_before(node):
|
||||
lower = experts_per_rank * rank
|
||||
# selected_experts_local = selected_experts - lower
|
||||
selected_experts_local = gm.graph.create_node(
|
||||
"call_function", operator.sub, args=(selected_experts, lower), kwargs={}
|
||||
)
|
||||
|
||||
# For num_experts % world_size != 0 case,
|
||||
# assign the last (num_experts % world_size) experts to the last rank
|
||||
# if rank == world_size -1:
|
||||
# rank_mask = (selected_experts // experts_per_rank) >= rank
|
||||
# else:
|
||||
# rank_mask = (selected_experts // experts_per_rank) == rank
|
||||
div_node = gm.graph.create_node(
|
||||
"call_function", operator.floordiv, args=(selected_experts, experts_per_rank), kwargs={}
|
||||
)
|
||||
comp_op = torch.ge if rank == world_size - 1 else torch.eq
|
||||
rank_mask = gm.graph.create_node("call_function", comp_op, args=(div_node, rank), kwargs={})
|
||||
|
||||
# final_scales_local = final_scales * rank_mask
|
||||
final_scales_local = gm.graph.create_node(
|
||||
"call_function", operator.mul, args=(final_scales, rank_mask), kwargs={}
|
||||
)
|
||||
|
||||
# -- Shard expert weights --
|
||||
def get_partition(lst, world_size, rank):
|
||||
num_experts = len(lst)
|
||||
expert_size_per_partition = num_experts // world_size
|
||||
expert_start = rank * expert_size_per_partition
|
||||
# For num_experts % world_size != 0 case,
|
||||
# assign the last (num_experts % world_size) experts to the last rank
|
||||
expert_end = (
|
||||
num_experts if (rank == world_size - 1) else expert_start + expert_size_per_partition
|
||||
)
|
||||
return lst[expert_start:expert_end]
|
||||
|
||||
w1_list_sharded = get_partition(args[3], world_size, rank)
|
||||
w2_list_sharded = get_partition(args[4], world_size, rank)
|
||||
w3_list_sharded = get_partition(args[5], world_size, rank)
|
||||
|
||||
# -- Update args --
|
||||
args[1] = selected_experts_local
|
||||
args[2] = final_scales_local
|
||||
args[3] = w1_list_sharded
|
||||
args[4] = w2_list_sharded
|
||||
args[5] = w3_list_sharded
|
||||
|
||||
# Shard scales for quantized ops
|
||||
for i in range(len(scale_names) * 3): # 3 layers (w1, w2, w3) × #scale_names per layer
|
||||
args[6 + i] = get_partition(args[6 + i], world_size, rank)
|
||||
|
||||
ad_logger.debug(
|
||||
f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}."
|
||||
)
|
||||
node.args = tuple(args)
|
||||
|
||||
# -- add an all_reduce node --
|
||||
with gm.graph.inserting_after(node):
|
||||
dist_node = gm.graph.call_function(
|
||||
torch.ops.auto_deploy.torch_dist_all_reduce, args=(node,)
|
||||
)
|
||||
node.replace_all_uses_with(dist_node)
|
||||
dist_node.replace_input_with(dist_node, node)
|
||||
|
||||
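To make the remainder handling above concrete, here is a small standalone sketch of the same partition rule with example numbers (plain Python, no torch required):

```python
def get_partition(lst, world_size, rank):
    per_rank = len(lst) // world_size
    start = rank * per_rank
    # the last rank absorbs the remainder when num_experts % world_size != 0
    end = len(lst) if rank == world_size - 1 else start + per_rank
    return lst[start:end]

experts = list(range(10))
print([get_partition(experts, 4, r) for r in range(4)])
# [[0, 1], [2, 3], [4, 5], [6, 7, 8, 9]]
```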
@ -5,12 +5,11 @@ from typing import Tuple
|
||||
|
||||
import model_explorer
|
||||
import torch
|
||||
import torch.export as te
|
||||
from model_explorer.graph_builder import GraphNode, KeyValue, MetadataItem
|
||||
from model_explorer.pytorch_exported_program_adater_impl import PytorchExportedProgramAdapterImpl
|
||||
from torch import fx
|
||||
|
||||
from ..export import torch_export
|
||||
|
||||
|
||||
def print_tensor(self, tensor: torch.Tensor, size_limit: int = 16):
|
||||
shape = tensor.shape
|
||||
@ -79,7 +78,7 @@ CUSTOM_OPS = (
|
||||
|
||||
# TODO(yudong): make viz as non-block call.
|
||||
def visualize_namespace(gm: fx.GraphModule, args: Tuple[torch.Tensor, ...], dynamic_shapes):
|
||||
ep = torch_export(gm, args=args, dynamic_shapes=dynamic_shapes)
|
||||
ep = te.export(gm, args=args, dynamic_shapes=dynamic_shapes)
|
||||
graph = ep.graph
|
||||
# Ensure the ops land up in the right module for better viz
|
||||
for n in graph.nodes:
|
||||
|
||||
@ -3,24 +3,26 @@
|
||||
import gc
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
|
||||
from ..compile import compile_and_capture
|
||||
from ..custom_ops.attention_interface import AttentionRegistry
|
||||
from ..distributed import common as dist_ad
|
||||
from ..llm_args import LlmArgs
|
||||
from ..llm_args import AutoDeployConfig
|
||||
from ..models.factory import ModelFactory
|
||||
from ..shim.interface import CachedSequenceInterface
|
||||
from ..transform.optimizer import InferenceOptimizer as ModularInferenceOptimizer
|
||||
from ..utils.logger import ad_logger
|
||||
from ._graph import canonicalize_graph, lift_to_meta, move_to_device
|
||||
from .export import torch_export_to_gm
|
||||
from .library import (
|
||||
column_row_shard,
|
||||
dp_bmm_shard,
|
||||
ShardingConfig,
|
||||
detect_column_row_shard,
|
||||
detect_dp_bmm_shard,
|
||||
detect_ep_shard,
|
||||
eliminate_redundant_transposes,
|
||||
ep_shard,
|
||||
fuse_allreduce_residual_rmsnorm,
|
||||
fuse_collectives,
|
||||
fuse_rmsnorm,
|
||||
insert_cached_attention,
|
||||
match_attention_layout,
|
||||
match_causal_attn_mask,
|
||||
@ -32,17 +34,19 @@ from .library import (
|
||||
match_rope_pattern,
|
||||
optimize_rope,
|
||||
quantize,
|
||||
quantize_moe,
|
||||
resize_kv_cache,
|
||||
sharding_transform_executor,
|
||||
update_in_out_nodes,
|
||||
)
|
||||
|
||||
|
||||
class InferenceOptimizer:
|
||||
def __init__(self, factory: ModelFactory, ad_config: LlmArgs):
|
||||
def __init__(self, factory: ModelFactory, ad_config: AutoDeployConfig):
|
||||
self.factory = factory
|
||||
self.ad_config = ad_config
|
||||
|
||||
def __call__(self, cm: CachedSequenceInterface) -> GraphModule:
|
||||
def __call__(self, cm: CachedSequenceInterface) -> nn.Module:
|
||||
"""Transform a model into an optimized inference model.
|
||||
|
||||
Args:
|
||||
@ -54,53 +58,46 @@ class InferenceOptimizer:
|
||||
quantization: The quantization method to use. Defaults to None.
|
||||
|
||||
Returns:
|
||||
A GraphModule representing the optimized inference model.
|
||||
A nn.Module representing the optimized inference model.
|
||||
"""
|
||||
############################################################################################
|
||||
# INITIALIZE MODEL
|
||||
# RUN MODULAR INFERENCE OPTIMIZER FOR ALREADY-MIGRATED TRANSFORMS
|
||||
############################################################################################
|
||||
model = self.factory.build_model(device="meta")
|
||||
new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms)
|
||||
egm = new_optimizer(cm)
|
||||
|
||||
############################################################################################
|
||||
# EXPORT MODEL TO GRAPH MODULE
|
||||
############################################################################################
|
||||
|
||||
cm.info.set_example_sequence()
|
||||
egm = torch_export_to_gm(model, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
|
||||
del model
|
||||
ad_logger.debug("original graph: " + str(egm))
|
||||
local_rank, world_size = dist_ad.get_rank_world_size()
|
||||
# TODO (lucaslie): continue moving legacy transforms to the new optimizer
|
||||
|
||||
############################################################################################
|
||||
# RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION
|
||||
############################################################################################
|
||||
|
||||
# quantization
|
||||
egm = quantize(egm, self.factory.get_quant_config())
|
||||
quantize(egm, self.factory.get_quant_config())
|
||||
quantize_moe(egm, self.factory.get_quant_config())
|
||||
|
||||
# Match MoE pattern
|
||||
egm = match_moe_pattern(egm)
|
||||
match_moe_pattern(egm)
|
||||
|
||||
# Match repeat_kv pattern
|
||||
egm = match_repeat_kv(egm)
|
||||
match_repeat_kv(egm)
|
||||
|
||||
# Match eager attention pattern
|
||||
egm = match_eager_attention(egm)
|
||||
match_eager_attention(egm)
|
||||
|
||||
# Match grouped attention pattern
|
||||
egm = match_grouped_attention(egm)
|
||||
match_grouped_attention(egm)
|
||||
|
||||
# Match and optimize causal attention masks
|
||||
egm = match_causal_attn_mask(egm)
|
||||
match_causal_attn_mask(egm)
|
||||
|
||||
# Match attention layout expected by our backend
|
||||
egm = match_attention_layout(egm, AttentionRegistry.get(self.ad_config.attn_backend))
|
||||
match_attention_layout(egm, AttentionRegistry.get(self.ad_config.attn_backend))
|
||||
|
||||
# Match rope
|
||||
egm, _ = match_rope_pattern(egm)
|
||||
match_rope_pattern(egm)
|
||||
|
||||
# Match RoPE layout expected by our backend
|
||||
egm = match_rope_layout(
|
||||
match_rope_layout(
|
||||
egm, AttentionRegistry.get(self.ad_config.attn_backend).get_attention_layout()
|
||||
)
|
||||
|
||||
@ -108,26 +105,35 @@ class InferenceOptimizer:
|
||||
# RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
|
||||
############################################################################################
|
||||
|
||||
local_rank, world_size = dist_ad.get_rank_world_size()
|
||||
|
||||
# eliminate redundant transpose operations
|
||||
egm = eliminate_redundant_transposes(egm)
|
||||
eliminate_redundant_transposes(egm)
|
||||
|
||||
# TODO (lucaslie): let's move this to perf optimization once TP sharding is improved
|
||||
# see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528
|
||||
egm = optimize_rope(egm)
|
||||
optimize_rope(egm)
|
||||
|
||||
# TODO: Infer sharding parameters (tp_size, row/column sharding) from the model config.
|
||||
sharding_config = ShardingConfig()
|
||||
|
||||
# run TP sharding across ranks
|
||||
egm = column_row_shard(egm, local_rank, world_size, self.ad_config.simple_shard_only)
|
||||
detect_column_row_shard(
|
||||
egm, local_rank, world_size, sharding_config, self.ad_config.simple_shard_only
|
||||
)
|
||||
|
||||
# run EP sharding across ranks
|
||||
egm = ep_shard(egm, local_rank, world_size)
|
||||
detect_ep_shard(egm, local_rank, world_size, sharding_config)
|
||||
|
||||
# run BMM sharding across ranks
|
||||
egm = dp_bmm_shard(egm, local_rank, world_size)
|
||||
detect_dp_bmm_shard(egm, local_rank, world_size, sharding_config)
|
||||
|
||||
sharding_transform_executor(egm, sharding_config)
|
||||
|
||||
# let's run a shape propagation pass to update the graph with correct meta values for
|
||||
# subsequent optimization passes. Lift state_dict to meta as shape propagation involves device check
|
||||
with lift_to_meta(egm):
|
||||
egm = canonicalize_graph(egm, shape_prop=True)
|
||||
canonicalize_graph(egm, shape_prop=True)
|
||||
|
||||
############################################################################################
|
||||
# MOVE MODEL AND LOAD WEIGHTS
|
||||
@ -146,17 +152,21 @@ class InferenceOptimizer:
|
||||
|
||||
# run MoE fusion
|
||||
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
|
||||
# egm = fuse_moe(egm)
|
||||
# fuse_moe(egm)
|
||||
|
||||
# run GEMM fusion
|
||||
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
|
||||
# egm = fuse_gemms(egm)
|
||||
# fuse_gemms(egm)
|
||||
|
||||
# check if we can fuse allreduce, residual and rmsnorm
|
||||
egm = fuse_allreduce_residual_rmsnorm(egm)
|
||||
fuse_allreduce_residual_rmsnorm(egm)
|
||||
|
||||
# check if we can fuse collectives
|
||||
egm = fuse_collectives(egm)
|
||||
fuse_collectives(egm)
|
||||
|
||||
# TODO (lucaslie): add backend selection as part of configurable inference optimizers
|
||||
# check if we can fuse rmsnorm
|
||||
fuse_rmsnorm(egm, "flashinfer")
|
||||
|
||||
# visualize the final graph
|
||||
if self.ad_config.visualize:
|
||||
@ -175,12 +185,12 @@ class InferenceOptimizer:
|
||||
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
|
||||
############################################################################################
|
||||
|
||||
egm = update_in_out_nodes(egm, cm)
|
||||
update_in_out_nodes(egm, cm)
|
||||
|
||||
# detect attention op and replace with cache-aware op
|
||||
for a_backend in [self.ad_config.attn_backend, self.ad_config.mla_backend]:
|
||||
attn_descriptor = AttentionRegistry.get(a_backend)
|
||||
egm = insert_cached_attention(egm, cm, attn_descriptor, self.factory.get_cache_config())
|
||||
insert_cached_attention(egm, cm, attn_descriptor, self.factory.get_cache_config())
|
||||
|
||||
# initialize cache on correct device
|
||||
cm.initialize_caches()
|
||||
|
||||
122
tensorrt_llm/_torch/auto_deploy/utils/_config.py
Normal file
@ -0,0 +1,122 @@
"""Helper functions for config-related settings."""

import os
from pathlib import Path
from typing import Any, Dict, List, Union

from omegaconf import DictConfig, OmegaConf
from pydantic import Field
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, YamlConfigSettingsSource
from pydantic_settings.sources.types import PathType


def deep_merge_dicts(*confs: Union[Dict, DictConfig]) -> Dict:
    """Deep merge a list of dictionaries via OmegaConf.merge.

    Args:
        *confs: A list of dictionaries or DictConfig objects to merge.

    Returns:
        A merged dictionary.
    """
    if len(confs) == 0:
        return {}
    merged_conf = OmegaConf.merge(*[OmegaConf.create(conf) for conf in confs])
    result = OmegaConf.to_container(merged_conf, resolve=True)
    assert isinstance(result, Dict), f"Expected dict, got {type(result)}"
    return result

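A hedged usage sketch of the merge semantics implemented by the helper above (requires omegaconf; the keys are illustrative):

```python
print(deep_merge_dicts(
    {"args": {"world_size": 1, "attn_backend": "triton"}},
    {"args": {"world_size": 2}},
))
# -> {'args': {'world_size': 2, 'attn_backend': 'triton'}}  (later configs override earlier ones)
```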
class DynamicYamlWithDeepMergeSettingsSource(YamlConfigSettingsSource):
    """YAML config settings source that dynamically loads files and merges them via deep update.

    We utilize the omegaconf library for deep merging.
    """

    def _read_files(self, files: PathType | None) -> dict[str, Any]:
        if files is None:
            return {}
        if isinstance(files, (str, os.PathLike)):
            files = [files]

        confs = []
        for file in files:
            file_path = Path(file).expanduser()
            if file_path.is_file():
                confs.append(OmegaConf.load(file_path))

        return deep_merge_dicts(*confs)

    def __call__(self):
        """Load additional config files based on the current state."""
        yaml_data = self.yaml_data  # this points to the default yaml data now
        additional_files_data = self._read_files(self.current_state.get("yaml_configs", []))

        return deep_merge_dicts(yaml_data, additional_files_data)


class DynamicYamlMixInForSettings:
    """Mix-in class for settings providing dynamic yaml loading as lowest priority source.

    NOTE: This class must come FIRST in the MRO so that ``yaml_configs`` is processed before the
    other settings sources; otherwise we cannot load default values from the ``yaml_configs``
    first.

    This mix-in enforces the following precedence order:
    - init settings
    - env settings
    - dotenv settings
    - file secret settings
    - yaml configs
    - default settings

    You can learn more about the different settings sources in
    https://docs.pydantic.dev/latest/concepts/pydantic_settings/#field-value-priority.

    Note in particular how yaml settings have precedence only over default settings. You can hence
    think of the yaml settings as a way to override default settings.

    Also consider the following consequences of precedence order in nested config settings:
    - yaml configs for outer settings get converted to init settings for inner settings and hence
      ALWAYS take precedence over yaml configs specified for inner settings.
    - This implies inner settings from outer yaml configs also take precedence over other inner
      settings sources like env settings, since they are now init settings from the view of the
      inner settings.
    - Explicitly initialized fields for inner settings take precedence over outer yaml configs for
      inner settings since they are provided as init arguments.
    - Check out ``tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_config.py`` for more
      examples.


    You can also provide multiple yaml config files to load. In this case, the files are deep
    merged together in the order they are provided. Hence, the following order (increasing
    precedence) applies to multiple yaml config files:
    - default yaml provided as ``yaml_file`` argument in the ``model_config`` (``ConfigDict``)
    - argument 0 of ``yaml_configs``
    - argument 1 of ``yaml_configs``
    - ...
    - last argument of ``yaml_configs``
    """

    yaml_configs: List[PathType] = Field(
        default_factory=list,
        description="Additional yaml config files to load.",
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Customise settings sources."""
        deferred_yaml_settings = DynamicYamlWithDeepMergeSettingsSource(settings_cls)
        return (
            init_settings,
            env_settings,
            dotenv_settings,
            file_secret_settings,
            deferred_yaml_settings,  # yaml files have lowest priority just before default values
        )
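A hedged sketch of how a settings class might opt into this behavior; the class name, fields, and yaml file names are illustrative, and the mix-in must precede BaseSettings in the bases:

```python
from pydantic_settings import BaseSettings

from tensorrt_llm._torch.auto_deploy.utils._config import DynamicYamlMixInForSettings


class ExampleSettings(DynamicYamlMixInForSettings, BaseSettings):
    attn_backend: str = "triton"
    world_size: int = 1


# Values from the yaml files override the defaults above, while explicit init arguments
# (and env vars) still win over the yaml files.
settings = ExampleSettings(yaml_configs=["base.yaml", "overrides.yaml"])
```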
@ -25,7 +25,8 @@ except ImportError:
|
||||
modelopt_quantize_op = None
|
||||
modelopt_dynamic_block_quantize_op = None
|
||||
|
||||
OperatorLike = Union[OpOverloadPacket, OpOverload, Callable]
|
||||
OpOrOverload = Union[OpOverloadPacket, OpOverload]
|
||||
OperatorLike = Union[OpOrOverload, Callable]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -106,27 +107,17 @@ def get_quantization_params_from_linear_node(linear_op: torch.fx.node.Node):
|
||||
return input_params, weight_params, output_params
|
||||
|
||||
|
||||
def is_match(node: Node, names_to_skip: List[str]):
|
||||
if names_to_skip is None:
|
||||
return False
|
||||
for n in names_to_skip:
|
||||
module_stack = node.meta.get("nn_module_stack", None)
|
||||
if module_stack is None:
|
||||
return False
|
||||
module_stack = list(module_stack.keys())
|
||||
if n in module_stack[-1]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def extract_weight_node(mm_node: Node) -> int:
|
||||
"""Extracts the weight node from the given matmul node."""
|
||||
"""Extracts the weight node from the given linear or BMM node. We assume torch.bmm(activation, weight)"""
|
||||
|
||||
def find_get_attr_node(node: Node) -> Node:
|
||||
"""Recursively traverse inputs of allowed nodes to find a node with 'get_attr' op."""
|
||||
# If node is a get_attr node return node
|
||||
# List of nodes allowed in between a get_attr node and the matmul node
|
||||
allowed_ops = {torch.ops.aten.to.dtype}
|
||||
allowed_ops = {
|
||||
torch.ops.aten.to.dtype,
|
||||
torch.ops.aten.view.default,
|
||||
}
|
||||
|
||||
if node.op == "get_attr":
|
||||
return node
|
||||
@ -161,8 +152,8 @@ def extract_param_names_from_lin_node(mm_node: Node) -> Tuple[str, Optional[str]
|
||||
Args:
|
||||
mm_node: Matmul node in the graph.
|
||||
"""
|
||||
assert is_linear_op(mm_node, include_quantization=True), (
|
||||
f"Expecting linear node, Found: {mm_node}"
|
||||
assert is_linear_op(mm_node, include_quantization=True) or is_bmm_op(mm_node), (
|
||||
f"Expecting linear or bmm node, Found: {mm_node}"
|
||||
)
|
||||
weight_node = extract_weight_node(mm_node)
|
||||
|
||||
@ -215,6 +206,37 @@ def is_op(node: Node, ops: Union[OperatorLike, Iterable[OperatorLike]]) -> bool:
|
||||
return is_match
|
||||
|
||||
|
||||
def filtered_nodes(
|
||||
nodes: Iterable[Node], ops: Union[OperatorLike, Iterable[OperatorLike]]
|
||||
) -> Iterable[Node]:
|
||||
"""Iterate over nodes that are filtered by the given operations.
|
||||
|
||||
This utility function simplifies the common pattern of iterating through nodes
|
||||
and filtering by operation type.
|
||||
|
||||
Args:
|
||||
nodes: Iterable of nodes to filter (e.g., gm.graph.nodes)
|
||||
ops: Operation(s) to match against
|
||||
|
||||
Yields:
|
||||
Node: Nodes that match the given operations
|
||||
|
||||
Example:
|
||||
# Instead of:
|
||||
for node in gm.graph.nodes:
|
||||
if not is_op(node, torch.ops.aten.linear):
|
||||
continue
|
||||
# process node
|
||||
|
||||
# Use:
|
||||
for node in filtered_nodes(gm.graph.nodes, torch.ops.aten.linear):
|
||||
# process node
|
||||
"""
|
||||
for node in nodes:
|
||||
if is_op(node, ops):
|
||||
yield node
|
||||
|
||||
|
||||
def is_linear_op(node: Node, include_quantization: bool = False) -> bool:
|
||||
"""Check if the node is a linear op.
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ from torch._inductor.pattern_matcher import (
|
||||
)
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.transformations.export import torch_export_to_gm
|
||||
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from fnmatch import fnmatch
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -12,7 +13,9 @@ from ..custom_ops.quant import (
|
||||
)
|
||||
from .logger import ad_logger
|
||||
from .node_utils import (
|
||||
extract_param_names_from_lin_node,
|
||||
get_quantization_params_from_linear_node,
|
||||
is_bmm_op,
|
||||
is_linear_op,
|
||||
is_op,
|
||||
modelopt_dynamic_block_quantize_op,
|
||||
@ -20,7 +23,7 @@ from .node_utils import (
|
||||
)
|
||||
|
||||
try:
|
||||
from ...quantization.utils import float4_sf_dtype
|
||||
from ....quantization.utils.fp4_utils import float4_sf_dtype
|
||||
except ImportError:
|
||||
float4_sf_dtype = None
|
||||
|
||||
@ -83,6 +86,7 @@ class QuantizationImpl:
|
||||
quantization_impl_map = {
|
||||
"": None,
|
||||
"FP8": FP8QuantizationImpl,
|
||||
"NVFP4": FP4QuantizationImpl,
|
||||
}
|
||||
return quantization_impl_map[quant_type_or_node]
|
||||
|
||||
@ -461,3 +465,48 @@ class FP8BMMQuantizationImpl(QuantizationImpl):
|
||||
attr_name,
|
||||
torch.nn.Parameter(param_cm, requires_grad=param.requires_grad),
|
||||
)
|
||||
|
||||
|
||||
def should_skip_quantization(
|
||||
node_or_name: Union[Node, str],
|
||||
excluded_patterns: list[str],
|
||||
) -> bool:
|
||||
"""Check if a node or parameter name should be skipped based on excluded patterns."""
|
||||
if isinstance(node_or_name, str):
|
||||
modname, _, _ = node_or_name.rpartition(".")
|
||||
else:
|
||||
if not (is_linear_op(node_or_name, include_quantization=False) or is_bmm_op(node_or_name)):
|
||||
return True
|
||||
param_name, _ = extract_param_names_from_lin_node(node_or_name)
|
||||
modname, _, _ = param_name.rpartition(".")
|
||||
|
||||
return any(fnmatch(modname, pattern) for pattern in excluded_patterns)
|
||||
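A small hedged sketch of the pattern check above; note that it matches against the module path, not the trailing parameter name:

```python
from fnmatch import fnmatch

name = "model.layers.0.mlp.experts.3.w1.weight"   # hypothetical parameter name
modname = name.rpartition(".")[0]                  # "model.layers.0.mlp.experts.3.w1"
print(any(fnmatch(modname, p) for p in ["*experts*"]))  # True -> quantization is skipped
```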
|
||||
|
||||
def extract_scales_from_node(node: Node, scale_names: list[str]) -> Dict[str, Optional[Node]]:
|
||||
"""
|
||||
Extracts scale tensors from node.args/kwargs using a fixed list of expected scale names.
|
||||
"""
|
||||
scales = {}
|
||||
args = list(node.args)
|
||||
|
||||
# Try kwargs first
|
||||
for i, name in enumerate(scale_names):
|
||||
scales[name] = node.kwargs.get(name, None)
|
||||
|
||||
# Fallback to positional args (starting after input, weight, bias)
|
||||
for i, name in enumerate(scale_names):
|
||||
if scales[name] is None and len(args) > 3 + i:
|
||||
scales[name] = args[3 + i]
|
||||
|
||||
return scales
|
||||
|
||||
|
||||
def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]:
|
||||
"""Returns a dict of scale args and quantization type string ('fp4', 'fp8', etc)."""
|
||||
for qtype in [FP4QuantizationImpl, FP8QuantizationImpl]:
|
||||
if is_op(node, qtype.target_op()):
|
||||
return extract_scales_from_node(
|
||||
node, qtype.scale_names()
|
||||
), qtype.__name__.lower().replace("quantizationimpl", "")
|
||||
return None, "simple"
|
||||
|
||||
@ -388,6 +388,9 @@ def throughput_command(
|
||||
logger.warning(
|
||||
"Ignore extended_runtime_perf_knob_config for _autodeploy backend."
|
||||
)
|
||||
kwargs["world_size"] = kwargs.pop("tensor_parallel_size", None)
|
||||
kwargs.pop("pipeline_parallel_size", None)
|
||||
|
||||
llm = AutoDeployLLM(**kwargs)
|
||||
else:
|
||||
llm = LLM(**kwargs)
|
||||
|
||||
@ -5,9 +5,19 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _torch_test_utils import all_close, reset_parameters
|
||||
from torch.export import export
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.transformations.export import torch_export, torch_export_to_gm
|
||||
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
|
||||
from tensorrt_llm._torch.auto_deploy.transformations.library.sharding import ShardingTransformInfo
|
||||
|
||||
|
||||
class FakeFactory:
|
||||
def __init__(self, model: nn.Module):
|
||||
self.model = model
|
||||
|
||||
def build_model(self, device: str) -> nn.Module:
|
||||
return self.model.to(device=device)
|
||||
|
||||
|
||||
def count_parameters(model: torch.nn.Module):
|
||||
@ -58,17 +68,17 @@ def run_test(

# graph transformation + check
if check_num_matches:
gm_transformed, num_matches = transform(gm, *args)
num_matches = transform(gm, *args)
assert check_num_matches == num_matches, (
f"expect {check_num_matches} matches, but got {num_matches}"
)
else:
gm_transformed = transform(gm, *args)
print(gm_transformed)
transform(gm, *args)
print(gm)
# in case buffers or other tensors were added during the transform
gm_transformed = gm_transformed.to("cuda")
y_transformed = gm_transformed(x)
n_p_transformed = count_parameters(gm_transformed)
gm = gm.to("cuda")
y_transformed = gm(x)
n_p_transformed = count_parameters(gm)

n_p_t_expected = _get_expected_num_params(num_params_model)
assert n_p_transformed == n_p_t_expected, (
@ -76,7 +86,7 @@ def run_test(
)

# check if the transformation worked
assert check_transformed_graph(gm_transformed)
assert check_transformed_graph(gm)

if strict_loading and not skip_output_assert:
# check if output equals without loading state dict
@ -84,26 +94,43 @@ def run_test(

if test_load_hook and not skip_output_assert:
# check if loading hook works from original state dict
reset_parameters(gm_transformed)
y_random = gm_transformed(x)
reset_parameters(gm)
y_random = gm(x)
assert not all_close(y_model, y_random), f"{y_model=}, {y_random=}"

gm_transformed.load_state_dict(model.state_dict(), strict=True if strict_loading else False)
y_loaded_from_original = gm_transformed(x)
gm.load_state_dict(model.state_dict(), strict=True if strict_loading else False)
y_loaded_from_original = gm(x)
torch.testing.assert_close(y_model, y_loaded_from_original, atol=atol, rtol=rtol)

# check if loading hook works from state_dict of a transformed model
state_dict_sharded = copy.deepcopy(gm_transformed.state_dict())
reset_parameters(gm_transformed)
y_random2 = gm_transformed(x)
state_dict_sharded = copy.deepcopy(gm.state_dict())
reset_parameters(gm)
y_random2 = gm(x)
assert not all_close(y_model, y_random2), f"{y_model=}, {y_random2=}"

gm_transformed.load_state_dict(state_dict_sharded, strict=True if strict_loading else False)
y_loaded_from_transformed = gm_transformed(x)
gm.load_state_dict(state_dict_sharded, strict=True if strict_loading else False)
y_loaded_from_transformed = gm(x)
torch.testing.assert_close(y_model, y_loaded_from_transformed, atol=atol, rtol=rtol)

# check if we can still export the model as expected
torch_export(gm_transformed, args=(x,))
export(gm, args=(x,))

# return graph module for further testing
return gm_transformed
return gm

def run_sharding_pattern_detection_test(
    detected_transformations: List[ShardingTransformInfo],
    expected_transformations: List[ShardingTransformInfo],
) -> None:
    """Compare two lists of transformations ignoring order.

    Args:
        detected_transformations: List of detected transformation configurations
        expected_transformations: List of expected transformation configurations
    """
    # Convert to sets for unordered comparison
    detected_set = set(detected_transformations)
    expected_set = set(expected_transformations)

    assert detected_set == expected_set, "Expected sharding pattern does not match detected pattern"
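A brief usage sketch of the helper above; the `TPShardingInfo` values are illustrative only and mirror the MLP case in the TP sharding test later in this diff.

```python
from tensorrt_llm._torch.auto_deploy.transformations.library import (
    SplitDimension,
    TPShardingInfo,
)

# Illustrative only: expected vs. detected transformations for a 2-way TP shard of an MLP.
expected = [
    TPShardingInfo(target_node="linear1", split_dim=SplitDimension.ROW,
                   rank=0, world_size=2, dist_op=None, min_local_shape=1),
    TPShardingInfo(target_node="linear2", split_dim=SplitDimension.COLUMN,
                   rank=0, world_size=2, dist_op="all_reduce", min_local_shape=1),
]
detected = sharding_config.tp_transforms  # filled by detect_column_row_shard(...)
run_sharding_pattern_detection_test(detected, expected)
```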
@ -242,23 +242,14 @@ class BMMDynamicModel(nn.Module):
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        # Create a linear layer to generate dynamic weights
        self.weight_generator = nn.Linear(hidden_dim, hidden_dim * hidden_dim)
        self.weight = nn.Parameter(torch.randn(batch_size, hidden_dim * hidden_dim))

    def forward(self, x):
        # x shape: [batch_size, seq_len, hidden_dim]
        batch_size, seq_len, hidden_dim = x.shape

        # Generate dynamic weights from input
        # Take mean across sequence dimension to get [batch_size, hidden_dim]
        weight_input = x.mean(dim=1)  # [batch_size, hidden_dim]

        # Generate weights: [batch_size, hidden_dim * hidden_dim]
        weight_flat = self.weight_generator(weight_input)

        # Reshape to BMM weight format: [batch_size, hidden_dim, hidden_dim]
        dynamic_weights = weight_flat.view(batch_size, hidden_dim, hidden_dim)

        # Perform BMM with dynamic weights
        dynamic_weights = self.weight.view(batch_size, hidden_dim, hidden_dim)
        return torch.bmm(x, dynamic_weights)

@ -437,6 +428,15 @@ _SMALL_MODEL_CONFIGS = {
            "q_lora_rank": 128,
        },
    },
    "Qwen/Qwen2.5-3B-Instruct": {
        "model": _hf_model_dir_or_hub_id(
            f"{llm_models_root()}/Qwen/Qwen2.5-3B-Instruct",
            "Qwen/Qwen2.5-3B-Instruct",
        ),
        "model_kwargs": {
            "num_hidden_layers": 2,
        },
    },
}


@ -0,0 +1,201 @@
|
||||
"""Torch attention reference implementations for testing.
|
||||
|
||||
This module provides clean reference implementations using the torch backend
|
||||
that can be used across all attention operation test files to eliminate
|
||||
code duplication and ensure consistency.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
import tensorrt_llm._torch.auto_deploy # noqa: F401
|
||||
|
||||
|
||||
class TorchAttentionReference:
|
||||
"""Reference implementation using the torch backend for consistency."""
|
||||
|
||||
@staticmethod
|
||||
def basic_mha_with_cache(q, k, v, k_cache, v_cache, input_positions, scale=None):
|
||||
"""Reference implementation for basic MHA with cache (generate phase).
|
||||
|
||||
This matches the signature of triton_attention_fused_mha_with_cache.
|
||||
|
||||
Args:
|
||||
q: Query tensor [batch, seq, n_heads, head_dim]
|
||||
k: Key tensor [batch, seq, n_kv_heads, head_dim]
|
||||
v: Value tensor [batch, seq, n_kv_heads, head_dim]
|
||||
k_cache: Key cache [batch, max_seq_len, n_kv_heads, head_dim]
|
||||
v_cache: Value cache [batch, max_seq_len, n_kv_heads, head_dim]
|
||||
input_positions: Positions to update cache [batch]
|
||||
scale: Optional attention scale
|
||||
|
||||
Returns:
|
||||
Attention output [batch, seq, n_heads, head_dim] (same shape as q)
|
||||
"""
|
||||
batch_size, seq_len = q.shape[:2]
|
||||
|
||||
# Convert to flattened format for torch backend
|
||||
seq_len_tensor = torch.full((batch_size,), seq_len, device=q.device, dtype=torch.int32)
|
||||
cache_loc = torch.arange(batch_size, device=q.device, dtype=torch.int32)
|
||||
seq_start = torch.arange(
|
||||
0, batch_size * seq_len, seq_len, device=q.device, dtype=torch.int32
|
||||
)
|
||||
|
||||
# Flatten inputs to [1, total_seq_len, ...] format
|
||||
q_flat = q.view(1, batch_size * seq_len, -1)
|
||||
k_flat = k.view(1, batch_size * seq_len, -1)
|
||||
v_flat = v.view(1, batch_size * seq_len, -1)
|
||||
|
||||
# Call torch backend via custom op registry
|
||||
output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache(
|
||||
q_flat,
|
||||
k_flat,
|
||||
v_flat,
|
||||
seq_len_tensor,
|
||||
input_positions,
|
||||
cache_loc,
|
||||
seq_start,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scale,
|
||||
)
|
||||
|
||||
# Reshape back to original format [batch, seq, n_heads, head_dim]
|
||||
if q.ndim == 4:
|
||||
# Input was [batch, seq, n_heads, head_dim], but triton always returns flattened
|
||||
# So return [batch, seq, n_heads * head_dim] to match triton behavior
|
||||
return output_flat.view(batch_size, seq_len, -1)
|
||||
else:
|
||||
# Input was [batch, seq, n_heads * head_dim], return same shape
|
||||
return output_flat.view(batch_size, seq_len, -1)
|
||||
|
||||
@staticmethod
|
||||
def flattened_mha_with_cache(
|
||||
q, k, v, seq_len, input_positions, cache_loc, seq_start, k_cache, v_cache, scale=None
|
||||
):
|
||||
"""Reference implementation following triton flattened MHA pattern.
|
||||
|
||||
This function directly calls the torch backend implementation via custom op registry.
|
||||
"""
|
||||
return torch.ops.auto_deploy.torch_cached_attention_with_cache(
|
||||
q, k, v, seq_len, input_positions, cache_loc, seq_start, k_cache, v_cache, scale
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def decode_with_prefilled_cache(q, k_ref, v_ref, k_cache, v_cache, prefill_lengths):
|
||||
"""Reference for decode phase with pre-filled cache (flashinfer tests).
|
||||
|
||||
Args:
|
||||
q: Query tensor [batch, seq=1, n_heads, head_dim]
|
||||
k_ref: Reference keys (full context including prefill + new token)
|
||||
v_ref: Reference values (full context including prefill + new token)
|
||||
k_cache: Key cache [batch, max_seq_len, n_heads, head_dim]
|
||||
v_cache: Value cache [batch, max_seq_len, n_heads, head_dim]
|
||||
prefill_lengths: Number of pre-filled tokens per batch [batch]
|
||||
|
||||
Returns:
|
||||
Attention output [batch, seq=1, n_heads * head_dim]
|
||||
"""
|
||||
batch_size = q.shape[0]
|
||||
seq_len = torch.ones(batch_size, device=q.device, dtype=torch.int32)
|
||||
cache_loc = torch.arange(batch_size, device=q.device, dtype=torch.int32)
|
||||
# Fix: Each sequence starts at its own position in the flattened tensor
|
||||
seq_start = torch.arange(batch_size, device=q.device, dtype=torch.int32)
|
||||
|
||||
# For decode phase, input_positions should be the prefill_lengths (where to append new token)
|
||||
input_positions = prefill_lengths.to(torch.int32)
|
||||
|
||||
# Extract the new k,v tokens from k_ref, v_ref (last token for each batch)
|
||||
k_new = k_ref[:, -1:, :, :] # [batch, 1, n_heads, head_dim]
|
||||
v_new = v_ref[:, -1:, :, :] # [batch, 1, n_heads, head_dim]
|
||||
|
||||
# Convert to flattened format [1, total_seq_len, ...]
|
||||
q_flat = q.view(1, batch_size, -1)
|
||||
k_flat = k_new.view(1, batch_size, -1)
|
||||
v_flat = v_new.view(1, batch_size, -1)
|
||||
|
||||
# Call torch backend via custom op registry
|
||||
output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache(
|
||||
q_flat,
|
||||
k_flat,
|
||||
v_flat,
|
||||
seq_len,
|
||||
input_positions,
|
||||
cache_loc,
|
||||
seq_start,
|
||||
k_cache,
|
||||
v_cache,
|
||||
None,
|
||||
)
|
||||
|
||||
# Return in flattened format to match flashinfer backend behavior [batch, seq=1, n_heads * head_dim]
|
||||
return output_flat.view(batch_size, 1, -1)
|
||||
|
||||
@staticmethod
|
||||
def mha_with_features(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
seq_len,
|
||||
input_positions,
|
||||
cache_loc,
|
||||
seq_start,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scale=None,
|
||||
logit_cap=None,
|
||||
sliding_window_size=None,
|
||||
):
|
||||
"""Reference implementation with advanced features (logit capping, sliding window).
|
||||
|
||||
This demonstrates how to use the torch backend with additional features.
|
||||
"""
|
||||
return torch.ops.auto_deploy.torch_cached_attention_with_cache(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
seq_len,
|
||||
input_positions,
|
||||
cache_loc,
|
||||
seq_start,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scale,
|
||||
None, # sinks
|
||||
sliding_window_size,
|
||||
logit_cap,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def prepare_flattened_inputs(q_list, k_list, v_list, input_positions_list):
|
||||
"""Helper to convert list of per-sequence tensors to flattened format.
|
||||
|
||||
Args:
|
||||
q_list: List of query tensors per sequence
|
||||
k_list: List of key tensors per sequence
|
||||
v_list: List of value tensors per sequence
|
||||
input_positions_list: List of input positions per sequence
|
||||
|
||||
Returns:
|
||||
Tuple of (q_flat, k_flat, v_flat, seq_len, input_positions, cache_loc, seq_start)
|
||||
"""
|
||||
device = q_list[0].device
|
||||
|
||||
# Compute sequence metadata
|
||||
seq_lengths = [q.shape[0] for q in q_list]
|
||||
seq_len = torch.tensor(seq_lengths, device=device, dtype=torch.int32)
|
||||
seq_start = torch.tensor(
|
||||
[sum(seq_lengths[:i]) for i in range(len(seq_lengths))],
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
# Flatten tensors
|
||||
q_flat = torch.cat(q_list, dim=0).unsqueeze(0) # [1, total_seq_len, ...]
|
||||
k_flat = torch.cat(k_list, dim=0).unsqueeze(0) # [1, total_seq_len, ...]
|
||||
v_flat = torch.cat(v_list, dim=0).unsqueeze(0) # [1, total_seq_len, ...]
|
||||
|
||||
# Create metadata tensors
|
||||
input_positions = torch.tensor(input_positions_list, device=device, dtype=torch.int32)
|
||||
cache_loc = torch.arange(len(q_list), device=device, dtype=torch.int32)
|
||||
|
||||
return q_flat, k_flat, v_flat, seq_len, input_positions, cache_loc, seq_start
|
||||
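For orientation, a minimal generate-phase call of the reference class added above; shapes follow its docstring and the tensor values are arbitrary.

```python
import torch
from torch_attention_reference import TorchAttentionReference

batch, seq, n_heads, n_kv_heads, d_head, max_len = 2, 1, 8, 4, 32, 64
q = torch.randn(batch, seq, n_heads, d_head, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seq, n_kv_heads, d_head, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seq, n_kv_heads, d_head, device="cuda", dtype=torch.float16)
k_cache = torch.zeros(batch, max_len, n_kv_heads, d_head, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
input_positions = torch.zeros(batch, device="cuda", dtype=torch.int32)

# Returns [batch, seq, n_heads * head_dim], flattened to match the triton backend output.
ref = TorchAttentionReference.basic_mha_with_cache(q, k, v, k_cache, v_cache, input_positions)
```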
@ -8,8 +8,8 @@ from transformers import AutoConfig, AutoProcessor, Llama4ForConditionalGenerati
from transformers.models.llama4.modeling_llama4 import Llama4CausalLMOutputWithPast
from utils.llm_data import llm_models_root

from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device
from tensorrt_llm._torch.auto_deploy.transformations.export import torch_export_to_gm


# Copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py#L1651

@ -3,10 +3,11 @@
import pytest
import torch
from _dist_test_utils import get_device_counts
from torch.export import export

from tensorrt_llm._torch.auto_deploy.distributed import common as dist
from tensorrt_llm._torch.auto_deploy.distributed.trtllm import is_trtllm_op_available
from tensorrt_llm._torch.auto_deploy.transformations.export import torch_export, torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations.library.collectives import (
fuse_allreduce_residual_rmsnorm,
)
@ -64,14 +65,14 @@ def _test_allreduce_fusion(port: int):
original_outputs, residual_original = gm(x, residual)

# Fuse ops
gm_fused = fuse_allreduce_residual_rmsnorm(gm)
fuse_allreduce_residual_rmsnorm(gm)

# Run the fused graph
fused_outputs, residual_fused = gm_fused(x, residual)
fused_outputs, residual_fused = gm(x, residual)

# Check if fused node in the graph
has_fused_node = False
for node in gm_fused.graph.nodes:
for node in gm.graph.nodes:
if is_op(node, torch.ops.dist.fused_allreduce_residual_rmsnorm):
has_fused_node = True
assert has_fused_node, "Fused node not found."
@ -85,8 +86,8 @@ def _test_allreduce_fusion(port: int):
)

# check if we can still export the model as expected
torch_export(gm_fused, args=args)
torch_export_to_gm(gm_fused, args=args)
export(gm, args=args)
torch_export_to_gm(gm, args=args)


@pytest.mark.parametrize("device_count", get_device_counts())

@ -6,10 +6,16 @@ import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _dist_test_utils import get_device_counts
|
||||
from _graph_test_helpers import run_test
|
||||
from _graph_test_helpers import run_sharding_pattern_detection_test, run_test
|
||||
|
||||
import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common
|
||||
from tensorrt_llm._torch.auto_deploy.transformations.library.sharding import dp_bmm_shard
|
||||
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
|
||||
from tensorrt_llm._torch.auto_deploy.transformations.library.sharding import (
|
||||
BMMShardingInfo,
|
||||
ShardingConfig,
|
||||
detect_dp_bmm_shard,
|
||||
sharding_transform_executor,
|
||||
)
|
||||
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
|
||||
|
||||
|
||||
@ -48,9 +54,9 @@ class BMM(nn.Module):
|
||||
|
||||
|
||||
def _run_job(
|
||||
num_experts_multiplier: int,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
num_experts_multiplier: int,
|
||||
) -> None:
|
||||
# init model and input
|
||||
batch_size = 4
|
||||
@ -63,22 +69,82 @@ def _run_job(
|
||||
num_params = num_p_og // world_size
|
||||
return num_params
|
||||
|
||||
def transform_func(gm) -> None:
|
||||
sharding_config = ShardingConfig()
|
||||
detect_dp_bmm_shard(gm, rank, world_size, sharding_config)
|
||||
sharding_transform_executor(gm, sharding_config)
|
||||
|
||||
# now run the test
|
||||
op_expected = getattr(torch.ops.auto_deploy, "torch_dist_all_gather")
|
||||
run_test(
|
||||
model,
|
||||
x,
|
||||
transform=partial(dp_bmm_shard, rank=rank, world_size=world_size),
|
||||
transform=transform_func,
|
||||
check_transformed_graph=lambda gm: any(is_op(n, op_expected) for n in gm.graph.nodes)
|
||||
== (world_size > 1),
|
||||
_get_expected_num_params=_get_expected_num_params,
|
||||
)
|
||||
|
||||
|
||||
def _run_pattern_detection_job(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
num_experts_multiplier: int,
|
||||
) -> None:
|
||||
# init model and input
|
||||
batch_size = 4
|
||||
num_features = 10
|
||||
num_experts = num_experts_multiplier * world_size
|
||||
start_idx = rank * num_experts_multiplier
|
||||
end_idx = start_idx + num_experts_multiplier
|
||||
model = BMM(num_experts, num_features).to(device="cuda", dtype=torch.float16)
|
||||
x = torch.randn(batch_size * num_experts, num_features, device="cuda", dtype=torch.float16)
|
||||
|
||||
# Test pattern detection - create expected transformations for validation
|
||||
gm = torch_export_to_gm(model, args=(x,), clone=True)
|
||||
expected_transformations = []
|
||||
# if world_size == 1, no sharding transformations should be detected
|
||||
if world_size > 1:
|
||||
for node in gm.graph.nodes:
|
||||
if is_op(node, torch.ops.aten.bmm):
|
||||
expected_transformations.append(
|
||||
BMMShardingInfo(
|
||||
target_node=node.name,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
start_idx=start_idx,
|
||||
end_idx=end_idx,
|
||||
)
|
||||
)
|
||||
|
||||
# get detected transformations
|
||||
sharding_config = ShardingConfig()
|
||||
detect_dp_bmm_shard(gm, rank, world_size, sharding_config)
|
||||
detected_transformations = sharding_config.bmm_transforms
|
||||
|
||||
# Run pattern detection test
|
||||
run_sharding_pattern_detection_test(detected_transformations, expected_transformations)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_experts_multiplier", [1, 2])
|
||||
@pytest.mark.parametrize("device_count", get_device_counts())
|
||||
def test_sharding(device_count: int, num_experts_multiplier: int):
|
||||
dist_common.spawn_multiprocess_job(
|
||||
job=partial(_run_job, num_experts_multiplier=num_experts_multiplier),
|
||||
job=partial(_run_job, num_experts_multiplier),
|
||||
size=device_count,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [1, 8])
|
||||
@pytest.mark.parametrize("num_experts_multiplier", [1, 2])
|
||||
def test_sharding_pattern_detection(world_size: int, num_experts_multiplier: int):
|
||||
"""Test pattern detection logic without distributed execution.
|
||||
|
||||
This test verifies only the pattern detection logic with provided world_size.
|
||||
No need to run distributed job, can be run on single process.
|
||||
"""
|
||||
_run_pattern_detection_job(
|
||||
num_experts_multiplier=num_experts_multiplier,
|
||||
rank=0,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
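The detect-then-execute flow exercised by the tests above, shown in isolation; `model` and `x` stand in for any exportable module and input, and the call signatures follow the test code in this diff.

```python
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations.library.sharding import (
    ShardingConfig,
    detect_dp_bmm_shard,
    sharding_transform_executor,
)

gm = torch_export_to_gm(model, args=(x,), clone=True)
sharding_config = ShardingConfig()
# Detection only records BMMShardingInfo entries; nothing is rewritten yet.
detect_dp_bmm_shard(gm, 0, 2, sharding_config)  # (gm, rank, world_size, sharding_config)
# The executor then applies all recorded transforms to the graph module in place.
sharding_transform_executor(gm, sharding_config)
```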
@ -5,11 +5,17 @@ from functools import partial
|
||||
import pytest
|
||||
import torch
|
||||
from _dist_test_utils import get_device_counts
|
||||
from _graph_test_helpers import run_test
|
||||
from _graph_test_helpers import run_sharding_pattern_detection_test, run_test
|
||||
from _model_test_utils import MoEOpModel
|
||||
|
||||
import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common
|
||||
from tensorrt_llm._torch.auto_deploy.transformations.library import ep_shard
|
||||
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
|
||||
from tensorrt_llm._torch.auto_deploy.transformations.library.sharding import (
|
||||
EPShardingInfo,
|
||||
ShardingConfig,
|
||||
detect_ep_shard,
|
||||
sharding_transform_executor,
|
||||
)
|
||||
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
|
||||
|
||||
|
||||
@ -33,12 +39,17 @@ def _run_ep_shard_job(num_experts: int, rank: int, world_size: int) -> None:
|
||||
expected_expert = num_experts_per_rank * hidden_size * intermediate_size * 3
|
||||
return n_gate + expected_expert
|
||||
|
||||
def transform_func(gm) -> None:
|
||||
sharding_config = ShardingConfig()
|
||||
detect_ep_shard(gm, rank, world_size, sharding_config)
|
||||
sharding_transform_executor(gm, sharding_config)
|
||||
|
||||
op_expected = torch.ops.auto_deploy.torch_dist_all_reduce
|
||||
|
||||
run_test(
|
||||
model,
|
||||
x,
|
||||
transform=partial(ep_shard, rank=rank, world_size=world_size),
|
||||
transform=transform_func,
|
||||
check_transformed_graph=lambda gm: any(is_op(n, op_expected) for n in gm.graph.nodes)
|
||||
== (world_size > 1),
|
||||
_get_expected_num_params=partial(_get_expected_num_params, rank, world_size),
|
||||
@ -46,6 +57,46 @@ def _run_ep_shard_job(num_experts: int, rank: int, world_size: int) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _run_pattern_detection_job(num_experts: int, rank: int, world_size: int) -> None:
|
||||
device = "cuda"
|
||||
hidden_size = 32
|
||||
intermediate_size = 16
|
||||
model = MoEOpModel(
|
||||
hidden_size=hidden_size, num_experts=num_experts, intermediate_size=intermediate_size
|
||||
).to(device=device, dtype=torch.bfloat16)
|
||||
x = model.get_input(device=device, dtype=torch.bfloat16)
|
||||
|
||||
# Test pattern detection - create expected transformations for validation
|
||||
gm = torch_export_to_gm(model, args=(x,), clone=True)
|
||||
expected_transformations = []
|
||||
# if world_size == 1, no sharding transformations should be detected
|
||||
if world_size > 1:
|
||||
for node in gm.graph.nodes:
|
||||
if is_op(
|
||||
node,
|
||||
(
|
||||
torch.ops.auto_deploy.torch_moe,
|
||||
torch.ops.auto_deploy.torch_quant_fp8_moe,
|
||||
torch.ops.auto_deploy.torch_quant_fp4_moe,
|
||||
),
|
||||
):
|
||||
expected_transformations.append(
|
||||
EPShardingInfo(
|
||||
target_node=node.name,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
)
|
||||
|
||||
# get detected transformations
|
||||
sharding_config = ShardingConfig()
|
||||
detect_ep_shard(gm, rank, world_size, sharding_config)
|
||||
detected_transformations = sharding_config.ep_transforms
|
||||
|
||||
# Run pattern detection test
|
||||
run_sharding_pattern_detection_test(detected_transformations, expected_transformations)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device_count", get_device_counts())
|
||||
@pytest.mark.parametrize("num_experts", [3, 8])
|
||||
def test_ep_shard(device_count: int, num_experts: int):
|
||||
@ -53,3 +104,18 @@ def test_ep_shard(device_count: int, num_experts: int):
|
||||
job=partial(_run_ep_shard_job, num_experts),
|
||||
size=device_count,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [1, 8])
|
||||
@pytest.mark.parametrize("num_experts", [3, 8])
|
||||
def test_sharding_pattern_detection(world_size: int, num_experts: int):
|
||||
"""Test pattern detection logic without distributed execution.
|
||||
|
||||
This test verifies only the pattern detection logic with provided world_size.
|
||||
No need to run distributed job, can be run on single process.
|
||||
"""
|
||||
_run_pattern_detection_job(
|
||||
num_experts=num_experts,
|
||||
rank=0,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@ -8,11 +8,18 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from _dist_test_utils import get_device_counts
|
||||
from _graph_test_helpers import run_test
|
||||
from _graph_test_helpers import run_sharding_pattern_detection_test, run_test
|
||||
|
||||
import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common
|
||||
from tensorrt_llm._torch.auto_deploy.transformations.library import column_row_shard
|
||||
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
|
||||
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
|
||||
from tensorrt_llm._torch.auto_deploy.transformations.library import (
|
||||
ShardingConfig,
|
||||
SplitDimension,
|
||||
TPShardingInfo,
|
||||
detect_column_row_shard,
|
||||
sharding_transform_executor,
|
||||
)
|
||||
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op
|
||||
|
||||
|
||||
class GQA_Block(nn.Module):
|
||||
@ -139,7 +146,10 @@ def _run_job(
|
||||
# now run the test
|
||||
op_expected = getattr(torch.ops.auto_deploy, dist_op_expected)
|
||||
|
||||
transform_func = partial(column_row_shard, rank=rank, world_size=world_size)
|
||||
def transform_func(gm) -> None:
|
||||
sharding_config = ShardingConfig()
|
||||
detect_column_row_shard(gm, rank, world_size, sharding_config)
|
||||
sharding_transform_executor(gm, sharding_config)
|
||||
|
||||
def combined_graph_check(gm) -> bool:
|
||||
# Check for expected distributed operations
|
||||
@ -159,6 +169,107 @@ def _run_job(
|
||||
)
|
||||
|
||||
|
||||
def _run_pattern_detection_job(
|
||||
model_cls: nn.Module,
|
||||
bias: bool,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
) -> None:
|
||||
# init model and input
|
||||
batch_size = 4
|
||||
sequence_len = 8
|
||||
num_features = 32
|
||||
|
||||
# GQA specific parameters
|
||||
num_heads = 4
|
||||
num_key_value_heads = 1
|
||||
|
||||
if model_cls == GQA_Block:
|
||||
model = model_cls(
|
||||
num_attention_heads=num_heads,
|
||||
hidden_size=num_features,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
).to(device="cuda", dtype=torch.float16)
|
||||
else:
|
||||
model = model_cls(num_features, num_features, bias=bias).to(
|
||||
device="cuda", dtype=torch.float16
|
||||
)
|
||||
x = torch.randn(batch_size, sequence_len, num_features, device="cuda", dtype=torch.float16)
|
||||
|
||||
# Test pattern detection - create expected transformations for validation
|
||||
gm = torch_export_to_gm(model, args=(x,), clone=True)
|
||||
expected_transformations = []
|
||||
# if world_size == 1, no sharding transformations should be detected
|
||||
if world_size > 1:
|
||||
if model_cls == GQA_Block:
|
||||
min_local_shape = num_features // num_heads
|
||||
for node in gm.graph.nodes:
|
||||
if is_linear_op(node, include_quantization=True):
|
||||
# for Q, K, V layers, we expect:
|
||||
# dim = 0, add_dist = False
|
||||
# for O layer, we expect:
|
||||
# dim = 1, add_dist = True
|
||||
if "o_proj" in node.args[1].name:
|
||||
dim = SplitDimension.COLUMN
|
||||
dist_op = "all_reduce"
|
||||
else:
|
||||
dim = SplitDimension.ROW
|
||||
dist_op = None
|
||||
expected_transformations.append(
|
||||
TPShardingInfo(
|
||||
target_node=node.name,
|
||||
split_dim=dim,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
dist_op=dist_op,
|
||||
min_local_shape=min_local_shape,
|
||||
)
|
||||
)
|
||||
elif model_cls == MLP:
|
||||
for node in gm.graph.nodes:
|
||||
if is_linear_op(node, include_quantization=True):
|
||||
# linear1 should be sharded on dim=0, add_dist=False, min_local_shape=1
|
||||
# linear2 should be sharded on dim=1, add_dist=True, min_local_shape=1
|
||||
if "linear1" in node.args[1].name:
|
||||
dim = SplitDimension.ROW
|
||||
dist_op = None
|
||||
else:
|
||||
dim = SplitDimension.COLUMN
|
||||
dist_op = "all_reduce"
|
||||
expected_transformations.append(
|
||||
TPShardingInfo(
|
||||
target_node=node.name,
|
||||
split_dim=dim,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
dist_op=dist_op,
|
||||
min_local_shape=1,
|
||||
)
|
||||
)
|
||||
elif model_cls == nn.Linear:
|
||||
# expect simple shard only (dim=0, add_dist=True, min_local_shape=1)
|
||||
for node in gm.graph.nodes:
|
||||
if is_linear_op(node, include_quantization=True):
|
||||
expected_transformations.append(
|
||||
TPShardingInfo(
|
||||
target_node=node.name,
|
||||
split_dim=SplitDimension.ROW, # Simple shard uses dim=0
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
dist_op="all_gather",
|
||||
min_local_shape=1,
|
||||
)
|
||||
)
|
||||
|
||||
# get detected transformations
|
||||
sharding_config = ShardingConfig()
|
||||
detect_column_row_shard(gm, rank, world_size, sharding_config)
|
||||
detected_transformations = sharding_config.tp_transforms
|
||||
|
||||
# Run pattern detection test
|
||||
run_sharding_pattern_detection_test(detected_transformations, expected_transformations)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device_count", get_device_counts())
|
||||
@pytest.mark.parametrize("bias", [False, True])
|
||||
@pytest.mark.parametrize(
|
||||
@ -174,3 +285,24 @@ def test_sharding(model_cls: Type[nn.Module], dist_op_expected: str, bias: bool,
|
||||
job=partial(_run_job, model_cls, dist_op_expected, bias),
|
||||
size=device_count,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [1, 8])
|
||||
@pytest.mark.parametrize("bias", [False, True])
|
||||
@pytest.mark.parametrize(
|
||||
"model_cls, dist_op_expected",
|
||||
(
|
||||
(MLP, "torch_dist_all_reduce"),
|
||||
(nn.Linear, "torch_dist_all_gather"),
|
||||
(GQA_Block, "torch_dist_all_reduce"),
|
||||
),
|
||||
)
|
||||
def test_sharding_pattern_detection(
|
||||
model_cls: Type[nn.Module], dist_op_expected: str, bias: bool, world_size: int
|
||||
):
|
||||
"""Test pattern detection logic without distributed execution.
|
||||
|
||||
This test verifies only the pattern detection logic with provided world_size.
|
||||
No need to run distributed job, can be run on single process.
|
||||
"""
|
||||
_run_pattern_detection_job(model_cls, bias, 0, world_size)
|
||||
@ -8,7 +8,7 @@ from _model_test_utils import (

from tensorrt_llm._torch.auto_deploy.compile.backends.torch_cudagraph import CapturedGraph
from tensorrt_llm._torch.auto_deploy.compile.compiler import _flatten_args
from tensorrt_llm._torch.auto_deploy.transformations.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm


class ModelWithMultipleInputs(torch.nn.Module):

@ -8,7 +8,7 @@ from _model_test_utils import (
from torch.nn import Module

from tensorrt_llm._torch.auto_deploy.compile import compile_and_capture
from tensorrt_llm._torch.auto_deploy.transformations.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm


@pytest.mark.parametrize(

@ -2,22 +2,23 @@ import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from _torch.helpers import reference_moe_torch
|
||||
from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_available
|
||||
|
||||
import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
|
||||
from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp4_global_scale
|
||||
from tensorrt_llm._torch.modules.fused_moe import MoE # noqa: F401
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_moe_op_run(dtype):
|
||||
def setup_moe_test(dtype, num_experts):
|
||||
SEQ_LEN = 8
|
||||
HIDDEN_SIZE = 64
|
||||
INTERMEDIATE_SIZE = 32
|
||||
NUM_EXPERTS = 3
|
||||
NUM_EXPERTS = num_experts
|
||||
TOP_K = 2
|
||||
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed(0)
|
||||
x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda() * 0.5
|
||||
torch.manual_seed(1234)
|
||||
torch.cuda.manual_seed(1234) # seed=0 will fail
|
||||
x = torch.rand(SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda() * 0.1
|
||||
|
||||
router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=torch.float32).cuda()
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
@ -25,18 +26,18 @@ def test_moe_op_run(dtype):
|
||||
final_scales = final_scales / final_scales.sum(dim=-1, keepdim=True)
|
||||
final_scales = final_scales.to(x.dtype)
|
||||
|
||||
w1_weight = []
|
||||
w2_weight = []
|
||||
w3_weight = []
|
||||
w1_weight, w2_weight, w3_weight = [], [], []
|
||||
weights = {}
|
||||
fused_w3_w1_stacked_weight = torch.empty(
|
||||
(NUM_EXPERTS, INTERMEDIATE_SIZE * 2, HIDDEN_SIZE), dtype=dtype
|
||||
).cuda()
|
||||
fused_w2_weight = torch.empty((NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE), dtype=dtype).cuda()
|
||||
|
||||
for expert_id in range(NUM_EXPERTS):
|
||||
w1 = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype).cuda() * 0.5
|
||||
w2 = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE), dtype=dtype).cuda() * 0.5
|
||||
w3 = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype).cuda() * 0.5
|
||||
w1 = torch.rand(INTERMEDIATE_SIZE, HIDDEN_SIZE, dtype=dtype).cuda() * 0.1
|
||||
w2 = torch.rand(HIDDEN_SIZE, INTERMEDIATE_SIZE, dtype=dtype).cuda() * 0.1
|
||||
w3 = torch.rand(INTERMEDIATE_SIZE, HIDDEN_SIZE, dtype=dtype).cuda() * 0.1
|
||||
|
||||
weights[f"{expert_id}.w1.weight"] = w1
|
||||
weights[f"{expert_id}.w2.weight"] = w2
|
||||
weights[f"{expert_id}.w3.weight"] = w3
|
||||
@ -48,6 +49,34 @@ def test_moe_op_run(dtype):
|
||||
fused_w3_w1_stacked_weight.data[expert_id].copy_(torch.cat([w3, w1], dim=-2))
|
||||
fused_w2_weight.data[expert_id].copy_(w2)
|
||||
|
||||
return (
|
||||
x,
|
||||
selected_experts,
|
||||
final_scales,
|
||||
w1_weight,
|
||||
w2_weight,
|
||||
w3_weight,
|
||||
weights,
|
||||
fused_w3_w1_stacked_weight,
|
||||
fused_w2_weight,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_moe_op_run(dtype):
|
||||
num_experts = 3
|
||||
(
|
||||
x,
|
||||
selected_experts,
|
||||
final_scales,
|
||||
w1_weight,
|
||||
w2_weight,
|
||||
w3_weight,
|
||||
weights,
|
||||
fused_w3_w1_stacked_weight,
|
||||
fused_w2_weight,
|
||||
) = setup_moe_test(dtype, num_experts)
|
||||
|
||||
with torch.inference_mode():
|
||||
output_torch_moe = torch.ops.auto_deploy.torch_moe(
|
||||
x,
|
||||
@ -71,11 +100,174 @@ def test_moe_op_run(dtype):
|
||||
fused_w3_w1_stacked_weight,
|
||||
fused_w2_weight,
|
||||
)
|
||||
|
||||
ref_output = reference_moe_torch(x, selected_experts, final_scales, NUM_EXPERTS, weights)
|
||||
ref_output = reference_moe_torch(x, selected_experts, final_scales, num_experts, weights)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(output_trt_fused_moe, output_torch_fused_moe, rtol=5e-2, atol=5e-2)
|
||||
torch.testing.assert_close(output_trt_fused_moe, ref_output, rtol=5e-2, atol=5e-2)
|
||||
torch.testing.assert_close(output_torch_fused_moe, ref_output, rtol=1e-5, atol=1e-5)
|
||||
torch.testing.assert_close(output_torch_moe, ref_output, rtol=1e-5, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.skipif(not fp8_compatible(), reason="Requires fp8 support")
|
||||
def test_fp8_moe_op_run(dtype):
|
||||
num_experts = 3
|
||||
(
|
||||
x,
|
||||
selected_experts,
|
||||
final_scales,
|
||||
w1_weight,
|
||||
w2_weight,
|
||||
w3_weight,
|
||||
weights,
|
||||
fused_w3_w1_stacked_weight,
|
||||
fused_w2_weight,
|
||||
) = setup_moe_test(dtype, num_experts)
|
||||
|
||||
with torch.inference_mode():
|
||||
output_torch_moe = torch.ops.auto_deploy.torch_moe(
|
||||
x,
|
||||
selected_experts,
|
||||
final_scales,
|
||||
w1_weight,
|
||||
w2_weight,
|
||||
w3_weight,
|
||||
)
|
||||
|
||||
w1_input_scale, w2_input_scale, w3_input_scale = [], [], []
|
||||
w1_weight_scale, w2_weight_scale, w3_weight_scale = [], [], []
|
||||
for i in range(num_experts):
|
||||
inp_scale_val = torch.tensor(1.0).float().cuda()
|
||||
wt_scale_factor = 448 if dtype == torch.bfloat16 else 432 # float16 overflow with 448
|
||||
wt_scale_val = (torch.max(torch.abs(w1_weight[i])) / wt_scale_factor).float().to("cuda")
|
||||
w1_input_scale.append(inp_scale_val)
|
||||
w2_input_scale.append(inp_scale_val)
|
||||
w3_input_scale.append(inp_scale_val)
|
||||
w1_weight_scale.append(wt_scale_val)
|
||||
w2_weight_scale.append(wt_scale_val)
|
||||
w3_weight_scale.append(wt_scale_val)
|
||||
# Cast the expert weight tensors and fused weights to FP8.
|
||||
w1_weight[i] = (w1_weight[i] / w1_weight_scale[i]).to(torch.float8_e4m3fn)
|
||||
w2_weight[i] = (w2_weight[i] / w2_weight_scale[i]).to(torch.float8_e4m3fn)
|
||||
w3_weight[i] = (w3_weight[i] / w3_weight_scale[i]).to(torch.float8_e4m3fn)
|
||||
fused_w3_w1_stacked_weight[i] = (fused_w3_w1_stacked_weight[i] / w1_weight_scale[i]).to(
|
||||
torch.float8_e4m3fn
|
||||
)
|
||||
fused_w2_weight[i] = (fused_w2_weight[i] / w2_weight_scale[i]).to(torch.float8_e4m3fn)
|
||||
|
||||
with torch.inference_mode():
|
||||
output_torch_fp8_moe = torch.ops.auto_deploy.torch_quant_fp8_moe(
|
||||
x,
|
||||
selected_experts,
|
||||
final_scales,
|
||||
w1_weight,
|
||||
w2_weight,
|
||||
w3_weight,
|
||||
w1_input_scale,
|
||||
w2_input_scale,
|
||||
w3_input_scale,
|
||||
w1_weight_scale,
|
||||
w2_weight_scale,
|
||||
w3_weight_scale,
|
||||
)
|
||||
ref_output = reference_moe_torch(x, selected_experts, final_scales, num_experts, weights)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
rtol = 0.5 if dtype == torch.bfloat16 else 1.5
|
||||
atol = 0.8 if dtype == torch.bfloat16 else 1
|
||||
torch.testing.assert_close(output_torch_fp8_moe, output_torch_moe, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(output_torch_fp8_moe, ref_output, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.skipif(
|
||||
not fp4_compatible() or not trtllm_ops_available(),
|
||||
reason="Requires fp4 and trtllm support",
|
||||
)
|
||||
def test_fp4_moe_op_run(dtype):
|
||||
num_experts = 3
|
||||
(
|
||||
x,
|
||||
selected_experts,
|
||||
final_scales,
|
||||
w1_weight,
|
||||
w2_weight,
|
||||
w3_weight,
|
||||
weights,
|
||||
_,
|
||||
_,
|
||||
) = setup_moe_test(dtype, num_experts)
|
||||
|
||||
with torch.inference_mode():
|
||||
output_torch_moe = torch.ops.auto_deploy.torch_moe(
|
||||
x,
|
||||
selected_experts,
|
||||
final_scales,
|
||||
w1_weight,
|
||||
w2_weight,
|
||||
w3_weight,
|
||||
)
|
||||
|
||||
# prepare FP4 scales and quantized weights
|
||||
w1_input_scale, w2_input_scale, w3_input_scale = [], [], []
|
||||
w1_weight_scale, w2_weight_scale, w3_weight_scale = [], [], []
|
||||
w1_alpha, w2_alpha, w3_alpha = [], [], []
|
||||
scaling_vector_size = 16
|
||||
|
||||
for i in range(num_experts):
|
||||
inp_scale = fp4_global_scale(x)
|
||||
wt_scale_2_w1 = fp4_global_scale(w1_weight[i])
|
||||
wt_scale_2_w2 = fp4_global_scale(w2_weight[i])
|
||||
wt_scale_2_w3 = fp4_global_scale(w3_weight[i])
|
||||
|
||||
# quantize weights
|
||||
w1_fp4, w1_scale = torch.ops.trtllm.fp4_quantize(
|
||||
w1_weight[i], wt_scale_2_w1, scaling_vector_size, False
|
||||
)
|
||||
w2_fp4, w2_scale = torch.ops.trtllm.fp4_quantize(
|
||||
w2_weight[i], wt_scale_2_w2, scaling_vector_size, False
|
||||
)
|
||||
w3_fp4, w3_scale = torch.ops.trtllm.fp4_quantize(
|
||||
w3_weight[i], wt_scale_2_w3, scaling_vector_size, False
|
||||
)
|
||||
w1_weight[i] = w1_fp4
|
||||
w2_weight[i] = w2_fp4
|
||||
w3_weight[i] = w3_fp4
|
||||
|
||||
# record scales and alpha
|
||||
w1_input_scale.append(inp_scale)
|
||||
w2_input_scale.append(inp_scale)
|
||||
w3_input_scale.append(inp_scale)
|
||||
w1_weight_scale.append(w1_scale)
|
||||
w2_weight_scale.append(w2_scale)
|
||||
w3_weight_scale.append(w3_scale)
|
||||
w1_alpha.append(1 / (inp_scale * wt_scale_2_w1))
|
||||
w2_alpha.append(1 / (inp_scale * wt_scale_2_w2))
|
||||
w3_alpha.append(1 / (inp_scale * wt_scale_2_w3))
|
||||
|
||||
# run FP4 MoE op
|
||||
with torch.inference_mode():
|
||||
output_torch_fp4_moe = torch.ops.auto_deploy.torch_quant_fp4_moe(
|
||||
x,
|
||||
selected_experts,
|
||||
final_scales,
|
||||
w1_weight,
|
||||
w2_weight,
|
||||
w3_weight,
|
||||
w1_input_scale,
|
||||
w2_input_scale,
|
||||
w3_input_scale,
|
||||
w1_weight_scale,
|
||||
w2_weight_scale,
|
||||
w3_weight_scale,
|
||||
w1_alpha,
|
||||
w2_alpha,
|
||||
w3_alpha,
|
||||
)
|
||||
ref_output = reference_moe_torch(x, selected_experts, final_scales, num_experts, weights)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
rtol, atol = 1.5, 1.0
|
||||
torch.testing.assert_close(output_torch_fp4_moe, output_torch_moe, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(output_torch_fp4_moe, ref_output, rtol=rtol, atol=atol)
|
||||
|
||||
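As a side note on the FP8 path tested above, the per-tensor scaling step reduces to a few lines; the tensor values here are arbitrary, and the 448/432 divisors are the ones used in the test.

```python
import torch

w = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") * 0.1
# Map the largest weight magnitude onto the float8_e4m3fn range (the test above uses 432
# instead of 448 for float16 inputs to avoid overflow).
weight_scale = (torch.max(torch.abs(w)) / 448).float()
w_fp8 = (w / weight_scale).to(torch.float8_e4m3fn)   # quantized weight
w_deq = w_fp8.to(torch.bfloat16) * weight_scale      # approximate dequantized weight
```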
@ -1,6 +1,7 @@
import pytest
import torch
from _custom_op_utils import torch_rope_reference
from torch_attention_reference import TorchAttentionReference

import tensorrt_llm._torch.auto_deploy  # noqa: F401

@ -24,12 +25,8 @@ def test_attention_op():
|
||||
output = torch.ops.auto_deploy.triton_attention_fused_mha_with_cache(
|
||||
q, k, v, input_positions, k_cache, v_cache, None
|
||||
)
|
||||
ref = torch.nn.functional.scaled_dot_product_attention(
|
||||
q.transpose(1, 2),
|
||||
k_cache[:, : input_positions[0] + 1].transpose(1, 2),
|
||||
v_cache[:, : input_positions[0] + 1].transpose(1, 2),
|
||||
)
|
||||
ref = ref.transpose(1, 2).contiguous().view(BATCH_SIZE, 1, -1)
|
||||
# Use torch backend as clean reference
|
||||
ref = TorchAttentionReference.basic_mha_with_cache(q, k, v, k_cache, v_cache, input_positions)
|
||||
assert torch.allclose(
|
||||
ref.cpu().to(torch.float32),
|
||||
output.cpu().to(torch.float32),
|
||||
@ -70,27 +67,8 @@ def test_gqa_op(device, dtype, n_heads, group_size, seq_len):
|
||||
q, k, v, input_positions, k_cache, v_cache, None
|
||||
)
|
||||
|
||||
k_cache[:, input_positions[0] : input_positions[0] + seq_len] = k
|
||||
v_cache[:, input_positions[0] : input_positions[0] + seq_len] = v
|
||||
|
||||
k_cache = torch.repeat_interleave(k_cache, group_size, dim=2) # [b,s,n,d]
|
||||
v_cache = torch.repeat_interleave(v_cache, group_size, dim=2) # [b,s,n,d]
|
||||
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.ones(seq_len, input_positions[0], device=device, dtype=torch.bool),
|
||||
torch.tril(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool)),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
ref = torch.nn.functional.scaled_dot_product_attention(
|
||||
q.transpose(1, 2),
|
||||
k_cache[:, : input_positions[0] + seq_len].transpose(1, 2),
|
||||
v_cache[:, : input_positions[0] + seq_len].transpose(1, 2),
|
||||
attn_mask=mask,
|
||||
)
|
||||
ref = ref.transpose(1, 2).contiguous().view(BATCH_SIZE, seq_len, n_heads * D_HEAD)
|
||||
# Use torch backend as clean reference
|
||||
ref = TorchAttentionReference.basic_mha_with_cache(q, k, v, k_cache, v_cache, input_positions)
|
||||
|
||||
assert torch.allclose(
|
||||
ref.cpu().to(torch.float32),
|
||||
@ -167,47 +145,10 @@ def test_flat_gqa_op(
|
||||
scale=None,
|
||||
)
|
||||
|
||||
# prep batched tensors for comparison
|
||||
q_b = torch.zeros(batch_size, n_heads, max_seq_len, D_HEAD, **dtype_kwargs)
|
||||
k_cache_b = k_cache[cache_loc].transpose(1, 2)
|
||||
v_cache_b = v_cache[cache_loc].transpose(1, 2)
|
||||
|
||||
def _store(t_batched, t_flat):
|
||||
# batched layout: [n,s,d]; flat layout: [s,n*d]
|
||||
n_h, _, d_h = t_batched.shape
|
||||
t_batched[:] = t_flat.view(-1, n_h, d_h).transpose(0, 1)
|
||||
|
||||
for i_b, (i_pos, s_start, s_len) in enumerate(zip(input_positions, seq_start, seq_len)):
|
||||
# fill q in a batched manner
|
||||
_store(q_b[i_b, :, :s_len], q[0, s_start : s_start + s_len])
|
||||
# fill k, v in a batched manner
|
||||
_store(k_cache_b[i_b, :, i_pos : i_pos + s_len], k[0, s_start : s_start + s_len])
|
||||
_store(v_cache_b[i_b, :, i_pos : i_pos + s_len], v[0, s_start : s_start + s_len])
|
||||
|
||||
k_cache_b = torch.repeat_interleave(k_cache_b, group_size, dim=1) # [b,n,s,d]
|
||||
v_cache_b = torch.repeat_interleave(v_cache_b, group_size, dim=1) # [b,n,s,d]
|
||||
|
||||
# run comparison
|
||||
refs = []
|
||||
for i_b, (i_pos, s_start, s_len) in enumerate(zip(input_positions, seq_start, seq_len)):
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.ones(s_len, i_pos, device=device, dtype=torch.bool),
|
||||
torch.tril(torch.ones(s_len, s_len, device=device, dtype=torch.bool)),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
ref_i = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_b[i_b, :, :s_len],
|
||||
k_cache_b[i_b, :, : i_pos + s_len],
|
||||
v_cache_b[i_b, :, : i_pos + s_len],
|
||||
attn_mask=mask,
|
||||
) # [n,s,d]
|
||||
ref_i = ref_i.transpose(0, 1).contiguous().view(s_len, n_heads * D_HEAD) # [s,n*d]
|
||||
refs.append(ref_i)
|
||||
|
||||
# flatten output for comparison
|
||||
ref_flat = torch.cat(refs, dim=0)[None] # [1,s_total,n*d]
|
||||
# Use torch backend as clean reference
|
||||
ref_flat = TorchAttentionReference.flattened_mha_with_cache(
|
||||
q, k, v, seq_len, input_positions, cache_loc, seq_start, k_cache, v_cache
|
||||
)
|
||||
|
||||
assert torch.allclose(
|
||||
ref_flat.cpu().to(torch.float32),
|
||||
@ -481,6 +422,8 @@ def test_paged_gqa_op(
|
||||
None,
|
||||
)
|
||||
|
||||
# TODO (nvchenghaoz): Replace this with torch backend reference.
|
||||
|
||||
# prep batched tensors for comparison
|
||||
def compute_reference(q, k_cache, v_cache):
|
||||
ref = []
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import flashinfer
|
||||
import pytest
|
||||
import torch
|
||||
from torch_attention_reference import TorchAttentionReference
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.flashinfer_attention import _GlobalFlashInferPlanner
|
||||
|
||||
@ -111,14 +112,19 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
|
||||
1.0,
|
||||
)
|
||||
|
||||
ref = torch.nn.functional.scaled_dot_product_attention(
|
||||
q.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD).transpose(1, 2),
|
||||
k.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD).transpose(1, 2),
|
||||
v.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD).transpose(1, 2),
|
||||
is_causal=True,
|
||||
# Use torch backend as clean reference
|
||||
q_reshaped = q.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
|
||||
k_reshaped = k.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
|
||||
v_reshaped = v.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
|
||||
|
||||
ref = TorchAttentionReference.basic_mha_with_cache(
|
||||
q_reshaped,
|
||||
k_reshaped,
|
||||
v_reshaped,
|
||||
k_cache,
|
||||
v_cache,
|
||||
torch.zeros(BATCH_SIZE, device=device, dtype=torch.int),
|
||||
)
|
||||
ref = ref.transpose(1, 2).contiguous()
|
||||
ref = ref.view(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD)
|
||||
|
||||
assert torch.allclose(
|
||||
flashinfer_output.cpu().to(torch.float32),
|
||||
@ -261,13 +267,16 @@ def test_flashinfer_attention_op_decode(
|
||||
BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD
|
||||
)
|
||||
|
||||
ref = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_ref.transpose(1, 2), k_ref.transpose(1, 2), v_ref.transpose(1, 2)
|
||||
# Use torch backend as clean reference for decode with prefilled cache
|
||||
ref = TorchAttentionReference.decode_with_prefilled_cache(
|
||||
q_ref,
|
||||
k_ref,
|
||||
v_ref,
|
||||
k_cache,
|
||||
v_cache,
|
||||
torch.tensor([PREFILL_SEQ_LEN] * BATCH_SIZE, device=device, dtype=torch.int),
|
||||
)
|
||||
|
||||
ref = ref.transpose(1, 2).contiguous()
|
||||
ref = ref.view(BATCH_SIZE, -1, N_HEADS * D_HEAD)
|
||||
|
||||
assert torch.allclose(
|
||||
flashinfer_output.cpu().to(torch.float32),
|
||||
ref.cpu().to(torch.float32),
|
||||
@ -357,15 +366,15 @@ def test_flashinfer_attention_context_and_generate(
|
||||
k_ref = k_cache[:BATCH_SIZE, 0:PREFILL_SEQ_LEN, :, :]
|
||||
v_ref = v_cache[:BATCH_SIZE, 0:PREFILL_SEQ_LEN, :, :]
|
||||
|
||||
ref = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_ref.view(BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS, D_HEAD).transpose(1, 2),
|
||||
k_ref.transpose(1, 2),
|
||||
v_ref.transpose(1, 2),
|
||||
is_causal=True,
|
||||
# Use torch backend as clean reference
|
||||
ref = TorchAttentionReference.basic_mha_with_cache(
|
||||
q_ref.view(BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS, D_HEAD),
|
||||
k_ref.transpose(1, 2).transpose(2, 3), # Convert [B,N,S,D] to [B,S,N,D]
|
||||
v_ref.transpose(1, 2).transpose(2, 3), # Convert [B,N,S,D] to [B,S,N,D]
|
||||
k_cache,
|
||||
v_cache,
|
||||
torch.zeros(BATCH_SIZE, device=device, dtype=torch.int),
|
||||
)
|
||||
|
||||
ref = ref.transpose(1, 2)
|
||||
ref = ref[0:BATCH_SIZE, :PREFILL_SEQ_LEN, :, :]
|
||||
flashinfer_output_1 = flashinfer_output_1.view(BATCH_SIZE, -1, N_HEADS, D_HEAD)
|
||||
|
||||
assert torch.allclose(
|
||||
|
||||
@ -0,0 +1,487 @@
|
||||
"""Concise test suite for torch attention backend operations."""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import tensorrt_llm._torch.auto_deploy # noqa: F401
|
||||
|
||||
|
||||
def numpy_attention_reference(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
seq_len,
|
||||
input_pos,
|
||||
cache_loc,
|
||||
seq_start,
|
||||
scale=None,
|
||||
logit_cap=None,
|
||||
sliding_window_size=None,
|
||||
sinks=None,
|
||||
):
|
||||
"""Numpy reference implementation of attention with all features."""
|
||||
# Convert to numpy
|
||||
q_np = q.detach().cpu().numpy().astype(np.float32)
|
||||
k_np = k.detach().cpu().numpy().astype(np.float32)
|
||||
v_np = v.detach().cpu().numpy().astype(np.float32)
|
||||
k_cache_np = k_cache.detach().cpu().numpy().astype(np.float32)
|
||||
v_cache_np = v_cache.detach().cpu().numpy().astype(np.float32)
|
||||
seq_len_np = seq_len.detach().cpu().numpy()
|
||||
input_pos_np = input_pos.detach().cpu().numpy()
|
||||
cache_loc_np = cache_loc.detach().cpu().numpy()
|
||||
seq_start_np = seq_start.detach().cpu().numpy()
|
||||
|
||||
# Get dimensions from cache (which has the actual dimensions)
|
||||
n_kv_heads = k_cache_np.shape[2]
|
||||
head_dim = k_cache_np.shape[3]
|
||||
v_head_dim = v_cache_np.shape[3]
|
||||
|
||||
# Calculate n_heads from the flattened query tensor
|
||||
if q_np.ndim == 3 and q_np.shape[0] > 1: # (batch, seq, features) - true batch case
|
||||
batch_size, seq_len_q, q_features = q_np.shape
|
||||
is_generate = seq_len_q == 1
|
||||
n_heads = q_features // head_dim
|
||||
else: # (1, total_seq, features) - flattened case OR single batch
|
||||
batch_size = len(seq_len_np) # Number of original sequences
|
||||
is_generate = np.all(seq_len_np == 1)
|
||||
n_heads = q_np.shape[2] // head_dim
|
||||
|
||||
# Set default scale
|
||||
if scale is None:
|
||||
scale = 1.0 / math.sqrt(head_dim)
|
||||
|
||||
# Update KV cache first
|
||||
if is_generate:
|
||||
# Generate phase: single token per sequence
|
||||
for i in range(batch_size):
|
||||
cache_idx = cache_loc_np[i]
|
||||
pos = input_pos_np[i]
|
||||
if q_np.ndim == 3 and q_np.shape[0] > 1:
|
||||
# True batch case
|
||||
k_cache_np[cache_idx, pos] = k_np[i, 0].reshape(n_kv_heads, head_dim)
|
||||
v_cache_np[cache_idx, pos] = v_np[i, 0].reshape(n_kv_heads, v_head_dim)
|
||||
else:
|
||||
# Flattened case
|
||||
k_cache_np[cache_idx, pos] = k_np[0, i].reshape(n_kv_heads, head_dim)
|
||||
v_cache_np[cache_idx, pos] = v_np[0, i].reshape(n_kv_heads, v_head_dim)
|
||||
else:
|
||||
# Context phase: multiple tokens
|
||||
for i in range(batch_size):
|
||||
cache_idx = cache_loc_np[i]
|
||||
pos = input_pos_np[i]
|
||||
seq_len_i = seq_len_np[i]
|
||||
seq_start_i = seq_start_np[i]
|
||||
|
||||
# Update cache for this sequence
|
||||
k_seq = k_np[0, seq_start_i : seq_start_i + seq_len_i].reshape(
|
||||
seq_len_i, n_kv_heads, head_dim
|
||||
)
|
||||
v_seq = v_np[0, seq_start_i : seq_start_i + seq_len_i].reshape(
|
||||
seq_len_i, n_kv_heads, v_head_dim
|
||||
)
|
||||
k_cache_np[cache_idx, pos : pos + seq_len_i] = k_seq
|
||||
v_cache_np[cache_idx, pos : pos + seq_len_i] = v_seq
|
||||
|
||||
# Compute attention for each sequence
|
||||
outputs = []
|
||||
|
||||
for i in range(batch_size):
|
||||
cache_idx = cache_loc_np[i]
|
||||
pos = input_pos_np[i]
|
||||
seq_len_i = seq_len_np[i]
|
||||
seq_start_i = seq_start_np[i]
|
||||
|
||||
if seq_len_i == 0:
|
||||
continue
|
||||
|
||||
# Get query for this sequence and reshape properly
|
||||
if q_np.ndim == 3 and q_np.shape[0] > 1:
|
||||
# True batch case: each sequence is in a separate batch dimension
|
||||
q_seq = q_np[i, :seq_len_i].reshape(
|
||||
seq_len_i, n_heads, head_dim
|
||||
) # [seq_len, n_heads, head_dim]
|
||||
else:
|
||||
# Flattened case: all sequences are flattened in the second dimension
|
||||
q_seq = q_np[0, seq_start_i : seq_start_i + seq_len_i].reshape(
|
||||
seq_len_i, n_heads, head_dim
|
||||
)
|
||||
|
||||
# Get keys and values from cache
|
||||
kv_seq_len = pos + seq_len_i
|
||||
k_seq = k_cache_np[cache_idx, :kv_seq_len] # [kv_seq_len, n_kv_heads, head_dim]
|
||||
v_seq = v_cache_np[cache_idx, :kv_seq_len] # [kv_seq_len, n_kv_heads, v_head_dim]
|
||||
|
||||
# Handle GQA: repeat KV if needed
|
||||
if n_heads != n_kv_heads:
|
||||
n_rep = n_heads // n_kv_heads
|
||||
k_seq = np.repeat(k_seq, n_rep, axis=1) # [kv_seq_len, n_heads, head_dim]
|
||||
v_seq = np.repeat(v_seq, n_rep, axis=1) # [kv_seq_len, n_heads, v_head_dim]
|
||||
|
||||
# Compute attention scores: Q @ K^T
|
||||
# q_seq: [seq_len, n_heads, head_dim], k_seq: [kv_seq_len, n_heads, head_dim]
|
||||
# We want [seq_len, n_heads, kv_seq_len]
|
||||
attn_scores = np.einsum("snh,knh->snk", q_seq, k_seq) * scale
|
||||
|
||||
# Apply causal mask - make sure it broadcasts correctly with [seq_len, n_heads, kv_seq_len]
|
||||
causal_mask = np.triu(np.ones((seq_len_i, kv_seq_len)), k=kv_seq_len - seq_len_i + 1)
|
||||
# Expand mask to match attention scores: [seq_len, kv_seq_len] -> [seq_len, 1, kv_seq_len]
|
||||
causal_mask_expanded = causal_mask[:, None, :]
|
||||
attn_scores = np.where(causal_mask_expanded, -np.inf, attn_scores)
|
||||
|
||||
# Apply sliding window mask if specified
|
||||
if sliding_window_size is not None and sliding_window_size > 0:
|
||||
# Query positions are [pos, pos + seq_len_i)
|
||||
# Key positions are [0, pos + seq_len_i)
|
||||
query_positions = np.arange(pos, pos + seq_len_i)[:, None] # [seq_len_i, 1]
|
||||
key_positions = np.arange(0, kv_seq_len)[None, :] # [1, kv_seq_len]
|
||||
|
||||
# Position difference: query_pos - key_pos
|
||||
pos_diff = query_positions - key_positions # [seq_len_i, kv_seq_len]
|
||||
|
||||
# Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window_size
|
||||
sliding_mask = (pos_diff < 0) | (pos_diff >= sliding_window_size)
|
||||
# Expand to match attention scores: [seq_len, kv_seq_len] -> [seq_len, 1, kv_seq_len]
|
||||
sliding_mask_expanded = sliding_mask[:, None, :]
|
||||
attn_scores = np.where(sliding_mask_expanded, -np.inf, attn_scores)
|
||||
|
||||
# Apply logit softcapping if enabled
|
||||
if logit_cap is not None and logit_cap > 0.0:
|
||||
attn_scores = logit_cap * np.tanh(attn_scores / logit_cap)
|
||||
|
||||
# Apply sinks if provided
|
||||
if sinks is not None:
|
||||
# Create sinks matrix matching attention scores shape
|
||||
# attn_scores: [seq_len, n_heads, kv_seq_len]
|
||||
# sinks should be: [seq_len, n_heads, num_sinks]
|
||||
|
||||
# Concatenate sinks to attention scores
|
||||
attn_scores_with_sinks = np.concatenate(
|
||||
[attn_scores, sinks], axis=-1
|
||||
) # [seq_len, n_heads, kv_seq_len + num_sinks]
|
||||
|
||||
# Apply softmax to combined scores
|
||||
attn_scores_max = np.max(attn_scores_with_sinks, axis=-1, keepdims=True)
|
||||
attn_scores_exp = np.exp(attn_scores_with_sinks - attn_scores_max)
|
||||
attn_weights_with_sinks = attn_scores_exp / np.sum(
|
||||
attn_scores_exp, axis=-1, keepdims=True
|
||||
)
|
||||
|
||||
# Use only the non-sink portion for computing output (ignore sinks)
|
||||
attn_weights = attn_weights_with_sinks[..., :-1] # [seq_len, n_heads, kv_seq_len]
|
||||
else:
|
||||
# Apply softmax normally
|
||||
attn_scores_max = np.max(attn_scores, axis=-1, keepdims=True)
|
||||
attn_scores_exp = np.exp(attn_scores - attn_scores_max)
|
||||
attn_weights = attn_scores_exp / np.sum(attn_scores_exp, axis=-1, keepdims=True)
|
||||
|
||||
# Compute output: weights @ V
|
||||
# attn_weights: [seq_len, n_heads, kv_seq_len], v_seq: [kv_seq_len, n_heads, v_head_dim]
|
||||
attn_out = np.einsum("snk,knh->snh", attn_weights, v_seq) # [seq_len, n_heads, v_head_dim]
|
||||
|
||||
outputs.append(attn_out)
|
||||
|
||||
# Concatenate outputs and flatten head dimension to match torch backend
|
||||
if len(outputs) == 0:
|
||||
return np.zeros((1, 0, n_heads * v_head_dim), dtype=np.float32)
|
||||
elif is_generate:
|
||||
# Generate phase: outputs is a list of [seq_len, n_heads, v_head_dim] tensors
|
||||
# We need to stack them to [batch_size, seq_len, n_heads * v_head_dim]
|
||||
result = np.stack(outputs, axis=0) # [batch_size, seq_len, n_heads, v_head_dim]
|
||||
return result.reshape(batch_size, result.shape[1], n_heads * v_head_dim)
|
||||
else:
|
||||
# Context phase: outputs is a list of [seq_len_i, n_heads, v_head_dim] tensors
|
||||
# We need to concatenate them to [total_seq, n_heads * v_head_dim]
|
||||
result = np.concatenate(outputs, axis=0) # [total_seq, n_heads, v_head_dim]
|
||||
return result.reshape(1, result.shape[0], n_heads * v_head_dim)
|
||||
|
||||
|
||||
class TestTorchBackendAttention:
    """Test torch backend attention with combined features."""

    @pytest.fixture(autouse=True)
    def setup_method(self):
        """Setup test configuration."""
        self.device = "cuda"
        self.dtype = torch.float16
        self.atol = 5e-2  # Increased tolerance for fp16 vs fp32 comparison
        self.rtol = 5e-2

        # Ensure clean state for each test
        torch.cuda.empty_cache()
        torch.manual_seed(123)  # Fixed seed for reproducibility
        np.random.seed(123)

    def _create_test_data(
        self, batch_size, seq_len, n_heads, n_kv_heads, d_head, max_seq_len, cache_offset=0
    ):
        """Create test data for attention operations."""
        # Create Q, K, V tensors
        q = torch.randn(batch_size, seq_len, n_heads, d_head, dtype=self.dtype, device=self.device)
        k = torch.randn(
            batch_size, seq_len, n_kv_heads, d_head, dtype=self.dtype, device=self.device
        )
        v = torch.randn(
            batch_size, seq_len, n_kv_heads, d_head, dtype=self.dtype, device=self.device
        )

        # Create KV cache
        k_cache = torch.randn(
            batch_size, max_seq_len, n_kv_heads, d_head, dtype=self.dtype, device=self.device
        )
        v_cache = torch.randn(
            batch_size, max_seq_len, n_kv_heads, d_head, dtype=self.dtype, device=self.device
        )

        # Setup metadata
        input_positions = torch.full(
            (batch_size,), cache_offset, device=self.device, dtype=torch.int
        )
        seq_len_tensor = torch.full((batch_size,), seq_len, device=self.device, dtype=torch.int32)
        cache_loc = torch.arange(batch_size, device=self.device, dtype=torch.int32)

        if seq_len == 1:
            seq_start = torch.arange(batch_size, device=self.device, dtype=torch.int32)
            q_flat = q.view(batch_size, seq_len, -1)
            k_flat = k.view(batch_size, seq_len, -1)
            v_flat = v.view(batch_size, seq_len, -1)
        else:
            seq_start = torch.arange(
                0, batch_size * seq_len, seq_len, device=self.device, dtype=torch.int32
            )
            q_flat = q.view(1, batch_size * seq_len, -1)
            k_flat = k.view(1, batch_size * seq_len, -1)
            v_flat = v.view(1, batch_size * seq_len, -1)

        return {
            "q": q_flat,
            "k": k_flat,
            "v": v_flat,
            "seq_len": seq_len_tensor,
            "input_pos": input_positions,
            "cache_loc": cache_loc,
            "seq_start": seq_start,
            "k_cache": k_cache,
            "v_cache": v_cache,
        }

    def _run_attention(
        self, data, scale=None, logit_cap=None, sliding_window_size=None, sinks=None
    ):
        """Run torch backend attention operation with optional sinks parameter."""
        return torch.ops.auto_deploy.torch_cached_attention_with_cache(
            data["q"],
            data["k"],
            data["v"],
            data["seq_len"],
            data["input_pos"],
            data["cache_loc"],
            data["seq_start"],
            data["k_cache"],
            data["v_cache"],
            scale,
            sinks,
            sliding_window_size,
            logit_cap,  # Updated parameter order
        )

    def test_basic_functionality(self):
        """Test basic attention functionality and output shape correctness."""
        batch_size, seq_len, n_heads, n_kv_heads, d_head, max_seq_len = 2, 1, 8, 4, 32, 128
        data = self._create_test_data(batch_size, seq_len, n_heads, n_kv_heads, d_head, max_seq_len)

        # Test basic operation
        output = self._run_attention(data)

        # Verify output shape
        expected_shape = (batch_size, seq_len, n_heads * d_head)
        assert output.shape == expected_shape, (
            f"Expected shape {expected_shape}, got {output.shape}"
        )

        # Verify output is not NaN or Inf
        assert torch.isfinite(output).all(), "Output contains NaN or Inf values"

    @pytest.mark.parametrize("logit_cap", [None, 5.0])
    @pytest.mark.parametrize("sliding_window_size", [None, 3])
    @pytest.mark.parametrize("sinks", [None, 1.0])
    def test_combined_features_with_reference(self, logit_cap, sliding_window_size, sinks):
        """Test combined logit capping, sliding window, and sinks features against numpy reference."""
        batch_size, n_heads, n_kv_heads, d_head, max_seq_len, seq_len = 2, 8, 4, 16, 64, 1
        cache_offset = 5  # Have some tokens in cache

        data = self._create_test_data(
            batch_size, seq_len, n_heads, n_kv_heads, d_head, max_seq_len, cache_offset
        )

        # Convert sinks to tensor if provided
        sinks_tensor = None
        if sinks is not None:
            # Create sinks tensor with correct dimensions [num_heads, 1, 1]
            # This works for generate phase and is the correct shape expectation
            sinks_tensor = torch.ones(n_heads, 1, 1, device=self.device, dtype=self.dtype) * sinks
        else:
            sinks_tensor = None

        # Test with combined features
        # For sinks: test that backend runs without crashing (backend has bugs)
        # and validate correct sinks behavior with numpy reference
        try:
            output = self._run_attention(data, None, logit_cap, sliding_window_size, sinks_tensor)
            backend_works = True
        except Exception as e:
            print(f"Backend failed with sinks: {e}")
            backend_works = False

        # Test correct sinks implementation with numpy reference
        if sinks is not None:
            ref_sinks = (
                torch.ones(1, n_heads, 1, device=torch.device("cpu"), dtype=torch.float32) * sinks
            )
        else:
            ref_sinks = None

        reference = numpy_attention_reference(
            data["q"],
            data["k"],
            data["v"],
            data["k_cache"],
            data["v_cache"],
            data["seq_len"],
            data["input_pos"],
            data["cache_loc"],
            data["seq_start"],
            None,
            logit_cap,
            sliding_window_size,
            ref_sinks,
        )

        # Verify sinks actually change the numpy reference output
        output_np = output.cpu().numpy() if backend_works else np.zeros_like(reference)

        if backend_works:
            # Use more lenient tolerance for float16 vs float32 comparisons
            tolerance = (
                5e-2 if (logit_cap is not None and sliding_window_size is not None) else 1e-2
            )
            assert np.allclose(reference, output_np, atol=tolerance, rtol=tolerance), (
                f"Backend output doesn't match reference. Max diff: {np.abs(reference - output_np).max():.6f}, "
                f"tolerance: {tolerance}"
            )

        # If backend works, test that it produces finite output
        if backend_works:
            assert torch.isfinite(output).all(), (
                "Backend output should be finite when sinks are enabled"
            )

    def test_gqa_functionality(self):
        """Test Grouped Query Attention with different head ratios."""
        batch_size, seq_len, d_head, max_seq_len = 2, 1, 16, 32

        # Test different GQA configurations
        for n_heads, n_kv_heads in [(8, 4), (12, 3), (16, 1)]:
            data = self._create_test_data(
                batch_size, seq_len, n_heads, n_kv_heads, d_head, max_seq_len
            )
            output = self._run_attention(data)

            # Compare with numpy reference
            reference = numpy_attention_reference(
                data["q"],
                data["k"],
                data["v"],
                data["k_cache"],
                data["v_cache"],
                data["seq_len"],
                data["input_pos"],
                data["cache_loc"],
                data["seq_start"],
            )
            reference_torch = torch.from_numpy(reference).to(output.device, output.dtype)

            # Verify output matches reference
            assert torch.allclose(output, reference_torch, atol=self.atol, rtol=self.rtol), (
                f"GQA failed for {n_heads}/{n_kv_heads} heads"
            )

    def test_context_vs_generate_phases(self):
        """Test both context (multi-token) and generate (single-token) phases."""
        batch_size, n_heads, n_kv_heads, d_head, max_seq_len = 2, 8, 4, 16, 64

        # Test context phase (multi-token)
        context_data = self._create_test_data(
            batch_size, 4, n_heads, n_kv_heads, d_head, max_seq_len
        )
        context_output = self._run_attention(context_data)

        context_reference = numpy_attention_reference(
            context_data["q"],
            context_data["k"],
            context_data["v"],
            context_data["k_cache"],
            context_data["v_cache"],
            context_data["seq_len"],
            context_data["input_pos"],
            context_data["cache_loc"],
            context_data["seq_start"],
        )
        context_reference_torch = torch.from_numpy(context_reference).to(
            context_output.device, context_output.dtype
        )

        assert torch.allclose(
            context_output, context_reference_torch, atol=self.atol, rtol=self.rtol
        ), "Context phase doesn't match reference"

        # Test generate phase (single-token)
        generate_data = self._create_test_data(
            batch_size, 1, n_heads, n_kv_heads, d_head, max_seq_len, 5
        )
        generate_output = self._run_attention(generate_data)

        generate_reference = numpy_attention_reference(
            generate_data["q"],
            generate_data["k"],
            generate_data["v"],
            generate_data["k_cache"],
            generate_data["v_cache"],
            generate_data["seq_len"],
            generate_data["input_pos"],
            generate_data["cache_loc"],
            generate_data["seq_start"],
        )
        generate_reference_torch = torch.from_numpy(generate_reference).to(
            generate_output.device, generate_output.dtype
        )

        assert torch.allclose(
            generate_output, generate_reference_torch, atol=self.atol, rtol=self.rtol
        ), "Generate phase doesn't match reference"

    def test_metadata_preparation(self):
        """Test metadata preparation operation."""
        batch_size, seq_len_val = 4, 8
        device = self.device

        input_ids = torch.randint(0, 1000, (batch_size, seq_len_val), device=device)
        position_ids = torch.arange(seq_len_val, device=device).expand(batch_size, -1)
        seq_len = torch.full((batch_size,), seq_len_val, device=device, dtype=torch.int32)
        input_pos = torch.zeros(batch_size, device=device, dtype=torch.int32)
        cache_loc = torch.arange(batch_size, device=device, dtype=torch.int32)
        pages_per_seq = torch.ones(batch_size, device=device, dtype=torch.int32)

        # Test metadata preparation
        result = torch.ops.auto_deploy.torch_cached_attention_prepare_metadata(
            input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, 128
        )

        # Verify result structure
        assert len(result) == 4, "Metadata preparation should return 4 tensors"
        assert all(torch.is_tensor(t) for t in result), "All results should be tensors"
        assert result[0].shape[0] == batch_size, "First tensor should have batch_size elements"

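Editor's note: a small standalone sketch (assumptions only, not part of the diff) of the sequence layout `_create_test_data` produces: in the context phase all sequences are flattened into a single row and `seq_start` holds the per-sequence offsets, while the generate phase keeps one token per sequence in the batch dimension.

import torch

batch_size, seq_len, dim = 2, 4, 8
tokens = torch.randn(batch_size, seq_len, dim)

# context phase: [batch_size, seq_len, dim] -> [1, batch_size * seq_len, dim]
flat = tokens.view(1, batch_size * seq_len, dim)
seq_start = torch.arange(0, batch_size * seq_len, seq_len)  # tensor([0, 4])
for i in range(batch_size):
    s = int(seq_start[i])
    # each original sequence can be recovered from its offset
    assert torch.equal(flat[0, s : s + seq_len], tokens[i])
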
@ -18,10 +18,14 @@ from tensorrt_llm._torch.auto_deploy.custom_ops.triton_kernels.attention_with_kv
)


def torch_reference_stage2(values, logsumexp):
def torch_reference_stage2(values, logsumexp, sinks=None):
    max_logsumexp = torch.max(logsumexp, axis=-1, keepdim=True)[0]  # [b, n_heads, 1]
    sumexp = torch.exp(logsumexp - max_logsumexp)  # [b, n_heads, num_blocks]
    aggregate_sumexp = torch.sum(sumexp, axis=-1)  # [b, n_heads]
    # Add sinks contribution to the softmax denominator
    if sinks is not None:
        sinks_exp = torch.exp(sinks - max_logsumexp.squeeze(-1))  # [b, n_heads]
        aggregate_sumexp += sinks_exp
    output = values * sumexp[:, :, :, None]  # [b, n_heads, num_blocks, d_head]
    output = output / aggregate_sumexp[:, :, None, None]
    output = torch.sum(output, axis=2)
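Editor's note: a standalone sanity check (not part of the diff; shapes are illustrative) that the sink handling above is equivalent to one softmax over all logits plus a single appended sink logit: the sink only enlarges the denominator and contributes no value vector.

import torch

logits = torch.randn(2, 4, 10)  # [batch, heads, kv]
sink = torch.randn(2, 4)        # [batch, heads]

# direct route: softmax over [logits, sink], then drop the sink column
full = torch.softmax(torch.cat([logits, sink.unsqueeze(-1)], dim=-1), dim=-1)[..., :-1]

# denominator route: sum(exp(logits)) + exp(sink), as in the reference above
denom = torch.exp(logits).sum(-1) + torch.exp(sink)
via_denom = torch.exp(logits) / denom.unsqueeze(-1)

assert torch.allclose(full, via_denom, atol=1e-6)
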
@ -198,7 +202,8 @@ def test_attention_kv_flash_decoding(d_head):
@pytest.mark.parametrize("q_d_head", [16, 96])
@pytest.mark.parametrize("v_d_head", [16, 96])
@pytest.mark.parametrize("n_heads,n_kv_heads", [(8, 8), (8, 1)])
def test_gqa_attention_kv_flash_decoding(q_d_head, v_d_head, n_heads, n_kv_heads):
@pytest.mark.parametrize("sliding_window", [-1, 16])
def test_gqa_attention_kv_flash_decoding(q_d_head, v_d_head, n_heads, n_kv_heads, sliding_window):
    DEVICE = "cuda"
    DTYPE = torch.float16
    BATCH_SIZE = 64
@ -271,6 +276,7 @@ def test_gqa_attention_kv_flash_decoding(q_d_head, v_d_head, n_heads, n_kv_heads
        V_D_HEAD,
        SEQ_BLOCK_SIZE,
        HEAD_BLOCK_SIZE,
        sliding_window,  # SLIDING_WINDOW: parameterized
    )

    run(q, k_cache, v_cache, output_tensor, output_logsumexp)
@ -301,7 +307,8 @@ def test_gqa_attention_kv_flash_decoding(q_d_head, v_d_head, n_heads, n_kv_heads
    )


def test_attention_with_kv_stage2():
@pytest.mark.parametrize("has_sinks", [False, True])
def test_attention_with_kv_stage2(has_sinks):
    DEVICE = "cuda"
    BATCH_SIZE = 4
    N_HEADS = 32
@ -315,6 +322,10 @@ def test_attention_with_kv_stage2():
    )
    logsumexp = torch.randn(BATCH_SIZE, N_HEADS, num_blocks, device=DEVICE, dtype=torch.float32)
    output = torch.zeros(BATCH_SIZE, N_HEADS, D_HEAD, device=DEVICE, dtype=torch.float32)
    # Create sink tokens if needed - kernel expects [BATCH_SIZE, N_HEADS] shape
    sinks = (
        torch.randn(BATCH_SIZE, N_HEADS, device=DEVICE, dtype=torch.float32) if has_sinks else None
    )

    def run():
        attention_kv_stage2[
@ -331,15 +342,20 @@ def test_attention_with_kv_stage2():
            N_HEADS,
            D_HEAD,
            SEQ_BLOCK_SIZE,
            has_sinks,
            sinks,
        )

    run()
    ref = []
    for i in range(BATCH_SIZE):
        block_id = input_positions[i].item() // SEQ_BLOCK_SIZE + 1
        batch_sinks = sinks[i : i + 1, :] if has_sinks else None  # [1, N_HEADS]
        ref.append(
            torch_reference_stage2(
                values[i, :, :block_id, :].unsqueeze(0), logsumexp[i, :, :block_id].unsqueeze(0)
                values[i, :, :block_id, :].unsqueeze(0),
                logsumexp[i, :, :block_id].unsqueeze(0),
                batch_sinks,
            )
        )
    ref = torch.cat(ref, dim=0)
@ -425,7 +441,10 @@ def test_context_attention_kv(batch_size, q_d_head, v_d_head, n_heads, n_kv_head
@pytest.mark.parametrize("n_heads,n_kv_heads", [(8, 8), (8, 1)])
@pytest.mark.parametrize("q_d_head", [32, 96])
@pytest.mark.parametrize("v_d_head", [32, 96])
def test_context_attention_kv_flattened(q_d_head, v_d_head, n_heads, n_kv_heads, dtype):
@pytest.mark.parametrize("sliding_window", [-1, 16])
def test_context_attention_kv_flattened(
    q_d_head, v_d_head, n_heads, n_kv_heads, dtype, sliding_window
):
    DEVICE = "cuda"
    DTYPE = getattr(torch, dtype)
    N_HEADS = n_heads
@ -472,6 +491,29 @@ def test_context_attention_kv_flattened(q_d_head, v_d_head, n_heads, n_kv_heads,
            torch.ones(q[i].shape[1], kk.shape[1], dtype=torch.bool),
            diagonal=kk.shape[1] - q[i].shape[1],
        )

        # Apply sliding window constraints if enabled
        if sliding_window > 0:
            seq_len_q = q[i].shape[1]  # Current sequence length
            seq_len_k = kk.shape[1]  # Total KV sequence length

            # Create sliding window mask
            sliding_mask = torch.zeros_like(mask)
            for q_pos in range(seq_len_q):
                # For each query position, determine its absolute position in the cache
                abs_q_pos = INPUT_POS[i] + q_pos
                # Calculate sliding window range
                sliding_start = max(0, abs_q_pos - sliding_window + 1)
                sliding_end = abs_q_pos + 1
                # Apply to KV cache positions
                k_start = max(0, sliding_start)
                k_end = min(seq_len_k, sliding_end)
                if k_start < k_end:
                    sliding_mask[q_pos, k_start:k_end] = True

            # Combine causal and sliding window masks
            mask = mask & sliding_mask

        ref.append(
            torch.nn.functional.scaled_dot_product_attention(
                q[i].transpose(1, 2),
@ -535,7 +577,9 @@ def test_context_attention_kv_flattened(q_d_head, v_d_head, n_heads, n_kv_heads,
        V_D_HEAD,
        SEQ_BLOCK,
        MAX_SEQ_LEN,
        num_stages=2,
        sliding_window,  # SLIDING_WINDOW: parameterized
        False,  # HAS_SINKS: no sink tokens used
        None,  # sinks_ptr: no sink tokens used
    )
    assert torch.allclose(ref, output_tensor, atol=1e-2, rtol=1e-2)

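Editor's note: the per-position loop above builds a causal band mask; the sketch below (not part of the diff, names are illustrative) expresses the same constraint with broadcasting — a key position is visible only if it is not in the future and lies within the last `sliding_window` positions of the query.

import torch

def sliding_causal_mask(seq_len_q, seq_len_k, input_pos, sliding_window):
    q_pos = torch.arange(seq_len_q).unsqueeze(1) + input_pos  # absolute query positions
    k_pos = torch.arange(seq_len_k).unsqueeze(0)              # absolute key positions
    causal = k_pos <= q_pos
    window = k_pos > q_pos - sliding_window
    return causal & window  # [seq_len_q, seq_len_k], True = may attend

print(sliding_causal_mask(seq_len_q=4, seq_len_k=10, input_pos=6, sliding_window=3))
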
@ -1,18 +1,10 @@
import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import *  # noqa
from tensorrt_llm._torch.auto_deploy.custom_ops.triton_kernels.rms_norm import rms_norm


def torch_forward(hidden_states, weight, variance_epsilon=1e-6):
    """pytorch forward."""
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return weight * hidden_states.to(input_dtype)


def test_rms_norm():
def test_rmsnorm_triton_op():
    bsz = 2
    ctx_len = 1024
    feat_len = 32
@ -25,6 +17,6 @@ def test_rms_norm():
    weight = (
        torch.empty((feat_len), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).contiguous()
    )
    triton_output = rms_norm(hidden_states=input, weight=weight)
    torch_output = torch_forward(hidden_states=input, weight=weight)
    triton_output = rms_norm(input, weight, 1e-6)
    torch_output = torch.ops.auto_deploy.torch_rmsnorm(input, weight, 1e-6)
    assert torch.allclose(torch_output, triton_output, atol=1e-2, rtol=0)
@ -8,7 +8,7 @@ from _model_test_utils import _hf_model_dir_or_hub_id
from transformers import AutoConfig, AutoModelForCausalLM
from utils.llm_data import llm_models_root

from tensorrt_llm._torch.auto_deploy.models.deepseek import (
from tensorrt_llm._torch.auto_deploy.models.patches.deepseek import (
    deepseek_v3_attention,
    deepseek_v3_moe_exact,
)

@ -41,7 +41,9 @@ def get_inference_model(cache_seq_interface):


@pytest.mark.parametrize("engine_cls", [ADEngine, DemoEngine])
@pytest.mark.parametrize("attn_backend, attn_page_size", [("triton", 0), ("flashinfer", 2)])
@pytest.mark.parametrize(
    "attn_backend, attn_page_size", [("triton", 0), ("flashinfer", 2), ("torch", 0)]
)
def test_engine(engine_cls: Type[ADEngine], attn_backend: str, attn_page_size: int):
    """Test the SimpleEngine functionality."""

@ -154,6 +154,32 @@ def test_invalid_model_factory():
        LlmArgs(model="test-model", model_factory="InvalidFactory")


@pytest.mark.parametrize(
    "parallel_field,invalid_value",
    [
        ("tensor_parallel_size", 2),
        ("pipeline_parallel_size", 2),
        ("context_parallel_size", 2),
        ("moe_cluster_parallel_size", 2),
        ("moe_tensor_parallel_size", 2),
        ("moe_expert_parallel_size", 2),
        ("enable_attention_dp", True),
        ("cp_config", {"some_key": "some_value"}),
    ],
)
def test_parallel_config_validation(parallel_field, invalid_value):
    """Test that parallel config fields raise ValueError when set to non-default values."""
    kwargs = {
        "model": "test-model",
        parallel_field: invalid_value,
    }

    with pytest.raises(
        ValueError, match="AutoDeploy only supports parallelization via the `world_size` argument."
    ):
        LlmArgs(**kwargs)


@pytest.mark.parametrize(
    "attn_backend,expected_attn_page_size",
    [

@ -6,35 +6,38 @@ import pytest
from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig, main

from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs, _ParallelConfig
from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig, LlmArgs, _ParallelConfig
from tensorrt_llm._torch.auto_deploy.transformations.transform import InferenceOptimizer


def _check_ad_config(experiment_config: ExperimentConfig, ad_config: LlmArgs):
    # Verify that ad_config was captured
    assert ad_config is not None, "ad_config should have been captured"
def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
    # Verify that llm_args was captured
    assert llm_args is not None, "llm_args should have been captured"

    # Check that ad_config is an instance of LlmArgs
    assert isinstance(ad_config, LlmArgs), f"Expected AutoDeploy LlmArgs, got {type(ad_config)}"

    # check that ad_config and experiment_config have the same args
    assert experiment_config.args == ad_config, (
        f"Expected experiment_config.args {experiment_config.args}, got {ad_config}"
    # Check that llm_args is an instance of LlmArgs and also an instance of AutoDeployConfig
    assert isinstance(llm_args, LlmArgs), f"Expected LlmArgs, got {type(llm_args)}"
    assert isinstance(llm_args, AutoDeployConfig), (
        f"Expected AutoDeployConfig, got {type(llm_args)}"
    )

    # check that llm_args and experiment_config have the same args
    expected_ad_config: AutoDeployConfig = experiment_config.args
    expected_llm_args: LlmArgs = expected_ad_config.to_llm_args()
    assert expected_llm_args == llm_args, f"Expected llm args {expected_llm_args}, got {llm_args}"

    # check expected parallel config
    world_size = experiment_config.args.world_size
    world_size = expected_ad_config.world_size
    expected_parallel_config = _ParallelConfig(
        auto_parallel=True, gpus_per_node=experiment_config.args.gpus_per_node
        auto_parallel=True, gpus_per_node=expected_llm_args.gpus_per_node
    )
    expected_parallel_config.world_size = world_size
    assert ad_config._parallel_config == expected_parallel_config, (
        f"Expected parallel_config {expected_parallel_config}, got {ad_config._parallel_config}"
    assert llm_args._parallel_config == expected_parallel_config, (
        f"Expected parallel_config {expected_parallel_config}, got {llm_args._parallel_config}"
    )

    # backend should always be "_autodeploy"
    assert ad_config.backend == "_autodeploy", (
        f"Expected backend '_autodeploy', got {ad_config.backend}"
    assert llm_args.backend == "_autodeploy", (
        f"Expected backend '_autodeploy', got {llm_args.backend}"
    )


@ -71,6 +74,16 @@ def _check_ad_config(experiment_config: ExperimentConfig, ad_config: LlmArgs):
            attn_backend="triton",
            compile_backend="torch-simple",
        ),
        get_small_model_config(
            "microsoft/Phi-3-mini-4k-instruct",
            attn_backend="torch",
            compile_backend="torch-simple",
        ),
        get_small_model_config(
            "Qwen/Qwen2.5-3B-Instruct",
            attn_backend="triton",
            compile_backend="torch-compile",
        ),
    ],
)
def test_build_ad(experiment_config: Dict):

@ -15,6 +15,7 @@ def prepare_dataset(root_dir: str, temp_dir: str, model_name: str):
    _DATASET_NAME = "synthetic_128_128.txt"
    dataset_path = Path(temp_dir, _DATASET_NAME)
    dataset_tool = Path(root_dir, "benchmarks", "cpp", "prepare_dataset.py")
    script_dir = Path(root_dir, "benchmarks", "cpp")

    # Generate a small dataset to run a test.
    command = [
@ -36,7 +37,7 @@ def prepare_dataset(root_dir: str, temp_dir: str, model_name: str):
        "10",
    ]
    print(f"Running command: {' '.join(command)}")
    result = subprocess.run(command, capture_output=True, text=True)
    result = subprocess.run(command, cwd=str(script_dir), capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Failed to prepare dataset: {result.stderr}")
    # Grab the stdout and write it to a dataset file for passing to suite.
@ -59,7 +60,8 @@ def run_benchmark(model_name: str, dataset_path: str, temp_dir: str):
        "--extra_llm_api_options",
        f"{temp_dir}/model_kwargs.yaml",
    ]
    runner.invoke(main, args, catch_exceptions=False)
    result = runner.invoke(main, args, catch_exceptions=False)
    assert result.exit_code == 0


def test_trtllm_bench(llm_root):  # noqa: F811

@ -4,8 +4,10 @@ import pytest
import torch
from _graph_test_helpers import run_test
from torch.export import Dim
from torch.fx import GraphModule
from transformers.integrations.sdpa_attention import repeat_kv as hf_repeat_kv

from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.transformations.library.attention import (
    match_attention_layout,
    match_causal_attn_mask,
@ -416,6 +418,21 @@ class GroupedAttentionModel(torch.nn.Module):
        return {0: Dim("batch_size", max=8), 1: Dim("seq_len", min=4, max=16)}


def _get_match_repeat_kv_optimizer() -> Callable:
    config = {
        "cleanup_noop_slice": {
            "stage": "post_export",
        },
    }

    def _transform(gm: GraphModule) -> GraphModule:
        gm = InferenceOptimizer(None, config)(None, gm)
        match_repeat_kv(gm)
        return gm

    return _transform


@pytest.mark.parametrize("num_heads, num_kv_heads", [(8, 8), (8, 4), (8, 2)])
@pytest.mark.parametrize(
    "model_cls", [RepeatKVModel, RepeatKVModel2, RepeatKVModel3, HFRepeatKVModel]
@ -488,7 +505,7 @@ def test_match_repeat_kv(num_heads, num_kv_heads, model_cls):
    _ = run_test(
        model,
        x,
        match_repeat_kv,
        _get_match_repeat_kv_optimizer(),
        verify_matcher,
        lambda num_p_og: num_p_og,
        atol=1e-3,

@ -44,13 +44,12 @@ class HFWrapper(nn.Module):
        return self.model(x)[0]


def _joint_transform(gm: GraphModule) -> GraphModule:
    gm = match_repeat_kv(gm)
    gm = match_eager_attention(gm)
    gm = match_grouped_attention(gm)
    gm = match_causal_attn_mask(gm)
    gm = match_attention_layout(gm, MockAttentionDescriptor())
    return gm
def _joint_transform(gm: GraphModule) -> None:
    match_repeat_kv(gm)
    match_eager_attention(gm)
    match_grouped_attention(gm)
    match_causal_attn_mask(gm)
    match_attention_layout(gm, MockAttentionDescriptor())


@pytest.mark.parametrize(
@ -78,6 +77,7 @@ def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str)
    dynamic_shapes = {0: Dim("batch_size", max=8), 1: Dim("seq_len", min=4, max=16)}

    model = HFWrapper(LlamaModel(LlamaConfig(**full_config))).to("cuda")
    model.eval()
    x = torch.randint(
        0, full_config["vocab_size"], (batch_size, seq_len), dtype=torch.long, device="cuda"
    )

@ -0,0 +1,67 @@
from functools import partial

import pytest
import torch
from _graph_test_helpers import run_test
from torch.export import Dim

from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import *  # noqa
from tensorrt_llm._torch.auto_deploy.transformations.library.rms_norm import fuse_rmsnorm
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size, device="cuda"))
        self.eps = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states.to(input_dtype)


class TestModel(torch.nn.Module):
    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.linear1 = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)
        self.rms_norm = RMSNorm(1024, eps).to(torch.float16)
        self.linear2 = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)

    def forward(self, x):
        x = self.linear1(x)
        x = self.rms_norm(x)
        x = self.linear2(x)
        return x


@pytest.mark.parametrize("eps", [1e-2, 1e-6])
@pytest.mark.parametrize(
    "variant, op",
    [
        ("flashinfer", torch.ops.auto_deploy.flashinfer_rms_norm),
        ("triton", torch.ops.auto_deploy.triton_rms_norm),
        ("torch", torch.ops.auto_deploy.torch_rmsnorm),
    ],
)
def test_rmsnorm_fusion(eps, variant, op):
    def checker(gm):
        return any(is_op(n, op) for n in gm.graph.nodes)

    model = TestModel(eps)
    gm_transformed = run_test(
        model,
        torch.randn(2, 1024, device="cuda", dtype=torch.float16),
        partial(fuse_rmsnorm, backend=variant),
        checker,
        lambda num_p_og: num_p_og,
        dynamic_shapes={0: Dim("batch_size", max=8)},
    )
    print(gm_transformed.graph)
    new_input = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
    y_transformed = gm_transformed(new_input)
    y_model = model(new_input)
    torch.testing.assert_close(y_transformed, y_model, atol=1e-3, rtol=1e-3)