Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

doc:fix llm api reference blank issue. (#2669)

This commit is contained in: parent 3537f75d25, commit 42d0585493
@@ -15,7 +15,7 @@
<script src="../_static/jquery.js?v=5d32c60e"></script>
<script src="../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/doctools.js?v=9bcbadda"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../_static/copybutton.js?v=65e89d2a"></script>
@@ -1326,10 +1326,6 @@
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="#version-h">version.h</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#_CPPv4N12tensorrt_llm8executor19kTensorRtLlmVersionE"><code class="docutils literal notranslate"><span class="pre">tensorrt_llm::executor::kTensorRtLlmVersion</span></code></a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="runtime.html">Runtime</a></li>
@@ -8445,27 +8441,6 @@

</dd></dl>

</section>
<section id="version-h">
<h2>version.h<a class="headerlink" href="#version-h" title="Link to this heading"></a></h2>
<dl class="cpp type">
<dt class="sig sig-object cpp">
namespace tensorrt_llm
<dd><dl class="cpp type">
<dt class="sig sig-object cpp">
namespace executor
<dd><div class="breathe-sectiondef docutils container">
<p class="breathe-sectiondef-title rubric" id="breathe-section-title-variables">Variables</p>
<dl class="cpp var">
<dt class="sig sig-object cpp" id="_CPPv4N12tensorrt_llm8executor19kTensorRtLlmVersionE">
static auto constexpr kTensorRtLlmVersion = "0.14.0.dev2024100800"
<dd></dd></dl>

</div>
</dd></dl>

</dd></dl>

</section>
</section>

@@ -8480,7 +8455,7 @@
<hr/>

<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x74a60e3c2030>
<jinja2.runtime.BlockReference object at 0x7f1ac9bc9c70>

<div class="footer">
<p>

@@ -15,7 +15,7 @@
<script src="../_static/jquery.js?v=5d32c60e"></script>
<script src="../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/doctools.js?v=9bcbadda"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../_static/copybutton.js?v=65e89d2a"></script>
@@ -12360,7 +12360,7 @@ one more than decoding draft tokens for prediction from primary head </p>
<hr/>

<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x74a60d0c5160>
<jinja2.runtime.BlockReference object at 0x7f1ac8691ac0>

<div class="footer">
<p>

@@ -14,7 +14,7 @@
<script src="../_static/jquery.js?v=5d32c60e"></script>
<script src="../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/doctools.js?v=9bcbadda"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../_static/copybutton.js?v=65e89d2a"></script>
@@ -161,7 +161,70 @@
<h1>All modules for which code is available</h1>
<ul><li><a href="tensorrt_llm/bindings.html">tensorrt_llm.bindings</a></li>
<ul><li><a href="tensorrt_llm/bindings/executor.html">tensorrt_llm.bindings.executor</a></li>
</ul></ul>
</ul><li><a href="tensorrt_llm/builder.html">tensorrt_llm.builder</a></li>
<li><a href="tensorrt_llm/executor.html">tensorrt_llm.executor</a></li>
<li><a href="tensorrt_llm/functional.html">tensorrt_llm.functional</a></li>
<li><a href="tensorrt_llm/layers/activation.html">tensorrt_llm.layers.activation</a></li>
<li><a href="tensorrt_llm/layers/attention.html">tensorrt_llm.layers.attention</a></li>
<li><a href="tensorrt_llm/layers/cast.html">tensorrt_llm.layers.cast</a></li>
<li><a href="tensorrt_llm/layers/conv.html">tensorrt_llm.layers.conv</a></li>
<li><a href="tensorrt_llm/layers/embedding.html">tensorrt_llm.layers.embedding</a></li>
<li><a href="tensorrt_llm/layers/linear.html">tensorrt_llm.layers.linear</a></li>
<li><a href="tensorrt_llm/layers/mlp.html">tensorrt_llm.layers.mlp</a></li>
<li><a href="tensorrt_llm/layers/normalization.html">tensorrt_llm.layers.normalization</a></li>
<li><a href="tensorrt_llm/layers/pooling.html">tensorrt_llm.layers.pooling</a></li>
<li><a href="tensorrt_llm/llmapi/build_cache.html">tensorrt_llm.llmapi.build_cache</a></li>
<li><a href="tensorrt_llm/llmapi/llm.html">tensorrt_llm.llmapi.llm</a></li>
<li><a href="tensorrt_llm/llmapi/llm_utils.html">tensorrt_llm.llmapi.llm_utils</a></li>
<li><a href="tensorrt_llm/models/baichuan/model.html">tensorrt_llm.models.baichuan.model</a></li>
<li><a href="tensorrt_llm/models/bert/model.html">tensorrt_llm.models.bert.model</a></li>
<li><a href="tensorrt_llm/models/bloom/model.html">tensorrt_llm.models.bloom.model</a></li>
<li><a href="tensorrt_llm/models/chatglm/config.html">tensorrt_llm.models.chatglm.config</a></li>
<li><a href="tensorrt_llm/models/chatglm/model.html">tensorrt_llm.models.chatglm.model</a></li>
<li><a href="tensorrt_llm/models/cogvlm/config.html">tensorrt_llm.models.cogvlm.config</a></li>
<li><a href="tensorrt_llm/models/cogvlm/model.html">tensorrt_llm.models.cogvlm.model</a></li>
<li><a href="tensorrt_llm/models/commandr/model.html">tensorrt_llm.models.commandr.model</a></li>
<li><a href="tensorrt_llm/models/dbrx/config.html">tensorrt_llm.models.dbrx.config</a></li>
<li><a href="tensorrt_llm/models/dbrx/model.html">tensorrt_llm.models.dbrx.model</a></li>
<li><a href="tensorrt_llm/models/deepseek_v1/model.html">tensorrt_llm.models.deepseek_v1.model</a></li>
<li><a href="tensorrt_llm/models/deepseek_v2/model.html">tensorrt_llm.models.deepseek_v2.model</a></li>
<li><a href="tensorrt_llm/models/dit/model.html">tensorrt_llm.models.dit.model</a></li>
<li><a href="tensorrt_llm/models/eagle/model.html">tensorrt_llm.models.eagle.model</a></li>
<li><a href="tensorrt_llm/models/enc_dec/model.html">tensorrt_llm.models.enc_dec.model</a></li>
<li><a href="tensorrt_llm/models/falcon/config.html">tensorrt_llm.models.falcon.config</a></li>
<li><a href="tensorrt_llm/models/falcon/model.html">tensorrt_llm.models.falcon.model</a></li>
<li><a href="tensorrt_llm/models/gemma/config.html">tensorrt_llm.models.gemma.config</a></li>
<li><a href="tensorrt_llm/models/gemma/model.html">tensorrt_llm.models.gemma.model</a></li>
<li><a href="tensorrt_llm/models/gpt/config.html">tensorrt_llm.models.gpt.config</a></li>
<li><a href="tensorrt_llm/models/gpt/model.html">tensorrt_llm.models.gpt.model</a></li>
<li><a href="tensorrt_llm/models/gptj/config.html">tensorrt_llm.models.gptj.config</a></li>
<li><a href="tensorrt_llm/models/gptj/model.html">tensorrt_llm.models.gptj.model</a></li>
<li><a href="tensorrt_llm/models/gptneox/model.html">tensorrt_llm.models.gptneox.model</a></li>
<li><a href="tensorrt_llm/models/llama/config.html">tensorrt_llm.models.llama.config</a></li>
<li><a href="tensorrt_llm/models/llama/model.html">tensorrt_llm.models.llama.model</a></li>
<li><a href="tensorrt_llm/models/mamba/model.html">tensorrt_llm.models.mamba.model</a></li>
<li><a href="tensorrt_llm/models/medusa/config.html">tensorrt_llm.models.medusa.config</a></li>
<li><a href="tensorrt_llm/models/medusa/model.html">tensorrt_llm.models.medusa.model</a></li>
<li><a href="tensorrt_llm/models/mllama/model.html">tensorrt_llm.models.mllama.model</a></li>
<li><a href="tensorrt_llm/models/modeling_utils.html">tensorrt_llm.models.modeling_utils</a></li>
<li><a href="tensorrt_llm/models/mpt/model.html">tensorrt_llm.models.mpt.model</a></li>
<li><a href="tensorrt_llm/models/opt/model.html">tensorrt_llm.models.opt.model</a></li>
<li><a href="tensorrt_llm/models/phi/model.html">tensorrt_llm.models.phi.model</a></li>
<li><a href="tensorrt_llm/models/phi3/model.html">tensorrt_llm.models.phi3.model</a></li>
<li><a href="tensorrt_llm/models/recurrentgemma/model.html">tensorrt_llm.models.recurrentgemma.model</a></li>
<li><a href="tensorrt_llm/models/redrafter/model.html">tensorrt_llm.models.redrafter.model</a></li>
<li><a href="tensorrt_llm/plugin/plugin.html">tensorrt_llm.plugin.plugin</a></li>
<li><a href="tensorrt_llm/quantization/mode.html">tensorrt_llm.quantization.mode</a></li>
<li><a href="tensorrt_llm/quantization/quantize_by_modelopt.html">tensorrt_llm.quantization.quantize_by_modelopt</a></li>
<li><a href="tensorrt_llm/runtime/enc_dec_model_runner.html">tensorrt_llm.runtime.enc_dec_model_runner</a></li>
<li><a href="tensorrt_llm/runtime/generation.html">tensorrt_llm.runtime.generation</a></li>
<li><a href="tensorrt_llm/runtime/kv_cache_manager.html">tensorrt_llm.runtime.kv_cache_manager</a></li>
<li><a href="tensorrt_llm/runtime/model_runner.html">tensorrt_llm.runtime.model_runner</a></li>
<li><a href="tensorrt_llm/runtime/model_runner_cpp.html">tensorrt_llm.runtime.model_runner_cpp</a></li>
<li><a href="tensorrt_llm/runtime/multimodal_model_runner.html">tensorrt_llm.runtime.multimodal_model_runner</a></li>
<li><a href="tensorrt_llm/runtime/session.html">tensorrt_llm.runtime.session</a></li>
<li><a href="tensorrt_llm/sampling_params.html">tensorrt_llm.sampling_params</a></li>
</ul>

</div>
</div>
@@ -170,7 +233,7 @@
<hr/>

<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x74a6148aee10>
<jinja2.runtime.BlockReference object at 0x7f1ac953a3c0>

<div class="footer">
<p>

@@ -14,7 +14,7 @@
<script src="../../_static/jquery.js?v=5d32c60e"></script>
<script src="../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../_static/doctools.js?v=9a2dae69"></script>
<script src="../../_static/doctools.js?v=9bcbadda"></script>
<script src="../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
<li class="toctree-l1"><a class="reference internal" href="../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
<ul>
@@ -667,7 +668,7 @@
    max_encoder_input_len: int = 1024  # for enc-dec DecoderModel
    use_fused_mlp: bool = True
    dry_run: bool = False
    visualize_network: str = None
    visualize_network: bool = False
    monitor_memory: bool = False
    use_mrope: bool = False

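The adjacent pair above shows visualize_network changing between the two doc builds: one side types it as an optional path string (str = None), the other as a plain flag (bool = False). A minimal sketch of the flag-style option, using a stand-in dataclass rather than the real BuildConfig:

    from dataclasses import dataclass

    @dataclass
    class BuildConfigSketch:  # stand-in for illustration, not tensorrt_llm.builder.BuildConfig
        dry_run: bool = False
        visualize_network: bool = False  # the other variant typed this as: str = None
        monitor_memory: bool = False

    cfg = BuildConfigSketch(visualize_network=True)
    assert cfg.visualize_network is True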
@@ -768,7 +769,7 @@
        plugin_config.update_from_dict(config["plugin_config"])

        dry_run = config.pop('dry_run', False)
        visualize_network = config.pop('visualize_network', None)
        visualize_network = config.pop('visualize_network', False)
        monitor_memory = config.pop('monitor_memory', False)
        use_mrope = config.pop('use_mrope', False)

@@ -1035,9 +1036,9 @@
        use_prompt_tuning=(build_config.max_prompt_embedding_table_size > 0),
        use_lora=build_config.plugin_config.lora_plugin is not None,
        max_lora_rank=build_config.lora_config.max_lora_rank,
        use_fp8_context_fmha=(model.config.quantization.quant_algo in [
            QuantAlgo.FP8, QuantAlgo.W4A8_AWQ, QuantAlgo.NVFP4
        ] and build_config.plugin_config.use_fp8_context_fmha),
        use_fp8_context_fmha=(
            QuantAlgo.FP8 == model.config.quantization.quant_algo
            and build_config.plugin_config.use_fp8_context_fmha),
        use_optimize_cross_qkv=True,
    )

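The two variants of use_fp8_context_fmha above differ in which quantization algorithms qualify: one accepts any of FP8, W4A8_AWQ, or NVFP4, the other only plain FP8. A hedged sketch of both predicates, with QuantAlgo stubbed out for illustration:

    from enum import Enum, auto

    class QuantAlgo(Enum):  # stub for illustration only
        FP8 = auto()
        W4A8_AWQ = auto()
        NVFP4 = auto()

    def fp8_context_fmha_broad(quant_algo, plugin_flag: bool) -> bool:
        # variant accepting several FP8-capable algorithms
        return quant_algo in (QuantAlgo.FP8, QuantAlgo.W4A8_AWQ, QuantAlgo.NVFP4) and plugin_flag

    def fp8_context_fmha_strict(quant_algo, plugin_flag: bool) -> bool:
        # variant restricted to plain FP8
        return QuantAlgo.FP8 == quant_algo and plugin_flag

    assert fp8_context_fmha_broad(QuantAlgo.W4A8_AWQ, True)
    assert not fp8_context_fmha_strict(QuantAlgo.W4A8_AWQ, True)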
@@ -1209,7 +1210,7 @@

def build(model: PretrainedModel, build_config: BuildConfig) -> Engine:
    '''Build engine from given model and optimization options specified in the build_config
    WARNING: this function may change the given model object state in some optimization passes
    WARNING: this function may change the given \p model object state in some optimization passes
    to avoid cloning a model since normally the LLM models consumes large memory.
    Create a new fresh model object if you need to build with different options.
    '''
@@ -1222,10 +1223,8 @@
    _init_max_seq_len(model.config, build_config)

    if build_config.plugin_config.reduce_fusion and (
            model.config.mapping.tp_size == 1
            or model.config.mapping.pp_size != 1 or
            model.config.mapping.tp_size == 1 or
            (model.config.architecture != "LlamaForCausalLM"
             and model.config.architecture != "Gemma2ForCausalLM"
             and model.config.architecture != "MedusaForCausalLM")):
        logger.warning('Overriding reduce_fusion to False')
        build_config.plugin_config.reduce_fusion = False
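In both variants, reduce_fusion survives only when tensor parallelism is active and the architecture is one of the supported decoder families; otherwise the builder logs a warning and overrides it to False. A sketch of that gate with names stubbed (this is not the builder's actual helper):

    SUPPORTED_REDUCE_FUSION_ARCHS = {  # assumption: mirrors the checks above
        "LlamaForCausalLM", "Gemma2ForCausalLM", "MedusaForCausalLM",
    }

    def maybe_disable_reduce_fusion(reduce_fusion: bool, tp_size: int, architecture: str) -> bool:
        """Return the effective reduce_fusion flag after the sanity check."""
        if reduce_fusion and (tp_size == 1 or architecture not in SUPPORTED_REDUCE_FUSION_ARCHS):
            print("Overriding reduce_fusion to False")  # stands in for logger.warning
            return False
        return reduce_fusion

    assert maybe_disable_reduce_fusion(True, tp_size=1, architecture="LlamaForCausalLM") is False
    assert maybe_disable_reduce_fusion(True, tp_size=2, architecture="LlamaForCausalLM") is True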
@@ -1238,10 +1237,7 @@
    build_config.strongly_typed = True

    if hasattr(model.config, 'max_draft_len'):
        # If model.config has 'max_draft_len' but build_config not specified,
        # use the value of model.config.max_draft_len to set the value of build_config.max_draft_len
        if build_config.max_draft_len == 0:
            build_config.max_draft_len = model.config.max_draft_len
        build_config.max_draft_len = model.config.max_draft_len

    if hasattr(model.config, 'redrafter_num_beams') and hasattr(
            model.config, 'redrafter_draft_len_per_beam'):
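One side of this hunk makes the inheritance conditional (only when the user left max_draft_len at its 0 default), while the other assigns unconditionally. A small sketch contrasting the two behaviors, with a stubbed config object:

    class Cfg:  # stand-in config for illustration
        def __init__(self, max_draft_len):
            self.max_draft_len = max_draft_len

    def inherit_if_unset(build_cfg: Cfg, model_cfg: Cfg) -> None:
        # only fill in the value when it was left at the 0 default
        if build_cfg.max_draft_len == 0:
            build_cfg.max_draft_len = model_cfg.max_draft_len

    def inherit_always(build_cfg: Cfg, model_cfg: Cfg) -> None:
        # unconditionally take the model's value
        build_cfg.max_draft_len = model_cfg.max_draft_len

    b = Cfg(max_draft_len=5)
    m = Cfg(max_draft_len=63)
    inherit_if_unset(b, m)
    assert b.max_draft_len == 5   # user choice preserved
    inherit_always(b, m)
    assert b.max_draft_len == 63  # model config wins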
@@ -1287,8 +1283,7 @@

    if build_config.plugin_config.use_paged_context_fmha:
        if (model.config.quant_mode.has_fp8_kv_cache()
                and not model.config.quantization.quant_algo
                in [QuantAlgo.FP8, QuantAlgo.W4A8_AWQ, QuantAlgo.NVFP4]):
                and not model.config.quant_mode.has_fp8_qdq()):
            raise RuntimeError(
                "FP8 Paged Context FMHA only works with fp8 quantization workflow currently."
            )
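The check above rejects an FP8 KV cache without an FP8 quantization workflow when paged context FMHA is requested; the two variants express the "has FP8 quantization" test differently (membership in a QuantAlgo list vs. quant_mode.has_fp8_qdq()). A runnable sketch of the validation with a stubbed quant mode (the real QuantMode lives in tensorrt_llm.quantization):

    class QuantModeStub:
        def __init__(self, fp8_kv_cache: bool, fp8_qdq: bool):
            self._kv, self._qdq = fp8_kv_cache, fp8_qdq
        def has_fp8_kv_cache(self) -> bool: return self._kv
        def has_fp8_qdq(self) -> bool: return self._qdq

    def check_paged_context_fmha(qm: QuantModeStub, use_paged_context_fmha: bool) -> None:
        if use_paged_context_fmha and qm.has_fp8_kv_cache() and not qm.has_fp8_qdq():
            raise RuntimeError(
                "FP8 Paged Context FMHA only works with fp8 quantization workflow currently.")

    check_paged_context_fmha(QuantModeStub(False, False), True)  # passes
    try:
        check_paged_context_fmha(QuantModeStub(True, False), True)
    except RuntimeError as e:
        print(e)  # FP8 KV cache without FP8 quantization is rejected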
@@ -1334,7 +1329,6 @@
    disable_weight_only_quant_plugin = model.config.disable_weight_only_quant_plugin if hasattr(
        model.config, 'disable_weight_only_quant_plugin') else False
    use_fp8_rowwise = model.config.quant_mode.has_fp8_rowwise()
    use_fp4_gemm = model.config.quant_mode.has_nvfp4()

    if build_config.plugin_config.manage_weights:
        if use_weight_only and disable_weight_only_quant_plugin:
@@ -1351,8 +1345,7 @@
        network.plugin_config.set_smooth_quant_plugins(model.config.dtype)
    if use_qserve:
        network.plugin_config.set_qserve_plugins(model.config.dtype)
    if use_fp4_gemm:
        network.plugin_config.fp4_gemm_plugin = model.config.dtype

    if use_fp8_rowwise:
        network.plugin_config.set_fp8_rowwise_quant_plugins(model.config.dtype)
    nccl_plugin = model.config.dtype if model.config.mapping.world_size > 1 else None
@@ -1393,8 +1386,7 @@
        build_config.lora_config.lora_target_modules
    }

    if model.config.architecture == "DecoderModel" or "mllama" in model.config.architecture.lower(
    ):
    if model.config.architecture == "DecoderModel" or model.config.architecture == "MllamaForConditionalGeneration":
        prepare_input_args["max_seq_len"] = build_config.max_seq_len
        prepare_input_args[
            "max_decoder_input_len"] = build_config.max_input_len
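The two if forms above are not quite equivalent: the substring test matches any architecture name containing "mllama", while the exact comparison matches only the one class name. A small illustration:

    def is_enc_dec_like_loose(arch: str) -> bool:
        # substring variant
        return arch == "DecoderModel" or "mllama" in arch.lower()

    def is_enc_dec_like_exact(arch: str) -> bool:
        # exact-name variant
        return arch in ("DecoderModel", "MllamaForConditionalGeneration")

    assert is_enc_dec_like_loose("MllamaForConditionalGeneration")
    assert is_enc_dec_like_exact("MllamaForConditionalGeneration")
    # the loose check also matches hypothetical names containing "mllama"
    assert is_enc_dec_like_loose("MyMllamaVariant") and not is_enc_dec_like_exact("MyMllamaVariant")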
@@ -1418,7 +1410,7 @@
            "spec_decoding_is_generation_length_variable"] = True
    if model.config.architecture == "Qwen2VLForConditionalGeneration":
        prepare_input_args[
            'mrope_rotary_cos_sin_size'] = model.config.max_position_embeddings * model.config.rotary_embedding_dim
            'mrope_rotary_sin_cos_size'] = model.config.max_position_embeddings * model.config.rotary_embedding_dim
    if build_config.speculative_decoding_mode == SpeculativeDecodingMode.EAGLE and not build_config.plugin_config.use_paged_context_fmha:
        logger.warning(
            "Paged Context FMHA is required for EAGLE. Turning it on")
@@ -1443,9 +1435,9 @@
        mapping = network.auto_parallel_config["mapping"]
        model.config.mapping = mapping

    if build_config.visualize_network is not None:
    if build_config.visualize_network:
        with net_guard(network):
            network.to_onnx(build_config.visualize_network)
            network.to_onnx(f'rank{model.config.mapping.rank}.onnx')

    # Network -> Engine
    logger.info(
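Under the bool-flag variant, the ONNX dump path is no longer user-supplied; it is derived from the rank. A tiny sketch of that behavior (stub network object and hypothetical helper name):

    class _NetStub:
        def to_onnx(self, path: str) -> None:
            print(f"would dump network to {path}")

    def maybe_dump_network(network, visualize_network: bool, rank: int) -> None:
        # flag only toggles the dump; the file name comes from the rank
        if visualize_network:
            network.to_onnx(f'rank{rank}.onnx')

    maybe_dump_network(_NetStub(), visualize_network=True, rank=0)  # -> rank0.onnx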
@@ -1484,7 +1476,7 @@
<hr/>

<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x7f5e6db881d0>
<jinja2.runtime.BlockReference object at 0x7f1ac642e570>

<div class="footer">
<p>

(File diff suppressed because it is too large.)
@@ -14,7 +14,7 @@
<script src="../../_static/jquery.js?v=5d32c60e"></script>
<script src="../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../_static/doctools.js?v=9a2dae69"></script>
<script src="../../_static/doctools.js?v=9bcbadda"></script>
<script src="../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
<li class="toctree-l1"><a class="reference internal" href="../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
<ul>
@@ -207,7 +208,7 @@
        self.min = [dim 0 min, dim 1 min]
        self.opt = [dim 0 opt, dim 1 opt]
        self.max = [dim 0 max, dim 1 max]
        For static dimension, it has min==opt==max, thus the shape param in the ctor can be an integer
        For static dimension, it has min==opt==max, thus the \p shape param in the ctor can be an integer
    '''

    def __init__(self, shape: List[Union[int, List[int], Tuple[int, int, int]]],
@@ -2759,9 +2760,9 @@
                               to_array=False),
                           reduction_length,
                           dtype='int64')
        lower_triangle = cast(unsqueeze(reduction_range, 0)
                              <= unsqueeze(reduction_range, 1),
                              dtype=input.dtype)
        lower_triangle = cast(
            unsqueeze(reduction_range, 0) <= unsqueeze(reduction_range, 1),
            dtype=input.dtype)
        output = sum(unsqueeze(input, -2) * lower_triangle, dim=-1)
        return output
    else:
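Both layouts of the cast compute the same thing: a cumulative sum expressed as a reduction against a lower-triangular mask, which avoids needing a native cumsum op. A NumPy sketch of the idea (illustration only, not the tensorrt_llm.functional implementation):

    import numpy as np

    def cumsum_via_mask(x: np.ndarray) -> np.ndarray:
        n = x.shape[-1]
        r = np.arange(n)
        # entry (i, j) is 1 where j <= i: lower triangle including the diagonal
        lower_triangle = (r[None, :] <= r[:, None]).astype(x.dtype)
        # broadcast x against the mask rows and reduce over the last dim
        return (x[..., None, :] * lower_triangle).sum(axis=-1)

    x = np.array([1.0, 2.0, 3.0])
    assert np.allclose(cumsum_via_mask(x), np.cumsum(x))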
@@ -4215,8 +4216,7 @@
    """
    NONE = 0
    RESIDUAL_RMS_NORM = 1
    LAST_PROCESS_FOR_UB = 2
    RESIDUAL_RMS_PREPOST_NORM = 3
    LAST_PROCESS_FOR_UB = 2

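One doc build shows the enum ending at LAST_PROCESS_FOR_UB = 2, the other also carries RESIDUAL_RMS_PREPOST_NORM = 3. A stand-in IntEnum covering the larger variant (sketch only; the real class lives in tensorrt_llm.functional):

    from enum import IntEnum

    class AllReduceFusionOp(IntEnum):
        NONE = 0
        RESIDUAL_RMS_NORM = 1
        LAST_PROCESS_FOR_UB = 2
        RESIDUAL_RMS_PREPOST_NORM = 3  # only in the variant with pre/post-norm fusion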
@@ -4232,7 +4232,6 @@
                 residual: Optional[Tensor] = None,
                 norm_weight: Optional[Tensor] = None,
                 scale: Optional[Tensor] = None,
                 norm_pre_residual_weight: Optional[Tensor] = None,
                 eps: float = 1e-06):
        self.strategy = strategy
        self.config = config
@@ -4241,7 +4240,6 @@
        self.residual = residual
        self.norm_weight = norm_weight
        self.scale = scale
        self.norm_pre_residual_weight = norm_pre_residual_weight
        self.eps = eps
        assert fusion_op == AllReduceFusionOp.NONE or (residual is not None)

@@ -4307,7 +4305,6 @@
        "eps", np.array([float(all_reduce_params.eps)], np.float32),
        trt.PluginFieldType.FLOAT32)
    pfc.append(p_eps)

    p_affine = trt.PluginField(
        "affine", np.array([int(all_reduce_params.has_affine())], np.int8),
        trt.PluginFieldType.INT8)
@@ -4332,9 +4329,6 @@
        plug_inputs.append(all_reduce_params.residual.trt_tensor)
        if all_reduce_params.has_affine() == 1:
            plug_inputs.append(all_reduce_params.norm_weight.trt_tensor)
        if all_reduce_params.fusion_op == AllReduceFusionOp.RESIDUAL_RMS_PREPOST_NORM:
            plug_inputs.append(
                all_reduce_params.norm_pre_residual_weight.trt_tensor)
        if all_reduce_params.has_scale() == 1:
            plug_inputs.append(all_reduce_params.scale.trt_tensor)

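The pattern above appends optional tensors to the plugin input list only when the corresponding feature is active, so the plugin's input layout depends on the fusion configuration; the PREPOST_NORM branch did the same for the pre/post-norm weight. A generic sketch of the pattern with a stubbed params object:

    from types import SimpleNamespace

    def collect_plugin_inputs(params) -> list:
        inputs = [params.input]
        if params.residual is not None:
            inputs.append(params.residual)
        if params.norm_weight is not None:               # has_affine()
            inputs.append(params.norm_weight)
        if params.norm_pre_residual_weight is not None:  # PREPOST_NORM variant only
            inputs.append(params.norm_pre_residual_weight)
        if params.scale is not None:                     # has_scale()
            inputs.append(params.scale)
        return inputs

    params = SimpleNamespace(input="x", residual="r", norm_weight="w",
                             norm_pre_residual_weight=None, scale=None)
    assert collect_plugin_inputs(params) == ["x", "r", "w"]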
@@ -5391,10 +5385,9 @@
        spec_decoding_generation_lengths: Tensor = None,
        spec_decoding_position_offsets: Tensor = None,
        spec_decoding_packed_mask: Tensor = None,
        spec_decoding_use: Tensor = None,
        long_rope_rotary_inv_freq: Optional[Tensor] = None,
        long_rope_rotary_cos_sin: Optional[Tensor] = None,
        mrope_rotary_cos_sin: Tensor = None,
        mrope_rotary_sin_cos: Tensor = None,
        mrope_position_deltas: Tensor = None,
        host_runtime_perf_knobs: Optional[Tensor] = None,
        host_context_progress: Tensor = None,
@@ -5679,7 +5672,7 @@

    assert host_request_types is not None
    assert (alibi_slopes is not None) == (position_embedding_type.is_alibi())
    assert (mrope_rotary_cos_sin
    assert (mrope_rotary_sin_cos
            is not None) == (position_embedding_type.is_mrope())
    attn_plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'GPTAttention', '1', TRT_LLM_PLUGIN_NAMESPACE)
@@ -6008,20 +6001,19 @@
        # add position_ids as well only if speculative decoding mode
        assert spec_decoding_position_offsets is not None
        assert spec_decoding_generation_lengths is not None
        assert spec_decoding_use is not None
        plug_inputs += [
            spec_decoding_generation_lengths, spec_decoding_packed_mask,
            spec_decoding_position_offsets, spec_decoding_use
            spec_decoding_position_offsets
        ]

    if long_rope_rotary_inv_freq is not None:
        assert long_rope_rotary_cos_sin is not None
        plug_inputs += [long_rope_rotary_inv_freq, long_rope_rotary_cos_sin]

    if mrope_rotary_cos_sin is not None:
    if mrope_rotary_sin_cos is not None:
        assert mrope_position_deltas is not None
        plug_inputs += [
            mrope_rotary_cos_sin,
            mrope_rotary_sin_cos,
            mrope_position_deltas,
        ]
    if host_runtime_perf_knobs is not None:
@@ -6502,7 +6494,6 @@
    'identity': identity,
    'silu': silu,
    'softplus': softplus,
    'relu2': squared_relu,
    'squared-relu': squared_relu,
    'swiglu': swiglu,
    'fast-swiglu': swiglu,
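The dict above is a name-to-callable activation registry in which several aliases can point at one function; the diff drops the 'relu2' alias in one build. A minimal sketch of the pattern:

    def squared_relu(x: float) -> float:
        return max(x, 0.0) ** 2

    ACT_MAP = {
        'relu2': squared_relu,        # alias present in only one doc build
        'squared-relu': squared_relu,
    }

    assert ACT_MAP['relu2'](3.0) == 9.0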
@@ -7524,7 +7515,7 @@
<hr/>

<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x7f59fcc7af90>
<jinja2.runtime.BlockReference object at 0x7f1ac686a480>

<div class="footer">
<p>

@@ -14,7 +14,7 @@
<script src="../../../_static/jquery.js?v=5d32c60e"></script>
<script src="../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../../_static/doctools.js?v=9a2dae69"></script>
<script src="../../../_static/doctools.js?v=9bcbadda"></script>
<script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
<ul>
@@ -202,6 +203,8 @@
from .lora import LoraRuntimeParams
from .normalization import GroupNorm, LayerNorm, RmsNorm

from ..functional import maximum  # isort:skip

layernorm_map = {
    LayerNormType.LayerNorm: LayerNorm,
    LayerNormType.RmsNorm: RmsNorm,
@@ -434,15 +437,13 @@
            spec_decoding_max_generation_length: int = 1,
            spec_decoding_generation_lengths: Tensor = None,
            spec_decoding_position_offsets: Tensor = None,
            spec_decoding_packed_mask: Tensor = None,
            spec_decoding_use: Tensor = None):
            spec_decoding_packed_mask: Tensor = None):

        self.spec_decoding_is_generation_length_variable = spec_decoding_is_generation_length_variable
        self.spec_decoding_max_generation_length = spec_decoding_max_generation_length
        self.spec_decoding_generation_lengths = spec_decoding_generation_lengths
        self.spec_decoding_position_offsets = spec_decoding_position_offsets
        self.spec_decoding_packed_mask = spec_decoding_packed_mask
        self.spec_decoding_use = spec_decoding_use
        self.spec_decoding_packed_mask = spec_decoding_packed_mask

@@ -452,10 +453,10 @@

    def __init__(
            self,
            mrope_rotary_cos_sin: Tensor = None,
            mrope_rotary_sin_cos: Tensor = None,
            mrope_position_deltas: Tensor = None,
    ):
        self.mrope_rotary_cos_sin = mrope_rotary_cos_sin
        self.mrope_rotary_sin_cos = mrope_rotary_sin_cos
        self.mrope_position_deltas = mrope_position_deltas

@@ -1204,17 +1205,20 @@

        # KV cache scales.
        if self.kv_cache_scaling_factor is not None:
            kv_orig_quant_scale = self.kv_cache_rcp_scaling_factor.value
            kv_orig_quant_scale = constant(fp32_array(
                [1.0])) / self.kv_cache_scaling_factor.value
            kv_quant_orig_scale = self.kv_cache_scaling_factor.value
        else:
            kv_orig_quant_scale = None
            kv_quant_orig_scale = None

        # Attention output scales
        assert (
            not default_net().plugin_config.use_fp8_context_fmha
        ) or self.quant_mode.has_fp8_qdq(
        ), "FP8 Context FMHA must be used together with the fp8 quantization workflow."

        attention_output_orig_quant_scale = self.attention_output_orig_quant_scale.value if self.attention_output_orig_quant_scale is not None else None
        if attention_output_orig_quant_scale is not None and default_net(
        ).plugin_config.fp4_gemm_plugin:
            attention_output_orig_quant_scale = attention_output_orig_quant_scale / 6.0

        # The rotary inv freq can be pre-computed.
        rotary_inv_freq = getattr(attention_params, "rotary_inv_freq", None)
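Note: the paired old/new lines above compute the same two scales either from a precomputed reciprocal or by dividing on the fly. A minimal eager-mode sketch of the convention (the helper name is hypothetical; torch stands in for the graph ops):

import torch

def kv_cache_scales(kv_cache_scaling_factor):
    # Quantize original-precision KV values with the reciprocal scale;
    # dequantize with the scale itself.
    if kv_cache_scaling_factor is None:
        return None, None
    kv_orig_quant_scale = torch.tensor([1.0]) / kv_cache_scaling_factor  # orig -> quant
    kv_quant_orig_scale = kv_cache_scaling_factor                        # quant -> orig
    return kv_orig_quant_scale, kv_quant_orig_scale

# e.g. with scale 0.5: quantization multiplies by 2.0, dequantization by 0.5
print(kv_cache_scales(torch.tensor([0.5])))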
@@ -1323,10 +1327,9 @@
            spec_decoding_position_offsets,
            spec_decoding_packed_mask=spec_decoding_params.
            spec_decoding_packed_mask,
            spec_decoding_use=spec_decoding_params.spec_decoding_use,
            long_rope_rotary_inv_freq=long_rope_rotary_inv_freq,
            long_rope_rotary_cos_sin=long_rope_rotary_cos_sin,
            mrope_rotary_cos_sin=mrope_params.mrope_rotary_cos_sin,
            mrope_rotary_sin_cos=mrope_params.mrope_rotary_sin_cos,
            mrope_position_deltas=mrope_params.mrope_position_deltas,
            attn_logit_softcapping_scale=self.max_attn_value,
            host_runtime_perf_knobs=attention_params.
@@ -1395,8 +1398,8 @@

        embed_positions = concat([short, long], dim=0)
        select = where(
            sequence_length
            <= self.original_max_position_embeddings, 0, 1)
            sequence_length <=
            self.original_max_position_embeddings, 0, 1)
        embed_positions = slice(embed_positions,
                                concat([select, 0, 0]),
                                sizes=shape(short))
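Note: both formattings above express the same LongRoPE-style selection: the short and long position tables are stacked and a 0/1 index picks one slice once the sequence outgrows the original window. A plain-PyTorch sketch of that semantics, with names assumed from the surrounding diff:

import torch

def select_positions(short: torch.Tensor, long: torch.Tensor,
                     sequence_length: int, original_max: int) -> torch.Tensor:
    # Stack both tables on dim 0, then take a short-shaped slice starting at
    # index 0 (sequence fits the original window) or 1 (it exceeds it).
    embed_positions = torch.cat([short, long], dim=0)
    select = 0 if sequence_length <= original_max else 1
    return embed_positions[select:select + short.shape[0]]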
@@ -1673,14 +1676,10 @@
            dense_lora_params = lora_layer_params.get_runtime_params(
                0, "attn_dense")

        if skip_attn is not None and not default_net(
        ).plugin_config.use_fp8_context_fmha:
        if skip_attn is not None:
            # This case is used when we can skip this attention layer directly.
            # The output is undefined and unused if skip_attn is not None
            # and skip_attn is set to True at runtime.
            # But when use_fp8_context_fmha is enabled, the output data type of
            # the attention plugin is fp8. Since TRT's conditional layer does not
            # support the FP8 data type yet, we cannot use it to skip the
            # computation in that case.

            dense_conditional = Conditional(skip_attn)
            skip_case = dense_conditional.add_input(attention_input)
@@ -1692,8 +1691,7 @@
                lora_runtime_params=dense_lora_params,
                all_reduce_params=all_reduce_params)

        if skip_attn is not None and not default_net(
        ).plugin_config.use_fp8_context_fmha:
        if skip_attn is not None:
            context = dense_conditional.add_output(skip_case, context)

        if use_cache:
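Note: a minimal eager-mode stand-in for the Conditional skip pattern in the two hunks above. TensorRT builds both branches into the graph and selects at runtime; this sketch (with hypothetical names) simply branches at call time:

import torch

def dense_with_skip(attention_input: torch.Tensor,
                    dense: torch.nn.Module,
                    skip_attn: bool) -> torch.Tensor:
    if skip_attn:
        # Skip case: the layer's input passes through untouched, mirroring
        # Conditional.add_input(...) feeding Conditional.add_output(...).
        return attention_input
    return dense(attention_input)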
@@ -2047,12 +2045,19 @@
        ], 'Plugin only support masked MHA.'

        # KV cache scales.
        kv_orig_quant_scale = self.kv_cache_rcp_scaling_factor.value if self.quant_mode.has_kv_cache_quant(
        kv_orig_quant_scale = constant(
            fp32_array([1.0])
        ) / self.kv_cache_scaling_factor.value if self.quant_mode.has_kv_cache_quant(
        ) else None
        kv_quant_orig_scale = self.kv_cache_scaling_factor.value if self.quant_mode.has_kv_cache_quant(
        ) else None

        # Attention output scales
        assert (
            not default_net().plugin_config.use_fp8_context_fmha
        ) or self.quant_mode.has_fp8_qdq(
        ), "FP8 Context FMHA must be used together with the fp8 quantization workflow."

        attention_output_orig_quant_scale = self.attention_output_orig_quant_scale.value if self.attention_output_orig_quant_scale is not None else None

        context, past_key_value = gpt_attention(
            qkv=qkv,
@@ -2100,7 +2105,7 @@
            use_cache=use_cache,
            spec_decoding_position_offsets=None,
            spec_decoding_packed_mask=None,
            mrope_rotary_cos_sin=None,
            mrope_rotary_sin_cos=None,
            mrope_position_deltas=None,
            host_runtime_perf_knobs=attention_params.
            host_runtime_perf_knobs,
@@ -2332,13 +2337,19 @@

        # KV cache scales.
        if self.kv_cache_scaling_factor is not None:
            kv_orig_quant_scale = self.kv_cache_rcp_scaling_factor.value
            kv_orig_quant_scale = constant(fp32_array(
                [1.0])) / self.kv_cache_scaling_factor.value
            kv_quant_orig_scale = self.kv_cache_scaling_factor.value
        else:
            kv_orig_quant_scale = None
            kv_quant_orig_scale = None

        # Attention output scales
        assert (
            not default_net().plugin_config.use_fp8_context_fmha
        ) or self.quant_mode.has_fp8_qdq(
        ), "FP8 Context FMHA must be used together with the fp8 quantization workflow."

        attention_output_orig_quant_scale = self.attention_output_orig_quant_scale.value if self.attention_output_orig_quant_scale is not None else None

        rotary_cos_sin = self.embed_positions_for_gpt_attention.value
@@ -2411,7 +2422,6 @@
            spec_decoding_position_offsets,
            spec_decoding_packed_mask=spec_decoding_params.
            spec_decoding_packed_mask,
            spec_decoding_use=spec_decoding_params.spec_decoding_use,
            attn_logit_softcapping_scale=self.max_attn_value,
            host_runtime_perf_knobs=attention_params.
            host_runtime_perf_knobs,
@@ -166,20 +166,17 @@
import shutil
import tempfile
import weakref
from contextlib import contextmanager
from pathlib import Path
from typing import Any, List, Literal, Optional, Sequence, Union

import torch
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase

from .. import bindings as tllm
from ..bindings import executor as tllm
from ..builder import EngineConfig
from ..executor import (DetokenizedGenerationResultBase, GenerationExecutor,
                        GenerationResult, LoRARequest, PromptAdapterRequest)
from ..inputs import PromptInputs, create_input_processor, prompt_inputs
from ..executor import (GenerationExecutor, GenerationResult, LoRARequest,
                        PromptAdapterRequest)
from ..logger import logger
from ..sampling_params import SamplingParams
from .llm_utils import (LLMARGS_DOCSTRING, CachedModelLoader, LlmArgs,
@@ -191,18 +188,9 @@
from .utils import append_docstring, exception_handler, get_device_count


@contextmanager
def nvtx_range(msg):
    torch.cuda.nvtx.range_push(msg)
    try:
        yield
    finally:
        torch.cuda.nvtx.range_pop()


class RequestOutput(DetokenizedGenerationResultBase, GenerationResult):
class RequestOutput(GenerationResult):
    """The output data of a completion request to the LLM.

    Parameters:
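Note: a usage sketch for the nvtx_range helper shown (being removed) above. The pushed range appears as a named span in an Nsight Systems timeline; this requires a CUDA build of PyTorch, and the workload below is only illustrative:

import torch
from contextlib import contextmanager

@contextmanager
def nvtx_range(msg):
    # Bracket a region with an NVTX push/pop so profilers can label it.
    torch.cuda.nvtx.range_push(msg)
    try:
        yield
    finally:
        torch.cuda.nvtx.range_pop()

with nvtx_range("demo_matmul"):
    a = torch.randn(256, 256, device="cuda")
    _ = a @ a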
@@ -220,11 +208,42 @@
                 generation_result: GenerationResult,
                 prompt: Optional[str] = None,
                 tokenizer: Optional[TokenizerBase] = None) -> None:

        self.__dict__.update(generation_result.__dict__)
        self.tokenizer = tokenizer
        self._streaming = generation_result.streaming
        self.prompt = prompt
        self.prompt = prompt
        self.tokenizer = tokenizer


    def handle_response(self, response):
        super().handle_response(response)

        sampling_params = self._generation_request.sampling_params
        kwargs = {
            'skip_special_tokens':
            sampling_params.skip_special_tokens,
            'spaces_between_special_tokens':
            sampling_params.spaces_between_special_tokens
        }
        if sampling_params.detokenize and self.tokenizer is not None:
            for beam_output in self.outputs:
                beam_output._last_text_len = len(beam_output.text)
                if hasattr(self.tokenizer, 'decode_incrementally'):
                    if self.streaming and not sampling_params.use_beam_search:
                        beam_output.text, beam_output._incremental_states = self.tokenizer.decode_incrementally(
                            beam_output.token_ids_diff,
                            prev_text=beam_output.text,
                            states=beam_output._incremental_states,
                            flush=self.finished,
                            **kwargs)
                    else:
                        beam_output.text, _ = self.tokenizer.decode_incrementally(
                            beam_output.token_ids,
                            flush=self.finished,
                            **kwargs)
                else:
                    beam_output.text = self.tokenizer.decode(
                        beam_output.token_ids, **kwargs)


    def _repr_fields(self):
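Note: the detokenization fallback added above, restated as one compact sketch. It assumes a tokenizer that may or may not expose decode_incrementally with the signature used at the call sites above; when streaming, only the newly generated ids (token_ids_diff) are decoded and the decoder state is carried forward:

def decode_beam(tokenizer, beam_output, finished, streaming, use_beam_search, **kwargs):
    if hasattr(tokenizer, 'decode_incrementally'):
        if streaming and not use_beam_search:
            # Incremental path: decode only the new ids, append to prior text.
            beam_output.text, beam_output._incremental_states = tokenizer.decode_incrementally(
                beam_output.token_ids_diff,
                prev_text=beam_output.text,
                states=beam_output._incremental_states,
                flush=finished,
                **kwargs)
        else:
            beam_output.text, _ = tokenizer.decode_incrementally(
                beam_output.token_ids, flush=finished, **kwargs)
    else:
        # Fallback: re-decode the full sequence each time.
        beam_output.text = tokenizer.decode(beam_output.token_ids, **kwargs)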
@@ -234,6 +253,9 @@


PromptInputs = Union[str, List[int]]


@append_docstring(LLMARGS_DOCSTRING)
@@ -310,7 +332,7 @@
            self.llm_build_stats = LlmBuildStats()

            self._build_model()
            self.input_processor = create_input_processor(model, self.tokenizer)

        except Exception as e:
            if self.mpi_session is not None:
                self.mpi_session.shutdown()
@@ -336,7 +358,6 @@
                 Sequence[LoRARequest]]] = None,
                 prompt_adapter_request: Optional[Union[
                     PromptAdapterRequest, Sequence[PromptAdapterRequest]]] = None,
                 queries: Optional[Union[PromptInputs, Sequence[PromptInputs]]] = None,
                 ) -> Union[RequestOutput, List[RequestOutput]]:
        """Generate output for the given prompts in the synchronous mode.

        Synchronous generation accepts either single prompt or batched prompts.
@ -351,24 +372,17 @@
|
||||
<span class="sd"> if any. Defaults to None.</span>
|
||||
<span class="sd"> prompt_adapter_request (PromptAdapterRequest, Sequence[PromptAdapterRequest], optional):</span>
|
||||
<span class="sd"> Prompt Adapter request to use for generation, if any. Defaults to None.</span>
|
||||
<span class="sd"> queries (PromptInputs or Sequence[PromptInputs]): The query text or token ids.</span>
|
||||
<span class="sd"> it can be single prompt or batched prompts. it is used for star attention to run long context tasks.</span>
|
||||
|
||||
<span class="sd"> Returns:</span>
|
||||
<span class="sd"> Union[RequestOutput, List[RequestOutput]]: The output data of the completion request to the LLM.</span>
|
||||
<span class="sd"> """</span>
|
||||
<span class="n">unbatched</span> <span class="o">=</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="n">unbatched</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">unbatched</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="nb">str</span><span class="p">):</span>
<span class="n">unbatched</span> <span class="o">=</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">unbatched</span> <span class="o">=</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="nb">int</span><span class="p">)</span>
<span class="k">if</span> <span class="n">unbatched</span><span class="p">:</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">inputs</span><span class="p">]</span>
<span class="k">if</span> <span class="n">queries</span><span class="p">:</span>
<span class="n">queries</span> <span class="o">=</span> <span class="p">[</span><span class="n">queries</span><span class="p">]</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">prompt_inputs</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">]</span>
<span class="k">if</span> <span class="n">queries</span><span class="p">:</span>
<span class="n">queries</span> <span class="o">=</span> <span class="p">[</span><span class="n">prompt_inputs</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">queries</span><span class="p">]</span>
<span class="n">futures</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">request_inputs</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">inputs</span><span class="p">):</span>
@ -384,9 +398,7 @@
<span class="n">pa_req</span> <span class="o">=</span> <span class="n">prompt_adapter_request</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">pa_req</span> <span class="o">=</span> <span class="n">prompt_adapter_request</span>
<span class="n">request_queries</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">queries</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">queries</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">future</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generate_async</span><span class="p">(</span><span class="n">request_inputs</span><span class="p">,</span>
<span class="n">queries</span><span class="o">=</span><span class="n">request_queries</span><span class="p">,</span>
<span class="n">sampling_params</span><span class="o">=</span><span class="n">sp</span><span class="p">,</span>
<span class="n">lora_request</span><span class="o">=</span><span class="n">lora_req</span><span class="p">,</span>
<span class="n">prompt_adapter_request</span><span class="o">=</span><span class="n">pa_req</span><span class="p">,</span>
@ -407,7 +419,6 @@
<div class="viewcode-block" id="LLM.generate_async">
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.LLM.generate_async">[docs]</a>
<span class="nd">@nvtx_range</span><span class="p">(</span><span class="s2">"generate_async"</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">generate_async</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">inputs</span><span class="p">:</span> <span class="n">PromptInputs</span><span class="p">,</span>
@ -415,7 +426,6 @@
<span class="n">lora_request</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LoRARequest</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_adapter_request</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">PromptAdapterRequest</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">streaming</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">queries</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">PromptInputs</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-></span> <span class="n">RequestOutput</span><span class="p">:</span>
<span class="w"> </span><span class="sd">"""Generate output for the given prompt in the asynchronous mode.</span>
<span class="sd"> Asynchronous generation accepts single prompt only.</span>
@ -430,52 +440,31 @@
<span class="sd"> use for generation, if any. Defaults to None.</span>
<span class="sd"> streaming (bool): Whether to use the streaming mode for the generation. Defaults to</span>
<span class="sd"> False.</span>
<span class="sd"> queries (PromptInputs or Sequence[PromptInputs]): The query text or token ids.</span>
<span class="sd"> it can be single prompt or batched prompts. it is used for star attention to run long context tasks.</span>
<span class="sd"> Returns:</span>
<span class="sd"> RequestOutput: The output data of the completion request to the LLM.</span>
<span class="sd"> """</span>
<span class="n">sampling_params</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_sampling_params</span><span class="p">(</span><span class="n">sampling_params</span><span class="p">)</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="n">prompt_inputs</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span>
<span class="k">if</span> <span class="n">queries</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">queries</span> <span class="o">=</span> <span class="n">prompt_inputs</span><span class="p">(</span><span class="n">queries</span><span class="p">)</span>
<span class="n">query_token_ids</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">prompt_tuning_config</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="s2">"prompt_token_ids"</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
<span class="n">prompt_token_ids</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="s1">'prompt_token_ids'</span><span class="p">]</span>
<span class="n">prompt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">queries</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">query_token_ids</span> <span class="o">=</span> <span class="n">queries</span><span class="p">[</span><span class="s1">'prompt_token_ids'</span><span class="p">]</span>
<span class="k">elif</span> <span class="s2">"prompt"</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
<span class="n">prompt_token_ids</span><span class="p">,</span> <span class="n">extra_processed_inputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_processor</span><span class="p">(</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
<span class="n">prompt_token_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_prompt_token_ids</span><span class="p">(</span>
<span class="n">inputs</span><span class="p">,</span> <span class="n">sampling_params</span><span class="p">)</span>
<span class="n">prompt</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="s1">'prompt'</span><span class="p">]</span>
<span class="k">if</span> <span class="n">queries</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">query_token_ids</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_processor</span><span class="p">(</span>
<span class="n">queries</span><span class="p">,</span> <span class="n">sampling_params</span><span class="p">)</span>
<span class="k">if</span> <span class="n">extra_processed_inputs</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">prompt_tuning_config</span> <span class="o">=</span> <span class="n">extra_processed_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
<span class="s1">'prompt_tuning_config'</span><span class="p">)</span>
<span class="n">prompt</span> <span class="o">=</span> <span class="n">inputs</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">prompt_token_ids</span> <span class="o">=</span> <span class="n">inputs</span>
<span class="n">prompt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">"The inputs must be type str or list of int, but got </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_check_arguments</span><span class="p">(</span>
<span class="nb">len</span><span class="p">(</span><span class="n">prompt_token_ids</span><span class="p">),</span>
<span class="nb">len</span><span class="p">(</span><span class="n">query_token_ids</span><span class="p">)</span> <span class="k">if</span> <span class="n">query_token_ids</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">sampling_params</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_check_arguments</span><span class="p">(</span><span class="n">prompt_token_ids</span><span class="p">,</span> <span class="n">sampling_params</span><span class="p">)</span>
<span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_executor</span><span class="o">.</span><span class="n">generate_async</span><span class="p">(</span>
<span class="n">prompt_token_ids</span><span class="p">,</span>
<span class="n">query_token_ids</span><span class="o">=</span><span class="n">query_token_ids</span><span class="p">,</span>
<span class="n">sampling_params</span><span class="o">=</span><span class="n">sampling_params</span><span class="p">,</span>
<span class="n">lora_request</span><span class="o">=</span><span class="n">lora_request</span><span class="p">,</span>
<span class="n">prompt_adapter_request</span><span class="o">=</span><span class="n">prompt_adapter_request</span><span class="p">,</span>
<span class="n">streaming</span><span class="o">=</span><span class="n">streaming</span><span class="p">,</span>
<span class="n">prompt_tuning_config</span><span class="o">=</span><span class="n">prompt_tuning_config</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">RequestOutput</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="n">prompt</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokenizer</span><span class="p">)</span></div>
@ -508,6 +497,21 @@
<span class="sd"> '''</span>
<span class="k">return</span> <span class="k">await</span> <span class="bp">self</span><span class="o">.</span><span class="n">_executor</span><span class="o">.</span><span class="n">aget_stats</span><span class="p">(</span><span class="n">timeout</span><span class="o">=</span><span class="n">timeout</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_prepare_prompt_token_ids</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prompt</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">sampling_params</span><span class="p">:</span> <span class="n">SamplingParams</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokenizer</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"tokenizer is required to tokenize string prompt"</span><span class="p">)</span>
<span class="k">if</span> <span class="n">sampling_params</span><span class="o">.</span><span class="n">truncate_prompt_tokens</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span>
<span class="n">prompt</span><span class="p">,</span> <span class="n">add_special_tokens</span><span class="o">=</span><span class="n">sampling_params</span><span class="o">.</span><span class="n">add_special_tokens</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span>
<span class="n">prompt</span><span class="p">,</span>
<span class="n">add_special_tokens</span><span class="o">=</span><span class="n">sampling_params</span><span class="o">.</span><span class="n">add_special_tokens</span><span class="p">,</span>
<span class="n">truncation</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">max_length</span><span class="o">=</span><span class="n">sampling_params</span><span class="o">.</span><span class="n">truncate_prompt_tokens</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_prepare_sampling_params</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">sampling_params</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">SamplingParams</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">SamplingParams</span><span class="p">:</span>
@ -531,12 +535,9 @@
<span class="sa">f</span><span class="s2">"The sampling_params must be type SamplingParams or None, but got </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">sampling_params</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">_check_arguments</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prompt_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">query_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="k">def</span> <span class="nf">_check_arguments</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prompt_token_ids</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
<span class="n">sampling_params</span><span class="p">:</span> <span class="n">SamplingParams</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="p">,</span> <span class="s1">'backend'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span> <span class="o">==</span> <span class="s1">'pytorch'</span><span class="p">:</span>
<span class="k">return</span>
<span class="n">build_config</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">build_config</span>
<span class="n">built_enging_cfg_file</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">model</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'config.json'</span>
@ -544,13 +545,14 @@
<span class="n">built_enging_cfg</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="n">max_seq_len</span> <span class="o">=</span> <span class="n">built_enging_cfg</span><span class="p">[</span><span class="s1">'build_config'</span><span class="p">][</span>
<span class="s1">'max_seq_len'</span><span class="p">]</span> <span class="k">if</span> <span class="s1">'build_config'</span> <span class="ow">in</span> <span class="n">built_enging_cfg</span> <span class="k">else</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span>
<span class="c1"># TODO: Remove this check and left the request verification to cpp runtime</span>
<span class="k">if</span> <span class="p">(</span><span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">enable_chunked_prefill</span><span class="p">)</span> <span class="ow">and</span> <span class="p">(</span>
<span class="n">prompt_len</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">cp_size</span> <span class="o">+</span> <span class="n">query_len</span> <span class="o">+</span>
<span class="n">sampling_params</span><span class="o">.</span><span class="n">max_tokens</span> <span class="o">></span> <span class="n">max_seq_len</span><span class="p">):</span>
<span class="n">prompt_len</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">prompt_token_ids</span><span class="p">)</span>
<span class="c1"># TODO: Remove this check and left the request verification to cpp runtime</span>
<span class="k">if</span> <span class="p">(</span><span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">enable_chunked_prefill</span>
<span class="p">)</span> <span class="ow">and</span> <span class="n">prompt_len</span> <span class="o">+</span> <span class="n">sampling_params</span><span class="o">.</span><span class="n">max_tokens</span> <span class="o">></span> <span class="n">max_seq_len</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">"The sum of prompt length (</span><span class="si">{</span><span class="n">prompt_len</span><span class="o">/</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">cp_size</span><span class="si">}</span><span class="s2">) and query length (</span><span class="si">{</span><span class="n">query_len</span><span class="si">}</span><span class="s2">) max_tokens (</span><span class="si">{</span><span class="n">sampling_params</span><span class="o">.</span><span class="n">max_tokens</span><span class="si">}</span><span class="s2">) should not exceed "</span>
<span class="sa">f</span><span class="s2">"The sum of prompt length (</span><span class="si">{</span><span class="n">prompt_len</span><span class="si">}</span><span class="s2">) and max_tokens (</span><span class="si">{</span><span class="n">sampling_params</span><span class="o">.</span><span class="n">max_tokens</span><span class="si">}</span><span class="s2">) should not exceed "</span>
<span class="sa">f</span><span class="s2">"max_seq_len (</span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="si">}</span><span class="s2">)"</span><span class="p">)</span>
<span class="k">if</span> <span class="n">sampling_params</span><span class="o">.</span><span class="n">beam_width</span> <span class="o">></span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_beam_width</span><span class="p">:</span>
@ -616,17 +618,19 @@
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">extended_runtime_perf_knob_config</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">executor_config</span><span class="o">.</span><span class="n">extended_runtime_perf_knob_config</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">extended_runtime_perf_knob_config</span>
<span class="n">trt_engine_dir</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_engine_dir</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_engine_dir</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="kc">None</span><span class="p">)</span>
<span class="c1"># PIVOT_TO_PYTHON_START</span>
<span class="n">hf_model_dir</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_hf_model_dir</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_hf_model_dir</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="kc">None</span><span class="p">)</span>
<span class="kn">from</span> <span class="nn">tensorrt_llm.pyexecutor.config</span> <span class="kn">import</span> <span class="n">update_executor_config</span>
<span class="n">update_executor_config</span><span class="p">(</span>
<span class="n">executor_config</span><span class="p">,</span>
<span class="n">backend</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">backend</span><span class="p">,</span>
<span class="n">pytorch_backend_config</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">pytorch_backend_config</span><span class="p">,</span>
<span class="n">mapping</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">to_mapping</span><span class="p">(),</span>
<span class="n">build_config</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">build_config</span><span class="p">,</span>
<span class="n">hf_model_dir</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_hf_model_dir</span><span class="p">,</span>
<span class="n">trt_engine_dir</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_engine_dir</span><span class="p">)</span>
<span class="n">executor_config</span><span class="o">.</span><span class="n">llm_parallel_config</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">parallel_config</span>
<span class="n">max_seq_len</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="p">,</span>
<span class="n">hf_model_dir</span><span class="o">=</span><span class="n">hf_model_dir</span><span class="p">,</span>
<span class="n">trt_engine_dir</span><span class="o">=</span><span class="n">trt_engine_dir</span><span class="p">)</span>
<span class="c1"># PIVOT_TO_PYTHON_END</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_executor</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_executor_cls</span><span class="o">.</span><span class="n">create</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_engine_dir</span><span class="p">,</span>
@ -636,9 +640,6 @@
<span class="n">mpi_session</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">mpi_session</span><span class="p">,</span>
<span class="n">reuse_mpi_comm</span><span class="o">=</span><span class="n">external_mpi_comm_available</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">world_size</span><span class="p">),</span>
<span class="n">num_postprocess_workers</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">_num_postprocess_workers</span><span class="p">,</span>
<span class="n">postprocess_tokenizer_dir</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">_postprocess_tokenizer_dir</span><span class="p">,</span>
<span class="n">postprocess_result_handler</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">_postprocess_result_handler</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">_try_load_tokenizer</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Optional</span><span class="p">[</span><span class="n">TokenizerBase</span><span class="p">]:</span>
@ -717,7 +718,7 @@
<hr/>
<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x7da791842210>
<jinja2.runtime.BlockReference object at 0x7f1ac6ac7cb0>
<div class="footer">
<p>
@ -14,7 +14,7 @@
<script src="../../../_static/jquery.js?v=5d32c60e"></script>
<script src="../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../../_static/doctools.js?v=9a2dae69"></script>
<script src="../../../_static/doctools.js?v=9bcbadda"></script>
<script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../../../_static/copybutton.js?v=65e89d2a"></script>
@ -211,7 +211,8 @@
<span class="n">ExtendedRuntimePerfKnobConfig</span><span class="p">,</span> <span class="n">KvCacheConfig</span><span class="p">,</span>
<span class="n">LookaheadDecodingConfig</span><span class="p">,</span> <span class="n">PeftCacheConfig</span><span class="p">,</span>
<span class="n">SchedulerConfig</span><span class="p">)</span>
<span class="kn">from</span> <span class="nn">..builder</span> <span class="kn">import</span> <span class="n">BuildConfig</span><span class="p">,</span> <span class="n">Engine</span><span class="p">,</span> <span class="n">EngineConfig</span><span class="p">,</span> <span class="n">build</span>
<span class="kn">from</span> <span class="nn">..builder</span> <span class="kn">import</span> <span class="p">(</span><span class="n">BuildConfig</span><span class="p">,</span> <span class="n">Engine</span><span class="p">,</span> <span class="n">EngineConfig</span><span class="p">,</span> <span class="n">_init_max_seq_len</span><span class="p">,</span>
<span class="n">build</span><span class="p">)</span>
<span class="kn">from</span> <span class="nn">..logger</span> <span class="kn">import</span> <span class="n">logger</span>
<span class="kn">from</span> <span class="nn">..mapping</span> <span class="kn">import</span> <span class="n">Mapping</span>
<span class="kn">from</span> <span class="nn">..models.automodel</span> <span class="kn">import</span> <span class="n">MODEL_MAP</span><span class="p">,</span> <span class="n">AutoConfig</span><span class="p">,</span> <span class="n">AutoModelForCausalLM</span>
@ -221,8 +222,7 @@
<span class="kn">from</span> <span class="nn">.build_cache</span> <span class="kn">import</span> <span class="p">(</span><span class="n">BuildCache</span><span class="p">,</span> <span class="n">BuildCacheConfig</span><span class="p">,</span> <span class="n">CachedStage</span><span class="p">,</span>
<span class="n">get_build_cache_config_from_env</span><span class="p">)</span>
<span class="kn">from</span> <span class="nn">.mpi_session</span> <span class="kn">import</span> <span class="n">MPINodeState</span><span class="p">,</span> <span class="n">MpiSession</span>
<span class="kn">from</span> <span class="nn">.tokenizer</span> <span class="kn">import</span> <span class="p">(</span><span class="n">TokenizerBase</span><span class="p">,</span> <span class="n">TransformersTokenizer</span><span class="p">,</span> <span class="n">load_hf_tokenizer</span><span class="p">,</span>
<span class="n">tokenizer_factory</span><span class="p">)</span>
<span class="kn">from</span> <span class="nn">.tokenizer</span> <span class="kn">import</span> <span class="n">TokenizerBase</span><span class="p">,</span> <span class="n">TransformersTokenizer</span><span class="p">,</span> <span class="n">tokenizer_factory</span>
<span class="c1"># TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import</span>
<span class="kn">from</span> <span class="nn">.utils</span> <span class="kn">import</span> <span class="p">(</span><span class="n">GpuArch</span><span class="p">,</span> <span class="n">append_docstring</span><span class="p">,</span> <span class="n">download_hf_model</span><span class="p">,</span>
<span class="n">download_hf_pretrained_config</span><span class="p">,</span> <span class="n">enable_llm_debug</span><span class="p">,</span>
@ -238,7 +238,6 @@
<span class="n">cp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">moe_tp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">moe_ep_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">cp_config</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="nb">dict</span><span class="p">)</span>
<span class="n">auto_parallel</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">_world_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
@ -260,7 +259,6 @@
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">world_size</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">auto_parallel</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">pp_size</span> <span class="o">></span> <span class="mi">1</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">cp_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
@ -280,24 +278,15 @@
<span class="k">elif</span> <span class="p">(</span><span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">auto_parallel</span>
<span class="p">)</span> <span class="ow">and</span> <span class="n">world_size</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">pp_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">cp_size</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">"world_size </span><span class="si">{</span><span class="n">world_size</span><span class="si">}</span><span class="s2"> should be equal to tp_size * pp_size </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="bp">self</span><span class="o">.</span><span class="n">pp_size</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="bp">self</span><span class="o">.</span><span class="n">cp_size</span><span class="si">}</span><span class="s2"> "</span>
<span class="sa">f</span><span class="s2">"world_size </span><span class="si">{</span><span class="n">world_size</span><span class="si">}</span><span class="s2"> should be equal to tp_size * pp_size * cp_size </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="bp">self</span><span class="o">.</span><span class="n">pp_size</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="bp">self</span><span class="o">.</span><span class="n">cp_size</span><span class="si">}</span><span class="s2"> "</span>
<span class="s2">"in non-auto_parallel mode.</span><span class="se">\n</span><span class="s2">"</span>
<span class="s2">"For non-auto-parallel mode, the world_size is not needed to set"</span>
<span class="p">)</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">is_multi_gpu</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">></span> <span class="mi">1</span>
<span class="k">def</span> <span class="nf">to_mapping</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Mapping</span><span class="p">:</span>
<span class="k">return</span> <span class="n">Mapping</span><span class="p">(</span><span class="n">world_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">world_size</span><span class="p">,</span>
<span class="n">rank</span><span class="o">=</span><span class="n">mpi_rank</span><span class="p">(),</span>
<span class="n">tp_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span>
<span class="n">pp_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">pp_size</span><span class="p">,</span>
<span class="n">cp_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">cp_size</span><span class="p">,</span>
<span class="n">cp_config</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">cp_config</span><span class="p">,</span>
<span class="n">moe_tp_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_tp_size</span><span class="p">,</span>
<span class="n">moe_ep_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_ep_size</span><span class="p">,</span>
<span class="n">auto_parallel</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">auto_parallel</span><span class="p">)</span>
<div class="viewcode-block" id="CalibConfig">
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.CalibConfig">[docs]</a>
@ -576,8 +565,6 @@
<span class="n">moe_expert_parallel_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">cp_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">dict</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="nb">dict</span><span class="p">)</span>
<span class="n">auto_parallel</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">auto_parallel_world_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
@ -661,12 +648,8 @@
<span class="c1"># Extra PyTorch backend options, ignored if backend != "pytorch".</span>
<span class="kn">from</span> <span class="nn">tensorrt_llm.pyexecutor.config</span> <span class="kn">import</span> <span class="n">PyTorchConfig</span>
<span class="n">pytorch_backend_config</span><span class="p">:</span> <span class="n">PyTorchConfig</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="n">PyTorchConfig</span><span class="p">)</span>
<span class="c1"># PIVOT_TO_PYTHON_END</span>
<span class="c1"># private options</span>
<span class="n">_num_postprocess_workers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># Number of postprocess worker processes</span>
<span class="n">_postprocess_tokenizer_dir</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">_postprocess_result_handler</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="c1"># PIVOT_TO_PYTHON_END</span>
<span class="k">def</span> <span class="nf">__post_init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="c1"># TODO[chunweiy]: Enable this option in the future</span>
@ -700,7 +683,6 @@
<span class="n">cp_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">context_parallel_size</span><span class="p">,</span>
<span class="n">moe_tp_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_tensor_parallel_size</span><span class="p">,</span>
<span class="n">moe_ep_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_expert_parallel_size</span><span class="p">,</span>
<span class="n">cp_config</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">cp_config</span><span class="p">,</span>
<span class="n">auto_parallel</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">auto_parallel</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">auto_parallel</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">world_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">auto_parallel_world_size</span>
@ -770,8 +752,7 @@
<span class="n">speculative_model_obj</span> <span class="o">=</span> <span class="n">_ModelWrapper</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">speculative_model</span>
<span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_model</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">model_obj</span><span class="o">.</span><span class="n">is_local_model</span> <span class="ow">and</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s1">'backend'</span><span class="p">,</span>
<span class="kc">None</span><span class="p">)</span> <span class="o">!=</span> <span class="s1">'pytorch'</span><span class="p">:</span>
<span class="k">if</span> <span class="n">model_obj</span><span class="o">.</span><span class="n">is_local_model</span><span class="p">:</span>
<span class="c1"># Load parallel_config from the engine.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">model_format</span> <span class="o">=</span> <span class="n">ModelLoader</span><span class="o">.</span><span class="n">get_model_format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
@ -1205,8 +1186,19 @@
<span class="bp">self</span><span class="o">.</span><span class="n">llm_args</span><span class="o">.</span><span class="n">speculative_model</span>
<span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">llm_args</span><span class="o">.</span><span class="n">speculative_model</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">convert_checkpoint_options</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">llm_args</span><span class="o">.</span><span class="n">_convert_checkpoint_options</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">=</span> <span class="n">mpi_rank</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span> <span class="o">=</span> <span class="n">llm_args</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">to_mapping</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">=</span> <span class="n">mpi_rank</span><span class="p">()</span> <span class="k">if</span> <span class="n">llm_args</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">is_multi_gpu</span> <span class="k">else</span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">llm_args</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">is_multi_gpu</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">llm_args</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">auto_parallel</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span> <span class="o">=</span> <span class="n">Mapping</span><span class="p">(</span>
<span class="n">tp_size</span><span class="o">=</span><span class="n">llm_args</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span>
<span class="n">pp_size</span><span class="o">=</span><span class="n">llm_args</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">pp_size</span><span class="p">,</span>
<span class="n">cp_size</span><span class="o">=</span><span class="n">llm_args</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">cp_size</span><span class="p">,</span>
<span class="n">moe_tp_size</span><span class="o">=</span><span class="n">llm_args</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">moe_tp_size</span><span class="p">,</span>
<span class="n">moe_ep_size</span><span class="o">=</span><span class="n">llm_args</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">moe_ep_size</span><span class="p">,</span>
<span class="n">rank</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">rank</span><span class="p">,</span>
<span class="n">world_size</span><span class="o">=</span><span class="n">llm_args</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">world_size</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span> <span class="o">=</span> <span class="n">Mapping</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_build_pipeline</span> <span class="o">=</span> <span class="p">[]</span>
@ -1215,7 +1207,7 @@
<span class="n">Path</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_obj</span><span class="o">.</span><span class="n">model_dir</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_obj</span><span class="o">.</span><span class="n">is_local_model</span> <span class="k">else</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_speculative_model_dir</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span>
<span class="n">Path</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_model_obj</span><span class="o">.</span><span class="n">model_dir</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_model_obj</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_obj</span><span class="o">.</span><span class="n">is_local_model</span> <span class="k">else</span> <span class="kc">None</span>
<span class="n">PATH</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_model_obj</span><span class="o">.</span><span class="n">model_dir</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_model_obj</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_obj</span><span class="o">.</span><span class="n">is_local_model</span> <span class="k">else</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_model_info</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">_ModelInfo</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_model_format</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">llm_args</span><span class="o">.</span><span class="n">model_format</span>
@ -1465,8 +1457,7 @@
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_model_obj</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_speculative_model_dir</span> <span class="o">=</span> <span class="n">mpi_broadcast</span><span class="p">(</span><span class="n">speculative_model_dir</span><span class="p">,</span>
<span class="n">root</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">speculative_model_obj</span><span class="o">.</span><span class="n">model_dir</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_speculative_model_dir</span>
<span class="bp">self</span><span class="o">.</span><span class="n">speculative_model_dir</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_speculative_model_dir</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_model_obj</span><span class="o">.</span><span class="n">is_local_model</span>
<span class="k">def</span> <span class="nf">_load_model_from_hf</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
@ -1476,20 +1467,6 @@
<span class="bp">self</span><span class="o">.</span><span class="n">_model_dir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">llm_args</span><span class="o">.</span><span class="n">trust_remote_code</span><span class="p">,</span>
<span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">llm_args</span><span class="p">,</span> <span class="s2">"speculative_model"</span><span class="p">)</span>
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">llm_args</span><span class="o">.</span><span class="n">speculative_model</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># Update quant_config if it's ModelOpt quantized ckpt</span>
|
||||
<span class="n">hf_quant_config_path</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_model_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s2">"hf_quant_config.json"</span>
|
||||
<span class="k">if</span> <span class="n">hf_quant_config_path</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
|
||||
<span class="n">already_quantized</span> <span class="o">=</span> <span class="kc">True</span>
|
||||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">hf_quant_config_path</span><span class="p">,</span> <span class="s2">"r"</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
|
||||
<span class="n">hf_quant_config</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
|
||||
<span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">hf_quant_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">"quantization"</span><span class="p">,</span>
|
||||
<span class="p">{})</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
||||
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">llm_args</span><span class="o">.</span><span class="n">quant_config</span><span class="p">,</span> <span class="n">key</span><span class="p">):</span>
|
||||
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">llm_args</span><span class="o">.</span><span class="n">quant_config</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="n">already_quantized</span> <span class="o">=</span> <span class="kc">False</span>
|
||||
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">llm_args</span><span class="o">.</span><span class="n">load_format</span> <span class="o">==</span> <span class="s1">'dummy'</span><span class="p">:</span>
|
||||
<span class="n">config</span> <span class="o">=</span> <span class="n">model_cls</span><span class="o">.</span><span class="n">config_class</span><span class="o">.</span><span class="n">from_hugging_face</span><span class="p">(</span>
|
||||
<span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_model_dir</span><span class="p">),</span>
|
||||
@ -1499,7 +1476,7 @@
|
||||
<span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">convert_checkpoint_options</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model_cls</span><span class="p">(</span><span class="n">config</span><span class="p">)</span>
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">llm_args</span><span class="o">.</span><span class="n">quant_config</span><span class="o">.</span><span class="n">requires_calibration</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">already_quantized</span><span class="p">:</span>
|
||||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">llm_args</span><span class="o">.</span><span class="n">quant_config</span><span class="o">.</span><span class="n">requires_calibration</span><span class="p">:</span>
|
||||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">workspace</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||||
<span class="n">checkpoint_dir</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">workspace</span><span class="si">}</span><span class="s2">/quantized-checkpoint"</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||||
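The hf_quant_config.json probe in the hunk above follows a recurring ModelOpt convention: if a quantized checkpoint ships that file, the keys of its "quantization" section are copied field-by-field onto the run's quant config, and the checkpoint is treated as already quantized. A minimal standalone sketch of the same pattern, assuming a hypothetical QuantConfig with only two illustrative fields (the real class has more):

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

@dataclass
class QuantConfig:
    # Illustrative fields only; not the library's actual schema.
    quant_algo: Optional[str] = None
    kv_cache_quant_algo: Optional[str] = None

def apply_hf_quant_config(model_dir: str, quant_config: QuantConfig) -> bool:
    """Return True if the checkpoint ships a ModelOpt hf_quant_config.json
    and its quantization fields were copied onto quant_config."""
    path = Path(model_dir) / "hf_quant_config.json"
    if not path.exists():
        return False
    with open(path) as f:
        hf_quant_config = json.load(f)
    # Only copy keys that the config object actually declares.
    for key, value in hf_quant_config.get("quantization", {}).items():
        if hasattr(quant_config, key):
            setattr(quant_config, key, value)
    return True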
@@ -1621,10 +1598,16 @@
            model_dir,
            trust_remote_code: bool = True,
            use_fast: bool = True) -> Optional[TransformersTokenizer]:
        if (tokenizer := load_hf_tokenizer(model_dir, trust_remote_code,
                                           use_fast)) is not None:
            return tokenizer
        else:
        try:

            return TransformersTokenizer.from_pretrained(
                model_dir,
                legacy=False,
                padding_side='left',
                truncation_side='left',
                trust_remote_code=trust_remote_code,
                use_fast=use_fast)
        except Exception as e:
            logger.error(f"Failed to load tokenizer from {model_dir}: {e}")
            return None
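This hunk implements a two-step fallback: try the primary loader first and, only if it yields nothing, fall back to a direct from_pretrained call whose failures are logged and mapped to None rather than raised. The same shape, reduced to a self-contained sketch (primary and fallback stand in for the real loaders, which are not shown here):

import logging
from typing import Callable, Optional, TypeVar

T = TypeVar("T")
logger = logging.getLogger(__name__)

def load_with_fallback(primary: Callable[[], Optional[T]],
                       fallback: Callable[[], T],
                       what: str) -> Optional[T]:
    # Fast path: accept the primary loader's result if it produced one.
    if (result := primary()) is not None:
        return result
    # Slow path: any failure is logged and converted to None instead of
    # propagating, so a missing tokenizer degrades gracefully.
    try:
        return fallback()
    except Exception as e:
        logger.error(f"Failed to load {what}: {e}")
        return None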
@@ -1707,6 +1690,8 @@
            logger.warning(
                "QuantConfig for pytorch backend is ignored. You can load"
                "quantized model with hf_quant_config.json directly.")
        _init_max_seq_len(self.get_pretrained_config(),
                          self.llm_args.build_config)
        return None, self._hf_model_dir
        # PIVOT_TO_PYTHON_END
@@ -1767,7 +1752,10 @@
        assert self._hf_model_dir is not None
        return AutoConfig.from_hugging_face(
            self._hf_model_dir,
            mapping=self.llm_args.parallel_config.to_mapping(),
            mapping=Mapping(world_size=self.llm_args.parallel_config.world_size,
                            tp_size=self.llm_args.parallel_config.tp_size,
                            pp_size=self.llm_args.parallel_config.pp_size,
                            cp_size=self.llm_args.parallel_config.cp_size),
            quant_config=self.llm_args.quant_config,
            dtype=self.llm_args.dtype)
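One side of this hunk reaches for a one-line parallel_config.to_mapping() helper; the other builds the Mapping explicitly from the same four fields. For reference, a helper with the behavior the one-liner implies could be as small as the following sketch (both classes are minimal stand-ins, not the library's definitions, and the world-size product is an assumption):

from dataclasses import dataclass

@dataclass
class Mapping:
    # Stand-in for tensorrt_llm.Mapping; only the fields used above.
    world_size: int = 1
    tp_size: int = 1
    pp_size: int = 1
    cp_size: int = 1

@dataclass
class ParallelConfig:
    tp_size: int = 1   # tensor parallel
    pp_size: int = 1   # pipeline parallel
    cp_size: int = 1   # context parallel

    @property
    def world_size(self) -> int:
        # Assumption: world size is the product of the parallel dimensions.
        return self.tp_size * self.pp_size * self.cp_size

    def to_mapping(self) -> Mapping:
        return Mapping(world_size=self.world_size,
                       tp_size=self.tp_size,
                       pp_size=self.pp_size,
                       cp_size=self.cp_size)

assert ParallelConfig(tp_size=4, pp_size=2).to_mapping().world_size == 8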
@@ -186,7 +187,7 @@
from ..._common import default_net, default_trtnet
from ..._utils import pad_vocab_size
from ...bindings import KVCacheType
from ...functional import (Tensor, _create_tensor, cast, concat,
from ...functional import (Tensor, _create_tensor, concat,
                           gather_last_token_logits, index_select, shape)
from ...layers import AttentionParams, ColumnLinear, SpecDecodingParams
from ...module import Module, ModuleList

@@ -685,11 +686,13 @@

class EagleNet(Module):

    def __init__(self, config, logits_dtype):
    def __init__(
        self,
        config,
    ):
        super().__init__()
        self.drafter = LLaMAModel(config)
        self.config = config
        self.logits_dtype = logits_dtype

        vocab_size_padded = pad_vocab_size(config.vocab_size,
                                           config.mapping.tp_size)
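pad_vocab_size is used here so the padded vocabulary divides evenly across tensor-parallel ranks when the LM head is column-sharded. Assuming it simply rounds up to the next multiple of tp_size (a plausible reading of the utility, not a quote of it), an equivalent is:

def pad_vocab_size(vocab_size: int, tp_size: int) -> int:
    # Round up to the next multiple of tp_size so each tensor-parallel
    # rank receives an equal slice of the LM head weight.
    return ((vocab_size + tp_size - 1) // tp_size) * tp_size

assert pad_vocab_size(32003, 4) == 32004  # padded up
assert pad_vocab_size(32000, 4) == 32000  # already divisible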
@@ -725,8 +728,7 @@
            hidden_states = gather_last_token_logits(
                hidden_states, last_token_indices,
                default_net().plugin_config.remove_input_padding)
            return cast(self.lm_head(hidden_states),
                        dtype=self.logits_dtype), hidden_states, cache
            return self.lm_head(hidden_states), hidden_states, cache

        return None, hidden_states, cache
@@ -758,8 +760,7 @@
        eagle_net_config.layer_idx_offset = config.num_hidden_layers
        if self.mapping.is_last_pp_rank():
            self.eagle_nets = ModuleList([
                EagleNet(config=eagle_net_config,
                         logits_dtype=config.logits_dtype)
                EagleNet(config=eagle_net_config)
                for _ in range(self.num_eagle_layers)
            ])
        self.max_draft_len = config.max_draft_len
@@ -777,7 +778,7 @@
                          host_gen_eagle_net_context_lengths,
                          host_gen_eagle_net_past_key_value_lengths,
                          hidden_size_batch_level_starts, input_gen_tokens,
                          input_spec_decoding_generation_lengths, spec_decoding_use):
                          input_spec_decoding_generation_lengths):

        drafter_inputs = eagle_prepare_drafter_inputs_plugin(
            layer_idx, self.num_eagle_layers, self.max_non_leaves_per_layer,
@@ -811,8 +812,7 @@
        if layer_idx > 0:
            spec_decoding_params = SpecDecodingParams(
                True, self.max_draft_len, spec_decoding_generation_lengths,
                spec_decoding_position_offsets, spec_decoding_packed_mask,
                spec_decoding_use)
                spec_decoding_position_offsets, spec_decoding_packed_mask)

        # Get hidden states for accepted ids
        hidden_states = self._slice_hidden_states(hidden_states,
@@ -984,8 +984,7 @@
            hidden_size_batch_level_starts=hidden_size_batch_level_starts,
            input_gen_tokens=input_gen_tokens,
            input_spec_decoding_generation_lengths=spec_decoding_params.
            spec_decoding_generation_lengths,
            spec_decoding_use=spec_decoding_params.spec_decoding_use)
            spec_decoding_generation_lengths)

        def single_eagle_net_iter(next_draft_tokens, next_draft_lens):
            # Run EAGLE Net
@@ -1073,7 +1072,6 @@

        if self.mapping.is_last_pp_rank():
            lm_logits, hidden_states, all_hidden_states = hidden_states
            lm_logits = cast(lm_logits, self.config.logits_dtype)
            # Call eagle logic to accept prev draft tokens and predict next draft tokens
            next_draft_tokens = self._eagle_fwd_helper(lm_logits,
                                                       all_hidden_states, *args,
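Taken together, these hunks move the dtype conversion: the per-drafter cast(..., dtype=logits_dtype) disappears from EagleNet, and a single lm_logits = cast(...) appears at the top-level forward where the logits are actually consumed. The refactoring shape in miniature, with placeholder ops standing in for the TensorRT-LLM graph functions:

# Placeholders; `cast` here stands in for tensorrt_llm.functional.cast.
def lm_head(hidden_states):
    return hidden_states

def cast(x, dtype):
    return x

# Before: every drafter casts its own logits, so the dtype must be
# threaded through every constructor.
class DrafterBefore:
    def __init__(self, logits_dtype):
        self.logits_dtype = logits_dtype
    def forward(self, h):
        return cast(lm_head(h), self.logits_dtype)

# After: the drafter returns raw logits; the one consumer casts once.
class DrafterAfter:
    def forward(self, h):
        return lm_head(h)

lm_logits = cast(DrafterAfter().forward([0.5]), "float32")  # cast at the call site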
@@ -279,14 +280,10 @@

        return {
            **super().to_dict(),
            **{
                f: getattr(self, f)
                for f in self.GEMMA_ADDED_FIELDS
            },
            **({
                f: getattr(self, f)
                for f in self.GEMMA2_ADDED_FIELDS
            } if self.is_gemma_2 else {})
            **{f: getattr(self, f)
               for f in self.GEMMA_ADDED_FIELDS},
            **({f: getattr(self, f)
                for f in self.GEMMA2_ADDED_FIELDS} if self.is_gemma_2 else {})
        }
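Both formattings of this return rely on the same idiom: the result dict is assembled by **-unpacking dict comprehensions into a literal, with the version-specific fields spread conditionally, since an empty dict unpacks to nothing. A self-contained sketch of the idiom with invented field names:

class ConfigExample:
    BASE_FIELDS = ("hidden_size", "num_layers")
    EXTRA_FIELDS = ("sliding_window",)

    def __init__(self, is_v2: bool):
        self.hidden_size, self.num_layers = 2048, 18
        self.sliding_window = 4096
        self.is_v2 = is_v2

    def to_dict(self) -> dict:
        return {
            **{f: getattr(self, f) for f in self.BASE_FIELDS},
            # An empty dict unpacks to nothing, so the extra fields are
            # included only when the condition holds.
            **({f: getattr(self, f) for f in self.EXTRA_FIELDS}
               if self.is_v2 else {}),
        }

assert "sliding_window" in ConfigExample(is_v2=True).to_dict()
assert "sliding_window" not in ConfigExample(is_v2=False).to_dict()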
@@ -182,10 +183,8 @@
 from tensorrt_llm.quantization.mode import (MODELOPT_FLOW_QUANTIZATIONS,
                                             QuantAlgo)

 from ..._common import default_net
 from ..._utils import pad_vocab_size
-from ...functional import (AllReduceFusionOp, AllReduceParams, Tensor, cast,
-                           recv, send)
+from ...functional import Tensor, cast, recv, send
 from ...layers import (Attention, AttentionMaskType, AttentionParams,
                        ColumnLinear, Embedding, GatedMLP, KeyValueCacheParams,
                        LoraParams, PositionEmbeddingType, RmsNorm)
@@ -212,7 +211,7 @@
                                         dtype=config.dtype)

         layers_range = config.mapping.pp_layers(config.num_hidden_layers)
-        self.local_layer_idx = layer_idx - layers_range[0]
+        local_layer_idx = layer_idx - layers_range[0]

         q_scaling = 1.0
         max_attn_value = 0.0
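Note: the hunk keeps local_layer_idx as a plain local instead of an attribute. A sketch of the pipeline-parallel indexing behind pp_layers(), assuming an even split of layers across ranks (the helper below is illustrative, not the TRT-LLM Mapping API):

```python
# Each pipeline rank owns a contiguous slice of the global layers; the
# per-rank ("local") index is the global index minus the slice start.
def pp_layers(num_layers: int, pp_size: int, pp_rank: int) -> range:
    per_rank = num_layers // pp_size  # assumes num_layers % pp_size == 0
    start = pp_rank * per_rank
    return range(start, start + per_rank)

layers_range = pp_layers(num_layers=32, pp_size=4, pp_rank=1)  # layers 8..15
local_layer_idx = 9 - layers_range[0]  # global layer 9 -> local index 1
print(local_layer_idx)
```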
@@ -225,7 +224,7 @@
             max_attn_value = config.attn_logit_softcapping or 0.0

         self.attention = Attention(
-            local_layer_idx=self.local_layer_idx,
+            local_layer_idx=local_layer_idx,
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
@@ -241,8 +240,7 @@
             tp_size=config.mapping.tp_size,
             quant_mode=config.quant_mode,
             q_scaling=q_scaling,
-            max_attn_value=max_attn_value,
-        )
+            max_attn_value=max_attn_value)

         mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size
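Note: the context line above applies the usual 4x rule of thumb for the FFN width when the checkpoint does not specify intermediate_size. For example:

```python
# Illustrative values only.
hidden_size = 3072
intermediate_size = None
mlp_hidden_size = hidden_size * 4 if intermediate_size is None else intermediate_size
print(mlp_hidden_size)  # 12288
```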
@@ -275,71 +273,36 @@
                 use_cache: bool = False,
                 kv_cache_params: Optional[KeyValueCacheParams] = None,
                 attention_params: Optional[AttentionParams] = None,
-                lora_layer_params: Optional[LoraParams] = None,
-                next_layer_input_layernorm_args=None):
-        # assert not (
-        #     default_net().plugin_config.reduce_fusion and self.has_residual_mlp
-        # ), "Custom all reduce and residual mlp can't be enabled at the same time."
-        if default_net(
-        ).plugin_config.reduce_fusion and self.local_layer_idx > 0:
-            hidden_states, residual = hidden_states  #FIXME:AN need to check if appropriate residual value is hidden state is pulled out.
-        else:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
+                lora_layer_params: Optional[LoraParams] = None):
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)

-        attention_output = self.attention(
-            hidden_states,
-            attention_mask=attention_mask,
-            use_cache=use_cache,
-            kv_cache_params=kv_cache_params,
-            attention_params=attention_params,
-            norm_before_bmm1=True,
-            lora_layer_params=lora_layer_params,
-            all_reduce_params=AllReduceParams(
-                fusion_op=AllReduceFusionOp.RESIDUAL_RMS_PREPOST_NORM
-                if default_net().plugin_config.reduce_fusion else
-                AllReduceFusionOp.NONE,
-                residual=residual,
-                norm_weight=self.post_layernorm.weight.value,
-                norm_pre_residual_weight=self.pre_feedforward_layernorm.weight.
-                value if self.config.inter_layernorms else None,
-                eps=self.post_layernorm.eps))
+        attention_output = self.attention(hidden_states,
+                                          attention_mask=attention_mask,
+                                          use_cache=use_cache,
+                                          kv_cache_params=kv_cache_params,
+                                          attention_params=attention_params,
+                                          norm_before_bmm1=True,
+                                          lora_layer_params=lora_layer_params)

         if use_cache:
             attention_output, presents = attention_output

-        if default_net().plugin_config.reduce_fusion:
-            hidden_states, residual = attention_output
-        else:
-            if self.config.inter_layernorms:
-                attention_output = self.post_layernorm(attention_output)
-            hidden_states = residual + attention_output
-            residual = hidden_states
-            if self.config.inter_layernorms:
-                hidden_states = self.pre_feedforward_layernorm(hidden_states)
-            else:
-                hidden_states = self.post_layernorm(hidden_states)
+        if self.config.inter_layernorms:
+            attention_output = self.post_layernorm(attention_output)
+        hidden_states = residual + attention_output
+        residual = hidden_states
+        if self.config.inter_layernorms:
+            hidden_states = self.pre_feedforward_layernorm(hidden_states)
+        else:
+            hidden_states = self.post_layernorm(hidden_states)

-        if next_layer_input_layernorm_args is not None:
-            hidden_states = self.mlp(
-                hidden_states,
-                lora_layer_params=lora_layer_params,
-                all_reduce_params=AllReduceParams(
-                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_PREPOST_NORM
-                    if default_net().plugin_config.reduce_fusion else
-                    AllReduceFusionOp.NONE,
-                    residual=residual,
-                    norm_weight=next_layer_input_layernorm_args[0],
-                    norm_pre_residual_weight=self.post_feedforward_layernorm.
-                    weight.value,
-                    eps=next_layer_input_layernorm_args[1]))
-        else:
-            hidden_states = self.mlp(hidden_states,
-                                     lora_layer_params=lora_layer_params)
-
-            if self.config.inter_layernorms:
-                hidden_states = self.post_feedforward_layernorm(hidden_states)
-            hidden_states = residual + hidden_states
+        hidden_states = self.mlp(hidden_states,
+                                 lora_layer_params=lora_layer_params)
+        if self.config.inter_layernorms:
+            hidden_states = self.post_feedforward_layernorm(hidden_states)
+        hidden_states = residual + hidden_states

         if use_cache:
             return (hidden_states, presents)
         return hidden_states
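Note: the removed (-) path fuses the tensor-parallel all-reduce with the residual add and the surrounding RMSNorms (AllReduceFusionOp.RESIDUAL_RMS_PREPOST_NORM), while the regenerated (+) page shows the unfused sequence. A NumPy sketch of what the fused op replaces, assuming the per-rank partial outputs have already been summed (names are illustrative, not the TRT-LLM kernel API):

```python
import numpy as np

def rmsnorm(x, w, eps=1e-6):
    # Root-mean-square norm with a learned per-channel weight.
    return x / np.sqrt((x * x).mean(-1, keepdims=True) + eps) * w

def unfused_step(attn_out, residual, norm_w):
    # The fused kernel folds these steps into one pass:
    # all-reduce (assumed already done) -> residual add -> norm for the
    # next sublayer; it emits both the normed output and the new residual.
    new_residual = residual + attn_out
    return rmsnorm(new_residual, norm_w), new_residual

h = np.ones((2, 4)); r = np.zeros((2, 4)); w = np.ones(4)
normed, r = unfused_step(h, r, w)
print(normed.shape, r.shape)
```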
@@ -551,7 +514,7 @@
     <hr/>

     <div role="contentinfo">
-      <jinja2.runtime.BlockReference object at 0x7f5e6e1ea360>
+      <jinja2.runtime.BlockReference object at 0x7f1ac6ea60f0>

     <div class="footer">
       <p>
@@ -14,7 +14,7 @@
 <script src="../../../../_static/jquery.js?v=5d32c60e"></script>
 <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
 <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
-<script src="../../../../_static/doctools.js?v=9a2dae69"></script>
+<script src="../../../../_static/doctools.js?v=9bcbadda"></script>
 <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
 <script src="../../../../_static/clipboard.min.js?v=a7894cd8"></script>
 <script src="../../../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
 </ul>
 <p class="caption" role="heading"><span class="caption-text">Performance</span></p>
 <ul>
@@ -503,7 +504,7 @@
     <hr/>

     <div role="contentinfo">
-      <jinja2.runtime.BlockReference object at 0x7f5e6e2ae060>
+      <jinja2.runtime.BlockReference object at 0x7f1ac6eabe00>

     <div class="footer">
       <p>
@@ -14,7 +14,7 @@
 <script src="../../../../_static/jquery.js?v=5d32c60e"></script>
 <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
 <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
-<script src="../../../../_static/doctools.js?v=9a2dae69"></script>
+<script src="../../../../_static/doctools.js?v=9bcbadda"></script>
 <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
 <script src="../../../../_static/clipboard.min.js?v=a7894cd8"></script>
 <script src="../../../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
 </ul>
 <p class="caption" role="heading"><span class="caption-text">Performance</span></p>
 <ul>
@@ -588,7 +589,7 @@
     <hr/>

     <div role="contentinfo">
-      <jinja2.runtime.BlockReference object at 0x7f5e6e199340>
+      <jinja2.runtime.BlockReference object at 0x7f1ac6ff5490>

     <div class="footer">
       <p>
@@ -14,7 +14,7 @@
 <script src="../../../../_static/jquery.js?v=5d32c60e"></script>
 <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
 <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
-<script src="../../../../_static/doctools.js?v=9a2dae69"></script>
+<script src="../../../../_static/doctools.js?v=9bcbadda"></script>
 <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
 <script src="../../../../_static/clipboard.min.js?v=a7894cd8"></script>
 <script src="../../../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
 </ul>
 <p class="caption" role="heading"><span class="caption-text">Performance</span></p>
 <ul>
@@ -232,7 +233,7 @@
     <hr/>

     <div role="contentinfo">
-      <jinja2.runtime.BlockReference object at 0x7f5e6e13a270>
+      <jinja2.runtime.BlockReference object at 0x7f1ac71bd400>

     <div class="footer">
       <p>
@@ -14,7 +14,7 @@
 <script src="../../../../_static/jquery.js?v=5d32c60e"></script>
 <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
 <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
-<script src="../../../../_static/doctools.js?v=9a2dae69"></script>
+<script src="../../../../_static/doctools.js?v=9bcbadda"></script>
 <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
 <script src="../../../../_static/clipboard.min.js?v=a7894cd8"></script>
 <script src="../../../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
 </ul>
 <p class="caption" role="heading"><span class="caption-text">Performance</span></p>
 <ul>
@@ -384,7 +385,7 @@
     <hr/>

     <div role="contentinfo">
-      <jinja2.runtime.BlockReference object at 0x7f5e6e25bb00>
+      <jinja2.runtime.BlockReference object at 0x7f1ac74eca10>

     <div class="footer">
       <p>
@@ -14,7 +14,7 @@
 <script src="../../../../_static/jquery.js?v=5d32c60e"></script>
 <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
 <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
-<script src="../../../../_static/doctools.js?v=9a2dae69"></script>
+<script src="../../../../_static/doctools.js?v=9bcbadda"></script>
 <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
 <script src="../../../../_static/clipboard.min.js?v=a7894cd8"></script>
 <script src="../../../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
 </ul>
 <p class="caption" role="heading"><span class="caption-text">Performance</span></p>
 <ul>
@@ -324,7 +325,7 @@
     <hr/>

     <div role="contentinfo">
-      <jinja2.runtime.BlockReference object at 0x7f5e6e0fe870>
+      <jinja2.runtime.BlockReference object at 0x7f1ac6fead80>

     <div class="footer">
       <p>
@@ -14,7 +14,7 @@
 <script src="../../../../_static/jquery.js?v=5d32c60e"></script>
 <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
 <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
-<script src="../../../../_static/doctools.js?v=9a2dae69"></script>
+<script src="../../../../_static/doctools.js?v=9bcbadda"></script>
 <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
 <script src="../../../../_static/clipboard.min.js?v=a7894cd8"></script>
 <script src="../../../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
 </ul>
 <p class="caption" role="heading"><span class="caption-text">Performance</span></p>
 <ul>
@@ -174,7 +175,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-import math
 import sys
 from pathlib import Path
 from typing import Optional, Union
@@ -199,10 +199,6 @@
             disable_weight_only_quant_plugin: bool = False,
             moe: Optional[Union[MoeConfig, dict]] = None,
             remove_duplicated_kv_heads: bool = False,
-            embedding_multiplier: float = 1.0,
-            attention_multiplier: float = 1.0,
-            residual_multiplier: float = 1.0,
-            output_multiplier_scale: float = 1.0,
             **kwargs):
         self.mlp_bias = mlp_bias
         self.attn_bias = attn_bias
@@ -227,10 +223,6 @@
         self.use_input_layernorm_in_first_layer = True
         self.use_last_layernorm = True
         self.layer_idx_offset = 0
-        self.embedding_multiplier = embedding_multiplier
-        self.attention_multiplier = attention_multiplier
-        self.residual_multiplier = residual_multiplier
-        self.output_multiplier_scale = output_multiplier_scale

         super().__init__(**kwargs)
@@ -273,7 +265,7 @@
             hf_config = hf_config_or_dir
         else:
             hf_config_dir = str(hf_config_or_dir)
-        if "vila" in hf_config_dir.lower():
+        if "vila" in hf_config_dir:
             sys.path.append(hf_config_dir + "/../VILA")
             from llava.model import LlavaLlamaConfig  # noqa
             from llava.model import LlavaLlamaModel
@@ -323,14 +315,8 @@
             'disable_weight_only_quant_plugin', False)
         remove_duplicated_kv_heads = kwargs.pop('remove_duplicated_kv_heads',
                                                 False)
-        embedding_multiplier = getattr(hf_config, "embedding_multiplier", 1.0)
-        attention_multiplier = getattr(hf_config, "attention_multiplier", 1.0)
-        if attention_multiplier != 1.0:
-            attention_multiplier *= math.sqrt(head_size)
-        residual_multiplier = getattr(hf_config, "residual_multiplier", 1.0)
-        output_multiplier_scale = 1.0 / getattr(hf_config, "logits_scaling",
-                                                1.0)
-        if hf_config.model_type in ["mixtral", "arctic", "granitemoe"]:
+
+        if hf_config.model_type == "mixtral" or hf_config.model_type == "arctic":
             # HF LLaMA-type models are implicitly using gated activation.
             # With our MoE implementation, we must make it explicit
             hidden_act = "swiglu"
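Note: the removed lines read Granite's scaling knobs from the HF config; attention_multiplier is later folded into attention as q_scaling = 1.0 / config.attention_multiplier (see the hunk further down), and logits are divided by logits_scaling. A sketch of the same arithmetic with made-up values (a plain dict stands in for the HF config object):

```python
import math

hf_config = {"attention_multiplier": 0.0078125, "logits_scaling": 16.0}  # invented values
head_size = 128

attention_multiplier = hf_config.get("attention_multiplier", 1.0)
if attention_multiplier != 1.0:
    # compensate for the 1/sqrt(head_size) folded into the attention kernel,
    # so the effective query scale equals attention_multiplier
    attention_multiplier *= math.sqrt(head_size)
output_multiplier_scale = 1.0 / hf_config.get("logits_scaling", 1.0)
print(attention_multiplier, output_multiplier_scale)
```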
@@ -371,10 +357,6 @@
             quantization=quant_config,
             remove_duplicated_kv_heads=remove_duplicated_kv_heads,
             tie_word_embeddings=tie_word_embeddings,
-            embedding_multiplier=embedding_multiplier,
-            attention_multiplier=attention_multiplier,
-            residual_multiplier=residual_multiplier,
-            output_multiplier_scale=output_multiplier_scale,
             **kwargs)
@@ -446,7 +428,7 @@
     <hr/>

     <div role="contentinfo">
-      <jinja2.runtime.BlockReference object at 0x7f5e6e229490>
+      <jinja2.runtime.BlockReference object at 0x7f1ac6f30290>

     <div class="footer">
       <p>
@@ -14,7 +14,7 @@
 <script src="../../../../_static/jquery.js?v=5d32c60e"></script>
 <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
 <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
-<script src="../../../../_static/doctools.js?v=9a2dae69"></script>
+<script src="../../../../_static/doctools.js?v=9bcbadda"></script>
 <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
 <script src="../../../../_static/clipboard.min.js?v=a7894cd8"></script>
 <script src="../../../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
 </ul>
 <p class="caption" role="heading"><span class="caption-text">Performance</span></p>
 <ul>
@@ -234,7 +235,6 @@
             tp_group=config.mapping.tp_group,
             tp_size=config.mapping.tp_size,
             tp_rank=config.mapping.tp_rank,
-            q_scaling=1.0 / config.attention_multiplier,
             quant_mode=config.quant_mode,
             cp_group=config.mapping.cp_group,
             cp_size=config.mapping.cp_size,
@@ -362,7 +362,7 @@
         if default_net().plugin_config.reduce_fusion:
             hidden_states, residual = attention_output
         else:
-            hidden_states = residual + attention_output * self.config.residual_multiplier
+            hidden_states = residual + attention_output
         residual = hidden_states
         hidden_states = self.post_layernorm(hidden_states)
         if next_layer_input_layernorm_args is not None:
@@ -397,7 +397,7 @@
         else:
             hidden_states = self.mlp(
                 hidden_states, lora_layer_params=lora_layer_params)
-        hidden_states = residual + hidden_states * self.config.residual_multiplier
+        hidden_states = residual + hidden_states

         if use_cache:
             return (hidden_states, presents)
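Note: the two residual hunks above drop the residual_multiplier scaling that Granite applies to each sublayer output before the residual add; with residual_multiplier = 1.0 the removed and added lines compute the same thing. A one-function sketch:

```python
# Sketch of the dropped behavior; residual_multiplier=1.0 recovers the
# plain `residual + out` shown on the + lines.
def residual_add(residual, sublayer_out, residual_multiplier=1.0):
    return residual + sublayer_out * residual_multiplier

print(residual_add(1.0, 2.0))                           # 3.0
print(residual_add(1.0, 2.0, residual_multiplier=0.5))  # 2.0
```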
@@ -417,7 +417,6 @@
         self.vocab_embedding = Embedding(config.vocab_size,
                                          config.hidden_size,
                                          dtype=config.dtype)
-        self.embedding_multiplier = config.embedding_multiplier

         self.layers = DecoderLayerList(LLaMADecoderLayer, config)
@@ -460,7 +459,6 @@

         if self.mapping.is_first_pp_rank():
             hidden_states = self.vocab_embedding(input_ids, *ptuning_args)
-            hidden_states *= self.embedding_multiplier
         else:
             hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
         if default_net().plugin_config.pp_reduce_scatter:
@@ -738,7 +736,7 @@
     <hr/>

     <div role="contentinfo">
-      <jinja2.runtime.BlockReference object at 0x7f5e6e23cd40>
+      <jinja2.runtime.BlockReference object at 0x7f1ac7724b00>

     <div class="footer">
       <p>
@@ -14,7 +14,7 @@
 <script src="../../../../_static/jquery.js?v=5d32c60e"></script>
 <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
 <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
-<script src="../../../../_static/doctools.js?v=9a2dae69"></script>
+<script src="../../../../_static/doctools.js?v=9bcbadda"></script>
 <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
 <script src="../../../../_static/clipboard.min.js?v=a7894cd8"></script>
 <script src="../../../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
 </ul>
 <p class="caption" role="heading"><span class="caption-text">Performance</span></p>
 <ul>
@@ -652,7 +653,7 @@
     <hr/>

     <div role="contentinfo">
-      <jinja2.runtime.BlockReference object at 0x7f5e6e0fd460>
+      <jinja2.runtime.BlockReference object at 0x7f1ac74d34a0>

     <div class="footer">
       <p>
@@ -14,7 +14,7 @@
 <script src="../../../../_static/jquery.js?v=5d32c60e"></script>
 <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
 <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
-<script src="../../../../_static/doctools.js?v=9a2dae69"></script>
+<script src="../../../../_static/doctools.js?v=9bcbadda"></script>
 <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
 <script src="../../../../_static/clipboard.min.js?v=a7894cd8"></script>
 <script src="../../../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
 </ul>
 <p class="caption" role="heading"><span class="caption-text">Performance</span></p>
 <ul>
@@ -286,7 +287,7 @@
     <hr/>

     <div role="contentinfo">
-      <jinja2.runtime.BlockReference object at 0x7f5e6e3563f0>
+      <jinja2.runtime.BlockReference object at 0x7f1ac7727fb0>

     <div class="footer">
       <p>
@@ -14,7 +14,7 @@
 <script src="../../../../_static/jquery.js?v=5d32c60e"></script>
 <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
 <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
-<script src="../../../../_static/doctools.js?v=9a2dae69"></script>
+<script src="../../../../_static/doctools.js?v=9bcbadda"></script>
 <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
 <script src="../../../../_static/clipboard.min.js?v=a7894cd8"></script>
 <script src="../../../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
 </ul>
 <p class="caption" role="heading"><span class="caption-text">Performance</span></p>
 <ul>
@@ -415,7 +416,7 @@
     <hr/>

     <div role="contentinfo">
-      <jinja2.runtime.BlockReference object at 0x7f5e6e3a2c00>
+      <jinja2.runtime.BlockReference object at 0x7f1ac7817aa0>

     <div class="footer">
       <p>
@@ -14,7 +14,7 @@
 <script src="../../../../_static/jquery.js?v=5d32c60e"></script>
 <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
 <script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
-<script src="../../../../_static/doctools.js?v=9a2dae69"></script>
+<script src="../../../../_static/doctools.js?v=9bcbadda"></script>
 <script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
 <script src="../../../../_static/clipboard.min.js?v=a7894cd8"></script>
 <script src="../../../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
 </ul>
 <p class="caption" role="heading"><span class="caption-text">Performance</span></p>
 <ul>
@@ -201,7 +202,6 @@
 from tensorrt_llm.models.modeling_utils import PretrainedModel, QuantConfig
 from tensorrt_llm.module import Module, ModuleList
 from tensorrt_llm.parameter import Parameter
-from tensorrt_llm.quantization import QuantMode

 from .config import MLLaMAConfig
@@ -223,36 +223,35 @@
class CrossAttentionTransformerBlock(Module):

    def __init__(
-        self,
-        *,
-        local_layer_idx,
-        hidden_size,
-        ffn_hidden_size,
-        num_attention_heads,
-        num_kv_heads,
-        head_size,
-        max_position_embeddings=None,
-        q_scaling=1.0,
-        has_attention_qkvo_bias=False,
-        has_mlp_bias=False,
-        layernorm_position=LayerNormPositionType.pre_layernorm,
-        layernorm_type=LayerNormType.RmsNorm,
-        layernorm_eps=1e-5,
-        hidden_act="gated-silu",
-        mlp_type=MLPType.GatedMLP,
-        mapping=Mapping(),
-        dtype=None,
-        residual_scaling=1.0,
-        relative_attention=False,
-        max_distance=0,
-        num_buckets=0,
-        fp16_clamping=False,
-        skip_cross_kv=False,
-        use_implicit_relative_attention=False,
-        rotary_embedding_base=None,
-        rotary_embedding_scaling=None,
-        layer_idx_in_cache_pool=None,
-        quant_mode=QuantMode(0),
+        self,
+        *,
+        local_layer_idx,
+        hidden_size,
+        ffn_hidden_size,
+        num_attention_heads,
+        num_kv_heads,
+        head_size,
+        max_position_embeddings=None,
+        q_scaling=1.0,
+        has_attention_qkvo_bias=False,
+        has_mlp_bias=False,
+        layernorm_position=LayerNormPositionType.pre_layernorm,
+        layernorm_type=LayerNormType.RmsNorm,
+        layernorm_eps=1e-5,
+        hidden_act="gated-silu",
+        mlp_type=MLPType.GatedMLP,
+        mapping=Mapping(),
+        dtype=None,
+        residual_scaling=1.0,
+        relative_attention=False,
+        max_distance=0,
+        num_buckets=0,
+        fp16_clamping=False,
+        skip_cross_kv=False,
+        use_implicit_relative_attention=False,
+        rotary_embedding_base=None,
+        rotary_embedding_scaling=None,
+        layer_idx_in_cache_pool=None,
    ):
        super().__init__()
        self.local_layer_idx = local_layer_idx
@@ -287,7 +286,6 @@
            qk_layernorm=True,
            layernorm_type=layernorm_type,
            layer_idx_in_cache_pool=layer_idx_in_cache_pool,
-            quant_mode=quant_mode,
        )

        self.input_layernorm = ln_type(normalized_shape=hidden_size,
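The hunks above drop the explicit quant_mode constructor argument from the cross-attention block; quantization settings are instead carried on the model config. A minimal sketch of the same keyword-only constructor pattern, with illustrative names rather than the TensorRT-LLM API:

import dataclasses

class Block:
    def __init__(self, *, hidden_size, num_attention_heads, dtype=None):
        # Keyword-only parameters (note the bare *) force call sites to name
        # every argument, so removing one (like quant_mode here) fails loudly
        # at the call site instead of silently shifting positional values.
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.dtype = dtype

block = Block(hidden_size=4096, num_attention_heads=32)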
@@ -455,36 +453,35 @@
class TransformerBlock(Module):

    def __init__(
-        self,
-        *,
-        local_layer_idx,
-        hidden_size,
-        ffn_hidden_size,
-        num_attention_heads,
-        num_kv_heads,
-        head_size,
-        max_position_embeddings=None,
-        q_scaling=1.0,
-        has_attention_qkvo_bias=False,
-        has_mlp_bias=False,
-        layernorm_position=LayerNormPositionType.pre_layernorm,
-        layernorm_type=LayerNormType.RmsNorm,
-        layernorm_eps=1e-5,
-        hidden_act="gated-silu",
-        mlp_type=MLPType.GatedMLP,
-        mapping=Mapping(),
-        dtype=None,
-        residual_scaling=1.0,
-        relative_attention=False,
-        max_distance=0,
-        num_buckets=0,
-        fp16_clamping=False,
-        skip_cross_kv=False,
-        use_implicit_relative_attention=False,
-        rotary_embedding_base=None,
-        rotary_embedding_scaling=None,
-        layer_idx_in_cache_pool=None,
-        quant_mode=QuantMode(0),
+        self,
+        *,
+        local_layer_idx,
+        hidden_size,
+        ffn_hidden_size,
+        num_attention_heads,
+        num_kv_heads,
+        head_size,
+        max_position_embeddings=None,
+        q_scaling=1.0,
+        has_attention_qkvo_bias=False,
+        has_mlp_bias=False,
+        layernorm_position=LayerNormPositionType.pre_layernorm,
+        layernorm_type=LayerNormType.RmsNorm,
+        layernorm_eps=1e-5,
+        hidden_act="gated-silu",
+        mlp_type=MLPType.GatedMLP,
+        mapping=Mapping(),
+        dtype=None,
+        residual_scaling=1.0,
+        relative_attention=False,
+        max_distance=0,
+        num_buckets=0,
+        fp16_clamping=False,
+        skip_cross_kv=False,
+        use_implicit_relative_attention=False,
+        rotary_embedding_base=None,
+        rotary_embedding_scaling=None,
+        layer_idx_in_cache_pool=None,
    ):
        super().__init__()
        self.local_layer_idx = local_layer_idx
@@ -518,7 +515,6 @@
            rotary_embedding_base=rotary_embedding_base,
            rotary_embedding_scaling=rotary_embedding_scaling,
            layer_idx_in_cache_pool=layer_idx_in_cache_pool,
-            quant_mode=quant_mode,
        )

        self.input_layernorm = ln_type(normalized_shape=hidden_size,
@@ -632,15 +628,22 @@
        return hidden_states


-class MLLaMAModel(Module):
+<div class="viewcode-block" id="MLLaMAModel">
+<a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.MLLaMAModel">[docs]</a>
+class MLLaMAModel(PretrainedModel):
+    config_class = MLLaMAConfig

-    def __init__(self, config: MLLaMAConfig) -> None:
-        super().__init__()
-        self.config = config
+    def __init__(self, config: MLLaMAConfig):
+        super().__init__(config)
+        Attention.create_attention_const_params(self, config)
+        self.position_embedding_type = config.position_embedding_type

        self.mapping = self.config.mapping

        type_vocab_size = self.config.type_vocab_size
        self.has_token_type_embedding = (type_vocab_size is not None)
        self.rescale_before_lm_head = self.config.rescale_before_lm_head

        self.layernorm_type = self.config.layernorm_type
        ln_type = layernorm_map[self.layernorm_type]

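Rebasing MLLaMAModel from Module onto PretrainedModel (and wrapping it in a viewcode block) is what lets the Sphinx API reference resolve the class instead of rendering a blank entry. A rough sketch of the config_class pattern this relies on, using hypothetical names rather than the actual base class:

class PretrainedModelBase:
    config_class = None  # each subclass declares its own config type

    def __init__(self, config):
        self.config = config

    @classmethod
    def from_config_dict(cls, data: dict):
        # cls.config_class tells the shared base how to parse the dict
        # without knowing the concrete model type in advance.
        return cls(cls.config_class(**data))

class DemoConfig:
    def __init__(self, hidden_size=16):
        self.hidden_size = hidden_size

class DemoModel(PretrainedModelBase):
    config_class = DemoConfig

model = DemoModel.from_config_dict({"hidden_size": 32})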
@@ -666,6 +669,8 @@
        self.num_kv_heads = num_kv_heads
        self.head_size = self.hidden_size // self.num_heads if self.config.head_size is None else self.config.head_size

+        self.has_token_type_embedding = type_vocab_size is not None
+
        self.fp16_clamping = False

        self.skip_cross_kv = self.config.skip_cross_kv
@@ -714,7 +719,6 @@
            "skip_cross_kv": self.skip_cross_kv,
            "rotary_embedding_base": self.config.rotary_base,
            "rotary_embedding_scaling": self.config.rotary_scaling,
-            "quant_mode": self.config.quant_mode,
        }
        if layer_idx in self.cross_attention_layers:
            assert layers_range[0] == 0, "not support PP now"
@@ -731,21 +735,44 @@
                        num_kv_heads_per_layer[:local_layer_idx].
                        count(num_kv_heads)))

-        self.layers = ModuleList(_layers)

+        self.decoder_layers = ModuleList(_layers)
        if self.mapping.is_last_pp_rank():
            self.ln_f = None
            if self.has_model_final_layernorm:
                self.ln_f = ln_type(normalized_shape=self.config.hidden_size,
                                    eps=self.config.norm_epsilon,
                                    dtype=self.config.dtype)

            self.lm_head = ColumnLinear(
                self.config.hidden_size,
                self.config.vocab_size,
                bias=False if not hasattr(self.config, "has_lm_head_bias") else
                self.config.has_lm_head_bias,
                dtype=self.config.dtype,
                tp_group=self.config.mapping.tp_group,
                tp_size=self.config.mapping.tp_size,
                gather_output=True,
            )

        self.trtllm_modules_to_hf_modules = {
            **get_default_trtllm_modules_to_hf_modules(),
            "attn_q": "self_attn.q_proj",
            "attn_k": "self_attn.k_proj",
            "attn_v": "self_attn.v_proj",
            "attn_dense": "self_attn.o_proj",
            "cross_attn_q": "encoder_attn.q_proj",
            "cross_attn_k": "encoder_attn.k_proj",
            "cross_attn_v": "encoder_attn.v_proj",
            "cross_attn_dense": "encoder_attn.o_proj",
        }
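The trtllm_modules_to_hf_modules table maps TensorRT-LLM layer names onto their Hugging Face counterparts, so that checkpoints or adapters keyed by HF names can be resolved. A toy illustration of such a renaming lookup, assuming simple dotted names:

trtllm_to_hf = {"attn_q": "self_attn.q_proj", "attn_dense": "self_attn.o_proj"}
# Invert the TRT-LLM -> HF table so HF checkpoint keys can be rewritten.
hf_to_trtllm = {v: k for k, v in trtllm_to_hf.items()}

def to_trtllm_name(hf_name: str) -> str:
    # Rewrite an HF module path to the TRT-LLM module name when a mapping
    # exists; leave unknown names untouched.
    return hf_to_trtllm.get(hf_name, hf_name)

assert to_trtllm_name("self_attn.q_proj") == "attn_q"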

        if self.config.relative_attention and not self.use_implicit_relative_attention:
            self.rel_attn_table = Parameter(
                shape=(self.config.num_attention_heads // self.mapping.tp_size,
                       self.config.num_buckets),
                dtype=self._dtype)

+<div class="viewcode-block" id="MLLaMAModel.forward">
+<a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.MLLaMAModel.forward">[docs]</a>
    def forward(
        self,
        decoder_input_ids: Tensor,
@@ -769,15 +796,19 @@
        else:
            assert isinstance(hidden_states, Tensor)

        attention_params = Attention.fill_attention_params(
            self, attention_params)

        # In PP, layer 0 has ids as inputs, all other layers have hidden_states as inputs
        if self.mapping.is_first_pp_rank():
            hidden_states = self.vocab_embedding(decoder_input_ids)
            self.register_network_output('embedding_layer_output',
                                         hidden_states)

        else:
            hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())

-        kv_cache_params.fill_none_tensor_list(len(self.layers))
+        kv_cache_params.fill_none_tensor_list(len(self.decoder_layers))

        full_text_row_masked_out_mask = reduce(
            (attention_mask_params.cross_attention_mask).cast(
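The forward pass follows the usual pipeline-parallel convention: the first rank embeds token ids, every later rank starts from activations received from its predecessor. A schematic of that control flow, where recv/send/embed/run_layers stand in for the real TensorRT-LLM communication and layer ops:

def pp_forward(rank: int, world: int, input_ids, recv, send, embed, run_layers):
    # First pipeline rank turns ids into embeddings; the others start
    # from the activations produced by the previous rank.
    if rank == 0:
        hidden = embed(input_ids)
    else:
        hidden = recv(rank - 1)
    hidden = run_layers(hidden)
    # Every rank but the last forwards its activations downstream.
    if rank < world - 1:
        send(hidden, rank + 1)
    return hidden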
@@ -794,7 +825,7 @@
        if use_cache:
            presents = []
        for i, (decoder_layer, past) in enumerate(
-                zip(self.layers, kv_cache_params.past_key_value)):
+                zip(self.decoder_layers, kv_cache_params.past_key_value)):

            lora_layer_params = None
            if lora_params is not None and lora_params.lora_ranks is not None:
@@ -845,135 +876,14 @@
            hidden_states = hidden_states[0]

        if self.mapping.is_last_pp_rank():
-            if self.ln_f:
+            if self.has_model_final_layernorm:
                hidden_states = self.ln_f(hidden_states)
        else:
            hidden_states = send(hidden_states, self.mapping.next_pp_rank())

        if use_cache:
            return (hidden_states, tuple(presents))
        return hidden_states

    def precompute_relative_attention_bias(self, build_config):
        if self.config.relative_attention and not self.use_implicit_relative_attention:
            relative_attention_bias_builder = torch.ops.tensorrt_llm.relative_attention_bias
            rel_attn_precomputed = torch.zeros(
                (self.config.num_attention_heads // self.mapping.tp_size,
                 build_config.max_seq_len + 1, build_config.max_seq_len + 1),
                dtype=str_dtype_to_torch(self.config.dtype),
                device='cuda')
            rel_attn_table = numpy_to_torch(
                self.rel_attn_table.raw_value).to('cuda')
            relative_attention_bias_builder(
                rel_attn_precomputed,
                rel_attn_table,
                self.config.num_attention_heads // self.mapping.tp_size,
                build_config.max_seq_len,
                self.config.num_buckets,
                False,
                self.config.max_distance,
            )
            for layer_idx in range(self.num_layers):
                self.layers[layer_idx].self_attention.set_rel_attn_table(
                    build_config.max_seq_len, rel_attn_precomputed)

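precompute_relative_attention_bias fills a dense (heads, seq+1, seq+1) bias tensor once at build time, so decoding only indexes a table instead of recomputing relative-position buckets per step. A NumPy sketch of that precompute idea; the bucketing below is a crude linear stand-in, not the plugin's actual log-spaced kernel:

import numpy as np

def precompute_bias(table: np.ndarray, max_seq_len: int, max_distance: int) -> np.ndarray:
    # table: (num_heads, num_buckets) of learned biases. Map every
    # (query, key) offset to a bucket, then gather the bias once.
    num_heads, num_buckets = table.shape
    q = np.arange(max_seq_len)[:, None]
    k = np.arange(max_seq_len)[None, :]
    dist = np.clip(np.abs(k - q), 0, max_distance - 1)
    bucket = (dist * num_buckets) // max_distance  # linear buckets for brevity
    return table[:, bucket]  # (num_heads, max_seq_len, max_seq_len)

bias = precompute_bias(np.zeros((8, 32)), max_seq_len=16, max_distance=128)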
<span class="c1"># TODO try to inherit the DecoderModelForCausalLM</span>
|
||||
<div class="viewcode-block" id="MLLaMAForCausalLM">
|
||||
<a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.MLLaMAForCausalLM">[docs]</a>
|
||||
<span class="k">class</span> <span class="nc">MLLaMAForCausalLM</span><span class="p">(</span><span class="n">PretrainedModel</span><span class="p">):</span>
|
||||
<span class="n">config_class</span> <span class="o">=</span> <span class="n">MLLaMAConfig</span>
|
||||
|
||||
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="n">MLLaMAConfig</span><span class="p">):</span>
|
||||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">config</span><span class="p">)</span>
|
||||
<span class="n">Attention</span><span class="o">.</span><span class="n">create_attention_const_params</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">position_embedding_type</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">position_embedding_type</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">transformer</span> <span class="o">=</span> <span class="n">MLLaMAModel</span><span class="p">(</span><span class="n">config</span><span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">mapping</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">has_model_final_layernorm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">has_model_final_layernorm</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_kv_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dtype</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">_logits_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">logits_dtype</span>
|
||||
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">lm_head</span> <span class="o">=</span> <span class="n">ColumnLinear</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">,</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span>
|
||||
<span class="n">bias</span><span class="o">=</span><span class="kc">False</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="p">,</span> <span class="s2">"has_lm_head_bias"</span><span class="p">)</span> <span class="k">else</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">has_lm_head_bias</span><span class="p">,</span>
|
||||
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
|
||||
<span class="n">tp_group</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_group</span><span class="p">,</span>
|
||||
<span class="n">tp_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span>
|
||||
<span class="n">gather_output</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">trtllm_modules_to_hf_modules</span> <span class="o">=</span> <span class="p">{</span>
|
||||
<span class="o">**</span><span class="n">get_default_trtllm_modules_to_hf_modules</span><span class="p">(),</span>
|
||||
<span class="s2">"attn_q"</span><span class="p">:</span> <span class="s2">"self_attn.q_proj"</span><span class="p">,</span>
|
||||
<span class="s2">"attn_k"</span><span class="p">:</span> <span class="s2">"self_attn.k_proj"</span><span class="p">,</span>
|
||||
<span class="s2">"attn_v"</span><span class="p">:</span> <span class="s2">"self_attn.v_proj"</span><span class="p">,</span>
|
||||
<span class="s2">"attn_dense"</span><span class="p">:</span> <span class="s2">"self_attn.o_proj"</span><span class="p">,</span>
|
||||
<span class="s2">"cross_attn_q"</span><span class="p">:</span> <span class="s2">"encoder_attn.q_proj"</span><span class="p">,</span>
|
||||
<span class="s2">"cross_attn_k"</span><span class="p">:</span> <span class="s2">"encoder_attn.k_proj"</span><span class="p">,</span>
|
||||
<span class="s2">"cross_attn_v"</span><span class="p">:</span> <span class="s2">"encoder_attn.v_proj"</span><span class="p">,</span>
|
||||
<span class="s2">"cross_attn_dense"</span><span class="p">:</span> <span class="s2">"encoder_attn.o_proj"</span><span class="p">,</span>
|
||||
<span class="p">}</span>
|
||||
|
||||
<div class="viewcode-block" id="MLLaMAForCausalLM.forward">
|
||||
<a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.MLLaMAForCausalLM.forward">[docs]</a>
|
||||
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="p">,</span>
|
||||
<span class="n">decoder_input_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||||
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||||
<span class="n">use_cache</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||||
<span class="n">attention_mask_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">last_token_ids</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">kv_cache_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">attention_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">hidden_states</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">lora_params</span><span class="p">:</span> <span class="n">LoraParams</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">cross_kv_cache_gen</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">cross_kv_reuse</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">prompt_tasks</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="n">skip_cross_attn_blocks</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||||
<span class="p">):</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
|
||||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">decoder_input_ids</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span>
|
||||
<span class="k">else</span><span class="p">:</span>
|
||||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span>
|
||||
<span class="n">attention_params</span> <span class="o">=</span> <span class="n">Attention</span><span class="o">.</span><span class="n">fill_attention_params</span><span class="p">(</span>
|
||||
<span class="bp">self</span><span class="p">,</span> <span class="n">attention_params</span><span class="p">)</span>
|
||||
<span class="n">hidden_states</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer</span><span class="p">(</span>
|
||||
<span class="n">decoder_input_ids</span><span class="o">=</span><span class="n">decoder_input_ids</span><span class="p">,</span>
|
||||
<span class="n">encoder_output</span><span class="o">=</span><span class="n">encoder_output</span><span class="p">,</span>
|
||||
<span class="n">use_cache</span><span class="o">=</span><span class="n">use_cache</span><span class="p">,</span>
|
||||
<span class="n">attention_mask_params</span><span class="o">=</span><span class="n">attention_mask_params</span><span class="p">,</span>
|
||||
<span class="n">last_token_ids</span><span class="o">=</span><span class="n">last_token_ids</span><span class="p">,</span>
|
||||
<span class="n">kv_cache_params</span><span class="o">=</span><span class="n">kv_cache_params</span><span class="p">,</span>
|
||||
<span class="n">attention_params</span><span class="o">=</span><span class="n">attention_params</span><span class="p">,</span>
|
||||
<span class="n">hidden_states</span><span class="o">=</span><span class="n">hidden_states</span><span class="p">,</span>
|
||||
<span class="n">lora_params</span><span class="o">=</span><span class="n">lora_params</span><span class="p">,</span>
|
||||
<span class="n">cross_kv_cache_gen</span><span class="o">=</span><span class="n">cross_kv_cache_gen</span><span class="p">,</span>
|
||||
<span class="n">cross_kv_reuse</span><span class="o">=</span><span class="n">cross_kv_reuse</span><span class="p">,</span>
|
||||
<span class="n">prompt_embedding_table</span><span class="o">=</span><span class="n">prompt_embedding_table</span><span class="p">,</span>
|
||||
<span class="n">prompt_tasks</span><span class="o">=</span><span class="n">prompt_tasks</span><span class="p">,</span>
|
||||
<span class="n">prompt_vocab_size</span><span class="o">=</span><span class="n">prompt_vocab_size</span><span class="p">,</span>
|
||||
<span class="n">skip_cross_attn_blocks</span><span class="o">=</span><span class="n">skip_cross_attn_blocks</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
|
||||
<span class="n">hidden_states</span><span class="p">,</span> <span class="n">presents</span> <span class="o">=</span> <span class="n">hidden_states</span>
|
||||
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
||||
<span class="k">pass</span>
|
||||
<span class="c1"># [bs, seq, hidden_size] or [num_tokens, hidden_size] -> [bs, hidden_size]</span>
|
||||
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">gather_last_token_logits</span><span class="p">(</span>
|
||||
<span class="n">hidden_states</span><span class="p">,</span> <span class="n">last_token_ids</span><span class="p">,</span>
|
||||
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">)</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">register_network_output</span><span class="p">(</span><span class="s1">'logits_before_lmhead'</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">)</span>
|
||||
|
||||
<span class="c1"># [bs, hidden_size] -> [bs, vocab_size]</span>
|
||||
<span class="n">lm_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lm_head</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span>
|
||||
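gather_last_token_logits reduces the decoder output to one hidden vector per sequence before the LM head; with remove_input_padding the tokens of all sequences are packed into one flat stream and last_token_ids holds the cumulative end offsets. A shape-level sketch of that gather, under those assumptions:

import numpy as np

def gather_last_token(hidden: np.ndarray, last_token_ids: np.ndarray) -> np.ndarray:
    # hidden: (num_tokens, hidden_size), packed tokens of all sequences;
    # last_token_ids: 1-based cumulative end offsets, one per sequence.
    return hidden[last_token_ids - 1]  # -> (batch, hidden_size)

h = np.arange(12, dtype=np.float32).reshape(6, 2)  # 6 packed tokens
ends = np.array([2, 6])                            # seqs end at tokens 2 and 6
print(gather_last_token(h, ends).shape)            # (2, 2)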
@@ -998,8 +908,8 @@
        return hidden_states</div>


-<div class="viewcode-block" id="MLLaMAForCausalLM.prepare_inputs">
-<a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.MLLaMAForCausalLM.prepare_inputs">[docs]</a>
+<div class="viewcode-block" id="MLLaMAModel.prepare_inputs">
+<a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.MLLaMAModel.prepare_inputs">[docs]</a>
    def prepare_inputs(self,
                       max_batch_size,
                       max_beam_width,
|
||||
<span class="c1"># Prepare inputs</span>
|
||||
<span class="n">max_output_len</span> <span class="o">=</span> <span class="n">max_decoder_input_len</span> <span class="o">+</span> <span class="n">max_seq_len</span>
|
||||
|
||||
<span class="n">head_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">head_size</span>
|
||||
<span class="n">num_kv_heads</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">num_kv_heads</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">-</span>
|
||||
<span class="n">head_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span>
|
||||
<span class="n">num_kv_heads</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_kv_heads</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">-</span>
|
||||
<span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
|
||||
|
||||
<span class="n">encoder_head_size</span> <span class="o">=</span> <span class="n">head_size</span>
|
||||
<span class="c1"># TODO check</span>
|
||||
<span class="c1"># encoder_head_size = self.encoder_head_size</span>
|
||||
<span class="c1"># encoder_num_kv_heads = (self.encoder_num_kv_heads + self.mapping.tp_size</span>
|
||||
<span class="c1"># - 1) // self.mapping.tp_size</span>
|
||||
<span class="n">encoder_head_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span>
|
||||
<span class="n">encoder_num_kv_heads</span> <span class="o">=</span> <span class="n">num_kv_heads</span>
|
||||
|
||||
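A quick check of the ceil-division arithmetic used above to shard KV heads across tensor-parallel ranks (values illustrative):

num_kv_heads, tp_size = 8, 3
per_rank_kv_heads = (num_kv_heads + tp_size - 1) // tp_size  # ceil(8 / 3) = 3
assert per_rank_kv_heads == 3  # every rank reserves space for 3 KV heads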
<span class="n">bb_range</span> <span class="o">=</span> <span class="p">[</span>
|
||||
@ -1288,7 +1202,7 @@
|
||||
<span class="p">]),</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="n">layers_range</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_layers</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">total_num_layers</span><span class="p">)</span>
|
||||
<span class="n">layers_range</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_layers</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">total_num_layers</span><span class="p">)</span>
|
||||
<span class="n">num_pp_layers</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">layers_range</span><span class="p">)</span>
|
||||
|
||||
<span class="n">host_max_attention_window_sizes</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
@ -1573,10 +1487,10 @@
|
||||
<span class="p">(</span><span class="s1">'boolean'</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">]),</span>
|
||||
<span class="p">]))</span>
|
||||
<span class="n">cross_kv_reuse</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
<span class="n">num_heads</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">-</span>
|
||||
<span class="n">num_heads</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">-</span>
|
||||
<span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
|
||||
<span class="n">cross_kv_out_dim</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">num_kv_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">head_size</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">skip_cross_kv</span><span class="p">:</span>
|
||||
<span class="n">cross_kv_out_dim</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">num_kv_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip_cross_kv</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
|
||||
<span class="n">cross_kv_reuse</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="s2">"cross_kv_reuse"</span><span class="p">,</span>
|
||||
@ -1617,14 +1531,15 @@
|
||||
<span class="mi">1</span><span class="p">,</span> <span class="n">prompt_embedding_table_size</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="n">prompt_embedding_table_size</span>
|
||||
<span class="p">]]</span>
|
||||
|
||||
<span class="n">prompt_embedding_table</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="s1">'prompt_embedding_table'</span><span class="p">,</span>
|
||||
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_dtype</span><span class="p">,</span>
|
||||
<span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">],</span>
|
||||
<span class="n">dim_range</span><span class="o">=</span><span class="n">OrderedDict</span><span class="p">([</span>
|
||||
<span class="p">(</span><span class="s1">'prompt_embedding_table_size'</span><span class="p">,</span> <span class="n">p_embedding_range</span><span class="p">),</span>
|
||||
<span class="p">(</span><span class="s1">'hidden_size'</span><span class="p">,</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">transformer</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">]),</span>
|
||||
<span class="p">]))</span>
|
||||
<span class="n">prompt_embedding_table</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">'prompt_embedding_table'</span><span class="p">,</span>
|
||||
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_dtype</span><span class="p">,</span>
|
||||
<span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">],</span>
|
||||
<span class="n">dim_range</span><span class="o">=</span><span class="n">OrderedDict</span><span class="p">([</span>
|
||||
<span class="p">(</span><span class="s1">'prompt_embedding_table_size'</span><span class="p">,</span>
|
||||
<span class="n">p_embedding_range</span><span class="p">),</span>
|
||||
<span class="p">(</span><span class="s1">'hidden_size'</span><span class="p">,</span>
|
||||
<span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">]),</span>
|
||||
<span class="p">]))</span>
|
||||
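The nested lists passed as dim_range follow TensorRT's optimization-profile convention: each dynamic dimension carries one [min, opt, max] triple per profile. A hedged sketch with illustrative numbers:

# Illustrative values; mirrors the [1, size // 2, size] pattern used for
# p_embedding_range above (one [min, opt, max] triple per TRT profile).
prompt_embedding_table_size = 256
p_embedding_range = [[
    1, prompt_embedding_table_size // 2, prompt_embedding_table_size
]]  # -> [[1, 128, 256]]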
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
|
||||
<span class="n">num_tokens_range</span> <span class="o">=</span> <span class="p">[</span>
|
||||
<span class="mi">1</span><span class="p">,</span>
|
||||
@ -1672,14 +1587,41 @@
|
||||
<span class="k">return</span> <span class="n">result</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="MLLaMAForCausalLM.use_lora">
|
||||
<a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.MLLaMAForCausalLM.use_lora">[docs]</a>
|
||||
<div class="viewcode-block" id="MLLaMAModel.use_lora">
|
||||
<a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.MLLaMAModel.use_lora">[docs]</a>
|
||||
<span class="k">def</span> <span class="nf">use_lora</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lora_config</span><span class="p">:</span> <span class="n">LoraConfig</span><span class="p">):</span>
|
||||
<span class="n">use_lora</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lora_config</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">trtllm_modules_to_hf_modules</span><span class="p">)</span></div>
|
||||

<div class="viewcode-block" id="MLLaMAForCausalLM.from_hugging_face">
<a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.MLLaMAForCausalLM.from_hugging_face">[docs]</a>
<div class="viewcode-block" id="MLLaMAModel.precompute_relative_attention_bias">
<a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.MLLaMAModel.precompute_relative_attention_bias">[docs]</a>
def precompute_relative_attention_bias(self, build_config):
    if self.config.relative_attention and not self.use_implicit_relative_attention:
        relative_attention_bias_builder = torch.ops.tensorrt_llm.relative_attention_bias
        rel_attn_precomputed = torch.zeros(
            (self.config.num_attention_heads // self.mapping.tp_size,
             build_config.max_seq_len + 1, build_config.max_seq_len + 1),
            dtype=str_dtype_to_torch(self.config.dtype),
            device='cuda')
        rel_attn_table = numpy_to_torch(
            self.rel_attn_table.raw_value).to('cuda')
        relative_attention_bias_builder(
            rel_attn_precomputed,
            rel_attn_table,
            self.config.num_attention_heads // self.mapping.tp_size,
            build_config.max_seq_len,
            self.config.num_buckets,
            False,
            self.config.max_distance,
        )
        for layer_idx in range(self.num_layers):
            self.decoder_layers[
                layer_idx].self_attention.set_rel_attn_table(
                    build_config.max_seq_len, rel_attn_precomputed)</div>
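The builder above fills a [heads/tp, max_seq_len+1, max_seq_len+1] bias tensor from the learned rel_attn_table. As background, a standalone sketch of the usual T5-style relative-position bucketing; this is an assumption about the op's semantics, not the plugin's code:

import math
import torch

def relative_position_bucket(rel_pos: torch.Tensor,
                             num_buckets: int = 32,
                             max_distance: int = 128) -> torch.Tensor:
    # Bidirectional bucketing: half the buckets distinguish sign, small
    # offsets get exact buckets, larger ones are log-spaced to max_distance.
    num_buckets //= 2
    ret = (rel_pos > 0).long() * num_buckets
    n = rel_pos.abs()
    max_exact = num_buckets // 2
    is_small = n < max_exact
    large = max_exact + (torch.log(n.float().clamp(min=1) / max_exact) /
                         math.log(max_distance / max_exact) *
                         (num_buckets - max_exact)).long()
    large = torch.minimum(large, torch.full_like(large, num_buckets - 1))
    return ret + torch.where(is_small, n, large)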
<div class="viewcode-block" id="MLLaMAModel.from_hugging_face">
<a class="viewcode-back" href="../../../../python-api/tensorrt_llm.models.html#tensorrt_llm.models.MLLaMAModel.from_hugging_face">[docs]</a>
@classmethod
def from_hugging_face(
        cls,
@@ -1688,7 +1630,7 @@
        mapping: Optional[Mapping] = None,
        quant_config: Optional[QuantConfig] = None,
        **kwargs):
    ''' Create a MLLaMAForCausalLM object from given parameters
    ''' Create a MLLaMAModel object from given parameters
    '''
    import transformers

@@ -1714,11 +1656,11 @@

custom_dict = {
    "lm_head": "language_model.lm_head",
    "transformer.ln_f": "language_model.model.norm",
    "transformer": "language_model.model",
    "ln_f": "language_model.model.norm",
    "decoder_layers": "language_model.model.layers",
    "self_attention": "self_attn",
    "cross_attention": "cross_attn",
    "vocab_embedding": "embed_tokens",
    "vocab_embedding": "language_model.model.embed_tokens",
    "gate_attn": "cross_attn_attn_gate",
    "gate_ffwd": "cross_attn_mlp_gate",
    "q_layernorm": "q_norm",
@@ -1744,7 +1686,7 @@
<hr/>

<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x7f5e6e01af00>
<jinja2.runtime.BlockReference object at 0x7f1ac71bddc0>

<div class="footer">
<p>

@@ -14,7 +14,7 @@
<script src="../../../_static/jquery.js?v=5d32c60e"></script>
<script src="../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../../_static/doctools.js?v=9a2dae69"></script>
<script src="../../../_static/doctools.js?v=9bcbadda"></script>
<script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
<ul>
@@ -314,8 +315,8 @@
<span class="nd">@property</span>
|
||||
<span class="k">def</span> <span class="nf">requires_modelopt_quantization</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">quant_algo</span> <span class="ow">in</span> <span class="p">[</span>
|
||||
<span class="n">QuantAlgo</span><span class="o">.</span><span class="n">NVFP4</span><span class="p">,</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span><span class="p">,</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">W4A16_AWQ</span><span class="p">,</span>
|
||||
<span class="n">QuantAlgo</span><span class="o">.</span><span class="n">W4A8_AWQ</span><span class="p">,</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">W8A8_SQ_PER_CHANNEL</span><span class="p">,</span>
|
||||
<span class="n">QuantAlgo</span><span class="o">.</span><span class="n">W4A16_AWQ</span><span class="p">,</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span><span class="p">,</span>
|
||||
<span class="n">QuantAlgo</span><span class="o">.</span><span class="n">W8A8_SQ_PER_CHANNEL</span><span class="p">,</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">W4A8_AWQ</span><span class="p">,</span>
|
||||
<span class="n">QuantAlgo</span><span class="o">.</span><span class="n">MIXED_PRECISION</span>
|
||||
<span class="p">]:</span>
|
||||
<span class="k">return</span> <span class="kc">True</span>
|
||||
@@ -336,11 +337,10 @@
algo_to_modelopt_map = {
    QuantAlgo.W8A16: "int8_wo",
    QuantAlgo.W4A16: "int4_wo",
    QuantAlgo.NVFP4: "nvfp4",
    QuantAlgo.FP8: "fp8",
    QuantAlgo.W4A16_AWQ: "int4_awq",
    QuantAlgo.W4A8_AWQ: "w4a8_awq",
    QuantAlgo.W8A8_SQ_PER_CHANNEL: "int8_sq",
    QuantAlgo.W4A8_AWQ: 'w4a8_awq',
    QuantAlgo.FP8: 'fp8',
    QuantAlgo.W8A8_SQ_PER_CHANNEL: 'int8_sq',
}
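Note that the hunk interleaves removed and added entries, so the apparent duplicate keys above are diff artifacts, not real dictionary contents. An illustrative lookup of the mapping (the downstream qformat consumer is an assumption):

qformat = algo_to_modelopt_map[QuantAlgo.FP8]  # 'fp8'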
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">quant_algo</span> <span class="o">!=</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">MIXED_PRECISION</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"We don't support mixed precision in QuantConfig"</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">quant_algo</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||||
@@ -448,11 +448,10 @@

def get_modelopt_qformat(self):
    algo_to_modelopt_map = {
        QuantAlgo.NVFP4: "nvfp4",
        QuantAlgo.FP8: "fp8",
        QuantAlgo.W4A16_AWQ: "int4_awq",
        QuantAlgo.W4A8_AWQ: "w4a8_awq",
        QuantAlgo.W8A8_SQ_PER_CHANNEL: "int8_sq",
        QuantAlgo.W4A8_AWQ: 'w4a8_awq',
        QuantAlgo.FP8: 'fp8',
        QuantAlgo.W8A8_SQ_PER_CHANNEL: 'int8_sq',
    }
    assert self.quant_algo == QuantAlgo.MIXED_PRECISION, f"We only support mixed precision quantization in LayerQuantConfig"
    autoq_format = ','.join(
@@ -675,8 +674,15 @@

<div class="viewcode-block" id="PretrainedConfig.set_rank">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.models.html#tensorrt_llm.llmapi.PretrainedConfig.set_rank">[docs]</a>
def set_rank(self, rank: int):
    self.mapping.rank = rank</div>
def set_rank(self, rank):
    self.mapping = Mapping(self.mapping.world_size,
                           rank=rank,
                           cp_size=self.mapping.cp_size,
                           tp_size=self.mapping.tp_size,
                           pp_size=self.mapping.pp_size,
                           moe_tp_size=self.mapping.moe_tp_size,
                           moe_ep_size=self.mapping.moe_ep_size,
                           gpus_per_node=self.mapping.gpus_per_node)</div>
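The rebuilt Mapping matters because rank-derived fields are computed at construction time; assigning self.mapping.rank alone would leave them stale. A hedged illustration of the usual derivation (attribute names are an assumption, not the exact API):

# Illustrative only: how per-rank coordinates typically fall out of `rank`.
world_size, tp_size, pp_size = 8, 4, 2
rank = 5
tp_rank = rank % tp_size   # 1
pp_rank = rank // tp_size  # 1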

<div class="viewcode-block" id="PretrainedConfig.get_config_group">
@@ -867,10 +873,8 @@
config.set_rank(rank)

rank = config.mapping.rank
if config.mapping.auto_parallel:
    rank = 0
elif config.mapping.cp_size > 1:
    # tp_cp_pp rank -> tp_pp rank: because different cp ranks share the same ckpt
# tp_cp_pp rank -> tp_pp rank: because different cp ranks share the same ckpt
if config.mapping.cp_size > 1:
    tp_size = config.mapping.tp_size
    cp_size = config.mapping.cp_size
    rank = rank % tp_size + rank // (tp_size * cp_size) * tp_size
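A worked instance of the fold above: with tp_size=2 and cp_size=2, the eight tp_cp_pp ranks collapse onto four tp_pp checkpoint shards, since context-parallel ranks share a shard:

tp_size, cp_size = 2, 2
folded = [
    rank % tp_size + rank // (tp_size * cp_size) * tp_size for rank in range(8)
]
assert folded == [0, 1, 0, 1, 2, 3, 2, 3]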
@@ -975,7 +979,7 @@
lora_target_modules: List[str] = None,
opt_batch_size: int = 0,
num_hidden_layers: int = None,
mrope_rotary_cos_sin_size: int = None,
mrope_rotary_sin_cos_size: int = None,
):
    '''@brief: Prepare input Tensors for the model; the given sizes are used to determine the
    ranges of the dimensions when using TRT dynamic shapes.
@@ -1040,7 +1044,7 @@
streamingllm=streamingllm,
opt_batch_size=opt_batch_size,
pp_reduce_scatter=pp_reduce_scatter,
mrope_rotary_cos_sin_size=mrope_rotary_cos_sin_size)
mrope_rotary_sin_cos_size=mrope_rotary_sin_cos_size)

result = {
    'input_ids':
@@ -1672,12 +1676,7 @@
        layer.dense, 'activation_scaling_factor'):
    scale = [1.0] / layer.dense.activation_scaling_factor.raw_value
    layer.attention_output_orig_quant_scale = Parameter(
        value=scale.astype(np.float32), dtype='float32')
elif isinstance(layer, Attention) and hasattr(
        layer.dense, 'activation_scaling_factor_2'):
    scale = layer.dense.activation_scaling_factor_2.raw_value
    layer.attention_output_orig_quant_scale = Parameter(
        value=scale.astype(np.float32), dtype='float32')
        value=scale.astype(np.float32))
return model

@@ -1749,18 +1748,16 @@
        (type(attn.qkv) == ColumnLinear or type(attn.qkv) == FP8Linear):
    old_qkv = attn.qkv
    linear_class = type(old_qkv)
    new_kv = linear_class(
        in_features=attn.hidden_size,
        out_features=2 * attn.tp_size * attn.num_attention_kv_heads *
        attn.attention_head_size,
        bias=old_qkv.bias,
        dtype=old_qkv.dtype,
        tp_group=old_qkv.tp_group,
        tp_size=old_qkv.tp_size,
        gather_output=old_qkv.gather_output,
        prefer_managed_weight=old_qkv.prefer_managed_weight,
        is_qkv=old_qkv.is_qkv,
    )
    new_kv = linear_class(in_features=attn.hidden_size,
                          out_features=2 * attn.tp_size *
                          attn.num_attention_kv_heads *
                          attn.attention_head_size,
                          bias=old_qkv.bias,
                          dtype=old_qkv.dtype,
                          tp_group=old_qkv.tp_group,
                          tp_size=old_qkv.tp_size,
                          gather_output=old_qkv.gather_output,
                          is_qkv=False)
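A quick check of the out_features arithmetic above: the factor 2 covers K and V, and the tp_size factor restores the unsplit (gathered) width so the tensor can be re-sharded afterwards (values illustrative):

tp_size, num_attention_kv_heads, attention_head_size = 2, 8, 64
out_features = 2 * tp_size * num_attention_kv_heads * attention_head_size
assert out_features == 2048  # 2 * 2 * 8 * 64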

old_qkv_weight_value = old_qkv.weight.raw_value
if (old_qkv_weight_value.shape == np.asarray([
@@ -1866,9 +1863,6 @@
weights[name.replace(
    'weights_scaling_factor', 'alpha'
)] = activation_scaling_factor * weights_scaling_factor_2
weights[name.replace('weights_scaling_factor',
                     'activation_scaling_factor'
                     )] = activation_scaling_factor

# FP8
elif quant_algo == QuantAlgo.FP8:
@@ -1891,33 +1885,7 @@
    model_config.dtype)
weights.pop('lm_head.weights_scaling_factor', None)
weights.pop('lm_head.activation_scaling_factor', None)
# FP4
elif quant_algo == QuantAlgo.NVFP4:
    # Interleave block scale for NVFP4 plugin.
    for name in list(weights):
        if name.endswith('weights_scaling_factor'):
            ori_shape = weights[name].shape
            weights[name] = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
                weights[name].view(torch.int8).cpu().contiguous()).reshape(
                    ori_shape).view(torch.int32)
        if name.endswith('activation_scaling_factor'):
            weights[name + '_2'] = 1.0 / weights[name]
            weights.pop(name, None)
        if name.endswith('weights_scaling_factor_2'):
            weights[name] = 1.0 / weights[name]
        if name.endswith('weight') and weights[name].dtype == torch.uint8:
            weights[name] = weights[name].view(torch.int64)
    for name in list(weights):
        if name.endswith('weights_scaling_factor_2'):
            weight_global_sf = weights[name]
            act_global_sf = weights[name.replace(
                'weights_scaling_factor_2', 'activation_scaling_factor_2')]
            weights[name.replace(
                'weights_scaling_factor_2',
                'alpha')] = 1.0 / (act_global_sf * weight_global_sf)
            weights.pop(name, None)
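A worked instance of the final fusion step; by this point both *_2 entries hold reciprocals of the original global scales, and the kernel consumes a single fused alpha (this interpretation is an assumption; values illustrative):

act_global_sf, weight_global_sf = 2.0, 4.0        # reciprocal *_2 entries
alpha = 1.0 / (act_global_sf * weight_global_sf)  # 0.125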

elif quant_algo in [QuantAlgo.W4A16, QuantAlgo.W8A16]:
    weights = weight_only_quantize_dict(weights=weights,
                                        quant_algo=quant_algo,
@@ -1940,20 +1908,6 @@

pattern_info = ['fc', 'gate', 'proj', 'qkv', 'dense']

def add_kv_cache_rcp_scaling_factor(weights: Dict[str, torch.Tensor]):
    new_entries = []
    # The unified converter generate_tllm_weights() already generates these rcp
    # weights, but legacy converters do not. Handle it here.
    for name, param in weights.items():
        if name.endswith('.kv_cache_scaling_factor'):
            rcp_name = name.replace('kv_cache_scaling_factor',
                                    'kv_cache_rcp_scaling_factor')
            if rcp_name not in weights:
                new_entries.append((rcp_name, torch.reciprocal(param)))
    weights.update(new_entries)

add_kv_cache_rcp_scaling_factor(weights)
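A minimal sketch of why the helper precomputes reciprocals: dequantizing KV-cache entries multiplies by 1/scale, so materializing the reciprocal as its own weight avoids a per-token division (assumed rationale):

import torch

kv_cache_scaling_factor = torch.tensor([0.05])
kv_cache_rcp_scaling_factor = torch.reciprocal(kv_cache_scaling_factor)  # 20.0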

per_layer_weights = {}

for name, param in weights.items():
@@ -2033,7 +1987,7 @@
<hr/>

<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x7f5e6e4a5700>
<jinja2.runtime.BlockReference object at 0x7f1ac7958170>

<div class="footer">
<p>

@@ -14,7 +14,7 @@
<script src="../../../../_static/jquery.js?v=5d32c60e"></script>
<script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../../../_static/doctools.js?v=9a2dae69"></script>
<script src="../../../../_static/doctools.js?v=9bcbadda"></script>
<script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../../../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../../../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
<ul>
@@ -356,7 +357,7 @@
<hr/>

<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x7f5e6e23cb00>
<jinja2.runtime.BlockReference object at 0x7f1ac74d3800>

<div class="footer">
<p>

@@ -14,7 +14,7 @@
<script src="../../../../_static/jquery.js?v=5d32c60e"></script>
<script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../../../_static/doctools.js?v=9a2dae69"></script>
<script src="../../../../_static/doctools.js?v=9bcbadda"></script>
<script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../../../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../../../../_static/copybutton.js?v=65e89d2a"></script>
@@ -110,6 +110,7 @@
<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
<ul>
@@ -360,7 +361,7 @@
<hr/>

<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x7f5e6e138c50>
<jinja2.runtime.BlockReference object at 0x7f1ac7959af0>

<div class="footer">
<p>

@ -14,7 +14,7 @@
|
||||
<script src="../../../../_static/jquery.js?v=5d32c60e"></script>
|
||||
<script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
|
||||
<script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
|
||||
<script src="../../../../_static/doctools.js?v=9a2dae69"></script>
|
||||
<script src="../../../../_static/doctools.js?v=9bcbadda"></script>
|
||||
<script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||||
<script src="../../../../_static/clipboard.min.js?v=a7894cd8"></script>
|
||||
<script src="../../../../_static/copybutton.js?v=65e89d2a"></script>
|
||||
@ -110,6 +110,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
|
||||
<ul>
|
||||
@ -390,7 +391,7 @@
|
||||
<hr/>
|
||||
|
||||
<div role="contentinfo">
|
||||
<jinja2.runtime.BlockReference object at 0x7f5e6e205400>
|
||||
<jinja2.runtime.BlockReference object at 0x7f1ac940dcd0>
|
||||
|
||||
<div class="footer">
|
||||
<p>
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
<script src="../../../../_static/jquery.js?v=5d32c60e"></script>
|
||||
<script src="../../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
|
||||
<script src="../../../../_static/documentation_options.js?v=5929fcd5"></script>
|
||||
<script src="../../../../_static/doctools.js?v=9a2dae69"></script>
|
||||
<script src="../../../../_static/doctools.js?v=9bcbadda"></script>
|
||||
<script src="../../../../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||||
<script src="../../../../_static/clipboard.min.js?v=a7894cd8"></script>
|
||||
<script src="../../../../_static/copybutton.js?v=65e89d2a"></script>
|
||||
@ -110,6 +110,7 @@
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
|
||||
<li class="toctree-l1"><a class="reference internal" href="../../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
|
||||
</ul>
|
||||
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
|
||||
<ul>
|
||||
@ -204,8 +205,8 @@
<span class="n">attention_head_size</span> <span class="o">=</span> <span class="n">hidden_size</span> <span class="o">/</span> <span class="n">num_attention_heads</span>
<span class="n">q_scaling</span> <span class="o">=</span> <span class="n">attention_head_size</span><span class="o">**</span><span class="mf">.5</span>

<span class="n">block_sparse</span> <span class="o">=</span> <span class="p">((</span><span class="n">layer_idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span>
<span class="n">config</span><span class="o">.</span><span class="n">dense_attention_every_n_layers</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">0</span>
<span class="n">block_sparse</span> <span class="o">=</span> <span class="p">(</span>
<span class="p">(</span><span class="n">layer_idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="n">config</span><span class="o">.</span><span class="n">dense_attention_every_n_layers</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">0</span>
<span class="n">attention_mask_type</span> <span class="o">=</span> <span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">blocksparse</span> <span class="k">if</span> <span class="n">block_sparse</span> <span class="k">else</span> <span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">causal</span>

<span class="n">block_sparse_attn_params</span> <span class="o">=</span> <span class="n">BlockSparseAttnParams</span><span class="p">(</span>
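The hunk above only reflows the block_sparse expression; the logic is unchanged. As a plain-Python reading of that logic (a sketch; config, layer_idx and the AttentionMaskType enum here are stand-ins for the real TensorRT-LLM objects):

from enum import Enum

class AttentionMaskType(Enum):  # stand-in for tensorrt_llm's enum
    causal = 0
    blocksparse = 1

def mask_type_for_layer(layer_idx: int, dense_attention_every_n_layers: int) -> AttentionMaskType:
    # Every n-th layer (1-based) stays dense; all other layers use block-sparse attention.
    block_sparse = (layer_idx + 1) % dense_attention_every_n_layers != 0
    return AttentionMaskType.blocksparse if block_sparse else AttentionMaskType.causal

# With dense_attention_every_n_layers=4, layers 3, 7, 11, ... (0-based) stay dense.
assert mask_type_for_layer(3, 4) is AttentionMaskType.causal
assert mask_type_for_layer(4, 4) is AttentionMaskType.blocksparse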
@ -442,8 +443,6 @@
<span class="s2">"help"</span><span class="p">:</span>
<span class="s2">"Enable customized kernels to speed up conv1d operator for Mamba."</span>
<span class="p">})</span>
<span class="n">_fp4_gemm_plugin</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">metadata</span><span class="o">=</span><span class="p">{</span><span class="s2">"help"</span><span class="p">:</span> <span class="s2">"Enable FP4 GEMM kernels."</span><span class="p">})</span>
<span class="n">_low_latency_gemm_plugin</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
@ -687,10 +686,6 @@
<span class="k">def</span> <span class="nf">set_nccl_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"auto"</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_plugin</span> <span class="o">=</span> <span class="n">dtype</span>
<span class="n">init_all_reduce_helper</span><span class="p">()</span>
<span class="k">return</span> <span class="bp">self</span>

<span class="k">def</span> <span class="nf">set_lora_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">lora_plugin</span> <span class="o">=</span> <span class="n">dtype</span>
<span class="k">return</span> <span class="bp">self</span></div>
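The plugin-config hunks above follow two patterns worth noting: non-init dataclass fields that carry their help text in the field metadata, and setters that return self so calls can be chained. A minimal sketch of both (the PluginConfig stand-in, field name and help string here are illustrative, not the real TensorRT-LLM definitions):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class PluginConfig:  # illustrative stand-in
    # init=False keeps the knob out of __init__; metadata carries the help text.
    _gemm_plugin: Optional[str] = field(
        default=None, init=False,
        metadata={"help": "Enable GEMM kernels."})

    def set_gemm_plugin(self, dtype: str = "auto"):
        self._gemm_plugin = dtype
        return self  # fluent style: setters can be chained

config = PluginConfig().set_gemm_plugin("float16")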
@ -201,7 +202,6 @@
<span class="n">FP8_PER_CHANNEL_PER_TOKEN</span> <span class="o">=</span> <span class="n">auto</span><span class="p">()</span>
<span class="n">INT8</span> <span class="o">=</span> <span class="n">auto</span><span class="p">()</span>
<span class="n">MIXED_PRECISION</span> <span class="o">=</span> <span class="n">auto</span><span class="p">()</span>
<span class="n">NVFP4</span> <span class="o">=</span> <span class="n">auto</span><span class="p">()</span>
<span class="n">NO_QUANT</span> <span class="o">=</span> <span class="n">auto</span><span class="p">()</span></div>


@ -247,8 +247,6 @@
<span class="n">FP8_ROWWISE</span> <span class="o">=</span> <span class="n">auto</span><span class="p">()</span>
<span class="c1"># W4A8 qserve</span>
<span class="n">W4A8_QSERVE</span> <span class="o">=</span> <span class="n">auto</span><span class="p">()</span>
<span class="c1"># FP4</span>
<span class="n">NVFP4</span> <span class="o">=</span> <span class="n">auto</span><span class="p">()</span>

<span class="c1"># The smallest power-of-two that is not used by a flag. Do not call auto() after that line.</span>
<span class="n">COUNT</span> <span class="o">=</span> <span class="n">auto</span><span class="p">()</span>
@ -324,9 +322,6 @@
<span class="k">def</span> <span class="nf">has_fp8_rowwise</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_any</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">FP8_ROWWISE</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">has_nvfp4</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_any</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">NVFP4</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">has_weight_quant</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_any</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">INT4_WEIGHTS</span> <span class="o">|</span> <span class="bp">self</span><span class="o">.</span><span class="n">INT8_WEIGHTS</span><span class="p">)</span>

@ -334,8 +329,7 @@
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_any</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">INT4_WEIGHTS</span> <span class="o">|</span> <span class="bp">self</span><span class="o">.</span><span class="n">INT8_WEIGHTS</span>
<span class="o">|</span> <span class="bp">self</span><span class="o">.</span><span class="n">ACTIVATIONS</span>
<span class="o">|</span> <span class="bp">self</span><span class="o">.</span><span class="n">INT8_KV_CACHE</span> <span class="o">|</span> <span class="bp">self</span><span class="o">.</span><span class="n">FP8_KV_CACHE</span>
<span class="o">|</span> <span class="bp">self</span><span class="o">.</span><span class="n">FP8_QDQ</span> <span class="o">|</span> <span class="bp">self</span><span class="o">.</span><span class="n">FP8_ROWWISE</span> <span class="o">|</span> <span class="bp">self</span><span class="o">.</span><span class="n">W4A8_QSERVE</span>
<span class="o">|</span> <span class="bp">self</span><span class="o">.</span><span class="n">NVFP4</span><span class="p">)</span>
<span class="o">|</span> <span class="bp">self</span><span class="o">.</span><span class="n">FP8_QDQ</span> <span class="o">|</span> <span class="bp">self</span><span class="o">.</span><span class="n">FP8_ROWWISE</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">set_int8_kv_cache</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span> <span class="o">|</span> <span class="bp">self</span><span class="o">.</span><span class="n">INT8_KV_CACHE</span>
@ -360,7 +354,6 @@
<span class="n">use_fp8_kv_cache</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">use_fp8_qdq</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">use_fp8_rowwise</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">use_nvfp4</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">use_w4a8_qserve</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>

<span class="k">def</span> <span class="nf">raise_error</span><span class="p">():</span>
@ -375,7 +368,6 @@
<span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">use_fp8_kv_cache</span><span class="si">=}</span><span class="s2">"</span>
<span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">use_fp8_qdq</span><span class="si">=}</span><span class="s2">"</span>
<span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">use_fp8_rowwise</span><span class="si">=}</span><span class="s2">"</span>
<span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">use_nvfp4</span><span class="si">=}</span><span class="s2">"</span>
<span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">use_w4a8_qserve</span><span class="si">=}</span><span class="s2">"</span><span class="p">)</span>

<span class="c1"># We must quantize weights when we quantize activations.</span>
@ -421,9 +413,6 @@
<span class="k">if</span> <span class="n">use_fp8_rowwise</span><span class="p">:</span>
<span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span> <span class="o">|</span> <span class="n">QuantMode</span><span class="o">.</span><span class="n">FP8_ROWWISE</span> <span class="o">|</span> <span class="n">QuantMode</span><span class="o">.</span><span class="n">PER_TOKEN</span> <span class="o">|</span> <span class="n">QuantMode</span><span class="o">.</span><span class="n">PER_CHANNEL</span>

<span class="k">if</span> <span class="n">use_nvfp4</span><span class="p">:</span>
<span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span> <span class="o">|</span> <span class="n">QuantMode</span><span class="o">.</span><span class="n">NVFP4</span>

<span class="c1"># W4A8 QServe</span>
<span class="k">if</span> <span class="n">use_w4a8_qserve</span><span class="p">:</span>
<span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span> <span class="o">|</span> <span class="n">QuantMode</span><span class="o">.</span><span class="n">W4A8_QSERVE</span>
@ -497,8 +486,6 @@
<span class="n">quant_mode</span> <span class="o">=</span> <span class="n">QuantMode</span><span class="o">.</span><span class="n">from_description</span><span class="p">(</span><span class="n">use_fp8_qdq</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">quant_algo</span> <span class="o">==</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8_PER_CHANNEL_PER_TOKEN</span><span class="p">:</span>
<span class="n">quant_mode</span> <span class="o">=</span> <span class="n">QuantMode</span><span class="o">.</span><span class="n">from_description</span><span class="p">(</span><span class="n">use_fp8_rowwise</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">quant_algo</span> <span class="o">==</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">NVFP4</span><span class="p">:</span>
<span class="n">quant_mode</span> <span class="o">=</span> <span class="n">QuantMode</span><span class="o">.</span><span class="n">from_description</span><span class="p">(</span><span class="n">use_nvfp4</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">quant_mode</span> <span class="o">=</span> <span class="n">QuantMode</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>

@ -525,8 +512,6 @@
<span class="bp">self</span><span class="o">.</span><span class="n">has_fp8_qdq</span><span class="p">(),</span>
<span class="s1">'enable_fp8_rowwise'</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">has_fp8_rowwise</span><span class="p">(),</span>
<span class="s1">'enable_nvfp4'</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">has_nvfp4</span><span class="p">(),</span>
<span class="s1">'fp8_kv_cache'</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">has_fp8_kv_cache</span><span class="p">(),</span>
<span class="s1">'use_weight_only'</span><span class="p">:</span>
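The QuantMode hunks above all revolve around one idea: quantization options are bit flags on an integer enum, so combining options is a bitwise OR and feature tests are masked ANDs. A minimal sketch of the pattern (the flag names mirror the diff; _any is written here as it is commonly done for IntFlag, not copied from the source):

from enum import IntFlag, auto

class QuantMode(IntFlag):
    INT4_WEIGHTS = auto()
    INT8_WEIGHTS = auto()
    ACTIVATIONS = auto()
    FP8_QDQ = auto()
    FP8_ROWWISE = auto()

    def _any(self, bits: "QuantMode") -> bool:
        # True if any of the requested flag bits are set on self.
        return (self & bits) != 0

    def has_weight_quant(self) -> bool:
        return self._any(QuantMode.INT4_WEIGHTS | QuantMode.INT8_WEIGHTS)

mode = QuantMode.INT8_WEIGHTS | QuantMode.FP8_ROWWISE
assert mode.has_weight_quant()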
@ -190,13 +191,11 @@
<span class="kn">from</span> <span class="nn">datasets</span> <span class="kn">import</span> <span class="n">load_dataset</span>
<span class="kn">from</span> <span class="nn">safetensors.torch</span> <span class="kn">import</span> <span class="n">load_file</span><span class="p">,</span> <span class="n">save_file</span>
<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span>
<span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="p">(</span><span class="n">AutoConfig</span><span class="p">,</span> <span class="n">AutoModelForCausalLM</span><span class="p">,</span> <span class="n">AutoProcessor</span><span class="p">,</span>
<span class="n">AutoTokenizer</span><span class="p">)</span>
<span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">AutoConfig</span><span class="p">,</span> <span class="n">AutoModelForCausalLM</span><span class="p">,</span> <span class="n">AutoTokenizer</span>

<span class="kn">from</span> <span class="nn">.._utils</span> <span class="kn">import</span> <span class="n">release_gc</span><span class="p">,</span> <span class="n">str_dtype_to_torch</span>
<span class="kn">from</span> <span class="nn">..logger</span> <span class="kn">import</span> <span class="n">logger</span>
<span class="kn">from</span> <span class="nn">..mapping</span> <span class="kn">import</span> <span class="n">Mapping</span>
<span class="kn">from</span> <span class="nn">.image_processing</span> <span class="kn">import</span> <span class="n">MllamaImageProcessor</span>
<span class="kn">from</span> <span class="nn">.mode</span> <span class="kn">import</span> <span class="n">QuantAlgo</span>

<span class="n">EMPTY_CFG</span> <span class="o">=</span> <span class="p">{</span>
@ -265,17 +264,13 @@
<span class="s2">"int4_wo"</span><span class="p">:</span> <span class="n">EMPTY_CFG</span><span class="p">,</span>
<span class="s2">"full_prec"</span><span class="p">:</span> <span class="n">EMPTY_CFG</span><span class="p">,</span>
<span class="p">}</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">mtq</span><span class="p">,</span> <span class="s2">"NVFP4_DEFAULT_CFG"</span><span class="p">):</span>
<span class="n">QUANT_CFG_CHOICES</span><span class="p">[</span><span class="s2">"nvfp4"</span><span class="p">]</span> <span class="o">=</span> <span class="n">mtq</span><span class="o">.</span><span class="n">NVFP4_DEFAULT_CFG</span>
<span class="k">return</span> <span class="n">QUANT_CFG_CHOICES</span>


<span class="n">MODEL_NAME_PATTERN_MAP</span> <span class="o">=</span> <span class="p">{</span>
<span class="s2">"GPT2"</span><span class="p">:</span> <span class="s2">"gpt2"</span><span class="p">,</span>
<span class="s2">"Xverse"</span><span class="p">:</span> <span class="s2">"llama"</span><span class="p">,</span>
<span class="s2">"MllamaForConditionalGeneration"</span><span class="p">:</span> <span class="s2">"mllama"</span><span class="p">,</span>
<span class="s2">"Llama"</span><span class="p">:</span> <span class="s2">"llama"</span><span class="p">,</span>
<span class="s2">"MllamaForCausalLM"</span><span class="p">:</span> <span class="s2">"mllama"</span><span class="p">,</span>
<span class="s2">"Mistral"</span><span class="p">:</span> <span class="s2">"llama"</span><span class="p">,</span>
<span class="s2">"GPTJ"</span><span class="p">:</span> <span class="s2">"gptj"</span><span class="p">,</span>
<span class="s2">"FalconForCausalLM"</span><span class="p">:</span> <span class="s2">"falcon"</span><span class="p">,</span>
@ -292,7 +287,6 @@
<span class="s2">"NemotronForCausalLM"</span><span class="p">:</span> <span class="s2">"nemotron"</span><span class="p">,</span>
<span class="s2">"GPTBigCodeForCausalLM"</span><span class="p">:</span> <span class="s2">"gpt_bigcode"</span><span class="p">,</span>
<span class="s2">"ArcticForCausalLM"</span><span class="p">:</span> <span class="s2">"llama"</span><span class="p">,</span>
<span class="s2">"PhiMoEForCausalLM"</span><span class="p">:</span> <span class="s2">"phi3"</span><span class="p">,</span>
<span class="s2">"Phi3SmallForCausalLM"</span><span class="p">:</span> <span class="s2">"phi3small"</span><span class="p">,</span>
<span class="s2">"Phi3ForCausalLM"</span><span class="p">:</span> <span class="s2">"phi3"</span><span class="p">,</span>
<span class="s2">"Starcoder2ForCausalLM"</span><span class="p">:</span> <span class="s2">"gptnext"</span><span class="p">,</span>
@ -301,12 +295,8 @@
<span class="s2">"Exaone"</span><span class="p">:</span> <span class="s2">"exaone"</span><span class="p">,</span>
<span class="s2">"DeciLMForCausalLM"</span><span class="p">:</span> <span class="s2">"deci"</span><span class="p">,</span>
<span class="s2">"DeepseekForCausalLM"</span><span class="p">:</span> <span class="s2">"deepseek"</span><span class="p">,</span>
<span class="s2">"GraniteForCausalLM"</span><span class="p">:</span> <span class="s2">"granite"</span><span class="p">,</span>
<span class="s2">"GraniteMoeForCausalLM"</span><span class="p">:</span> <span class="s2">"granitemoe"</span><span class="p">,</span>
<span class="p">}</span>

<span class="n">MULTIMODAL_DATASETS</span> <span class="o">=</span> <span class="p">[</span><span class="s1">'scienceqa'</span><span class="p">,</span> <span class="s1">'science_qa'</span><span class="p">]</span>


<span class="k">class</span> <span class="nc">_CustomDataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="p">):</span>
@ -345,31 +335,6 @@
<span class="k">return</span> <span class="n">tokenizer</span>


<span class="k">def</span> <span class="nf">get_processor</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">,</span> <span class="n">max_seq_length</span><span class="o">=</span><span class="mi">2048</span><span class="p">,</span> <span class="n">model_type</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Initializing tokenizer from </span><span class="si">{</span><span class="n">ckpt_path</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
<span class="n">processor</span> <span class="o">=</span> <span class="n">AutoProcessor</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span>
<span class="n">ckpt_path</span><span class="p">,</span>
<span class="n">model_max_length</span><span class="o">=</span><span class="n">max_seq_length</span><span class="p">,</span>
<span class="n">padding_side</span><span class="o">=</span><span class="s2">"left"</span><span class="p">,</span>
<span class="n">trust_remote_code</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="p">)</span>

<span class="k">if</span> <span class="n">processor</span><span class="o">.</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">pad_token</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="n">model_type</span> <span class="ow">and</span> <span class="n">model_type</span> <span class="o">==</span> <span class="s2">"qwen"</span><span class="p">:</span>
<span class="c1"># qwen use token id 151643 as pad and eos tokens</span>
<span class="n">processor</span><span class="o">.</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">eos_token</span> <span class="o">=</span> <span class="n">processor</span><span class="o">.</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">convert_ids_to_tokens</span><span class="p">(</span>
<span class="mi">151643</span><span class="p">)</span>
<span class="n">processor</span><span class="o">.</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">pad_token</span> <span class="o">=</span> <span class="n">processor</span><span class="o">.</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">convert_ids_to_tokens</span><span class="p">(</span>
<span class="mi">151643</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">processor</span><span class="o">.</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">pad_token</span> <span class="o">=</span> <span class="n">processor</span><span class="o">.</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">eos_token</span>
<span class="k">assert</span> <span class="n">processor</span><span class="o">.</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">pad_token</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"Pad token for </span><span class="si">{</span><span class="n">model_type</span><span class="si">}</span><span class="s2"> cannot be set!"</span>

<span class="k">if</span> <span class="n">model_type</span> <span class="o">==</span> <span class="s1">'mllama'</span><span class="p">:</span>
<span class="n">processor</span> <span class="o">=</span> <span class="n">MllamaImageProcessor</span><span class="p">(</span><span class="n">processor</span><span class="p">,</span> <span class="n">device</span><span class="p">)</span>
<span class="k">return</span> <span class="n">processor</span>


<span class="k">def</span> <span class="nf">_get_vila_model</span><span class="p">(</span><span class="n">model_dir</span><span class="p">):</span>
<span class="n">sys</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">model_dir</span> <span class="o">+</span> <span class="s2">"/../VILA"</span><span class="p">)</span>
<span class="kn">from</span> <span class="nn">llava.model</span> <span class="kn">import</span> <span class="n">LlavaLlamaConfig</span><span class="p">,</span> <span class="n">LlavaLlamaModel</span> <span class="c1"># noqa</span>
@ -420,10 +385,6 @@
<span class="k">elif</span> <span class="n">hf_config</span><span class="o">.</span><span class="n">model_type</span> <span class="o">==</span> <span class="s2">"mpt"</span><span class="p">:</span>
<span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">MptForCausalLM</span>
<span class="n">model_cls</span> <span class="o">=</span> <span class="n">MptForCausalLM</span>
<span class="k">elif</span> <span class="n">hf_config</span><span class="o">.</span><span class="n">model_type</span> <span class="o">==</span> <span class="s1">'mllama'</span><span class="p">:</span>
<span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">MllamaForConditionalGeneration</span>
<span class="n">model_cls</span> <span class="o">=</span> <span class="n">MllamaForConditionalGeneration</span>

<span class="k">if</span> <span class="s2">"vila"</span> <span class="ow">in</span> <span class="n">ckpt_path</span><span class="p">:</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">_get_vila_model</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">)</span>
<span class="k">elif</span> <span class="s2">"llava-onevision-qwen2"</span> <span class="ow">in</span> <span class="n">ckpt_path</span><span class="p">:</span>
@ -442,7 +403,6 @@
<span class="n">trust_remote_code</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">if</span> <span class="n">hf_config</span><span class="o">.</span><span class="n">model_type</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">"llava"</span><span class="p">,</span> <span class="s2">"internvl_chat"</span><span class="p">]:</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">language_model</span>

<span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>

<span class="n">model_dtype</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span><span class="o">.</span><span class="n">dtype</span>
@ -455,8 +415,6 @@


<span class="k">def</span> <span class="nf">get_model_type</span><span class="p">(</span><span class="n">model</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="o">.</span><span class="vm">__name__</span> <span class="ow">in</span> <span class="n">MODEL_NAME_PATTERN_MAP</span><span class="p">:</span>
<span class="k">return</span> <span class="n">MODEL_NAME_PATTERN_MAP</span><span class="p">[</span><span class="nb">type</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="o">.</span><span class="vm">__name__</span><span class="p">]</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">MODEL_NAME_PATTERN_MAP</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="n">k</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span> <span class="ow">in</span> <span class="nb">type</span><span class="p">(</span><span class="n">model</span><span class="p">)</span><span class="o">.</span><span class="vm">__name__</span><span class="o">.</span><span class="n">lower</span><span class="p">():</span>
<span class="k">return</span> <span class="n">v</span>
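The get_model_type hunk above removes the exact-name lookup, leaving only the case-insensitive substring scan over MODEL_NAME_PATTERN_MAP. A minimal sketch of the remaining lookup strategy (the two-entry map and the demo class are illustrative):

from typing import Optional

MODEL_NAME_PATTERN_MAP = {"Llama": "llama", "GPTJ": "gptj"}  # illustrative subset

def get_model_type(model) -> Optional[str]:
    # Substring match on the class name, case-insensitive, first hit wins.
    for pattern, model_type in MODEL_NAME_PATTERN_MAP.items():
        if pattern.lower() in type(model).__name__.lower():
            return model_type
    return None

class LlamaForCausalLM: ...
assert get_model_type(LlamaForCausalLM()) == "llama"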
@ -477,13 +435,6 @@
<span class="n">data_files</span><span class="o">=</span><span class="s2">"https://the-eye.eu/public/AI/pile/val.jsonl.zst"</span><span class="p">,</span>
<span class="n">split</span><span class="o">=</span><span class="s2">"train"</span><span class="p">)</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">[</span><span class="s2">"text"</span><span class="p">][:</span><span class="n">calib_size</span><span class="p">]</span>
<span class="k">elif</span> <span class="s2">"scienceqa"</span> <span class="ow">in</span> <span class="n">dataset_name_or_dir</span><span class="o">.</span><span class="n">lower</span><span class="p">(</span>
<span class="p">)</span> <span class="ow">or</span> <span class="s2">"science_qa"</span> <span class="ow">in</span> <span class="n">dataset_name_or_dir</span><span class="o">.</span><span class="n">lower</span><span class="p">():</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isdir</span><span class="p">(</span><span class="n">dataset_name_or_dir</span><span class="p">):</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="n">load_dataset</span><span class="p">(</span><span class="n">dataset_name_or_dir</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="s2">"train"</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="n">load_dataset</span><span class="p">(</span><span class="s2">"derek-thomas/ScienceQA"</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="s2">"train"</span><span class="p">)</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">calib_size</span><span class="p">))</span>
<span class="k">elif</span> <span class="s2">"cnn_dailymail"</span> <span class="ow">in</span> <span class="n">dataset_name_or_dir</span><span class="p">:</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="n">load_dataset</span><span class="p">(</span>
<span class="n">dataset_name_or_dir</span><span class="p">,</span>
@ -503,48 +454,33 @@
<span class="sa">f</span><span class="s2">"Unsupported dataset name or local repo directory: </span><span class="si">{</span><span class="n">dataset_name_or_dir</span><span class="si">}</span><span class="s2">."</span>
<span class="p">)</span>

<span class="n">is_multimodal</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">for</span> <span class="n">dataset_name</span> <span class="ow">in</span> <span class="n">MULTIMODAL_DATASETS</span><span class="p">:</span>
<span class="k">if</span> <span class="n">dataset_name</span> <span class="ow">in</span> <span class="n">dataset_name_or_dir</span><span class="p">:</span>
<span class="n">is_multimodal</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">if</span> <span class="n">is_multimodal</span><span class="p">:</span>
<span class="c1"># Apply the preprocessing function to the dataset</span>
<span class="n">processed_dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">preprocess_function</span><span class="p">,</span>
<span class="n">batched</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">remove_columns</span><span class="o">=</span><span class="n">dataset</span><span class="o">.</span><span class="n">column_names</span><span class="p">)</span>
<span class="n">batch_encoded</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">batch_encode_plus</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span>
<span class="n">return_tensors</span><span class="o">=</span><span class="s2">"pt"</span><span class="p">,</span>
<span class="n">padding</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">truncation</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">max_length</span><span class="o">=</span><span class="n">block_size</span><span class="p">)</span>

<span class="c1"># Create DataLoader with the custom collate function</span>
<span class="n">calib_dataloader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">processed_dataset</span><span class="p">,</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">collate_fn</span><span class="o">=</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">collate_function</span><span class="p">)</span>
<span class="k">if</span> <span class="n">device</span><span class="p">:</span>
<span class="n">batch_encoded</span> <span class="o">=</span> <span class="n">batch_encoded</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>

<span class="k">if</span> <span class="n">include_labels</span><span class="p">:</span>
<span class="c1"># Labels are needed when backward is called in the model.</span>
<span class="c1"># The labels should be a shifted version of the input_ids.</span>
<span class="c1"># However, we should not shift the input_ids here since the labels are shifted by</span>
<span class="c1"># Huggingface models during loss calculation as shown here -</span>
<span class="c1"># https://github.com/huggingface/transformers/blob/7f79a97399bb52aad8460e1da2f36577d5dccfed/src/transformers/models/llama/modeling_llama.py#L1093-L1095</span>
<span class="n">batch_encoded</span><span class="p">[</span><span class="s2">"labels"</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span>
<span class="n">batch_encoded</span><span class="p">[</span><span class="s2">"attention_mask"</span><span class="p">]</span> <span class="o">></span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">batch_encoded</span><span class="p">[</span><span class="s2">"input_ids"</span><span class="p">],</span>
<span class="o">-</span><span class="mi">100</span><span class="p">)</span>
<span class="n">batch_encoded</span> <span class="o">=</span> <span class="n">_CustomDataset</span><span class="p">(</span><span class="n">batch_encoded</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">batch_encoded</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">batch_encode_plus</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span>
<span class="n">return_tensors</span><span class="o">=</span><span class="s2">"pt"</span><span class="p">,</span>
<span class="n">padding</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">truncation</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">max_length</span><span class="o">=</span><span class="n">block_size</span><span class="p">)</span>
<span class="k">if</span> <span class="n">device</span><span class="p">:</span>
<span class="n">batch_encoded</span> <span class="o">=</span> <span class="n">batch_encoded</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="c1"># For backward compatibility, if labels are not needed, we only return input_ids.</span>
<span class="n">batch_encoded</span> <span class="o">=</span> <span class="n">_CustomDataset</span><span class="p">(</span>
<span class="p">{</span><span class="s2">"input_ids"</span><span class="p">:</span> <span class="n">batch_encoded</span><span class="p">[</span><span class="s2">"input_ids"</span><span class="p">]})</span>

<span class="k">if</span> <span class="n">include_labels</span><span class="p">:</span>
<span class="c1"># Labels are needed when backward is called in the model.</span>
<span class="c1"># The labels should be a shifted version of the input_ids.</span>
<span class="c1"># However, we should not shift the input_ids here since the labels are shifted by</span>
<span class="c1"># Huggingface models during loss calculation as shown here -</span>
<span class="c1"># https://github.com/huggingface/transformers/blob/7f79a97399bb52aad8460e1da2f36577d5dccfed/src/transformers/models/llama/modeling_llama.py#L1093-L1095</span>
<span class="n">batch_encoded</span><span class="p">[</span><span class="s2">"labels"</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span>
<span class="n">batch_encoded</span><span class="p">[</span><span class="s2">"attention_mask"</span><span class="p">]</span> <span class="o">></span> <span class="mf">0.5</span><span class="p">,</span>
<span class="n">batch_encoded</span><span class="p">[</span><span class="s2">"input_ids"</span><span class="p">],</span> <span class="o">-</span><span class="mi">100</span><span class="p">)</span>
<span class="n">batch_encoded</span> <span class="o">=</span> <span class="n">_CustomDataset</span><span class="p">(</span><span class="n">batch_encoded</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># For backward compatibility, if labels are not needed, we only return input_ids.</span>
<span class="n">batch_encoded</span> <span class="o">=</span> <span class="n">_CustomDataset</span><span class="p">(</span>
<span class="p">{</span><span class="s2">"input_ids"</span><span class="p">:</span> <span class="n">batch_encoded</span><span class="p">[</span><span class="s2">"input_ids"</span><span class="p">]})</span>

<span class="n">calib_dataloader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">batch_encoded</span><span class="p">,</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">calib_dataloader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">batch_encoded</span><span class="p">,</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>

<span class="k">return</span> <span class="n">calib_dataloader</span>
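The large hunk above drops the multimodal branch and keeps a single tokenize-then-wrap path; its one subtle step is building labels for calibration runs that call backward: input_ids are copied into labels with padding positions masked to -100 so the Hugging Face loss ignores them, and no shifting is done here because HF causal-LM models shift labels internally. A minimal sketch of that masking step (the model name and input texts are placeholders):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(["a short text", "a much longer calibration text"],
                  return_tensors="pt", padding=True, truncation=True)
# Copy input_ids into labels, but set padded positions to -100 so the
# Hugging Face loss ignores them; the model shifts labels internally.
batch["labels"] = torch.where(batch["attention_mask"] > 0.5,
                              batch["input_ids"], -100)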
@ -562,8 +498,7 @@
<span class="k">return</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="n">low_mem_mode</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">data</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">calib_dataloader</span><span class="p">):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Calibrating batch </span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
<span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">data</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">calib_dataloader</span><span class="p">):</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="nb">list</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">keys</span><span class="p">())[</span><span class="mi">0</span><span class="p">]]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="n">batch_size</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">model</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span>
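The calibration hunk above only adds a per-batch debug log; the surrounding pattern is a forward-only pass over the calibration dataloader under torch.no_grad(). A minimal sketch of that pattern (model, calib_dataloader and logger are placeholders for the real objects):

import torch

def calibrate(model, calib_dataloader, logger):
    # Forward-only passes: quantizer hooks observe activations, no gradients needed.
    with torch.no_grad():
        for idx, data in enumerate(calib_dataloader):
            logger.debug(f"Calibrating batch {idx}")
            model(**data)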
@ -783,11 +718,6 @@
<span class="n">tokenizer</span> <span class="o">=</span> <span class="n">get_tokenizer</span><span class="p">(</span><span class="n">model_dir</span> <span class="o">+</span> <span class="s2">"/llm"</span><span class="p">,</span>
<span class="n">max_seq_length</span><span class="o">=</span><span class="n">tokenizer_max_seq_length</span><span class="p">,</span>
<span class="n">model_type</span><span class="o">=</span><span class="n">model_type</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">model_type</span> <span class="o">==</span> <span class="s2">"mllama"</span><span class="p">:</span>
<span class="n">tokenizer</span> <span class="o">=</span> <span class="n">get_processor</span><span class="p">(</span><span class="n">model_dir</span><span class="p">,</span>
<span class="n">max_seq_length</span><span class="o">=</span><span class="n">tokenizer_max_seq_length</span><span class="p">,</span>
<span class="n">model_type</span><span class="o">=</span><span class="n">model_type</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">tokenizer</span> <span class="o">=</span> <span class="n">get_tokenizer</span><span class="p">(</span><span class="n">model_dir</span><span class="p">,</span>
<span class="n">max_seq_length</span><span class="o">=</span><span class="n">tokenizer_max_seq_length</span><span class="p">,</span>
@ -875,8 +805,6 @@
<span class="s2">"w4a8_awq"</span><span class="p">:</span> <span class="s2">"W4A8_AWQ"</span><span class="p">,</span>
<span class="p">}</span>

<span class="k">if</span> <span class="n">model_type</span> <span class="o">==</span> <span class="s1">'mllama'</span><span class="p">:</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">language_model</span>
<span class="n">export_tensorrt_llm_checkpoint</span><span class="p">(</span>
<span class="n">model</span><span class="p">,</span>
<span class="n">model_type</span><span class="p">,</span>
@@ -912,7 +840,7 @@
        json.dump(tensorrt_llm_config, f, indent=4)

    # Workaround for Modelopt 0.9.x fp8_kv_cache knob issue
    if qformat in ['fp8', 'nvfp4'] and kv_cache_dtype is None:
    if qformat == 'fp8' and kv_cache_dtype is None:
        with open(f"{export_path}/config.json", "r") as f:
            tensorrt_llm_config = json.load(f)
        tensorrt_llm_config["quantization"]["kv_cache_quant_algo"] = None
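The workaround in this hunk post-edits the exported checkpoint config so the KV cache stays unquantized when no kv_cache_dtype was requested. The same read-patch-write cycle, as a stand-alone sketch (the helper name is hypothetical):

import json

def disable_kv_cache_quant(export_path):
    # Rewrite the exported config so kv_cache_quant_algo is cleared,
    # mirroring the Modelopt 0.9.x knob workaround above.
    with open(f"{export_path}/config.json", "r") as f:
        cfg = json.load(f)
    cfg["quantization"]["kv_cache_quant_algo"] = None
    with open(f"{export_path}/config.json", "w") as f:
        json.dump(cfg, f, indent=4)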
@@ -990,21 +918,6 @@
        num_medusa_heads, num_medusa_layers,
        max_draft_len, medusa_hidden_act,
        medusa_model_dir, quant_medusa_head)

    # Workaround for mllama
    if model_type == 'mllama':
        from tensorrt_llm.models.mllama.config import MLLaMAConfig
        config = MLLaMAConfig.from_hugging_face(
            model_dir,
            dtype=dtype,
        )
        for key, value in config.to_dict().items():
            if key not in tensorrt_llm_config:
                tensorrt_llm_config[key] = value

    with open(f"{export_path}/config.json", "w") as f:
        json.dump(tensorrt_llm_config, f, indent=4)

    end_time = time.time()
    logger.info(
        "Quantized model exported to {} \nTotal time used {:.2f} s.".format(
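The mllama workaround above merges architecture fields from MLLaMAConfig into the exported checkpoint config without overwriting anything already present. The rule reduces to a dict merge that prefers existing keys; a small stand-alone sketch (helper name is hypothetical):

def merge_missing_keys(base, extra):
    # Only fill keys the exported config does not define yet, mirroring the
    # `if key not in tensorrt_llm_config` guard above.
    for key, value in extra.items():
        base.setdefault(key, value)
    return base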
@@ -110,6 +110,7 @@
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
<ul>
@@ -1204,7 +1205,7 @@
            expected_tensor_names += [
                'spec_decoding_generation_lengths',
                'spec_decoding_position_offsets', 'spec_decoding_packed_mask',
                'spec_decoding_use', 'medusa_logits'
                'medusa_logits'
            ]

        if self.is_redrafter_mode:
@@ -1793,7 +1794,6 @@
        # Note: spec_decoding_packed_mask has no paddings in the first dimension.
        self.spec_decoding_packed_mask = medusa_info.medusa_packed_mask.unsqueeze(
            0).expand(target_shape).reshape(-1, target_shape[-1]).cuda()
        self.spec_decoding_use = medusa_info.medusa_spec_decoding_use

        self.medusa_paths = medusa_info.medusa_paths
        self.medusa_tree_ids = medusa_info.medusa_tree_ids
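The unsqueeze/expand/reshape chain above replicates one packed speculative-decoding mask across every sequence in the batch. Isolated as a sketch (shape names are assumptions; the real code also moves the result to the GPU):

import torch

def broadcast_packed_mask(packed_mask, batch_size):
    # Replicate a [num_tokens, mask_words] packed mask for each sequence,
    # then flatten to [batch_size * num_tokens, mask_words].
    target_shape = (batch_size, *packed_mask.shape)
    return packed_mask.unsqueeze(0).expand(*target_shape).reshape(
        -1, target_shape[-1])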
@@ -2149,7 +2149,6 @@
            'spec_decoding_position_offsets'] = self.spec_decoding_position_offsets
        self.buffer[
            'spec_decoding_generation_lengths'] = self.spec_decoding_generation_lengths
        self.buffer['spec_decoding_use'] = self.spec_decoding_use
        self.buffer_allocated = True
        if self.is_medusa_mode:
            return self.num_draft_tokens
@@ -2504,7 +2503,6 @@
                       'spec_decoding_position_offsets')
            add_tensor(self.buffer['spec_decoding_generation_lengths'],
                       'spec_decoding_generation_lengths')
            add_tensor(self.buffer['spec_decoding_use'], 'spec_decoding_use')

        return tensors

@@ -2847,7 +2845,6 @@
                       'spec_decoding_position_offsets')
            add_tensor(self.buffer['spec_decoding_generation_lengths'],
                       'spec_decoding_generation_lengths')
            add_tensor(self.buffer['spec_decoding_use'], 'spec_decoding_use')

        if self.is_redrafter_mode:
            set_redrafter_gen_tensors(self, batch_size, add_tensor,
@@ -3319,11 +3316,11 @@
                b, :self.accept_lengths[b]] == self.end_ids[b]
            should_stop_with_end_id = torch.any(end_id_mask)
            self.medusa_should_stop[b] = self.medusa_should_stop[b] or (
                prev_total_output_length + self.accept_lengths[b]
                >= self.max_new_tokens) or should_stop_with_end_id
                prev_total_output_length + self.accept_lengths[b] >=
                self.max_new_tokens) or should_stop_with_end_id
            # update accept lengths for the current step.
            if (prev_total_output_length + self.accept_lengths[b]
                    >= self.max_new_tokens):
            if (prev_total_output_length + self.accept_lengths[b] >=
                    self.max_new_tokens):
                self.accept_lengths[b] = min(
                    self.max_new_tokens - prev_total_output_length,
                    self.accept_lengths[b])
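Both sides of this hunk implement the same Medusa stopping rule (only the line wrapping changes): stop once the new-token budget is reached or an end id was accepted, and clamp the accepted length so it never counts tokens past the budget. Distilled into a sketch with illustrative names:

def clamp_accept_length(prev_len, accept_len, max_new_tokens, saw_end_id):
    # Stop when the token budget is exhausted or an end id was accepted;
    # never report more accepted tokens than the remaining budget allows.
    should_stop = saw_end_id or prev_len + accept_len >= max_new_tokens
    if prev_len + accept_len >= max_new_tokens:
        accept_len = min(max_new_tokens - prev_len, accept_len)
    return accept_len, should_stop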
@@ -210,15 +210,6 @@
SamplingConfigType = Union[SamplingConfig, trtllm.SamplingConfig]


def _world_config_to_mapping(world_config: WorldConfig):
    return Mapping(world_size=world_config.size,
                   rank=world_config.rank,
                   gpus_per_node=world_config.gpus_per_node,
                   tp_size=world_config.tensor_parallelism,
                   pp_size=world_config.pipeline_parallelism,
                   cp_size=world_config.context_parallelism)


class ModelRunnerCpp(ModelRunnerMixin):
@@ -242,7 +233,12 @@
        self.max_seq_len = max_seq_len
        self.max_beam_width = max_beam_width
        self.model_config = model_config
        self.mapping = _world_config_to_mapping(world_config)
        self.mapping = Mapping(world_size=world_config.size,
                               rank=world_config.rank,
                               gpus_per_node=world_config.gpus_per_node,
                               tp_size=world_config.tensor_parallelism,
                               pp_size=world_config.pipeline_parallelism,
                               cp_size=world_config.context_parallelism)
        self.world_config = world_config
        self.use_kv_cache = use_kv_cache
        self.lora_manager = lora_manager
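Both sides of this hunk build an identical Mapping from the executor's WorldConfig; the change only moves the construction in or out of the shared helper. A hedged sketch of that helper (the Mapping import path is an assumption about the package layout):

from tensorrt_llm.mapping import Mapping

def world_config_to_mapping(world_config):
    # Translate C++ executor world-configuration fields into the Python
    # Mapping used by the runtime; one place for tp/pp/cp bookkeeping.
    return Mapping(world_size=world_config.size,
                   rank=world_config.rank,
                   gpus_per_node=world_config.gpus_per_node,
                   tp_size=world_config.tensor_parallelism,
                   pp_size=world_config.pipeline_parallelism,
                   cp_size=world_config.context_parallelism)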
@@ -461,8 +457,9 @@
            gpus_per_node=gpus_per_node)
        assert rank == world_config.rank

        engine_config = EngineConfig.from_json_file(f"{engine_dir}/config.json")
        if model_config.use_lora_plugin and rank == 0:
            engine_config = EngineConfig.from_json_file(
                f"{engine_dir}/config.json")
            lora_manager = LoraManager()
            if lora_dir is None:
                config_lora_dir = engine_config.build_config.lora_config.lora_dir
@@ -613,13 +610,11 @@
            from ..pyexecutor.backend_registries.backend_registry import \
                unique_create_executor
            from ..pyexecutor.config import update_executor_config
            update_executor_config(
                trtllm_config,
                backend=backend,
                mapping=_world_config_to_mapping(world_config),
                build_config=engine_config.build_config,
                trt_engine_dir=engine_dir,
                **py_executor_config)
            update_executor_config(trtllm_config,
                                   backend=backend,
                                   max_seq_len=max_seq_len,
                                   trt_engine_dir=engine_dir,
                                   **py_executor_config)
            executor = unique_create_executor(Path(engine_dir),
                                              trtllm.ModelType.DECODER_ONLY,
                                              trtllm_config)
@@ -1014,11 +1009,11 @@
    def _prepare_mrope_executor(self, batch_input_ids_list, mrope: MropeParams):
        mrope_configs = len(batch_input_ids_list) * [None]
        if mrope != None:
            mrope_rotary_cos_sin = mrope.mrope_rotary_cos_sin
            mrope_rotary_sin_cos = mrope.mrope_rotary_sin_cos
            assert isinstance(
                mrope_rotary_cos_sin,
                torch.Tensor), "mrope_rotary_cos_sin should be torch.Tensor"
            mrope_rotary_cos_sin_data = mrope_rotary_cos_sin.to(
                mrope_rotary_sin_cos,
                torch.Tensor), "mrope_rotary_sin_cos should be torch.Tensor"
            mrope_rotary_sin_cos_data = mrope_rotary_sin_cos.to(
                dtype=torch.float32).to(torch.device('cpu'))

            mrope_position_deltas = mrope.mrope_position_deltas
@@ -1030,7 +1025,7 @@

            mrope_configs = [
                trtllm.MropeConfig(
                    mrope_rotary_cos_sin=mrope_rotary_cos_sin_data[i],
                    mrope_rotary_sin_cos=mrope_rotary_sin_cos_data[i],
                    mrope_position_deltas=mrope_position_deltas_data[i])
                for i in range(len(batch_input_ids_list))
            ]
@@ -596,10 +597,7 @@
        else:
            use_fast = self.model_type in ["phi-3-vision", "internvl"]
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.args.hf_model_dir,
                use_fast=use_fast,
                use_legacy=False,
                trust_remote_code=True)
                self.args.hf_model_dir, use_fast=use_fast, use_legacy=False)

        self.tokenizer.padding_side = "right"
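The branch above selects the fast tokenizer implementation only for a couple of vision model families. A minimal sketch of the same selection, assuming Hugging Face transformers is available (use_legacy from the original call is omitted because it is tokenizer-specific and not part of the generic AutoTokenizer API):

from transformers import AutoTokenizer

def load_multimodal_tokenizer(hf_model_dir, model_type):
    # phi-3-vision and internvl require the fast tokenizer; everything else
    # takes the default, with right padding for batched generation.
    use_fast = model_type in ["phi-3-vision", "internvl"]
    tokenizer = AutoTokenizer.from_pretrained(hf_model_dir,
                                              use_fast=use_fast,
                                              trust_remote_code=True)
    tokenizer.padding_side = "right"
    return tokenizer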
@@ -676,11 +674,7 @@
                model.config).to(model.device, dtype=torch.float16)

        self.processor = processor

        if self.model_type == 'mllama':
            from .processor_wrapper import MllamaProcessorWrapper
            self.processor = MllamaProcessorWrapper(self.processor, logger)
        self.processor = processor

@@ -821,15 +815,6 @@
    def preprocess(self, warmup, pre_prompt, post_prompt, image,
                   other_vision_inputs):
        # same prompt for single/multiple image(s)
        n_prompts_n_images = False
        if isinstance(
                post_prompt,
                list) and len(post_prompt) > 1 and image is not None and len(
                    post_prompt) == image.shape[0]:
            # n prompts and n images
            n_prompts_n_images = True

        if self.model_type == 'kosmos-2':
            input_ids = image['input_ids'].clone()
            image_mask = image["image_embeds_position_mask"]
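The n_prompts_n_images flag introduced above turns on per-prompt/per-image pairing. The detection condition, restated as a stand-alone predicate (name is hypothetical):

def has_one_prompt_per_image(post_prompt, image):
    # A list of more than one post-prompt whose length equals the image
    # tensor's leading (batch) dimension means a 1-1 prompt/image pairing.
    return (isinstance(post_prompt, list) and len(post_prompt) > 1
            and image is not None and len(post_prompt) == image.shape[0])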
@@ -915,43 +900,25 @@
        elif self.model_type == 'kosmos-2':
            visual_features = visual_features.squeeze()
        elif self.model_type == 'vila':
            if n_prompts_n_images:
                input_ids = self.tokenizer_image_token(self.args.batch_size,
                                                       pre_prompt[0],
                                                       post_prompt,
                                                       self.tokenizer)
            else:
                input_ids = self.tokenizer_image_token(self.args.batch_size,
                                                       pre_prompt[0],
                                                       post_prompt[0],
                                                       self.tokenizer)
            input_ids = self.tokenizer_image_token(
                self.args.batch_size, pre_prompt[0] + post_prompt[0],
                self.tokenizer)
            batch_split_prompts = self.split_prompt_by_images(input_ids)
            if not n_prompts_n_images:
                first_batch_split_prompts = batch_split_prompts[0]
                # compute prompt length + visual length
                length = sum(
                    [ids.shape[1] for ids in first_batch_split_prompts])
                if self.args.batch_size == 1 and len(image) > 1:
                    # mode 1: multiple image as a whole, flatten visual dims
                    length += visual_atts.shape[0] * visual_atts.shape[1]
                else:
                    length += visual_atts.shape[1]
                input_lengths = torch.IntTensor(
                    [length] * self.args.batch_size).to(torch.int32)
                input_ids, ptuning_args = self.setup_fake_prompts_vila(
                    self.args.batch_size, visual_features,
                    first_batch_split_prompts, input_lengths)
            first_batch_split_prompts = batch_split_prompts[0]
            # compute prompt length + visual length
            length = sum([ids.shape[1] for ids in first_batch_split_prompts])
            if self.args.batch_size == 1 and len(image) > 1:
                # mode 1: multiple image as a whole, flatten visual dims
                length += visual_atts.shape[0] * visual_atts.shape[1]
            else:
                # mode 2: multiple different prompts corresponding to multiple images (1-1 correspondence)
                length = [
                    sum([ids.shape[1] for ids in batch_split_prompt])
                    for batch_split_prompt in batch_split_prompts
                ]
                length = [l + visual_atts.shape[1] for l in length]
                input_lengths = torch.IntTensor(length).to(torch.int32)
                input_ids, ptuning_args = self.setup_fake_prompts_vila(
                    self.args.batch_size, visual_features, batch_split_prompts,
                    input_lengths)
                # mode 2: multiple images individually (replicate prompt for each image)
                length += visual_atts.shape[1]

            input_lengths = torch.IntTensor([length] * self.args.batch_size).to(
                torch.int32)
            input_ids, ptuning_args = self.setup_fake_prompts_vila(
                self.args.batch_size, visual_features,
                first_batch_split_prompts, input_lengths)
            return input_ids, input_lengths, ptuning_args, visual_features
        elif self.model_type == 'phi-3-vision':
            image_sizes = input["image_sizes"]
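Both variants of the vila branch compute an effective input length as text tokens plus visual tokens, differing only in whether every prompt shares one length or each prompt gets its own. A hedged sketch of that bookkeeping (function and argument names are illustrative; tensor shapes follow the diff's usage):

def vila_input_lengths(batch_split_prompts, visual_atts, batch_size,
                       num_images, one_prompt_per_image):
    # Text length is summed over the image-split prompt chunks, then the
    # visual token count is added.
    if one_prompt_per_image:
        # mode 2: one prompt per image, each gets one image's worth of tokens
        return [
            sum(ids.shape[1] for ids in chunks) + visual_atts.shape[1]
            for chunks in batch_split_prompts
        ]
    text_len = sum(ids.shape[1] for ids in batch_split_prompts[0])
    if batch_size == 1 and num_images > 1:
        # mode 1: all images concatenated into a single sequence
        text_len += visual_atts.shape[0] * visual_atts.shape[1]
    else:
        text_len += visual_atts.shape[1]
    return [text_len] * batch_size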
@@ -977,10 +944,7 @@
            pre_input_ids = self.tokenizer(pre_prompt,
                                           return_tensors="pt",
                                           padding=True).input_ids
            if n_prompts_n_images:
                length = [pre_input_ids.shape[1]] * self.args.batch_size
            else:
                length = pre_input_ids.shape[1]
            length = pre_input_ids.shape[1]
            post_input_ids = None
        elif self.model_type == 'llava_onevision':
            if self.args.video_path is None:
@@ -1018,17 +982,9 @@
                                           return_tensors="pt",
                                           padding=True).input_ids
            if post_prompt[0] is not None:
                post_input_encoded = self.tokenizer(post_prompt,
                                                    return_tensors="pt",
                                                    padding=True)
                post_input_ids = post_input_encoded.input_ids
                if n_prompts_n_images:
                    post_input_attention_mask = post_input_encoded.attention_mask
                    post_input_ids = [
                        input_id[mask.bool()] for input_id, mask in zip(
                            post_input_ids, post_input_attention_mask)
                    ]

                post_input_ids = self.tokenizer(post_prompt,
                                                return_tensors="pt",
                                                padding=True).input_ids
                if self.model_type == 'video-neva':
                    length = pre_input_ids.shape[1] + post_input_ids.shape[
                        1] + visual_atts.shape[2] * visual_atts.shape[1]
@ -1036,24 +992,14 @@
<span class="n">length</span> <span class="o">=</span> <span class="n">pre_input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">post_input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">visual_atts</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">visual_atts</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="n">n_prompts_n_images</span><span class="p">:</span>
<span class="n">length</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">pre_input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">visual_atts</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span>
<span class="n">post_input_id</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">for</span> <span class="n">post_input_id</span> <span class="ow">in</span> <span class="n">post_input_ids</span>
<span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">length</span> <span class="o">=</span> <span class="n">pre_input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">post_input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">visual_atts</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">length</span> <span class="o">=</span> <span class="n">pre_input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">post_input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">visual_atts</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">post_input_ids</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">length</span> <span class="o">=</span> <span class="n">pre_input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">visual_atts</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>

<span class="k">if</span> <span class="n">n_prompts_n_images</span><span class="p">:</span>
<span class="n">input_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">IntTensor</span><span class="p">(</span><span class="n">length</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">input_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">IntTensor</span><span class="p">([</span><span class="n">length</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">input_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">IntTensor</span><span class="p">([</span><span class="n">length</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="ow">in</span> <span class="p">[</span>
<span class="s1">'fuyu'</span><span class="p">,</span> <span class="s1">'kosmos-2'</span><span class="p">,</span> <span class="s1">'phi-3-vision'</span><span class="p">,</span> <span class="s1">'llava_next'</span>
@ -1070,43 +1016,33 @@
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.MultimodalModelRunner.tokenizer_image_token">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">tokenizer_image_token</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">pre_prompt</span><span class="p">,</span>
<span class="n">post_prompt</span><span class="p">,</span>
<span class="n">prompt</span><span class="p">,</span>
<span class="n">tokenizer</span><span class="p">,</span>
<span class="n">image_token_index</span><span class="o">=-</span><span class="mi">200</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">post_prompt</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">prompts</span> <span class="o">=</span> <span class="p">[</span><span class="n">pre_prompt</span> <span class="o">+</span> <span class="n">item</span> <span class="k">for</span> <span class="n">item</span> <span class="ow">in</span> <span class="n">post_prompt</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">prompts</span> <span class="o">=</span> <span class="p">[</span><span class="n">pre_prompt</span> <span class="o">+</span> <span class="n">post_prompt</span><span class="p">]</span>
<span class="n">prompt_chunks</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">tokenizer</span><span class="p">(</span><span class="n">chunk</span><span class="p">)</span><span class="o">.</span><span class="n">input_ids</span> <span class="k">for</span> <span class="n">chunk</span> <span class="ow">in</span> <span class="n">prompt</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"<image>"</span><span class="p">)</span>
<span class="p">]</span>

<span class="k">def</span> <span class="nf">insert_separator</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">sep</span><span class="p">):</span>
<span class="k">return</span> <span class="p">[</span>
<span class="n">ele</span> <span class="k">for</span> <span class="n">sublist</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="p">[</span><span class="n">sep</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">X</span><span class="p">))</span> <span class="k">for</span> <span class="n">ele</span> <span class="ow">in</span> <span class="n">sublist</span>
<span class="p">][:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>

<span class="n">result</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">prompt</span> <span class="ow">in</span> <span class="n">prompts</span><span class="p">:</span>
<span class="n">prompt_chunks</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">tokenizer</span><span class="p">(</span><span class="n">chunk</span><span class="p">)</span><span class="o">.</span><span class="n">input_ids</span> <span class="k">for</span> <span class="n">chunk</span> <span class="ow">in</span> <span class="n">prompt</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"<image>"</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">offset</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="k">if</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">prompt_chunks</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">prompt_chunks</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">></span> <span class="mi">0</span>
|
||||
<span class="ow">and</span> <span class="n">prompt_chunks</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">bos_token_id</span><span class="p">):</span>
|
||||
<span class="n">offset</span> <span class="o">=</span> <span class="mi">1</span>
|
||||
<span class="n">input_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">prompt_chunks</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">])</span>
|
||||
<span class="n">input_ids</span> <span class="o">=</span> <span class="p">[]</span>
|
||||
<span class="n">offset</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="k">if</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">prompt_chunks</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">prompt_chunks</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">></span> <span class="mi">0</span>
|
||||
<span class="ow">and</span> <span class="n">prompt_chunks</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">bos_token_id</span><span class="p">):</span>
|
||||
<span class="n">offset</span> <span class="o">=</span> <span class="mi">1</span>
|
||||
<span class="n">input_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">prompt_chunks</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">])</span>
|
||||
|
||||
<span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">insert_separator</span><span class="p">(</span><span class="n">prompt_chunks</span><span class="p">,</span>
|
||||
<span class="p">[</span><span class="n">image_token_index</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">offset</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)):</span>
|
||||
<span class="n">input_ids</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">x</span><span class="p">[</span><span class="n">offset</span><span class="p">:])</span>
|
||||
<span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">insert_separator</span><span class="p">(</span><span class="n">prompt_chunks</span><span class="p">,</span>
|
||||
<span class="p">[</span><span class="n">image_token_index</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">offset</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)):</span>
|
||||
<span class="n">input_ids</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">x</span><span class="p">[</span><span class="n">offset</span><span class="p">:])</span>
|
||||
|
||||
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">)</span>
|
||||
<span class="n">input_ids</span><span class="p">[</span><span class="n">input_ids</span> <span class="o">==</span> <span class="n">image_token_index</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="n">result</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
|
||||
|
||||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">post_prompt</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
|
||||
<span class="n">result</span> <span class="o">=</span> <span class="n">result</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">result</span></div>
|
||||
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">)</span>
|
||||
<span class="n">input_ids</span><span class="p">[</span><span class="n">input_ids</span> <span class="o">==</span> <span class="n">image_token_index</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
|
||||
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||||
<span class="k">return</span> <span class="n">input_ids</span></div>
|
||||
|
||||
|
||||
<div class="viewcode-block" id="MultimodalModelRunner.split_prompt_by_images">
|
||||
@ -1166,7 +1102,7 @@
<span class="n">input_ids</span><span class="p">,</span> <span class="n">input_lengths</span><span class="p">,</span> <span class="n">ptuning_args</span><span class="p">,</span> <span class="n">visual_features</span><span class="p">,</span> <span class="n">mrope_args</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">preprocess</span><span class="p">(</span>
<span class="n">warmup</span><span class="p">,</span> <span class="n">pre_prompt</span><span class="p">,</span> <span class="n">post_prompt</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">other_vision_inputs</span><span class="p">)</span>
<span class="n">mrope_params</span> <span class="o">=</span> <span class="n">MropeParams</span><span class="p">(</span>
<span class="n">mrope_rotary_cos_sin</span><span class="o">=</span><span class="n">mrope_args</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="n">mrope_rotary_sin_cos</span><span class="o">=</span><span class="n">mrope_args</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="n">mrope_position_deltas</span><span class="o">=</span><span class="n">mrope_args</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
@ -1356,15 +1292,14 @@
<span class="n">output_beams_list</span> <span class="o">=</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">batch_decode</span><span class="p">(</span>
<span class="n">output_ids</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">,</span> <span class="p">:,</span> <span class="n">input_lengths</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">]:],</span>
<span class="n">skip_special_tokens</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">for</span> <span class="n">batch_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span>
<span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">input_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span>
<span class="n">skip_special_tokens</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">for</span> <span class="n">batch_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)</span>
<span class="p">]</span>

<span class="n">stripped_text</span> <span class="o">=</span> <span class="p">[[</span>
<span class="n">output_beams_list</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">][</span><span class="n">beam_idx</span><span class="p">]</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span>
<span class="k">for</span> <span class="n">beam_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">num_beams</span><span class="p">)</span>
<span class="p">]</span> <span class="k">for</span> <span class="n">batch_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span>
<span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">input_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]))]</span>
<span class="p">]</span> <span class="k">for</span> <span class="n">batch_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)]</span>
<span class="n">profiler</span><span class="o">.</span><span class="n">stop</span><span class="p">(</span><span class="s2">"Generate"</span><span class="p">)</span>
<span class="k">return</span> <span class="n">stripped_text</span>
<span class="k">else</span><span class="p">:</span>
@ -1397,10 +1332,9 @@
<span class="n">tensor_info</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">visual_encoder_session</span><span class="o">.</span><span class="n">set_shapes</span><span class="p">(</span><span class="n">visual_features</span><span class="p">)</span>
<span class="n">visual_outputs</span> <span class="o">=</span> <span class="p">{</span>
<span class="n">t</span><span class="o">.</span><span class="n">name</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="n">image</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">t</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="n">image</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">visual_output_info</span>
<span class="p">}</span>

@ -1443,27 +1377,21 @@
<span class="c1"># in case no inter or post prompt</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">split_input_ids</span><span class="p">)</span> <span class="o">></span> <span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">input_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">split_input_ids</span><span class="p">[</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">])</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>

<span class="k">elif</span> <span class="n">batch_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># mode 2: each image has a specific prompt, input_ids = [<pre1><image1><post1>, <pre2><image2><post2>, ...]</span>
<span class="c1"># mode 2: each image have individual prompt, <pre><image><post></span>
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">visual_feature</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">visual_features</span><span class="p">):</span>
<span class="n">input_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">split_input_ids</span><span class="p">[</span><span class="n">idx</span><span class="p">][</span><span class="mi">0</span><span class="p">])</span>
<span class="n">input_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">split_input_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">fake_prompt_id</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span>
<span class="n">fake_prompt_counter</span><span class="p">,</span>
<span class="n">fake_prompt_counter</span> <span class="o">+</span> <span class="n">visual_feature</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">fake_prompt_id</span> <span class="o">=</span> <span class="n">fake_prompt_id</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">input_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">fake_prompt_id</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">split_input_ids</span><span class="p">[</span><span class="n">idx</span><span class="p">])</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
<span class="n">input_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">split_input_ids</span><span class="p">[</span><span class="n">idx</span><span class="p">][</span><span class="mi">1</span><span class="p">])</span>
<span class="n">result</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_ids</span><span class="p">),</span> <span class="mi">3</span><span class="p">):</span>
<span class="c1"># Concatenate every 3 items (<pre>, <image>, <post>)</span>
<span class="n">concatenated</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">input_ids</span><span class="p">[</span><span class="n">i</span><span class="p">:</span><span class="n">i</span> <span class="o">+</span> <span class="mi">3</span><span class="p">],</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">result</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">concatenated</span><span class="p">)</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">result</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">split_input_ids</span><span class="p">)</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
<span class="n">input_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">split_input_ids</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>

<span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_llm</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime_mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="n">ptuning_args</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ptuning_setup</span><span class="p">(</span><span class="n">visual_features</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span>
@ -1514,27 +1442,10 @@
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">pre_input_ids</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="n">post_input_ids</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">post_input_ids</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">pre_input_fake_prompt_ids</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">pre_input_ids</span><span class="p">[:</span><span class="nb">len</span><span class="p">(</span><span class="n">fake_prompt_id</span><span class="p">)],</span> <span class="n">fake_prompt_id</span>
<span class="p">]</span>
<span class="n">pre_input_fake_prompt_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="n">pre_input_fake_prompt_ids</span><span class="p">,</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">pre_input_fake_prompt_id</span><span class="p">,</span>
<span class="n">post_input_id</span><span class="p">))</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">pre_input_fake_prompt_id</span><span class="p">,</span> <span class="n">post_input_id</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span>
<span class="n">pre_input_fake_prompt_ids</span><span class="p">,</span> <span class="n">post_input_ids</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="p">[</span><span class="n">pre_input_ids</span><span class="p">,</span> <span class="n">fake_prompt_id</span><span class="p">,</span> <span class="n">post_input_ids</span><span class="p">]</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="p">[</span><span class="n">pre_input_ids</span><span class="p">,</span> <span class="n">fake_prompt_id</span><span class="p">,</span> <span class="n">post_input_ids</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="p">[</span><span class="n">fake_prompt_id</span><span class="p">,</span> <span class="n">pre_input_ids</span><span class="p">]</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>

<span class="k">if</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_llm</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime_mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">()</span>
<span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="o">!=</span> <span class="s2">"mllama"</span><span class="p">:</span>
@ -1733,16 +1644,14 @@
<span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span><span class="p">,</span>
<span class="p">)</span>

<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ids</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)):</span>
<span class="n">input_id</span> <span class="o">=</span> <span class="n">input_ids</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
<span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">input_id</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_token_id</span><span class="p">)</span> <span class="o">|</span> <span class="p">(</span>
<span class="n">input_id</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">vision_token_id</span><span class="p">)</span> <span class="o">|</span> <span class="p">(</span><span class="n">input_id</span>
<span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">video_token_id</span><span class="p">)</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">as_tuple</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_config</span><span class="o">.</span><span class="n">vocab_size</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="n">indices</span><span class="p">:</span>
<span class="n">input_id</span><span class="p">[</span><span class="nb">tuple</span><span class="p">(</span><span class="n">idx</span><span class="p">)]</span> <span class="o">=</span> <span class="n">value</span>
<span class="n">value</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">input_ids</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_token_id</span><span class="p">)</span> <span class="o">|</span> <span class="p">(</span>
<span class="n">input_ids</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">vision_token_id</span><span class="p">)</span> <span class="o">|</span> <span class="p">(</span><span class="n">input_ids</span>
<span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">video_token_id</span><span class="p">)</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">as_tuple</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_config</span><span class="o">.</span><span class="n">vocab_size</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="n">indices</span><span class="p">:</span>
<span class="n">input_ids</span><span class="p">[</span><span class="nb">tuple</span><span class="p">(</span><span class="n">idx</span><span class="p">)]</span> <span class="o">=</span> <span class="n">value</span>
<span class="n">value</span> <span class="o">+=</span> <span class="mi">1</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_llm</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime_mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="n">ptuning_args</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ptuning_setup</span><span class="p">(</span><span class="n">visual_features</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span>
@ -1884,19 +1793,14 @@
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_llm</span><span class="p">:</span> <span class="n">tasks</span> <span class="o">=</span> <span class="n">tasks</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">tasks</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">max_length</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">input_id</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">input_id</span> <span class="ow">in</span> <span class="n">input_ids</span><span class="p">)</span>
<span class="n">tasks</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">input_ids</span><span class="p">),</span> <span class="n">max_length</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">tasks</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>

<span class="k">return</span> <span class="p">[</span><span class="n">prompt_table</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">task_vocab_size</span><span class="p">]</span></div>


<div class="viewcode-block" id="MultimodalModelRunner.load_test_data">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.MultimodalModelRunner.load_test_data">[docs]</a>
<span class="k">def</span> <span class="nf">load_test_data</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image_path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">video_path</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<div class="viewcode-block" id="MultimodalModelRunner.load_test_image">
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.MultimodalModelRunner.load_test_image">[docs]</a>
<span class="k">def</span> <span class="nf">load_test_image</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>

<span class="k">def</span> <span class="nf">load_images</span><span class="p">(</span><span class="n">image_paths</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">image_paths</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
@ -1914,7 +1818,7 @@
<span class="k">return</span> <span class="n">images</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">images</span><span class="p">)</span> <span class="o">></span> <span class="mi">1</span> <span class="k">else</span> <span class="n">images</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>

<span class="k">if</span> <span class="s2">"vila"</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span><span class="p">:</span>
<span class="k">if</span> <span class="n">image_path</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">img_urls</span> <span class="o">=</span> <span class="p">[</span>
<span class="s1">'https://github.com/NVlabs/VILA/blob/6b941da19e31ddfdfaa60160908ccf0978d96615/demo_images/av.png?raw=true'</span><span class="p">,</span>
<span class="s1">'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png'</span>
@ -1924,9 +1828,8 @@
<span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span> <span class="o">=</span> <span class="s2">","</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">img_urls</span><span class="p">)</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">load_images</span><span class="p">(</span><span class="n">img_urls</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">image_path</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
<span class="n">image_path</span> <span class="o">=</span> <span class="n">image_path</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">path_sep</span><span class="p">)</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">load_images</span><span class="p">(</span><span class="n">image_path</span><span class="p">)</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">load_images</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">path_sep</span><span class="p">))</span>

<span class="k">elif</span> <span class="s2">"nougat"</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span><span class="p">:</span>
|
||||
<span class="n">filepath</span> <span class="o">=</span> <span class="n">hf_hub_download</span><span class="p">(</span>
|
||||
@ -1940,38 +1843,33 @@
|
||||
<span class="n">repo_type</span><span class="o">=</span><span class="s1">'model'</span><span class="p">)</span>
|
||||
<span class="n">images</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">filepath</span><span class="p">)</span>
|
||||
<span class="k">elif</span> <span class="s2">"kosmos"</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">image_path</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="n">image_path</span> <span class="o">=</span> <span class="s1">'https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.png'</span>
|
||||
<span class="n">images</span> <span class="o">=</span> <span class="n">load_images</span><span class="p">(</span><span class="n">image_path</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span> <span class="o">=</span> <span class="s1">'https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.png'</span>
|
||||
<span class="n">images</span> <span class="o">=</span> <span class="n">load_images</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span><span class="p">)</span>
|
||||
<span class="k">elif</span> <span class="s2">"pix2struct"</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span><span class="p">:</span>
|
||||
<span class="k">if</span> <span class="n">image_path</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="n">image_path</span> <span class="o">=</span> <span class="s1">'https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_40963.png'</span>
|
||||
<span class="n">images</span> <span class="o">=</span> <span class="n">load_images</span><span class="p">(</span><span class="n">image_path</span><span class="p">)</span>
|
||||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span> <span class="o">=</span> <span class="s1">'https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_40963.png'</span>
|
||||
<span class="n">images</span> <span class="o">=</span> <span class="n">load_images</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span><span class="p">)</span>
|
||||
<span class="k">elif</span> <span class="s2">"video-neva"</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span><span class="p">:</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">video_path</span>
<span class="n">images</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">video_path</span>
<span class="k">elif</span> <span class="s2">"internvl"</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span><span class="p">:</span>
<span class="k">if</span> <span class="n">image_path</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">img_url</span> <span class="o">=</span> <span class="s1">'https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/examples/image1.jpg'</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span>
<span class="n">requests</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">img_url</span><span class="p">,</span> <span class="n">stream</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">timeout</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span><span class="o">.</span><span class="n">raw</span><span class="p">)</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s1">'RGB'</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">image_path</span><span class="p">)</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s1">'RGB'</span><span class="p">)</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span><span class="p">)</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s1">'RGB'</span><span class="p">)</span>
<span class="k">elif</span> <span class="s2">"qwen2_vl"</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span><span class="p">:</span>
<span class="n">images</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">img_url</span> <span class="o">=</span> <span class="s1">'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span>
<span class="n">requests</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">img_url</span><span class="p">,</span> <span class="n">stream</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">timeout</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span><span class="o">.</span><span class="n">raw</span><span class="p">)</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s1">'RGB'</span><span class="p">)</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">((</span><span class="mi">504</span><span class="p">,</span> <span class="mi">504</span><span class="p">))</span>
<span class="n">images</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">image</span><span class="p">)</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">images</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span>
<span class="p">(</span><span class="n">images</span><span class="o">.</span><span class="n">size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="n">images</span><span class="o">.</span><span class="n">size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">images</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">image_path</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span><span class="p">:</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">image_path</span><span class="p">)</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s1">'RGB'</span><span class="p">)</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">((</span><span class="mi">504</span><span class="p">,</span> <span class="mi">504</span><span class="p">))</span>
<span class="n">images</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">image</span><span class="p">)</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span><span class="p">)</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s1">'RGB'</span><span class="p">)</span>
<span class="k">elif</span> <span class="s2">"llava_onevision"</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">video_path</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">video_path</span> <span class="o">==</span> <span class="s1">'llava-onevision-accuracy'</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">video_path</span> <span class="o">=</span> <span class="n">hf_hub_download</span><span class="p">(</span>
@ -1979,7 +1877,7 @@
<span class="n">filename</span><span class="o">=</span><span class="s2">"sample_demo_1.mp4"</span><span class="p">,</span>
<span class="n">repo_type</span><span class="o">=</span><span class="s2">"dataset"</span><span class="p">)</span>
<span class="kn">import</span> <span class="nn">av</span>
<span class="k">with</span> <span class="n">av</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">video_path</span><span class="p">)</span> <span class="k">as</span> <span class="n">container</span><span class="p">:</span>
<span class="k">with</span> <span class="n">av</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">video_path</span><span class="p">)</span> <span class="k">as</span> <span class="n">container</span><span class="p">:</span>
<span class="n">total_frames</span> <span class="o">=</span> <span class="n">container</span><span class="o">.</span><span class="n">streams</span><span class="o">.</span><span class="n">video</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">frames</span>
<span class="k">assert</span> <span class="n">total_frames</span> <span class="o">>=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_frames</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">total_frames</span><span class="p">,</span>
@ -1997,21 +1895,10 @@
<span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">to_ndarray</span><span class="p">(</span><span class="nb">format</span><span class="o">=</span><span class="s2">"rgb24"</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">frames</span><span class="p">])</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">images</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="o">!=</span> <span class="s1">'mllama'</span><span class="p">:</span>
<span class="k">if</span> <span class="n">image_path</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="o">==</span> <span class="s2">"llava"</span><span class="p">:</span>
<span class="n">image_path</span> <span class="o">=</span> <span class="p">[</span>
<span class="s1">'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png'</span>
<span class="p">]</span> <span class="o">*</span> <span class="mi">8</span>
<span class="n">image_path</span> <span class="o">=</span> <span class="n">image_path</span><span class="p">[:</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span> <span class="o">=</span> <span class="s2">","</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">image_path</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">image_path</span> <span class="o">=</span> <span class="s1">'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png'</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">image_path</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
<span class="n">image_path</span> <span class="o">=</span> <span class="n">image_path</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">path_sep</span><span class="p">)</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">load_images</span><span class="p">(</span><span class="n">image_path</span><span class="p">)</span> <span class="k">if</span> <span class="n">image_path</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="o">!=</span> <span class="s1">'mllama'</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span> <span class="o">=</span> <span class="s1">'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png'</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">load_images</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span>
<span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">image_path</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="kc">None</span>
<span class="k">return</span> <span class="n">images</span></div>
@ -2021,9 +1908,6 @@
<span class="kn">from</span> <span class="nn">..tools.multimodal_builder</span> <span class="kn">import</span> <span class="n">compute_rotary_pos_emb</span>
<span class="n">other_vision_inputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">other_decoder_inputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'qwen2_vl'</span><span class="p">,</span> <span class="s1">'vila'</span><span class="p">,</span> <span class="s1">'llava'</span><span class="p">]:</span>
<span class="n">input_text</span> <span class="o">=</span> <span class="n">input_text</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">input_text</span><span class="p">,</span>
<span class="nb">list</span><span class="p">)</span> <span class="k">else</span> <span class="n">input_text</span>
<span class="k">if</span> <span class="s1">'blip2'</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span><span class="p">:</span>
<span class="n">image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">processor</span><span class="p">(</span><span class="n">raw_image</span><span class="p">,</span> <span class="n">input_text</span><span class="p">,</span>
@ -2037,34 +1921,28 @@
<span class="kn">from</span> <span class="nn">transformers.models.qwen2_vl.modeling_qwen2_vl</span> <span class="kn">import</span> \
<span class="n">VisionRotaryEmbedding</span>
<span class="n">hf_config</span> <span class="o">=</span> <span class="n">AutoConfig</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">hf_model_dir</span><span class="p">)</span>
<span class="n">processor</span> <span class="o">=</span> <span class="n">AutoProcessor</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">hf_model_dir</span><span class="p">)</span>
<span class="k">if</span> <span class="n">input_text</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">input_text</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"Question: Describe this image. Answer:"</span>
<span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span>
<span class="n">messages</span> <span class="o">=</span> <span class="p">[[{</span>
<span class="n">input_text</span> <span class="o">=</span> <span class="s2">"Question: Describe this image. Answer:"</span>
<span class="n">messages</span> <span class="o">=</span> <span class="p">[{</span>
<span class="s2">"role"</span><span class="p">:</span>
<span class="s2">"user"</span><span class="p">,</span>
<span class="s2">"content"</span><span class="p">:</span> <span class="p">[</span>
<span class="p">{</span>
<span class="s2">"type"</span><span class="p">:</span> <span class="s2">"image"</span><span class="p">,</span>
<span class="s2">"image"</span><span class="p">:</span> <span class="n">raw_image</span><span class="p">[</span><span class="n">idx</span><span class="p">],</span>
<span class="s2">"image"</span><span class="p">:</span> <span class="n">raw_image</span><span class="p">,</span>
<span class="p">},</span>
<span class="p">{</span>
<span class="s2">"type"</span><span class="p">:</span> <span class="s2">"text"</span><span class="p">,</span>
<span class="s2">"text"</span><span class="p">:</span> <span class="n">input_text</span><span class="p">[</span><span class="n">idx</span><span class="p">],</span>
<span class="s2">"text"</span><span class="p">:</span> <span class="n">input_text</span>
<span class="p">},</span>
<span class="p">],</span>
<span class="p">}]</span> <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)]</span>
<span class="p">}]</span>
<span class="n">texts</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">processor</span><span class="o">.</span><span class="n">apply_chat_template</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span>
<span class="n">tokenize</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">add_generation_prompt</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">for</span> <span class="n">msg</span> <span class="ow">in</span> <span class="n">messages</span>
<span class="p">]</span>
<span class="n">text</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">processor</span><span class="o">.</span><span class="n">apply_chat_template</span><span class="p">(</span>
<span class="n">messages</span><span class="p">,</span> <span class="n">tokenize</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">add_generation_prompt</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">image_inputs</span><span class="p">,</span> <span class="n">video_inputs</span> <span class="o">=</span> <span class="n">process_vision_info</span><span class="p">(</span><span class="n">messages</span><span class="p">)</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="n">processor</span><span class="p">(</span>
<span class="n">text</span><span class="o">=</span><span class="n">texts</span><span class="p">,</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">processor</span><span class="p">(</span>
<span class="n">text</span><span class="o">=</span><span class="p">[</span><span class="n">text</span><span class="p">],</span>
<span class="n">images</span><span class="o">=</span><span class="n">image_inputs</span><span class="p">,</span>
<span class="n">videos</span><span class="o">=</span><span class="n">video_inputs</span><span class="p">,</span>
<span class="n">padding</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
@ -2080,17 +1958,17 @@
<span class="n">image_grid_thw</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">cu_seqlens</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="n">cu_seqlens</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">value</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">seq_length</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">attention_mask_vit</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="n">seq_length</span><span class="p">,</span> <span class="n">seq_length</span><span class="p">],</span>
<span class="n">torch</span><span class="o">.</span><span class="n">finfo</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span><span class="o">.</span><span class="n">min</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">image</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">image</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">attention_mask_vit</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="n">seq_length</span><span class="p">,</span> <span class="n">seq_length</span><span class="p">],</span>
<span class="n">device</span><span class="o">=</span><span class="n">image</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">cu_seqlens</span><span class="p">)):</span>
<span class="n">attention_mask_vit</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">cu_seqlens</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]:</span><span class="n">cu_seqlens</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
<span class="n">cu_seqlens</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]:</span><span class="n">cu_seqlens</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">cu_seqlens</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]:</span><span class="n">cu_seqlens</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">decoder_input_ids</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">pre_prompt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">input_text</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">images_qwenvl</span> <span class="o">=</span> <span class="p">{</span>
<span class="s2">"image"</span><span class="p">:</span> <span class="n">image</span><span class="p">,</span>
<span class="s2">"input_ids"</span><span class="p">:</span> <span class="n">input_ids</span><span class="p">,</span>
@ -2195,65 +2073,39 @@
<span class="n">images</span><span class="o">=</span><span class="n">raw_image</span><span class="p">,</span>
<span class="n">return_tensors</span><span class="o">=</span><span class="s2">"pt"</span><span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="o">==</span> <span class="s1">'vila'</span><span class="p">:</span>
<span class="k">if</span> <span class="n">input_text</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">input_text</span> <span class="o">=</span> <span class="s2">"<image></span><span class="se">\n</span><span class="s2"> Please elaborate what you see in the images?"</span>
<span class="k">if</span> <span class="s1">'8b'</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">hf_model_dir</span><span class="o">.</span><span class="n">lower</span><span class="p">():</span>
<span class="n">pre_prompt</span> <span class="o">=</span> <span class="s2">"<|begin_of_text|><|start_header_id|>system<|end_header_id|></span><span class="se">\n\n</span><span class="s2">You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|>user<|end_header_id|></span><span class="se">\n\n</span><span class="s2">"</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="s2">"<|eot_id|><|start_header_id|>assistant<|end_header_id|></span><span class="se">\n\n</span><span class="s2">"</span>
<span class="k">elif</span> <span class="s1">'40b'</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">hf_model_dir</span><span class="o">.</span><span class="n">lower</span><span class="p">():</span>
<span class="n">pre_prompt</span> <span class="o">=</span> <span class="s2">"<|im_start|>system</span><span class="se">\n</span><span class="s2">Answer the questions.<|im_end|><|im_start|>user</span><span class="se">\n</span><span class="s2">"</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="s2">"<|im_end|><|im_start|>assistant</span><span class="se">\n</span><span class="s2">"</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">pre_prompt</span> <span class="o">=</span> <span class="s2">"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: "</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="s2">" ASSISTANT:"</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">input_text</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span> <span class="o">+</span> <span class="n">post_prompt</span> <span class="k">for</span> <span class="nb">input</span> <span class="ow">in</span> <span class="n">input_text</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="n">input_text</span> <span class="o">+</span> <span class="n">post_prompt</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">raw_image</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">raw_image</span> <span class="o">=</span> <span class="p">[</span><span class="n">raw_image</span><span class="p">]</span>
<span class="n">image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">processor</span><span class="p">(</span><span class="n">raw_image</span><span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'llava'</span><span class="p">,</span> <span class="s1">'fuyu'</span><span class="p">,</span> <span class="s1">'kosmos-2'</span><span class="p">]:</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'llava'</span><span class="p">,</span> <span class="s1">'vila'</span><span class="p">,</span> <span class="s1">'fuyu'</span><span class="p">,</span> <span class="s1">'kosmos-2'</span><span class="p">]:</span>
<span class="c1"># LLaVA and VILA</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="o">==</span> <span class="s2">"llava"</span><span class="p">:</span>
<span class="n">pre_prompt</span> <span class="o">=</span> <span class="s2">"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: "</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="s2">" ASSISTANT:"</span>
<span class="n">pre_prompt</span> <span class="o">=</span> <span class="s2">"USER:</span><span class="se">\n</span><span class="s2">"</span>
<span class="k">if</span> <span class="n">input_text</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">input_text</span> <span class="o">=</span> <span class="s2">"</span><span class="se">\n</span><span class="s2"> Which city is this? Answer:"</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">input_text</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span> <span class="o">+</span> <span class="n">post_prompt</span> <span class="k">for</span> <span class="nb">input</span> <span class="ow">in</span> <span class="n">input_text</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="n">input_text</span> <span class="o">+</span> <span class="n">post_prompt</span>
<span class="n">input_text</span> <span class="o">=</span> <span class="s2">"Question: which city is this? Answer:"</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="o">==</span> <span class="s2">"vila"</span><span class="p">:</span>
<span class="n">pre_prompt</span> <span class="o">=</span> <span class="s2">"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: "</span>
<span class="k">if</span> <span class="n">input_text</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">input_text</span> <span class="o">=</span> <span class="s2">"<image></span><span class="se">\n</span><span class="s2"> Please elaborate what you see in the images?"</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="o">==</span> <span class="s1">'fuyu'</span><span class="p">:</span>
<span class="n">pre_prompt</span> <span class="o">=</span> <span class="s2">"Describe this image:"</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">input_text</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">input_text</span> <span class="o">=</span> <span class="s2">"Answer the following VQAv2 question based on the image: How many people are in the image?</span><span class="se">\n</span><span class="s2">"</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="o">==</span> <span class="s2">"kosmos-2"</span><span class="p">:</span>
<span class="n">pre_prompt</span> <span class="o">=</span> <span class="s2">""</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">input_text</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">input_text</span> <span class="o">=</span> <span class="s2">"<grounding>An image of"</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'fuyu'</span><span class="p">,</span> <span class="s1">'kosmos-2'</span><span class="p">]:</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="s2">" ASSISTANT:"</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">input_text</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span> <span class="o">+</span> <span class="n">post_prompt</span> <span class="k">for</span> <span class="nb">input</span> <span class="ow">in</span> <span class="n">input_text</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="n">input_text</span> <span class="o">+</span> <span class="n">post_prompt</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="n">input_text</span> <span class="o">+</span> <span class="s2">" ASSISTANT:"</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'fuyu'</span><span class="p">,</span> <span class="s1">'kosmos-2'</span><span class="p">]:</span>
<span class="n">image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">processor</span><span class="p">(</span><span class="n">text</span><span class="o">=</span><span class="n">input_text</span><span class="p">,</span>
<span class="n">images</span><span class="o">=</span><span class="n">raw_image</span><span class="p">,</span>
<span class="n">return_tensors</span><span class="o">=</span><span class="s1">'pt'</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="o">==</span> <span class="s2">"vila"</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">raw_image</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">raw_image</span> <span class="o">=</span> <span class="p">[</span><span class="n">raw_image</span><span class="p">]</span>
<span class="n">image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">processor</span><span class="p">(</span><span class="n">raw_image</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">input_text</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">processor</span><span class="p">(</span><span class="n">text</span><span class="o">=</span><span class="n">input_text</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'fuyu'</span><span class="p">,</span> <span class="s1">'kosmos-2'</span><span class="p">]:</span>
<span class="n">image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">processor</span><span class="p">(</span><span class="n">text</span><span class="o">=</span><span class="n">input_text</span><span class="p">,</span>
<span class="n">images</span><span class="o">=</span><span class="n">raw_image</span><span class="p">,</span>
<span class="n">padding</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">return_tensors</span><span class="o">=</span><span class="s2">"pt"</span><span class="p">)[</span><span class="s1">'pixel_values'</span><span class="p">]</span>
<span class="n">return_tensors</span><span class="o">=</span><span class="s1">'pt'</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">processor</span><span class="p">(</span><span class="n">text</span><span class="o">=</span><span class="n">input_text</span><span class="p">,</span>
<span class="n">images</span><span class="o">=</span><span class="n">raw_image</span><span class="p">,</span>
@ -2261,8 +2113,6 @@
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'mllama'</span><span class="p">]:</span>
<span class="k">if</span> <span class="n">raw_image</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">input_text</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">processor</span><span class="o">.</span><span class="n">apply_chat_template</span><span class="p">(</span>
<span class="n">images</span><span class="o">=</span><span class="n">raw_image</span><span class="p">,</span> <span class="n">text</span><span class="o">=</span><span class="n">input_text</span><span class="p">)</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">processor</span><span class="p">(</span><span class="n">images</span><span class="o">=</span><span class="n">raw_image</span><span class="p">,</span>
<span class="n">text</span><span class="o">=</span><span class="n">input_text</span><span class="p">,</span>
<span class="n">return_tensors</span><span class="o">=</span><span class="s2">"pt"</span><span class="p">)</span>
@ -2308,8 +2158,7 @@
<span class="c1"># Repeat inputs to match batch size</span>
<span class="n">pre_prompt</span> <span class="o">=</span> <span class="p">[</span><span class="n">pre_prompt</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">input_text</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="p">[</span><span class="n">post_prompt</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span>
<span class="n">post_prompt</span> <span class="o">=</span> <span class="p">[</span><span class="n">post_prompt</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_type</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span>
<span class="s1">'fuyu'</span><span class="p">,</span> <span class="s1">'pix2struct'</span><span class="p">,</span> <span class="s1">'kosmos-2'</span><span class="p">,</span> <span class="s1">'vila'</span><span class="p">,</span> <span class="s1">'phi-3-vision'</span><span class="p">,</span>
<span class="s1">'llava_next'</span><span class="p">,</span> <span class="s1">'internvl'</span><span class="p">,</span> <span class="s1">'llava_onevision'</span>
@ -2322,13 +2171,8 @@
<span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">input_text</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span>
<span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_text</span><span class="p">)),</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="k">if</span> <span class="n">image</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="c1"># Generate decoder_input_ids for enc-dec models</span>
@ -2378,7 +2222,7 @@
<hr/>
<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x7f5e6e502480>
<jinja2.runtime.BlockReference object at 0x7f1ac93f6f00>
<div class="footer">
<p>
@ -14,7 +14,7 @@
<script src="../../../_static/jquery.js?v=5d32c60e"></script>
<script src="../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../../_static/doctools.js?v=9a2dae69"></script>
<script src="../../../_static/doctools.js?v=9bcbadda"></script>
<script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../../../_static/copybutton.js?v=65e89d2a"></script>
@ -110,6 +110,7 @@
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
<ul>
@ -460,10 +461,9 @@
<span class="p">]</span>
<span class="n">outputs_info</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">infer_shapes</span><span class="p">(</span><span class="n">inputs_info</span><span class="p">)</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{</span>
<span class="n">t</span><span class="o">.</span><span class="n">name</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
<span class="n">t</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
<span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">outputs_info</span>
<span class="p">}</span>
<span class="k">with</span> <span class="n">_scoped_stream</span><span class="p">()</span> <span class="k">as</span> <span class="n">stream</span><span class="p">:</span>
@ -482,7 +482,7 @@
<hr/>
<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x7f5e6e5d7d40>
<jinja2.runtime.BlockReference object at 0x7f1ac938c620>
<div class="footer">
<p>
@ -14,7 +14,7 @@
<script src="../../_static/jquery.js?v=5d32c60e"></script>
<script src="../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../_static/doctools.js?v=9a2dae69"></script>
<script src="../../_static/doctools.js?v=9bcbadda"></script>
<script src="../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../../_static/copybutton.js?v=65e89d2a"></script>
@ -110,6 +110,7 @@
<li class="toctree-l1"><a class="reference internal" href="../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
<ul>
@ -204,19 +205,6 @@
<span class="nd">@dataclass</span><span class="p">(</span><span class="n">slots</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">kw_only</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">AdditionalModelOutput</span><span class="p">:</span>
<span class="w"> </span><span class="sd">"""</span>
<span class="sd"> An additional output to gather from the model.</span>
<span class="sd"> Args:</span>
<span class="sd"> name (str): The name of the additional output to gather from the model.</span>
<span class="sd"> gather_context (bool): A value indicating whether or not to gather the additional output from the context too. Defaults to False.</span>
<span class="sd"> """</span>
<span class="n">name</span><span class="p">:</span> <span class="nb">str</span>
<span class="n">gather_context</span><span class="p">:</span> <span class="nb">bool</span>
<div class="viewcode-block" id="SamplingParams">
<a class="viewcode-back" href="../../llm-api/reference.html#tensorrt_llm.llmapi.SamplingParams">[docs]</a>
<span class="nd">@dataclass</span><span class="p">(</span><span class="n">slots</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">kw_only</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
@ -269,7 +257,6 @@
<span class="sd"> exclude_input_from_output (bool): Controls if output tokens in Result should include the input tokens. Default is true.</span>
<span class="sd"> return_encoder_output (bool): Controls if Result should contain encoder output hidden states (for encoder-only and encoder-decoder models). Default is false.</span>
<span class="sd"> return_perf_metrics (bool): Controls if Result should contain the performance metrics for this request. Default is false.</span>
<span class="sd"> additional_model_outputs (list[AdditionalModelOutput], optional): The additional outputs to gather from the model.</span>
<span class="sd"> lookahead_config (LookaheadDecodingConfig , optional): Lookahead decoding config. Defaults to None.</span>
<span class="sd"> guided_decoding (GuidedDecodingParams, optional): Guided decoding params. Defaults to None.</span>
@ -344,7 +331,6 @@
<span class="n">exclude_input_from_output</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">return_encoder_output</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">return_perf_metrics</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">additional_model_outputs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">list</span><span class="p">[</span><span class="n">AdditionalModelOutput</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
|
||||
<span class="c1"># Lookahead decoding config</span>
|
||||
<span class="n">lookahead_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">tllme</span><span class="o">.</span><span class="n">LookaheadDecodingConfig</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
||||
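A hedged sketch of how these output-control fields combine in practice (the constructor surface may differ across versions; max_tokens is an assumption, not part of this hunk):

    from tensorrt_llm.llmapi import SamplingParams, AdditionalModelOutput  # assumed import path

    params = SamplingParams(
        max_tokens=32,                   # assumed field, shown only for context
        exclude_input_from_output=True,  # keep prompt tokens out of the Result
        return_perf_metrics=True,        # attach per-request performance metrics
        additional_model_outputs=[
            AdditionalModelOutput(name="logits", gather_context=False),  # placeholder output name
        ],
    )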
@ -568,8 +554,7 @@
<span class="n">expected_fields</span> <span class="o">=</span> <span class="p">[</span>
<span class="s2">"return_log_probs"</span><span class="p">,</span> <span class="s2">"return_context_logits"</span><span class="p">,</span>
<span class="s2">"return_generation_logits"</span><span class="p">,</span> <span class="s2">"exclude_input_from_output"</span><span class="p">,</span>
<span class="s2">"return_encoder_output"</span><span class="p">,</span> <span class="s2">"return_perf_metrics"</span><span class="p">,</span>
<span class="s2">"additional_model_outputs"</span>
<span class="s2">"return_encoder_output"</span><span class="p">,</span> <span class="s2">"return_perf_metrics"</span>
<span class="p">]</span>
<span class="n">found_fields</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">f</span> <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="nb">dir</span><span class="p">(</span><span class="n">tllme</span><span class="o">.</span><span class="n">OutputConfig</span><span class="p">)</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">f</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">'__'</span><span class="p">)</span>
@ -616,7 +601,7 @@
<hr/>

<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x7f5e6e3542f0>
<jinja2.runtime.BlockReference object at 0x7f1ac93f56d0>

<div class="footer">
<p>
@ -34,9 +34,3 @@ _______
.. doxygenfile:: types.h
   :project: TensorRT-LLM

version.h
_________

.. doxygenfile:: version.h
   :project: TensorRT-LLM
@ -1,12 +1,5 @@
/*
 * basic.css
 * ~~~~~~~~~
 *
 * Sphinx stylesheet -- basic theme.
 *
 * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS.
 * :license: BSD, see LICENSE for details.
 *
 */

/* -- main layout ----------------------------------------------------------- */
@ -115,15 +108,11 @@ img {
/* -- search page ----------------------------------------------------------- */

ul.search {
    margin: 10px 0 0 20px;
    padding: 0;
    margin-top: 10px;
}

ul.search li {
    padding: 5px 0 5px 20px;
    background-image: url(file.png);
    background-repeat: no-repeat;
    background-position: 0 7px;
    padding: 5px 0;
}

ul.search li a {
@ -1,12 +1,5 @@
/*
 * doctools.js
 * ~~~~~~~~~~~
 *
 * Base JavaScript utilities for all Sphinx HTML documentation.
 *
 * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS.
 * :license: BSD, see LICENSE for details.
 *
 */
"use strict";
@ -1,13 +1,6 @@
/*
 * language_data.js
 * ~~~~~~~~~~~~~~~~
 *
 * This script contains the language-specific data used by searchtools.js,
 * namely the list of stopwords, stemmer, scorer and splitter.
 *
 * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS.
 * :license: BSD, see LICENSE for details.
 *
 */

var stopwords = ["a", "and", "are", "as", "at", "be", "but", "by", "for", "if", "in", "into", "is", "it", "near", "no", "not", "of", "on", "or", "such", "that", "the", "their", "then", "there", "these", "they", "this", "to", "was", "will", "with"];
@ -1,12 +1,5 @@
/*
 * searchtools.js
 * ~~~~~~~~~~~~~~~~
 *
 * Sphinx JavaScript utilities for the full-text search.
 *
 * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS.
 * :license: BSD, see LICENSE for details.
 *
 */
"use strict";
@ -20,7 +13,7 @@ if (typeof Scorer === "undefined") {
// and returns the new score.
/*
score: result => {
  const [docname, title, anchor, descr, score, filename] = result
  const [docname, title, anchor, descr, score, filename, kind] = result
  return score
},
*/
@ -47,6 +40,14 @@ if (typeof Scorer === "undefined") {
};
}

// Global search result kind enum, used by themes to style search results.
class SearchResultKind {
  static get index() { return "index"; }
  static get object() { return "object"; }
  static get text() { return "text"; }
  static get title() { return "title"; }
}

const _removeChildren = (element) => {
  while (element && element.lastChild) element.removeChild(element.lastChild);
};
@ -64,9 +65,13 @@ const _displayItem = (item, searchTerms, highlightTerms) => {
const showSearchSummary = DOCUMENTATION_OPTIONS.SHOW_SEARCH_SUMMARY;
const contentRoot = document.documentElement.dataset.content_root;

const [docName, title, anchor, descr, score, _filename] = item;
const [docName, title, anchor, descr, score, _filename, kind] = item;

let listItem = document.createElement("li");
// Add a class representing the item's type:
// can be used by a theme's CSS selector for styling
// See SearchResultKind for the class names.
listItem.classList.add(`kind-${kind}`);
let requestUrl;
let linkUrl;
if (docBuilder === "dirhtml") {
@ -115,8 +120,10 @@ const _finishSearch = (resultCount) => {
"Your search did not match any documents. Please make sure that all words are spelled correctly and that you've selected enough categories."
);
else
  Search.status.innerText = _(
    "Search finished, found ${resultCount} page(s) matching the search query."
  Search.status.innerText = Documentation.ngettext(
    "Search finished, found one page matching the search query.",
    "Search finished, found ${resultCount} pages matching the search query.",
    resultCount,
  ).replace('${resultCount}', resultCount);
};
const _displayNextItem = (
@ -138,7 +145,7 @@ const _displayNextItem = (
else _finishSearch(resultCount);
};
// Helper function used by query() to order search results.
// Each input is an array of [docname, title, anchor, descr, score, filename].
// Each input is an array of [docname, title, anchor, descr, score, filename, kind].
// Order the results by score (in opposite order of appearance, since the
// `_displayNextItem` function uses pop() to retrieve items) and then alphabetically.
const _orderResultsByScoreThenName = (a, b) => {
@ -248,6 +255,7 @@ const Search = {
searchSummary.classList.add("search-summary");
searchSummary.innerText = "";
const searchList = document.createElement("ul");
searchList.setAttribute("role", "list");
searchList.classList.add("search");

const out = document.getElementById("search-results");
@ -318,7 +326,7 @@ const Search = {
const indexEntries = Search._index.indexentries;

// Collect multiple result groups to be sorted separately and then ordered.
// Each is an array of [docname, title, anchor, descr, score, filename].
// Each is an array of [docname, title, anchor, descr, score, filename, kind].
const normalResults = [];
const nonMainIndexResults = [];
@ -337,6 +345,7 @@ const Search = {
null,
score + boost,
filenames[file],
SearchResultKind.title,
]);
}
}
@ -354,6 +363,7 @@ const Search = {
null,
score,
filenames[file],
SearchResultKind.index,
];
if (isMain) {
  normalResults.push(result);
@ -475,6 +485,7 @@ const Search = {
descr,
score,
filenames[match[0]],
SearchResultKind.object,
]);
};
Object.keys(objects).forEach((prefix) =>
@ -585,6 +596,7 @@ const Search = {
null,
score,
filenames[file],
SearchResultKind.text,
]);
}
return results;
@ -15,7 +15,7 @@
<script src="../_static/jquery.js?v=5d32c60e"></script>
<script src="../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/doctools.js?v=9bcbadda"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../_static/copybutton.js?v=65e89d2a"></script>
@ -87,7 +87,15 @@
</ul>
<p class="caption" role="heading"><span class="caption-text">Command-Line Reference</span></p>
<ul class="current">
<li class="toctree-l1 current"><a class="current reference internal" href="#">trtllm-build</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">trtllm-build</a><ul>
<li class="toctree-l2"><a class="reference internal" href="#tensorrt_llm.commands.build-parse_arguments-named-arguments">Named Arguments</a></li>
<li class="toctree-l2"><a class="reference internal" href="#tensorrt_llm.commands.build-parse_arguments-logits-arguments">Logits arguments</a></li>
<li class="toctree-l2"><a class="reference internal" href="#tensorrt_llm.commands.build-parse_arguments-lora-arguments">LoRA arguments</a></li>
<li class="toctree-l2"><a class="reference internal" href="#tensorrt_llm.commands.build-parse_arguments-speculative-decoding-arguments">Speculative decoding arguments</a></li>
<li class="toctree-l2"><a class="reference internal" href="#tensorrt_llm.commands.build-parse_arguments-auto-parallel-arguments">Auto parallel arguments</a></li>
<li class="toctree-l2"><a class="reference internal" href="#tensorrt_llm.commands.build-parse_arguments-plugin-config-arguments">Plugin config arguments</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="trtllm-serve.html">trtllm-serve</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Architecture</span></p>
@ -164,6 +172,352 @@
<section id="trtllm-build">
<h1>trtllm-build<a class="headerlink" href="#trtllm-build" title="Link to this heading"></a></h1>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">usage</span><span class="p">:</span> <span class="n">trtllm</span><span class="o">-</span><span class="n">build</span> <span class="p">[</span><span class="o">-</span><span class="n">h</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">checkpoint_dir</span> <span class="n">CHECKPOINT_DIR</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">model_config</span> <span class="n">MODEL_CONFIG</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">build_config</span> <span class="n">BUILD_CONFIG</span><span class="p">]</span>
<span class="p">[</span><span class="o">--</span><span class="n">model_cls_file</span> <span class="n">MODEL_CLS_FILE</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">model_cls_name</span> <span class="n">MODEL_CLS_NAME</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">output_dir</span> <span class="n">OUTPUT_DIR</span><span class="p">]</span>
<span class="p">[</span><span class="o">--</span><span class="n">max_batch_size</span> <span class="n">MAX_BATCH_SIZE</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">max_input_len</span> <span class="n">MAX_INPUT_LEN</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">max_seq_len</span> <span class="n">MAX_SEQ_LEN</span><span class="p">]</span>
<span class="p">[</span><span class="o">--</span><span class="n">max_beam_width</span> <span class="n">MAX_BEAM_WIDTH</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">max_num_tokens</span> <span class="n">MAX_NUM_TOKENS</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">opt_num_tokens</span> <span class="n">OPT_NUM_TOKENS</span><span class="p">]</span>
<span class="p">[</span><span class="o">--</span><span class="n">max_encoder_input_len</span> <span class="n">MAX_ENCODER_INPUT_LEN</span><span class="p">]</span>
<span class="p">[</span><span class="o">--</span><span class="n">max_prompt_embedding_table_size</span> <span class="n">MAX_PROMPT_EMBEDDING_TABLE_SIZE</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">kv_cache_type</span> <span class="n">KV_CACHE_TYPE</span><span class="p">]</span>
<span class="p">[</span><span class="o">--</span><span class="n">paged_kv_cache</span> <span class="n">PAGED_KV_CACHE</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">input_timing_cache</span> <span class="n">INPUT_TIMING_CACHE</span><span class="p">]</span>
<span class="p">[</span><span class="o">--</span><span class="n">output_timing_cache</span> <span class="n">OUTPUT_TIMING_CACHE</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">profiling_verbosity</span> <span class="p">{</span><span class="n">layer_names_only</span><span class="p">,</span><span class="n">detailed</span><span class="p">,</span><span class="n">none</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">strip_plan</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">weight_sparsity</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">weight_streaming</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">fast_build</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">workers</span> <span class="n">WORKERS</span><span class="p">]</span>
<span class="p">[</span><span class="o">--</span><span class="n">log_level</span> <span class="p">{</span><span class="n">internal_error</span><span class="p">,</span><span class="n">error</span><span class="p">,</span><span class="n">warning</span><span class="p">,</span><span class="n">info</span><span class="p">,</span><span class="n">verbose</span><span class="p">,</span><span class="n">debug</span><span class="p">}]</span> <span class="p">[</span><span class="o">--</span><span class="n">enable_debug_output</span><span class="p">]</span>
<span class="p">[</span><span class="o">--</span><span class="n">visualize_network</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">dry_run</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">monitor_memory</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">logits_dtype</span> <span class="p">{</span><span class="n">float16</span><span class="p">,</span><span class="n">float32</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">gather_context_logits</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">gather_generation_logits</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">gather_all_token_logits</span><span class="p">]</span>
<span class="p">[</span><span class="o">--</span><span class="n">lora_dir</span> <span class="n">LORA_DIR</span> <span class="p">[</span><span class="n">LORA_DIR</span> <span class="o">...</span><span class="p">]]</span> <span class="p">[</span><span class="o">--</span><span class="n">lora_ckpt_source</span> <span class="p">{</span><span class="n">hf</span><span class="p">,</span><span class="n">nemo</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">lora_target_modules</span> <span class="p">{</span><span class="n">attn_qkv</span><span class="p">,</span><span class="n">attn_q</span><span class="p">,</span><span class="n">attn_k</span><span class="p">,</span><span class="n">attn_v</span><span class="p">,</span><span class="n">attn_dense</span><span class="p">,</span><span class="n">mlp_h_to_4h</span><span class="p">,</span><span class="n">mlp_4h_to_h</span><span class="p">,</span><span class="n">mlp_gate</span><span class="p">,</span><span class="n">cross_attn_qkv</span><span class="p">,</span><span class="n">cross_attn_q</span><span class="p">,</span><span class="n">cross_attn_k</span><span class="p">,</span><span class="n">cross_attn_v</span><span class="p">,</span><span class="n">cross_attn_dense</span><span class="p">,</span><span class="n">moe_h_to_4h</span><span class="p">,</span><span class="n">moe_4h_to_h</span><span class="p">,</span><span class="n">moe_gate</span><span class="p">,</span><span class="n">moe_router</span><span class="p">,</span><span class="n">mlp_router</span><span class="p">}</span> <span class="p">[{</span><span class="n">attn_qkv</span><span class="p">,</span><span class="n">attn_q</span><span class="p">,</span><span class="n">attn_k</span><span class="p">,</span><span class="n">attn_v</span><span class="p">,</span><span class="n">attn_dense</span><span class="p">,</span><span class="n">mlp_h_to_4h</span><span class="p">,</span><span class="n">mlp_4h_to_h</span><span class="p">,</span><span class="n">mlp_gate</span><span class="p">,</span><span class="n">cross_attn_qkv</span><span class="p">,</span><span class="n">cross_attn_q</span><span class="p">,</span><span class="n">cross_attn_k</span><span class="p">,</span><span class="n">cross_attn_v</span><span class="p">,</span><span class="n">cross_attn_dense</span><span class="p">,</span><span class="n">moe_h_to_4h</span><span class="p">,</span><span class="n">moe_4h_to_h</span><span class="p">,</span><span class="n">moe_gate</span><span class="p">,</span><span class="n">moe_router</span><span class="p">,</span><span class="n">mlp_router</span><span class="p">}</span> <span class="o">...</span><span class="p">]]</span>
<span class="p">[</span><span class="o">--</span><span class="n">max_lora_rank</span> <span class="n">MAX_LORA_RANK</span><span class="p">]</span>
<span class="p">[</span><span class="o">--</span><span class="n">speculative_decoding_mode</span> <span class="p">{</span><span class="n">draft_tokens_external</span><span class="p">,</span><span class="n">lookahead_decoding</span><span class="p">,</span><span class="n">medusa</span><span class="p">,</span><span class="n">explicit_draft_tokens</span><span class="p">,</span><span class="n">eagle</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">max_draft_len</span> <span class="n">MAX_DRAFT_LEN</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">auto_parallel</span> <span class="n">AUTO_PARALLEL</span><span class="p">]</span> <span class="p">[</span><span class="o">--</span><span class="n">gpus_per_node</span> <span class="n">GPUS_PER_NODE</span><span class="p">]</span>
<span class="p">[</span><span class="o">--</span><span class="n">cluster_key</span> <span class="p">{</span><span class="n">A100</span><span class="o">-</span><span class="n">SXM</span><span class="o">-</span><span class="mi">80</span><span class="n">GB</span><span class="p">,</span><span class="n">A100</span><span class="o">-</span><span class="n">SXM</span><span class="o">-</span><span class="mi">40</span><span class="n">GB</span><span class="p">,</span><span class="n">A100</span><span class="o">-</span><span class="n">PCIe</span><span class="o">-</span><span class="mi">80</span><span class="n">GB</span><span class="p">,</span><span class="n">A100</span><span class="o">-</span><span class="n">PCIe</span><span class="o">-</span><span class="mi">40</span><span class="n">GB</span><span class="p">,</span><span class="n">H100</span><span class="o">-</span><span class="n">SXM</span><span class="p">,</span><span class="n">H100</span><span class="o">-</span><span class="n">PCIe</span><span class="p">,</span><span class="n">H20</span><span class="p">,</span><span class="n">V100</span><span class="o">-</span><span class="n">PCIe</span><span class="o">-</span><span class="mi">16</span><span class="n">GB</span><span class="p">,</span><span class="n">V100</span><span class="o">-</span><span class="n">PCIe</span><span class="o">-</span><span class="mi">32</span><span class="n">GB</span><span class="p">,</span><span class="n">V100</span><span class="o">-</span><span class="n">SXM</span><span class="o">-</span><span class="mi">16</span><span class="n">GB</span><span class="p">,</span><span class="n">V100</span><span class="o">-</span><span class="n">SXM</span><span class="o">-</span><span class="mi">32</span><span class="n">GB</span><span class="p">,</span><span class="n">V100S</span><span class="o">-</span><span class="n">PCIe</span><span class="p">,</span><span class="n">A40</span><span class="p">,</span><span class="n">A30</span><span class="p">,</span><span class="n">A10</span><span class="p">,</span><span class="n">A10G</span><span class="p">,</span><span class="n">L40S</span><span class="p">,</span><span class="n">L40</span><span class="p">,</span><span class="n">L20</span><span class="p">,</span><span class="n">L4</span><span class="p">,</span><span class="n">L2</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">bert_attention_plugin</span> <span class="p">{</span><span class="n">auto</span><span class="p">,</span><span class="n">float16</span><span class="p">,</span><span class="n">float32</span><span class="p">,</span><span class="n">bfloat16</span><span class="p">,</span><span class="n">int32</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">gpt_attention_plugin</span> <span class="p">{</span><span class="n">auto</span><span class="p">,</span><span class="n">float16</span><span class="p">,</span><span class="n">float32</span><span class="p">,</span><span class="n">bfloat16</span><span class="p">,</span><span class="n">int32</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">gemm_plugin</span> <span class="p">{</span><span class="n">auto</span><span class="p">,</span><span class="n">float16</span><span class="p">,</span><span class="n">float32</span><span class="p">,</span><span class="n">bfloat16</span><span class="p">,</span><span class="n">int32</span><span class="p">,</span><span class="n">fp8</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">gemm_swiglu_plugin</span> <span class="p">{</span><span class="n">fp8</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">fp8_rowwise_gemm_plugin</span> <span class="p">{</span><span class="n">auto</span><span class="p">,</span><span class="n">float16</span><span class="p">,</span><span class="n">float32</span><span class="p">,</span><span class="n">bfloat16</span><span class="p">,</span><span class="n">int32</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">nccl_plugin</span> <span class="p">{</span><span class="n">auto</span><span class="p">,</span><span class="n">float16</span><span class="p">,</span><span class="n">float32</span><span class="p">,</span><span class="n">bfloat16</span><span class="p">,</span><span class="n">int32</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">lora_plugin</span> <span class="p">{</span><span class="n">auto</span><span class="p">,</span><span class="n">float16</span><span class="p">,</span><span class="n">float32</span><span class="p">,</span><span class="n">bfloat16</span><span class="p">,</span><span class="n">int32</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">moe_plugin</span> <span class="p">{</span><span class="n">auto</span><span class="p">,</span><span class="n">float16</span><span class="p">,</span><span class="n">float32</span><span class="p">,</span><span class="n">bfloat16</span><span class="p">,</span><span class="n">int32</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">mamba_conv1d_plugin</span> <span class="p">{</span><span class="n">auto</span><span class="p">,</span><span class="n">float16</span><span class="p">,</span><span class="n">float32</span><span class="p">,</span><span class="n">bfloat16</span><span class="p">,</span><span class="n">int32</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">low_latency_gemm_plugin</span> <span class="p">{</span><span class="n">fp8</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span> <span class="p">[</span><span class="o">--</span><span class="n">low_latency_gemm_swiglu_plugin</span> <span class="p">{</span><span class="n">fp8</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">context_fmha</span> <span class="p">{</span><span class="n">enable</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span> <span class="p">[</span><span class="o">--</span><span class="n">bert_context_fmha_fp32_acc</span> <span class="p">{</span><span class="n">enable</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">remove_input_padding</span> <span class="p">{</span><span class="n">enable</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span> <span class="p">[</span><span class="o">--</span><span class="n">reduce_fusion</span> <span class="p">{</span><span class="n">enable</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">user_buffer</span> <span class="p">{</span><span class="n">enable</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span> <span class="p">[</span><span class="o">--</span><span class="n">tokens_per_block</span> <span class="n">TOKENS_PER_BLOCK</span><span class="p">]</span>
<span class="p">[</span><span class="o">--</span><span class="n">use_paged_context_fmha</span> <span class="p">{</span><span class="n">enable</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span> <span class="p">[</span><span class="o">--</span><span class="n">use_fp8_context_fmha</span> <span class="p">{</span><span class="n">enable</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">multiple_profiles</span> <span class="p">{</span><span class="n">enable</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span> <span class="p">[</span><span class="o">--</span><span class="n">paged_state</span> <span class="p">{</span><span class="n">enable</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">streamingllm</span> <span class="p">{</span><span class="n">enable</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span> <span class="p">[</span><span class="o">--</span><span class="n">use_fused_mlp</span> <span class="p">{</span><span class="n">enable</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span>
<span class="p">[</span><span class="o">--</span><span class="n">pp_reduce_scatter</span> <span class="p">{</span><span class="n">enable</span><span class="p">,</span><span class="n">disable</span><span class="p">}]</span>
</pre></div>
</div>
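To make the synopsis above concrete, a minimal sketch of a typical invocation driven from Python (all paths are placeholders; the flags come straight from the argument list below):

    import subprocess

    # Build an engine from a converted checkpoint; paths are illustrative only.
    subprocess.run(
        [
            "trtllm-build",
            "--checkpoint_dir", "./tllm_checkpoint",
            "--output_dir", "./engine_outputs",
            "--max_batch_size", "8",
            "--max_input_len", "1024",
            "--max_seq_len", "2048",
            "--gemm_plugin", "auto",
            "--log_level", "info",
        ],
        check=True,  # raise CalledProcessError on a non-zero exit code
    )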
<section id="tensorrt_llm.commands.build-parse_arguments-named-arguments">
<h2>Named Arguments<a class="headerlink" href="#tensorrt_llm.commands.build-parse_arguments-named-arguments" title="Link to this heading"></a></h2>
<dl class="option-list">
<dt><kbd>--checkpoint_dir</kbd></dt>
<dd><p>The directory path that contains the TensorRT-LLM checkpoint.</p>
</dd>
<dt><kbd>--model_config</kbd></dt>
<dd><p>The file path that saves the TensorRT-LLM checkpoint config.</p>
</dd>
<dt><kbd>--build_config</kbd></dt>
<dd><p>The file path that saves the TensorRT-LLM build config.</p>
</dd>
<dt><kbd>--model_cls_file</kbd></dt>
<dd><p>The file path that defines a customized TensorRT-LLM model.</p>
</dd>
<dt><kbd>--model_cls_name</kbd></dt>
<dd><p>The customized TensorRT-LLM model class name.</p>
</dd>
<dt><kbd>--output_dir</kbd></dt>
<dd><p>The directory path to save the serialized engine files and engine config file.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'engine_outputs'</span></code></p>
</dd>
<dt><kbd>--max_batch_size</kbd></dt>
<dd><p>Maximum number of requests that the engine can schedule.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">2048</span></code></p>
</dd>
<dt><kbd>--max_input_len</kbd></dt>
<dd><p>Maximum input length of one request.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">1024</span></code></p>
</dd>
<dt><kbd>--max_seq_len, --max_decoder_seq_len</kbd></dt>
<dd><p>Maximum total length of one request, including prompt and outputs. If unspecified, the value is deduced from the model config.</p>
</dd>
<dt><kbd>--max_beam_width</kbd></dt>
<dd><p>Maximum number of beams for beam search decoding.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">1</span></code></p>
</dd>
<dt><kbd>--max_num_tokens</kbd></dt>
<dd><p>Maximum number of batched input tokens after padding is removed in each batch. Currently, the input padding is removed by default; you may explicitly disable it by specifying <code class="docutils literal notranslate"><span class="pre">--remove_input_padding</span> <span class="pre">disable</span></code>.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">8192</span></code></p>
</dd>
<dt><kbd>--opt_num_tokens</kbd></dt>
<dd><p>Optimal number of batched input tokens after padding is removed in each batch. It defaults to <code class="docutils literal notranslate"><span class="pre">max_batch_size</span> <span class="pre">*</span> <span class="pre">max_beam_width</span></code>; set this value as close as possible to the actual number of tokens in your workload (a short worked check follows this argument list). Note that this argument might be removed in the future.</p>
</dd>
<dt><kbd>--max_encoder_input_len</kbd></dt>
<dd><p>Maximum encoder input length for enc-dec models. Set <code class="docutils literal notranslate"><span class="pre">max_input_len</span></code> to 1 to start generation from decoder_start_token_id of length 1.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">1024</span></code></p>
</dd>
<dt><kbd>--max_prompt_embedding_table_size, --max_multimodal_len</kbd></dt>
<dd><p>Maximum prompt embedding table size for prompt tuning, or maximum multimodal input size for multimodal models. Setting a value > 0 enables prompt tuning or multimodal input.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">0</span></code></p>
</dd>
<dt><kbd>--kv_cache_type</kbd></dt>
<dd><p>Set the KV cache type (continuous, paged, or disabled). In the disabled case, the KV cache is turned off and only the context phase is allowed.</p>
</dd>
<dt><kbd>--paged_kv_cache</kbd></dt>
<dd><p>Deprecated. Enabling this option is equivalent to <code class="docutils literal notranslate"><span class="pre">--kv_cache_type</span> <span class="pre">paged</span></code> for transformer-based models.</p>
</dd>
<dt><kbd>--input_timing_cache</kbd></dt>
<dd><p>The file path to read the timing cache from. This option is ignored if the file does not exist.</p>
</dd>
<dt><kbd>--output_timing_cache</kbd></dt>
<dd><p>The file path to write the timing cache to.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'model.cache'</span></code></p>
</dd>
<dt><kbd>--profiling_verbosity</kbd></dt>
<dd><p>Possible choices: layer_names_only, detailed, none</p>
<p>The profiling verbosity for the generated TensorRT engine. Setting to detailed allows inspecting tactic choices and kernel parameters.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'layer_names_only'</span></code></p>
</dd>
<dt><kbd>--strip_plan</kbd></dt>
<dd><p>Enable stripping weights from the final TensorRT engine under the assumption that the refit weights are identical to those provided at build time.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">False</span></code></p>
</dd>
<dt><kbd>--weight_sparsity</kbd></dt>
<dd><p>Enable weight sparsity.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">False</span></code></p>
</dd>
<dt><kbd>--weight_streaming</kbd></dt>
<dd><p>Enable offloading weights to the CPU and streaming their loading at runtime.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">False</span></code></p>
</dd>
<dt><kbd>--fast_build</kbd></dt>
<dd><p>Enable features for faster engine building. This may cause some performance degradation and is currently incompatible with int8/int4 quantization without plugins.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">False</span></code></p>
</dd>
<dt><kbd>--workers</kbd></dt>
<dd><p>The number of workers for building in parallel.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">1</span></code></p>
</dd>
<dt><kbd>--log_level</kbd></dt>
|
||||
<dd><p>Possible choices: internal_error, error, warning, info, verbose, debug</p>
|
||||
<p>The logging level.</p>
|
||||
<p>Default: <code class="docutils literal notranslate"><span class="pre">'info'</span></code></p>
|
||||
</dd>
|
||||
<dt><kbd>--enable_debug_output</kbd></dt>
|
||||
<dd><p>Enable debug output.</p>
|
||||
<p>Default: <code class="docutils literal notranslate"><span class="pre">False</span></code></p>
|
||||
</dd>
|
||||
<dt><kbd>--visualize_network</kbd></dt>
|
||||
<dd><p>Export TensorRT Networks to ONNX prior to Engine build for debugging.</p>
|
||||
<p>Default: <code class="docutils literal notranslate"><span class="pre">False</span></code></p>
|
||||
</dd>
|
||||
<dt><kbd>--dry_run</kbd></dt>
|
||||
<dd><p>Run through the build process except the actual Engine build for debugging.</p>
|
||||
<p>Default: <code class="docutils literal notranslate"><span class="pre">False</span></code></p>
|
||||
</dd>
|
||||
<dt><kbd>--monitor_memory</kbd></dt>
|
||||
<dd><p>Enable memory monitor during Engine build.</p>
|
||||
<p>Default: <code class="docutils literal notranslate"><span class="pre">False</span></code></p>
|
||||
</dd>
|
||||
</dl>
|
||||
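<p>For illustration only, the following sketch combines several of the options above into one build command; the checkpoint and engine paths are placeholders, and <code class="docutils literal notranslate"><span class="pre">--checkpoint_dir</span></code>/<code class="docutils literal notranslate"><span class="pre">--output_dir</span></code> are the standard input/output options documented earlier.</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre><span></span># Hypothetical example; ./ckpt and ./engine are placeholder paths.
trtllm-build --checkpoint_dir ./ckpt \
             --output_dir ./engine \
             --kv_cache_type paged \
             --output_timing_cache model.cache \
             --profiling_verbosity detailed \
             --workers 4 \
             --log_level info
</pre></div>
</div>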
</section>
<section id="tensorrt_llm.commands.build-parse_arguments-logits-arguments">
<h2>Logits arguments<a class="headerlink" href="#tensorrt_llm.commands.build-parse_arguments-logits-arguments" title="Link to this heading"></a></h2>
<dl class="option-list">
<dt><kbd>--logits_dtype</kbd></dt>
<dd><p>Possible choices: float16, float32</p>
<p>The data type of logits.</p>
</dd>
<dt><kbd>--gather_context_logits</kbd></dt>
<dd><p>Enable gathering context logits.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">False</span></code></p>
</dd>
<dt><kbd>--gather_generation_logits</kbd></dt>
<dd><p>Enable gathering generation logits.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">False</span></code></p>
</dd>
<dt><kbd>--gather_all_token_logits</kbd></dt>
<dd><p>Enable both <code class="docutils literal notranslate"><span class="pre">gather_context_logits</span></code> and <code class="docutils literal notranslate"><span class="pre">gather_generation_logits</span></code>.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">False</span></code></p>
</dd>
</dl>
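<p>For illustration only, a hypothetical build that gathers both context and generation logits in FP32 via the options above (paths are placeholders):</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre><span></span># Hypothetical example: gather all token logits in FP32.
trtllm-build --checkpoint_dir ./ckpt \
             --output_dir ./engine \
             --logits_dtype float32 \
             --gather_all_token_logits
</pre></div>
</div>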
</section>
<section id="tensorrt_llm.commands.build-parse_arguments-lora-arguments">
<h2>LoRA arguments<a class="headerlink" href="#tensorrt_llm.commands.build-parse_arguments-lora-arguments" title="Link to this heading"></a></h2>
<dl class="option-list">
<dt><kbd>--lora_dir</kbd></dt>
<dd><p>The directory of LoRA weights. If multiple directories are provided, the first one is used for configuration.</p>
</dd>
<dt><kbd>--lora_ckpt_source</kbd></dt>
<dd><p>Possible choices: hf, nemo</p>
<p>The source type of the LoRA checkpoint.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'hf'</span></code></p>
</dd>
<dt><kbd>--lora_target_modules</kbd></dt>
<dd><p>Possible choices: attn_qkv, attn_q, attn_k, attn_v, attn_dense, mlp_h_to_4h, mlp_4h_to_h, mlp_gate, cross_attn_qkv, cross_attn_q, cross_attn_k, cross_attn_v, cross_attn_dense, moe_h_to_4h, moe_4h_to_h, moe_gate, moe_router, mlp_router</p>
<p>The names of the target modules to which LoRA is applied. Only effective when <code class="docutils literal notranslate"><span class="pre">lora_plugin</span></code> is enabled.</p>
</dd>
<dt><kbd>--max_lora_rank</kbd></dt>
<dd><p>Maximum LoRA rank across the LoRA modules. It is used to compute the workspace size of the LoRA plugin.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">64</span></code></p>
</dd>
</dl>
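<p>For illustration only, a hypothetical LoRA-enabled build might look as follows; it also sets <code class="docutils literal notranslate"><span class="pre">--lora_plugin</span></code> from the plugin config options below, since <code class="docutils literal notranslate"><span class="pre">--lora_target_modules</span></code> is only effective when that plugin is enabled. The paths and rank are placeholders.</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre><span></span># Hypothetical example; ./lora_weights is a placeholder HF LoRA directory.
trtllm-build --checkpoint_dir ./ckpt \
             --output_dir ./engine \
             --lora_plugin auto \
             --lora_dir ./lora_weights \
             --lora_ckpt_source hf \
             --lora_target_modules attn_q attn_k attn_v \
             --max_lora_rank 32
</pre></div>
</div>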
</section>
<section id="tensorrt_llm.commands.build-parse_arguments-speculative-decoding-arguments">
<h2>Speculative decoding arguments<a class="headerlink" href="#tensorrt_llm.commands.build-parse_arguments-speculative-decoding-arguments" title="Link to this heading"></a></h2>
<dl class="option-list">
<dt><kbd>--speculative_decoding_mode</kbd></dt>
<dd><p>Possible choices: draft_tokens_external, lookahead_decoding, medusa, explicit_draft_tokens, eagle</p>
<p>Mode of speculative decoding.</p>
</dd>
<dt><kbd>--max_draft_len</kbd></dt>
<dd><p>Maximum length of draft tokens for the speculative decoding target model.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">0</span></code></p>
</dd>
</dl>
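<p>For illustration only, a hypothetical Medusa build might combine the two options above; the checkpoint path and draft length are placeholders, not recommended values.</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre><span></span># Hypothetical example; a Medusa-style checkpoint is assumed.
trtllm-build --checkpoint_dir ./medusa_ckpt \
             --output_dir ./engine \
             --speculative_decoding_mode medusa \
             --max_draft_len 63
</pre></div>
</div>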
</section>
<section id="tensorrt_llm.commands.build-parse_arguments-auto-parallel-arguments">
<h2>Auto parallel arguments<a class="headerlink" href="#tensorrt_llm.commands.build-parse_arguments-auto-parallel-arguments" title="Link to this heading"></a></h2>
<dl class="option-list">
<dt><kbd>--auto_parallel</kbd></dt>
<dd><p>MPI world size for auto parallel.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">1</span></code></p>
</dd>
<dt><kbd>--gpus_per_node</kbd></dt>
<dd><p>Number of GPUs each node has in a multi-node setup. This is a cluster spec and can be greater or smaller than the world size. This option is only used for auto parallel specified with <code class="docutils literal notranslate"><span class="pre">--auto_parallel</span></code>.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">8</span></code></p>
</dd>
<dt><kbd>--cluster_key</kbd></dt>
<dd><p>Possible choices: A100-SXM-80GB, A100-SXM-40GB, A100-PCIe-80GB, A100-PCIe-40GB, H100-SXM, H100-PCIe, H20, V100-PCIe-16GB, V100-PCIe-32GB, V100-SXM-16GB, V100-SXM-32GB, V100S-PCIe, A40, A30, A10, A10G, L40S, L40, L20, L4, L2</p>
<p>Unique name for the target GPU type. Inferred from the current GPU type if not specified. This option is only used for auto parallel specified with <code class="docutils literal notranslate"><span class="pre">--auto_parallel</span></code>.</p>
</dd>
</dl>
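<p>For illustration only, a hypothetical invocation that lets auto parallel shard the model across four GPUs on an assumed H100 cluster:</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre><span></span># Hypothetical example; the cluster spec values are placeholders.
trtllm-build --checkpoint_dir ./ckpt \
             --output_dir ./engine \
             --auto_parallel 4 \
             --gpus_per_node 8 \
             --cluster_key H100-SXM
</pre></div>
</div>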
</section>
<section id="tensorrt_llm.commands.build-parse_arguments-plugin-config-arguments">
<h2>Plugin config arguments<a class="headerlink" href="#tensorrt_llm.commands.build-parse_arguments-plugin-config-arguments" title="Link to this heading"></a></h2>
<dl class="option-list">
<dt><kbd>--bert_attention_plugin</kbd></dt>
<dd><p>Possible choices: auto, float16, float32, bfloat16, int32, disable</p>
<p>The plugin that uses efficient kernels and enables an in-place update of the KV cache for the attention layer of BERT-like encoder models.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'auto'</span></code></p>
</dd>
<dt><kbd>--gpt_attention_plugin</kbd></dt>
<dd><p>Possible choices: auto, float16, float32, bfloat16, int32, disable</p>
<p>The plugin that uses efficient kernels and enables an in-place update of the KV cache for the attention layer of GPT-like decoder models.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'auto'</span></code></p>
</dd>
<dt><kbd>--gemm_plugin</kbd></dt>
<dd><p>Possible choices: auto, float16, float32, bfloat16, int32, fp8, disable</p>
<p>The GEMM plugin that utilizes NVIDIA cuBLASLt to perform GEMM operations. Note: it is only effective for non-quantized GEMM operations (except FP8). For FP8, it also requires matching calibration in the checkpoint.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'disable'</span></code></p>
</dd>
<dt><kbd>--gemm_swiglu_plugin</kbd></dt>
<dd><p>Possible choices: fp8, disable</p>
<p>The GEMM + SwiGLU fusion in Gated-MLP combines two Matmul operations and one SwiGLU operation into a single kernel. Currently this is only supported for FP8 precision on Hopper.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'disable'</span></code></p>
</dd>
<dt><kbd>--fp8_rowwise_gemm_plugin</kbd></dt>
<dd><p>Possible choices: auto, float16, float32, bfloat16, int32, disable</p>
<p>The quantized GEMM plugin for FP8, which uses per-token dynamic scales for activations and per-channel static scales for weights. Note: it also requires matching calibration in the checkpoint.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'disable'</span></code></p>
</dd>
<dt><kbd>--nccl_plugin</kbd></dt>
<dd><p>Possible choices: auto, float16, float32, bfloat16, int32, disable</p>
<p>The NCCL plugin wraps NCCL operators to support multi-GPU and multi-node execution.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'auto'</span></code></p>
</dd>
<dt><kbd>--lora_plugin</kbd></dt>
<dd><p>Possible choices: auto, float16, float32, bfloat16, int32, disable</p>
<p>Enable LoRA.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'disable'</span></code></p>
</dd>
<dt><kbd>--moe_plugin</kbd></dt>
<dd><p>Possible choices: auto, float16, float32, bfloat16, int32, disable</p>
<p>Enable customized kernels to speed up the MoE layers of MoE models.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'auto'</span></code></p>
</dd>
<dt><kbd>--mamba_conv1d_plugin</kbd></dt>
<dd><p>Possible choices: auto, float16, float32, bfloat16, int32, disable</p>
<p>Enable customized kernels to speed up the conv1d operator for Mamba.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'auto'</span></code></p>
</dd>
<dt><kbd>--low_latency_gemm_plugin</kbd></dt>
<dd><p>Possible choices: fp8, disable</p>
<p>The GEMM plugin that is specially optimized for low-latency scenarios.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'disable'</span></code></p>
</dd>
<dt><kbd>--low_latency_gemm_swiglu_plugin</kbd></dt>
<dd><p>Possible choices: fp8, disable</p>
<p>The GEMM + SwiGLU fusion plugin that is specially optimized for low-latency scenarios.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'disable'</span></code></p>
</dd>
<dt><kbd>--context_fmha</kbd></dt>
<dd><p>Possible choices: enable, disable</p>
<p>Enable fused multi-head attention during the context phase. This triggers a kernel that performs the MHA/MQA/GQA block as a single kernel.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'enable'</span></code></p>
</dd>
<dt><kbd>--bert_context_fmha_fp32_acc</kbd></dt>
<dd><p>Possible choices: enable, disable</p>
<p>Enable the FP32 accumulator for context FMHA in the bert_attention_plugin. If disabled, FP16 is used, which gives better performance but potentially worse accuracy.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'disable'</span></code></p>
</dd>
<dt><kbd>--remove_input_padding</kbd></dt>
<dd><p>Possible choices: enable, disable</p>
<p>Pack tokens from different requests together without padding, which reduces both computation and memory consumption.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'enable'</span></code></p>
</dd>
<dt><kbd>--reduce_fusion</kbd></dt>
<dd><p>Possible choices: enable, disable</p>
<p>Fuse the ResidualAdd and LayerNorm kernels after AllReduce into a single kernel, resulting in improved end-to-end performance.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'disable'</span></code></p>
</dd>
<dt><kbd>--user_buffer</kbd></dt>
<dd><p>Possible choices: enable, disable</p>
<p>Eliminate extra copies from the local buffer to the shared buffer in the communication kernel, leading to improved end-to-end performance. This feature must be enabled together with <cite>--reduce_fusion enable</cite> and is currently only supported for the FP8 LLaMA model.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'disable'</span></code></p>
</dd>
<dt><kbd>--tokens_per_block</kbd></dt>
<dd><p>Define how many tokens are contained in each paged KV cache block.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">64</span></code></p>
</dd>
<dt><kbd>--use_paged_context_fmha</kbd></dt>
<dd><p>Possible choices: enable, disable</p>
<p>Allow advanced features like KV cache reuse and chunked context.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'disable'</span></code></p>
</dd>
<dt><kbd>--use_fp8_context_fmha</kbd></dt>
<dd><p>Possible choices: enable, disable</p>
<p>When FP8 quantization is activated, attention can be further accelerated by enabling FP8 context FMHA.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'disable'</span></code></p>
</dd>
<dt><kbd>--multiple_profiles</kbd></dt>
<dd><p>Possible choices: enable, disable</p>
<p>Enable multiple TensorRT optimization profiles in the built engines. This benefits performance, especially when the GEMM plugin is disabled, because more optimization profiles give TensorRT more chances to select better kernels. Note: this feature increases engine build time, but no other adverse effects are expected.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'disable'</span></code></p>
</dd>
<dt><kbd>--paged_state</kbd></dt>
<dd><p>Possible choices: enable, disable</p>
<p>Enable paged state, which helps manage memory for the RNN state more efficiently.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'enable'</span></code></p>
</dd>
<dt><kbd>--streamingllm</kbd></dt>
<dd><p>Possible choices: enable, disable</p>
<p>Enable <a class="reference external" href="https://arxiv.org/abs/2309.17453">StreamingLLM</a>, which uses window attention to perform efficient and stable LLM inference on long texts.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'disable'</span></code></p>
</dd>
<dt><kbd>--use_fused_mlp</kbd></dt>
<dd><p>Possible choices: enable, disable</p>
<p>Enable horizontal fusion in Gated-MLP that combines two Matmul operations into a single one, followed by a separate SwiGLU kernel.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'enable'</span></code></p>
</dd>
<dt><kbd>--pp_reduce_scatter</kbd></dt>
<dd><p>Possible choices: enable, disable</p>
<p>Enable a pipeline parallelism optimization with ReduceScatter + AllGather, targeting large MoE models.</p>
<p>Default: <code class="docutils literal notranslate"><span class="pre">'disable'</span></code></p>
</dd>
</dl>
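<p>For illustration only, the sketch below enables the commonly paired prerequisites for KV cache reuse and chunked context described above, together with <code class="docutils literal notranslate"><span class="pre">--kv_cache_type</span> <span class="pre">paged</span></code> from the engine options earlier; the paths are placeholders.</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre><span></span># Hypothetical example: paged KV cache + paged context FMHA.
trtllm-build --checkpoint_dir ./ckpt \
             --output_dir ./engine \
             --gpt_attention_plugin auto \
             --remove_input_padding enable \
             --kv_cache_type paged \
             --use_paged_context_fmha enable \
             --tokens_per_block 64
</pre></div>
</div>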
</section>
</section>
@ -177,7 +531,7 @@
<hr/>

<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x74a61b8ff3e0>
<jinja2.runtime.BlockReference object at 0x7f1ac74c8650>

<div class="footer">
<p>
@ -15,7 +15,7 @@
<script src="../_static/jquery.js?v=5d32c60e"></script>
<script src="../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/doctools.js?v=9bcbadda"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../_static/copybutton.js?v=65e89d2a"></script>
@ -88,7 +88,10 @@
<p class="caption" role="heading"><span class="caption-text">Command-Line Reference</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="trtllm-build.html">trtllm-build</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">trtllm-serve</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">trtllm-serve</a><ul>
<li class="toctree-l2"><a class="reference internal" href="#trtllm-serve">trtllm-serve</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Architecture</span></p>
<ul>
@ -164,6 +167,88 @@
<section id="trtllm-serve">
|
||||
<h1>trtllm-serve<a class="headerlink" href="#trtllm-serve" title="Link to this heading"></a></h1>
|
||||
<section id="trtllm-serve">
|
||||
<h2>trtllm-serve<a class="headerlink" href="#trtllm-serve" title="Link to this heading"></a></h2>
|
||||
<p>Running an OpenAI API compatible server</p>
|
||||
<p>MODEL: model name | HF checkpoint path | TensorRT engine path</p>
|
||||
<div class="highlight-shell notranslate"><div class="highlight"><pre><span></span>trtllm-serve<span class="w"> </span><span class="o">[</span>OPTIONS<span class="o">]</span><span class="w"> </span>MODEL
|
||||
</pre></div>
|
||||
</div>
|
||||
<p class="rubric">Options</p>
|
||||
<dl class="std option">
|
||||
<dt class="sig sig-object std" id="cmdoption-trtllm-serve-tokenizer">
|
||||
<span class="sig-name descname"><span class="pre">--tokenizer</span></span><span class="sig-prename descclassname"> <span class="pre"><tokenizer></span></span><a class="headerlink" href="#cmdoption-trtllm-serve-tokenizer" title="Link to this definition"></a></dt>
|
||||
<dd><p>Path | Name of the tokenizer.Specify this value only if using TensorRT engine as model.</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="std option">
|
||||
<dt class="sig sig-object std" id="cmdoption-trtllm-serve-host">
|
||||
<span class="sig-name descname"><span class="pre">--host</span></span><span class="sig-prename descclassname"> <span class="pre"><host></span></span><a class="headerlink" href="#cmdoption-trtllm-serve-host" title="Link to this definition"></a></dt>
|
||||
<dd><p>Hostname of the server.</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="std option">
|
||||
<dt class="sig sig-object std" id="cmdoption-trtllm-serve-port">
|
||||
<span class="sig-name descname"><span class="pre">--port</span></span><span class="sig-prename descclassname"> <span class="pre"><port></span></span><a class="headerlink" href="#cmdoption-trtllm-serve-port" title="Link to this definition"></a></dt>
|
||||
<dd><p>Port of the server.</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="std option">
|
||||
<dt class="sig sig-object std" id="cmdoption-trtllm-serve-max_beam_width">
|
||||
<span class="sig-name descname"><span class="pre">--max_beam_width</span></span><span class="sig-prename descclassname"> <span class="pre"><max_beam_width></span></span><a class="headerlink" href="#cmdoption-trtllm-serve-max_beam_width" title="Link to this definition"></a></dt>
|
||||
<dd><p>Maximum number of beams for beam search decoding.</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="std option">
|
||||
<dt class="sig sig-object std" id="cmdoption-trtllm-serve-max_batch_size">
|
||||
<span class="sig-name descname"><span class="pre">--max_batch_size</span></span><span class="sig-prename descclassname"> <span class="pre"><max_batch_size></span></span><a class="headerlink" href="#cmdoption-trtllm-serve-max_batch_size" title="Link to this definition"></a></dt>
|
||||
<dd><p>Maximum number of requests that the engine can schedule.</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="std option">
|
||||
<dt class="sig sig-object std" id="cmdoption-trtllm-serve-max_num_tokens">
|
||||
<span class="sig-name descname"><span class="pre">--max_num_tokens</span></span><span class="sig-prename descclassname"> <span class="pre"><max_num_tokens></span></span><a class="headerlink" href="#cmdoption-trtllm-serve-max_num_tokens" title="Link to this definition"></a></dt>
|
||||
<dd><p>Maximum number of batched input tokens after padding is removed in each batch.</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="std option">
|
||||
<dt class="sig sig-object std" id="cmdoption-trtllm-serve-max_seq_len">
|
||||
<span class="sig-name descname"><span class="pre">--max_seq_len</span></span><span class="sig-prename descclassname"> <span class="pre"><max_seq_len></span></span><a class="headerlink" href="#cmdoption-trtllm-serve-max_seq_len" title="Link to this definition"></a></dt>
|
||||
<dd><p>Maximum total length of one request, including prompt and outputs. If unspecified, the value is deduced from the model config.</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="std option">
|
||||
<dt class="sig sig-object std" id="cmdoption-trtllm-serve-tp_size">
|
||||
<span class="sig-name descname"><span class="pre">--tp_size</span></span><span class="sig-prename descclassname"> <span class="pre"><tp_size></span></span><a class="headerlink" href="#cmdoption-trtllm-serve-tp_size" title="Link to this definition"></a></dt>
|
||||
<dd><p>Tensor parallelism size.</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="std option">
|
||||
<dt class="sig sig-object std" id="cmdoption-trtllm-serve-pp_size">
|
||||
<span class="sig-name descname"><span class="pre">--pp_size</span></span><span class="sig-prename descclassname"> <span class="pre"><pp_size></span></span><a class="headerlink" href="#cmdoption-trtllm-serve-pp_size" title="Link to this definition"></a></dt>
|
||||
<dd><p>Pipeline parallelism size.</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="std option">
|
||||
<dt class="sig sig-object std" id="cmdoption-trtllm-serve-kv_cache_free_gpu_memory_fraction">
|
||||
<span class="sig-name descname"><span class="pre">--kv_cache_free_gpu_memory_fraction</span></span><span class="sig-prename descclassname"> <span class="pre"><kv_cache_free_gpu_memory_fraction></span></span><a class="headerlink" href="#cmdoption-trtllm-serve-kv_cache_free_gpu_memory_fraction" title="Link to this definition"></a></dt>
|
||||
<dd><p>Free GPU memory fraction reserved for KV Cache, after allocating model weights and buffers.</p>
|
||||
</dd></dl>
|
||||
|
||||
<dl class="std option">
|
||||
<dt class="sig sig-object std" id="cmdoption-trtllm-serve-trust_remote_code">
|
||||
<span class="sig-name descname"><span class="pre">--trust_remote_code</span></span><span class="sig-prename descclassname"></span><a class="headerlink" href="#cmdoption-trtllm-serve-trust_remote_code" title="Link to this definition"></a></dt>
|
||||
<dd><p>Flag for HF transformers.</p>
|
||||
</dd></dl>
|
||||
|
||||
<p class="rubric">Arguments</p>
|
||||
<dl class="std option">
|
||||
<dt class="sig sig-object std" id="cmdoption-trtllm-serve-arg-MODEL">
|
||||
<span class="sig-name descname"><span class="pre">MODEL</span></span><span class="sig-prename descclassname"></span><a class="headerlink" href="#cmdoption-trtllm-serve-arg-MODEL" title="Link to this definition"></a></dt>
|
||||
<dd><p>Required argument</p>
|
||||
</dd></dl>
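<p>For illustration only, a hypothetical launch and request are sketched below; the engine path, tokenizer name, and parallelism sizes are placeholders, and the request assumes the usual OpenAI-style <code class="docutils literal notranslate"><span class="pre">/v1/completions</span></code> route implied by the description above.</p>
<div class="highlight-shell notranslate"><div class="highlight"><pre><span></span># Hypothetical example; serve a TensorRT engine across two GPUs.
trtllm-serve ./engine \
             --tokenizer meta-llama/Meta-Llama-3-8B \
             --host 0.0.0.0 \
             --port 8000 \
             --tp_size 2 \
             --kv_cache_free_gpu_memory_fraction 0.9

# Assumed OpenAI-compatible completions request against the server above.
curl http://localhost:8000/v1/completions \
     -H "Content-Type: application/json" \
     -d '{"model": "./engine", "prompt": "Hello", "max_tokens": 16}'
</pre></div>
</div>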
</section>
</section>
@ -177,7 +262,7 @@
<hr/>

<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x74a61b15b830>
<jinja2.runtime.BlockReference object at 0x7f1ac74c98b0>

<div class="footer">
<p>
2629
genindex.html
File diff suppressed because it is too large
@ -15,7 +15,7 @@
<script src="_static/jquery.js?v=5d32c60e"></script>
<script src="_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="_static/documentation_options.js?v=5929fcd5"></script>
<script src="_static/doctools.js?v=9a2dae69"></script>
<script src="_static/doctools.js?v=9bcbadda"></script>
<script src="_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="_static/clipboard.min.js?v=a7894cd8"></script>
<script src="_static/copybutton.js?v=65e89d2a"></script>
@ -446,7 +446,7 @@
<hr/>

<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x74a60f067fb0>
<jinja2.runtime.BlockReference object at 0x7f1ac742aab0>

<div class="footer">
<p>
@ -15,7 +15,7 @@
<script src="../_static/jquery.js?v=5d32c60e"></script>
<script src="../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/doctools.js?v=9bcbadda"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../_static/copybutton.js?v=65e89d2a"></script>
@ -337,7 +337,7 @@ relevant classes. The associated unit tests should also be consulted for underst
<hr/>

<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x74a61aa536b0>
<jinja2.runtime.BlockReference object at 0x7f1ac742a8d0>

<div class="footer">
<p>
@ -15,7 +15,7 @@
<script src="../_static/jquery.js?v=5d32c60e"></script>
<script src="../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/doctools.js?v=9bcbadda"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../_static/copybutton.js?v=65e89d2a"></script>
@ -392,7 +392,7 @@ pip<span class="w"> </span>uninstall<span class="w"> </span>-y<span class="w"> <
<hr/>

<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x74a61adea0f0>
<jinja2.runtime.BlockReference object at 0x7f1ac72a4c50>

<div class="footer">
<p>
@ -15,7 +15,7 @@
<script src="../_static/jquery.js?v=5d32c60e"></script>
<script src="../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../_static/doctools.js?v=9a2dae69"></script>
<script src="../_static/doctools.js?v=9bcbadda"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../_static/copybutton.js?v=65e89d2a"></script>
@ -210,7 +210,7 @@ sudo<span class="w"> </span>apt-get<span class="w"> </span>-y<span class="w"> </
<hr/>

<div role="contentinfo">
<jinja2.runtime.BlockReference object at 0x74a61aa522d0>
<jinja2.runtime.BlockReference object at 0x7f1ac77ed340>

<div class="footer">
<p>
Some files were not shown because too many files have changed in this diff