[tests] Fix failing float16 cuda tests (#11835)
* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in: parent f3e1310469, commit 3f3f0c16a6
@@ -1378,7 +1378,6 @@ class PipelineTesterMixin:
         for component in pipe_fp16.components.values():
             if hasattr(component, "set_default_attn_processor"):
                 component.set_default_attn_processor()
 
         pipe_fp16.to(torch_device, torch.float16)
         pipe_fp16.set_progress_bar_config(disable=None)
 
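The hunk above is unchanged context: it shows the float16 copy of the pipeline being prepared before inference. A minimal sketch of that setup, assuming a generic DiffusionPipeline subclass (the build_fp16_pipe helper name is illustrative and not part of PipelineTesterMixin):

import torch

def build_fp16_pipe(pipeline_class, components, torch_device):
    # Illustrative helper, not part of the test mixin.
    pipe_fp16 = pipeline_class(**components)
    # Custom attention processors can change numerics; reset to the default
    # so the fp32 and fp16 runs are compared on equal footing.
    for component in pipe_fp16.components.values():
        if hasattr(component, "set_default_attn_processor"):
            component.set_default_attn_processor()
    # DiffusionPipeline.to() accepts a device and a dtype together,
    # casting every torch.nn.Module component to float16 in one call.
    pipe_fp16.to(torch_device, torch.float16)
    pipe_fp16.set_progress_bar_config(disable=None)
    return pipe_fp16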
@@ -1386,17 +1385,20 @@ class PipelineTesterMixin:
         # Reset generator in case it is used inside dummy inputs
         if "generator" in inputs:
             inputs["generator"] = self.get_generator(0)
 
         output = pipe(**inputs)[0]
 
         fp16_inputs = self.get_dummy_inputs(torch_device)
         # Reset generator in case it is used inside dummy inputs
         if "generator" in fp16_inputs:
             fp16_inputs["generator"] = self.get_generator(0)
 
         output_fp16 = pipe_fp16(**fp16_inputs)[0]
 
+        if isinstance(output, torch.Tensor):
+            output = output.cpu()
+            output_fp16 = output_fp16.cpu()
 
         max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
-        assert max_diff < 1e-2
+        assert max_diff < expected_max_diff
 
     @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
     @require_accelerator
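For context on why the added .cpu() guard is needed: pipelines that return torch.Tensor outputs keep them on the accelerator, and converting a CUDA or XPU tensor to a NumPy array raises a TypeError, so the cosine-similarity comparison failed before the tensors were moved to the host. Below is a minimal sketch of that comparison path; cosine_similarity_distance is a hypothetical stand-in for the numpy_cosine_similarity_distance helper used by the test, and the tolerance value is illustrative rather than the mixin's actual expected_max_diff default.

import numpy as np
import torch

def cosine_similarity_distance(a, b):
    # Hypothetical stand-in for numpy_cosine_similarity_distance:
    # 1 - cosine similarity of the two flattened vectors.
    a = np.asarray(a, dtype=np.float64).flatten()
    b = np.asarray(b, dtype=np.float64).flatten()
    return float(1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Stand-in fp32/fp16 pipeline outputs; on a real accelerator run these would
# live on the device, and tensor.numpy() on a CUDA/XPU tensor raises TypeError.
output = torch.randn(1, 3, 8, 8)
output_fp16 = (output + 1e-3 * torch.randn_like(output)).half()

if isinstance(output, torch.Tensor):
    output = output.cpu()            # the guard added by this commit
    output_fp16 = output_fp16.cpu()

max_diff = cosine_similarity_distance(output.numpy(), output_fp16.float().numpy())
assert max_diff < 5e-2  # illustrative tolerance only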