mps: Alternative implementation for repeat_interleave (#766)
* mps: alt. implementation for repeat_interleave * style * Bump mps version of PyTorch in the documentation. * Apply suggestions from code review Co-authored-by: Suraj Patil <surajp815@gmail.com> * Simplify: do not check for device. * style * Fix repeat dimensions: - The unconditional embeddings are always created from a single prompt. - I was shadowing the batch_size var. * Split long lines as suggested by Suraj. Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -19,7 +19,7 @@ specific language governing permissions and limitations under the License.
|
|||||||
- Mac computer with Apple silicon (M1/M2) hardware.
|
- Mac computer with Apple silicon (M1/M2) hardware.
|
||||||
- macOS 12.3 or later.
|
- macOS 12.3 or later.
|
||||||
- arm64 version of Python.
|
- arm64 version of Python.
|
||||||
- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.13.0.dev20220830` or later.
|
- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.14.0.dev20221007` or later.
|
||||||
|
|
||||||
## Inference Pipeline
|
## Inference Pipeline
|
||||||
|
|
||||||
|
|||||||
@@ -218,8 +218,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||||
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
|
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
|
||||||
|
|
||||||
# duplicate text embeddings for each generation per prompt
|
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||||
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
bs_embed, seq_len, _ = text_embeddings.shape
|
||||||
|
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||||
|
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||||
|
|
||||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||||
@@ -256,8 +258,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
||||||
|
|
||||||
# duplicate unconditional embeddings for each generation per prompt
|
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||||
uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
|
seq_len = uncond_embeddings.shape[1]
|
||||||
|
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
|
||||||
|
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||||
|
|
||||||
# For classifier free guidance, we need to do two forward passes.
|
# For classifier free guidance, we need to do two forward passes.
|
||||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||||
|
|||||||
Reference in New Issue
Block a user