test: Get Eagle tests working (#3593)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
brb-nv 2025-04-19 09:50:57 -07:00 committed by GitHub
parent e70961f541
commit c35d2a7532
5 changed files with 79 additions and 12 deletions

View File

@@ -295,6 +295,14 @@ if __name__ == '__main__':
        args.n_positions = hf_config.max_position_embeddings
        args.dtype = str(
            hf_config.torch_dtype)[6:] if args.dtype == 'auto' else args.dtype
        if 'head_dim' in hf_config:
            args.head_dim = hf_config.head_dim
        else:
            args.head_dim = args.n_embd // args.n_head
        if 'head_size' in hf_config:
            args.head_size = hf_config.head_size
        else:
            args.head_size = args.head_dim

        if args.eagle_model_dir is None:
            hf_config_eagle = hf_config.eagle
@@ -305,6 +313,14 @@ if __name__ == '__main__':
            args.n_kv_head_eagle = hf_config_eagle['num_key_value_heads']
            args.rms_norm_eps_eagle = hf_config_eagle['rms_norm_eps']
            args.n_positions_eagle = hf_config_eagle['max_position_embeddings']
            if 'head_dim' in hf_config_eagle:
                args.head_dim_eagle = hf_config_eagle['head_dim']
            else:
                args.head_dim_eagle = args.n_embd_eagle // args.n_head_eagle
            if 'head_size' in hf_config_eagle:
                args.head_size_eagle = hf_config_eagle['head_size']
            else:
                args.head_size_eagle = args.head_dim_eagle
        else:
            hf_config_eagle = LlamaConfig.from_pretrained(args.eagle_model_dir)
            args.n_head_eagle = hf_config_eagle.num_attention_heads
@@ -314,6 +330,14 @@ if __name__ == '__main__':
            args.n_kv_head_eagle = hf_config_eagle.num_key_value_heads
            args.rms_norm_eps_eagle = hf_config_eagle.rms_norm_eps
            args.n_positions_eagle = hf_config_eagle.max_position_embeddings
            if 'head_dim' in hf_config_eagle:
                args.head_dim_eagle = hf_config_eagle.head_dim
            else:
                args.head_dim_eagle = args.n_embd_eagle // args.n_head_eagle
            if 'head_size' in hf_config_eagle:
                args.head_size_eagle = hf_config_eagle.head_size
            else:
                args.head_size_eagle = args.head_dim_eagle

    elif args.meta_ckpt_dir is not None:
        assert False, "meta ckpt is not supported yet"
@@ -370,6 +394,8 @@ if __name__ == '__main__':
        },
        'use_parallel_embedding': args.use_parallel_embedding,
        'embedding_sharding_dim': args.embedding_sharding_dim,
        'head_dim': args.head_dim_eagle,
        'head_size': args.head_size_eagle
    }

    config = {
@@ -402,7 +428,9 @@ if __name__ == '__main__':
        'max_draft_len': args.max_draft_len,
        'num_eagle_layers': args.num_eagle_layers,
        'max_non_leaves_per_layer': args.max_non_leaves_per_layer,
        'eagle_net_config': eagle_net_config
        'eagle_net_config': eagle_net_config,
        'head_dim': args.head_dim,
        'head_size': args.head_size
    }

    assert args.max_draft_len <= 256, "args.max_draft_len > 256 is not supported"
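For reference, the fallback added above can be read as a small standalone helper: use the config's explicit head_dim when it is present, otherwise derive it as hidden size divided by the number of attention heads, and let head_size default to head_dim. The sketch below is illustrative only; resolve_head_dims and the plain-dict config are assumptions, not part of this change (n_embd and n_head in the converter correspond to hidden_size and num_attention_heads here).

def resolve_head_dims(cfg: dict) -> tuple[int, int]:
    # Hypothetical helper mirroring the converter's fallback logic above.
    if 'head_dim' in cfg:
        head_dim = cfg['head_dim']
    else:
        head_dim = cfg['hidden_size'] // cfg['num_attention_heads']
    # head_size falls back to head_dim when the config does not set it separately.
    head_size = cfg.get('head_size', head_dim)
    return head_dim, head_size

# Example: a Mistral-Nemo-style config whose explicit head_dim (128) differs from
# hidden_size // num_attention_heads (5120 // 32 = 160), so the explicit value must win.
print(resolve_head_dims({'hidden_size': 5120, 'num_attention_heads': 32, 'head_dim': 128}))  # (128, 128)
print(resolve_head_dims({'hidden_size': 4096, 'num_attention_heads': 32}))  # (128, 128)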

View File

@@ -88,6 +88,14 @@ class EagleConfig(LLaMAConfig):
        n_positions = hf_config.max_position_embeddings
        hidden_act = hf_config.hidden_act
        dtype = str(hf_config.torch_dtype)[6:] if dtype == 'auto' else dtype
        if hasattr(hf_config, 'head_dim'):
            head_dim = hf_config.head_dim
        else:
            head_dim = hf_config.n_embd // hf_config.n_head
        if hasattr(hf_config, 'head_size'):
            head_size = hf_config.head_size
        else:
            head_size = head_dim

        if speculative_config_or_dir is None:
            hf_config_eagle = hf_config.eagle
@@ -143,6 +151,8 @@ class EagleConfig(LLaMAConfig):
            },
            'use_parallel_embedding': kwargs['use_parallel_embedding'],
            'embedding_sharding_dim': kwargs['embedding_sharding_dim'],
            'head_dim': head_dim,
            'head_size': head_size
        }

        config = {

View File

@@ -945,11 +945,22 @@ def get_dummy_spec_decoding_heads(hf_model_dir,
    )
    quant_cfg = getattr(mtq, "FP8_DEFAULT_CFG")
    # Following quantizers are needed for KV cache quantization.
    quant_cfg["quant_cfg"]["*output_quantizer"] = {
        "num_bits": (4, 3),
        "axis": None,
        "enable": True,
    }
    quant_cfg["quant_cfg"]["*k_bmm_quantizer"] = {
        "num_bits": (4, 3),
        "axis": None,
        "enable": True,
    }
    quant_cfg["quant_cfg"]["*v_bmm_quantizer"] = {
        "num_bits": (4, 3),
        "axis": None,
        "enable": True,
    }

    calibrate_loop = dataset_utils.create_forward_loop(
        calib_dataloader, dataloader=calib_dataloader)
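The three overrides above all carry the same settings: num_bits=(4, 3) selects the FP8 E4M3 format and axis=None keeps the scaling per-tensor, enabled on the quantizers that feed the KV cache. A compact equivalent is sketched below; it assumes the usual import modelopt.torch.quantization as mtq and the dict layout of FP8_DEFAULT_CFG, and is a sketch rather than a drop-in replacement for the code above.

import copy

import modelopt.torch.quantization as mtq

# Copy modelopt's stock FP8 recipe so the shared module-level default stays untouched.
quant_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG)

# Same FP8 E4M3, per-tensor settings as the three literal blocks above,
# applied to the quantizers involved in KV-cache quantization.
kv_cache_quant = {"num_bits": (4, 3), "axis": None, "enable": True}
for name in ("*output_quantizer", "*k_bmm_quantizer", "*v_bmm_quantizer"):
    quant_cfg["quant_cfg"][name] = dict(kv_cache_quant)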

View File

@@ -270,17 +270,6 @@ def test_codellama_eagle_1gpu(code_llama_model_root,
                          llm_datasets_root=llm_datasets_root,
                          llm_rouge_root=llm_rouge_root)
    test_with_dummy_eagle(hf_model_root=code_llama_model_root,
                          eagle_example_root=eagle_example_root,
                          llm_venv=llm_venv,
                          cmodel_dir=cmodel_dir,
                          engine_dir=engine_dir,
                          batch_size=batch_size,
                          data_type=data_type,
                          use_dynamic_tree=use_dynamic_tree,
                          llm_datasets_root=llm_datasets_root,
                          llm_rouge_root=llm_rouge_root)


@pytest.mark.parametrize("use_dynamic_tree", [False, True],
                         ids=['eagle1', 'eagle2'])
@@ -309,6 +298,33 @@ def test_mistral_eagle_1gpu(llm_mistral_model_root,
                          llm_rouge_root=llm_rouge_root)


@pytest.mark.parametrize("use_dynamic_tree", [False, True],
                         ids=['eagle1', 'eagle2'])
@pytest.mark.parametrize("mistral_nemo_model_root", ['Mistral-Nemo-12b-Base'],
                         indirect=True)
def test_mistral_nemo_eagle_1gpu(mistral_nemo_model_root,
                                 eagle_example_root,
                                 llm_datasets_root,
                                 llm_rouge_root,
                                 llm_venv,
                                 cmodel_dir,
                                 engine_dir,
                                 use_dynamic_tree,
                                 batch_size=8,
                                 data_type='bfloat16'):
    test_with_dummy_eagle(hf_model_root=mistral_nemo_model_root,
                          eagle_example_root=eagle_example_root,
                          llm_venv=llm_venv,
                          cmodel_dir=cmodel_dir,
                          engine_dir=engine_dir,
                          batch_size=batch_size,
                          data_type=data_type,
                          use_dynamic_tree=use_dynamic_tree,
                          llm_datasets_root=llm_datasets_root,
                          llm_rouge_root=llm_rouge_root)


@pytest.mark.parametrize("use_dynamic_tree", [False, True],
                         ids=['eagle1', 'eagle2'])
@pytest.mark.parametrize("llm_qwen_model_root", [

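The two stacked parametrize marks on the new test determine the IDs that appear in the test list below: pytest builds the bracketed ID from one value per mark, joined with '-', with the mark closest to the function contributing the leading piece, which is how test_mistral_nemo_eagle_1gpu[Mistral-Nemo-12b-Base-eagle1] and [Mistral-Nemo-12b-Base-eagle2] arise. A minimal, self-contained illustration of that ID scheme follows; the body is a placeholder rather than the real test_with_dummy_eagle flow, and model_root is a plain parameter, not the repository's indirect mistral_nemo_model_root fixture.

import pytest

@pytest.mark.parametrize("use_dynamic_tree", [False, True], ids=['eagle1', 'eagle2'])
@pytest.mark.parametrize("model_root", ['Mistral-Nemo-12b-Base'])
def test_id_demo(model_root, use_dynamic_tree):
    # Collected as test_id_demo[Mistral-Nemo-12b-Base-eagle1] and
    # test_id_demo[Mistral-Nemo-12b-Base-eagle2].
    assert isinstance(use_dynamic_tree, bool)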
View File

@@ -500,6 +500,7 @@ examples/test_eagle.py::test_llama_eagle_1gpu[llama-v2-7b-hf-eagle1]
examples/test_eagle.py::test_llama_eagle_1gpu[llama-3.2-1b-eagle1]
examples/test_eagle.py::test_llama_eagle_1gpu[llama-3.1-8b-eagle1]
examples/test_eagle.py::test_mistral_eagle_1gpu[mistral-7b-v0.1-eagle1]
examples/test_eagle.py::test_mistral_nemo_eagle_1gpu[Mistral-Nemo-12b-Base-eagle1]
examples/test_eagle.py::test_qwen_eagle_1gpu[qwen_7b_chat-eagle1]
examples/test_eagle.py::test_qwen_eagle_1gpu[qwen1.5_7b_chat-eagle1]
examples/test_eagle.py::test_qwen_eagle_1gpu[qwen2_7b_instruct-eagle1]
@@ -514,6 +515,7 @@ examples/test_eagle.py::test_llama_eagle_1gpu[llama-v2-7b-hf-eagle2]
examples/test_eagle.py::test_llama_eagle_1gpu[llama-3.2-1b-eagle2]
examples/test_eagle.py::test_llama_eagle_1gpu[llama-3.1-8b-eagle2]
examples/test_eagle.py::test_mistral_eagle_1gpu[mistral-7b-v0.1-eagle2]
examples/test_eagle.py::test_mistral_nemo_eagle_1gpu[Mistral-Nemo-12b-Base-eagle2]
examples/test_eagle.py::test_qwen_eagle_1gpu[qwen_7b_chat-eagle2]
examples/test_eagle.py::test_qwen_eagle_1gpu[qwen1.5_7b_chat-eagle2]
examples/test_eagle.py::test_qwen_eagle_1gpu[qwen2_7b_instruct-eagle2]