mps: Alternative implementation for repeat_interleave (#766)
* mps: alt. implementation for repeat_interleave * style * Bump mps version of PyTorch in the documentation. * Apply suggestions from code review Co-authored-by: Suraj Patil <surajp815@gmail.com> * Simplify: do not check for device. * style * Fix repeat dimensions: - The unconditional embeddings are always created from a single prompt. - I was shadowing the batch_size var. * Split long lines as suggested by Suraj. Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -19,7 +19,7 @@ specific language governing permissions and limitations under the License.
|
||||
- Mac computer with Apple silicon (M1/M2) hardware.
|
||||
- macOS 12.3 or later.
|
||||
- arm64 version of Python.
|
||||
- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.13.0.dev20220830` or later.
|
||||
- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.14.0.dev20221007` or later.
|
||||
|
||||
## Inference Pipeline
|
||||
|
||||
|
||||
@@ -218,8 +218,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
|
||||
|
||||
# duplicate text embeddings for each generation per prompt
|
||||
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
@@ -256,8 +258,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt
|
||||
uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
|
||||
Reference in New Issue
Block a user