Remove deprecated torch_device kwarg (#623)
* Remove deprecated `torch_device` kwarg. * Remove unused imports.
This commit is contained in:
@@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -74,20 +73,6 @@ class DDIMPipeline(DiffusionPipeline):
|
|||||||
generated images.
|
generated images.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if "torch_device" in kwargs:
|
|
||||||
device = kwargs.pop("torch_device")
|
|
||||||
warnings.warn(
|
|
||||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
|
||||||
" Consider using `pipe.to(torch_device)` instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set device as before (to be removed in 0.3.0)
|
|
||||||
if device is None:
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
self.to(device)
|
|
||||||
|
|
||||||
# eta corresponds to η in paper and should be between [0, 1]
|
|
||||||
|
|
||||||
# Sample gaussian noise to begin loop
|
# Sample gaussian noise to begin loop
|
||||||
image = torch.randn(
|
image = torch.randn(
|
||||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
||||||
@@ -103,6 +88,7 @@ class DDIMPipeline(DiffusionPipeline):
|
|||||||
model_output = self.unet(image, t).sample
|
model_output = self.unet(image, t).sample
|
||||||
|
|
||||||
# 2. predict previous mean of image x_t-1 and add variance depending on eta
|
# 2. predict previous mean of image x_t-1 and add variance depending on eta
|
||||||
|
# eta corresponds to η in paper and should be between [0, 1]
|
||||||
# do x_t -> x_t-1
|
# do x_t -> x_t-1
|
||||||
image = self.scheduler.step(model_output, t, image, eta).prev_sample
|
image = self.scheduler.step(model_output, t, image, eta).prev_sample
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -66,17 +65,6 @@ class DDPMPipeline(DiffusionPipeline):
|
|||||||
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
|
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
|
||||||
generated images.
|
generated images.
|
||||||
"""
|
"""
|
||||||
if "torch_device" in kwargs:
|
|
||||||
device = kwargs.pop("torch_device")
|
|
||||||
warnings.warn(
|
|
||||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
|
||||||
" Consider using `pipe.to(torch_device)` instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set device as before (to be removed in 0.3.0)
|
|
||||||
if device is None:
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
self.to(device)
|
|
||||||
|
|
||||||
# Sample gaussian noise to begin loop
|
# Sample gaussian noise to begin loop
|
||||||
image = torch.randn(
|
image = torch.randn(
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import warnings
|
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -94,17 +93,6 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
|||||||
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
|
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
|
||||||
generated images.
|
generated images.
|
||||||
"""
|
"""
|
||||||
if "torch_device" in kwargs:
|
|
||||||
device = kwargs.pop("torch_device")
|
|
||||||
warnings.warn(
|
|
||||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
|
||||||
" Consider using `pipe.to(torch_device)` instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set device as before (to be removed in 0.3.0)
|
|
||||||
if device is None:
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
self.to(device)
|
|
||||||
|
|
||||||
if isinstance(prompt, str):
|
if isinstance(prompt, str):
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import warnings
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -60,18 +59,6 @@ class LDMPipeline(DiffusionPipeline):
|
|||||||
generated images.
|
generated images.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if "torch_device" in kwargs:
|
|
||||||
device = kwargs.pop("torch_device")
|
|
||||||
warnings.warn(
|
|
||||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
|
||||||
" Consider using `pipe.to(torch_device)` instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set device as before (to be removed in 0.3.0)
|
|
||||||
if device is None:
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
self.to(device)
|
|
||||||
|
|
||||||
latents = torch.randn(
|
latents = torch.randn(
|
||||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
||||||
generator=generator,
|
generator=generator,
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -75,18 +74,6 @@ class PNDMPipeline(DiffusionPipeline):
|
|||||||
# For more information on the sampling method you can take a look at Algorithm 2 of
|
# For more information on the sampling method you can take a look at Algorithm 2 of
|
||||||
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
|
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
|
||||||
|
|
||||||
if "torch_device" in kwargs:
|
|
||||||
device = kwargs.pop("torch_device")
|
|
||||||
warnings.warn(
|
|
||||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
|
||||||
" Consider using `pipe.to(torch_device)` instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set device as before (to be removed in 0.3.0)
|
|
||||||
if device is None:
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
self.to(device)
|
|
||||||
|
|
||||||
# Sample gaussian noise to begin loop
|
# Sample gaussian noise to begin loop
|
||||||
image = torch.randn(
|
image = torch.randn(
|
||||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import warnings
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -53,18 +52,6 @@ class ScoreSdeVePipeline(DiffusionPipeline):
|
|||||||
generated images.
|
generated images.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if "torch_device" in kwargs:
|
|
||||||
device = kwargs.pop("torch_device")
|
|
||||||
warnings.warn(
|
|
||||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
|
||||||
" Consider using `pipe.to(torch_device)` instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set device as before (to be removed in 0.3.0)
|
|
||||||
if device is None:
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
self.to(device)
|
|
||||||
|
|
||||||
img_size = self.unet.config.sample_size
|
img_size = self.unet.config.sample_size
|
||||||
shape = (batch_size, 3, img_size, img_size)
|
shape = (batch_size, 3, img_size, img_size)
|
||||||
|
|
||||||
|
|||||||
@@ -169,18 +169,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||||||
(nsfw) content, according to the `safety_checker`.
|
(nsfw) content, according to the `safety_checker`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if "torch_device" in kwargs:
|
|
||||||
device = kwargs.pop("torch_device")
|
|
||||||
warnings.warn(
|
|
||||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
|
||||||
" Consider using `pipe.to(torch_device)` instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set device as before (to be removed in 0.3.0)
|
|
||||||
if device is None:
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
self.to(device)
|
|
||||||
|
|
||||||
if isinstance(prompt, str):
|
if isinstance(prompt, str):
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
elif isinstance(prompt, list):
|
elif isinstance(prompt, list):
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import warnings
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -64,17 +63,6 @@ class KarrasVePipeline(DiffusionPipeline):
|
|||||||
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
|
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
|
||||||
generated images.
|
generated images.
|
||||||
"""
|
"""
|
||||||
if "torch_device" in kwargs:
|
|
||||||
device = kwargs.pop("torch_device")
|
|
||||||
warnings.warn(
|
|
||||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
|
||||||
" Consider using `pipe.to(torch_device)` instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set device as before (to be removed in 0.3.0)
|
|
||||||
if device is None:
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
self.to(device)
|
|
||||||
|
|
||||||
img_size = self.unet.config.sample_size
|
img_size = self.unet.config.sample_size
|
||||||
shape = (batch_size, 3, img_size, img_size)
|
shape = (batch_size, 3, img_size, img_size)
|
||||||
|
|||||||
Reference in New Issue
Block a user