mps: Alternative implementation for repeat_interleave (#766)

* mps: alt. implementation for repeat_interleave

* style

* Bump mps version of PyTorch in the documentation.

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Simplify: do not check for device.

* style

* Fix repeat dimensions:

- The unconditional embeddings are always created from a single prompt.
- I was shadowing the batch_size var.

* Split long lines as suggested by Suraj.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Pedro Cuenca
2022-10-11 20:30:09 +02:00
committed by GitHub
parent 757babfcad
commit 24b8b5cf5e
2 changed files with 9 additions and 5 deletions
@@ -19,7 +19,7 @@ specific language governing permissions and limitations under the License.
 - Mac computer with Apple silicon (M1/M2) hardware.
 - macOS 12.3 or later.
 - arm64 version of Python.
-- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.13.0.dev20220830` or later.
+- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.14.0.dev20221007` or later.
 ## Inference Pipeline
@@ -218,8 +218,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
 text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
 text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
-# duplicate text embeddings for each generation per prompt
-text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+# duplicate text embeddings for each generation per prompt, using mps friendly method
+bs_embed, seq_len, _ = text_embeddings.shape
+text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -256,8 +258,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
 )
 uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
-# duplicate unconditional embeddings for each generation per prompt
-uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+seq_len = uncond_embeddings.shape[1]
+uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
 # For classifier free guidance, we need to do two forward passes.
 # Here we concatenate the unconditional and text embeddings into a single batch