提交 69548718 authored 作者: kijai's avatar kijai

VRAM optimizations

上级 b3680624
...@@ -542,9 +542,9 @@ class MimicMotionPipeline(DiffusionPipeline): ...@@ -542,9 +542,9 @@ class MimicMotionPipeline(DiffusionPipeline):
image_latents = image_latents.to(image_embeddings.dtype) image_latents = image_latents.to(image_embeddings.dtype)
ref_latent = first_n_frames[:, 0] if first_n_frames is not None else None ref_latent = first_n_frames[:, 0] if first_n_frames is not None else None
pose_latents = self._encode_pose_image( # pose_latents = self._encode_pose_image(
image_pose, do_classifier_free_guidance=self.do_classifier_free_guidance, # image_pose, do_classifier_free_guidance=self.do_classifier_free_guidance,
) # )
# cast back to fp16 if needed # cast back to fp16 if needed
# if needs_upcasting: # if needs_upcasting:
...@@ -609,7 +609,8 @@ class MimicMotionPipeline(DiffusionPipeline): ...@@ -609,7 +609,8 @@ class MimicMotionPipeline(DiffusionPipeline):
print(f"start_step_index: {start_step_index}, end_step_index: {end_step_index}") print(f"start_step_index: {start_step_index}, end_step_index: {end_step_index}")
pose_latents = einops.rearrange(pose_latents, '(b f) c h w -> b f c h w', f=num_frames) #pose_latents = einops.rearrange(pose_latents, '(b f) c h w -> b f c h w', f=num_frames)
pose_latents_shape = self.pose_net(image_pose[0].to(device))
indices = [[0, *range(i + 1, min(i + tile_size, num_frames))] for i in indices = [[0, *range(i + 1, min(i + tile_size, num_frames))] for i in
range(0, num_frames - tile_size + 1, tile_size - tile_overlap)] range(0, num_frames - tile_size + 1, tile_size - tile_overlap)]
if indices[-1][-1] < num_frames - 1: if indices[-1][-1] < num_frames - 1:
...@@ -636,24 +637,51 @@ class MimicMotionPipeline(DiffusionPipeline): ...@@ -636,24 +637,51 @@ class MimicMotionPipeline(DiffusionPipeline):
if start_step_index <= i <= end_step_index: if start_step_index <= i <= end_step_index:
# Apply pose_latents as currently done # Apply pose_latents as currently done
#print(f"Applying pose on step {i}") #print(f"Applying pose on step {i}")
pose_latents_to_use = pose_latents[:, idx].flatten(0, 1) pose_latents_to_use = self.pose_net(image_pose[idx].to(device))
#pose_latents_to_use = pose_latents[:, idx].flatten(0, 1).to(device)
else: else:
#print(f"Not applying pose on step {i}") #print(f"Not applying pose on step {i}")
# Apply an alternative if pose_latents should not be used outside this range # Apply an alternative if pose_latents should not be used outside this range
# This could be zeros, or any other placeholder logic you define. pose_latents_to_use = torch.zeros_like(pose_latents_shape, device=device)
pose_latents_to_use = torch.zeros_like(pose_latents[:, idx].flatten(0, 1))
# _noise_pred = self.unet(
# latent_model_input[:, idx],
# t,
# encoder_hidden_states=image_embeddings,
# added_time_ids=added_time_ids,
# pose_latents=pose_latents_to_use,
# pose_strength=pose_strength,
# image_only_indicator=image_only_indicator,
# return_dict=False,
# )[0]
# noise_pred[:, idx] += _noise_pred * weight[:, None, None, None]
# classification-free inference
_noise_pred = self.unet(
latent_model_input[:1, idx],
t,
encoder_hidden_states=image_embeddings[:1],
added_time_ids=added_time_ids[:1],
pose_latents=None,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
noise_pred[:1, idx] += _noise_pred * weight[:, None, None, None]
# normal inference
_noise_pred = self.unet( _noise_pred = self.unet(
latent_model_input[:, idx], latent_model_input[1:, idx],
t, t,
encoder_hidden_states=image_embeddings, encoder_hidden_states=image_embeddings[1:],
added_time_ids=added_time_ids, added_time_ids=added_time_ids[1:],
pose_latents=pose_latents_to_use, pose_latents=pose_latents_to_use,
pose_strength=pose_strength,
image_only_indicator=image_only_indicator, image_only_indicator=image_only_indicator,
return_dict=False, return_dict=False,
)[0] )[0]
noise_pred[:, idx] += _noise_pred * weight[:, None, None, None] noise_pred[1:, idx] += _noise_pred * weight[:, None, None, None]
noise_pred_cnt[idx] += weight noise_pred_cnt[idx] += weight
progress_bar.update() progress_bar.update()
comfy_pbar.update(1) comfy_pbar.update(1)
......
...@@ -7,6 +7,17 @@ import folder_paths ...@@ -7,6 +7,17 @@ import folder_paths
import comfy.model_management as mm import comfy.model_management as mm
import comfy.utils import comfy.utils
try:
import diffusers.models.activations
def patch_geglu_inplace():
"""Patch GEGLU with inplace multiplication to save GPU memory."""
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states.mul_(self.gelu(gate))
diffusers.models.activations.GEGLU.forward = forward
except:
pass
from diffusers.models import AutoencoderKLTemporalDecoder from diffusers.models import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler from diffusers.schedulers import EulerDiscreteScheduler
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论