Commit b546cfc3 authored by kijai

bigger update

Parent 4f5929a8
mimicmotion/pipelines/pipeline_mimicmotion.py

@@ -20,6 +20,8 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 from ..modules.pose_net import PoseNet
 from comfy.utils import ProgressBar
+import comfy.model_management as mm
+offload_device = mm.unet_offload_device()

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
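Note: this import pair sets up the offload pattern used through the rest of the file: each submodule is moved onto the compute device only for its forward pass, then parked on whatever comfy.model_management designates as the offload device (usually system RAM). A minimal sketch of the pattern, with an illustrative helper name:

import torch
import comfy.model_management as mm

device = mm.get_torch_device()             # main compute device, e.g. cuda:0
offload_device = mm.unet_offload_device()  # typically cpu when VRAM is scarce

def run_offloaded(module: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: load -> run -> offload, so only one stage
    # occupies VRAM at a time.
    module.to(device)
    try:
        return module(x.to(device))
    finally:
        module.to(offload_device)
        mm.soft_empty_cache()  # let ComfyUI reclaim the freed blocks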
@@ -145,8 +147,10 @@ class MimicMotionPipeline(DiffusionPipeline):
         ).pixel_values
         image = image.to(device=device, dtype=dtype)
+        self.image_encoder.to(device)
         image_embeddings = self.image_encoder(image).image_embeds
         image_embeddings = image_embeddings.unsqueeze(1)
+        self.image_encoder.to(offload_device)

         # duplicate image embeddings for each generation per prompt, using mps friendly method
         bs_embed, seq_len, _ = image_embeddings.shape
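Note: the "mps friendly method" the surrounding comment refers to is the repeat-then-view idiom diffusers uses instead of repeat_interleave, which historically misbehaved on Apple's mps backend. Roughly (num_videos_per_prompt is the usual diffusers parameter name; treat the exact variables as assumptions):

# image_embeddings: (bs_embed, seq_len, dim)
image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)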
@@ -189,7 +193,9 @@ class MimicMotionPipeline(DiffusionPipeline):
         do_classifier_free_guidance: bool,
     ):
         image = image.to(device=device)
+        self.vae.to(device)
         image_latents = self.vae.encode(image).latent_dist.mode()
+        self.vae.to(offload_device)

         if do_classifier_free_guidance:
             negative_image_latents = torch.zeros_like(image_latents)
@@ -256,7 +262,10 @@ class MimicMotionPipeline(DiffusionPipeline):
                 # we only pass num_frames_in if it's expected
                 decode_kwargs["num_frames"] = num_frames_in

+            self.vae.to(latents.device)
             frame = self.vae.decode(latents[i: i + decode_chunk_size], **decode_kwargs).sample
+            self.vae.to(offload_device)
             frames.append(frame.cpu())
         frames = torch.cat(frames, dim=0)
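Note: decoding in decode_chunk_size slices, moving each result to the CPU immediately, bounds peak VRAM by one chunk instead of the whole clip. A standalone sketch of the same idea (names illustrative; the real code also conditionally passes a num_frames kwarg, omitted here):

import torch

def decode_in_chunks(vae, latents: torch.Tensor, decode_chunk_size: int = 8) -> torch.Tensor:
    # latents: (num_frames, c, h, w); decode a few frames at a time
    frames = []
    for i in range(0, latents.shape[0], decode_chunk_size):
        chunk = latents[i: i + decode_chunk_size]
        frames.append(vae.decode(chunk).sample.cpu())  # offload each chunk right away
    return torch.cat(frames, dim=0)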
@@ -568,6 +577,8 @@ class MimicMotionPipeline(DiffusionPipeline):
         self._guidance_scale = guidance_scale

         # 8. Denoising loop
+        self.unet.to(device)
         self._num_timesteps = len(timesteps)
         pose_latents = einops.rearrange(pose_latents, '(b f) c h w -> b f c h w', f=num_frames)
         indices = [[0, *range(i + 1, min(i + tile_size, num_frames))] for i in
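Note: the indices comprehension (its iteration range is cut off by the hunk boundary) builds overlapping temporal tiles that all share frame 0, the reference frame. A worked example for one tile:

num_frames, tile_size = 32, 16
i = 14  # hypothetical tile start
tile = [0, *range(i + 1, min(i + tile_size, num_frames))]
# -> [0, 15, 16, ..., 29]: frame 0 is prepended so every window is
#    denoised together with the reference frame.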
@@ -627,6 +638,8 @@ class MimicMotionPipeline(DiffusionPipeline):
                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                     latents = callback_outputs.pop("latents", latents)

+        self.unet.to(offload_device)
         if not output_type == "latent":
             # cast back to fp16 if needed
...

nodes.py
 import os
-from omegaconf import OmegaConf
 import torch
-import torch.nn.functional as F
 import sys
 import numpy as np
-script_directory = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(script_directory)
-from einops import repeat
+import gc
 import folder_paths
 import comfy.model_management as mm
 import comfy.utils
-from contextlib import nullcontext
-try:
-    from accelerate import init_empty_weights
-    is_accelerate_available = True
-except:
-    pass
-from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline
 from diffusers.models import AutoencoderKLTemporalDecoder
 from diffusers.schedulers import EulerDiscreteScheduler
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+script_directory = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(script_directory)
+from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline
 from mimicmotion.modules.unet import UNetSpatioTemporalConditionModel
 from mimicmotion.modules.pose_net import PoseNet
@@ -65,6 +55,7 @@ class DownloadAndLoadMimicMotionModel:
                 ], {
                     "default": 'fp16'
                 }),
             },
         }
@@ -77,6 +68,8 @@ class DownloadAndLoadMimicMotionModel:
         device = mm.get_torch_device()
         mm.soft_empty_cache()
         dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+        pbar = comfy.utils.ProgressBar(3)

         download_path = os.path.join(folder_paths.models_dir, "mimicmotion")
         model_path = os.path.join(download_path, model)
@@ -89,13 +82,20 @@ class DownloadAndLoadMimicMotionModel:
                               local_dir=download_path,
                               local_dir_use_symlinks=False)

-        ckpt_base_name = os.path.basename(model_path)
         print(f"Loading model from: {model_path}")
+        pbar.update(1)

         svd_path = os.path.join(folder_paths.models_dir, "diffusers", "stable-video-diffusion-img2vid-xt-1-1")

         if not os.path.exists(svd_path):
-            raise ValueError(f"Please download stable-video-diffusion-img2vid-xt-1-1 to {svd_path}")
+            #raise ValueError(f"Please download stable-video-diffusion-img2vid-xt-1-1 to {svd_path}")
+            print(f"Downloading SVD model to: {model_path}")
+            from huggingface_hub import snapshot_download
+            snapshot_download(repo_id="vdo/stable-video-diffusion-img2vid-xt-1-1",
+                              allow_patterns=[f"*.json", "*fp16*"],
+                              local_dir=svd_path,
+                              local_dir_use_symlinks=False)
+        pbar.update(1)

         mimicmotion_models = MimicMotionModel(svd_path).to(device=device).eval()
         mimicmotion_models.load_state_dict(comfy.utils.load_torch_file(model_path), strict=False)
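Note: instead of erroring out, the loader now fetches SVD itself, and allow_patterns restricts the snapshot to config JSONs plus fp16 weight variants so nothing else is downloaded. A standalone equivalent (repo id from the diff; the target directory here is illustrative):

from huggingface_hub import snapshot_download

snapshot_download(repo_id="vdo/stable-video-diffusion-img2vid-xt-1-1",
                  allow_patterns=["*.json", "*fp16*"],  # configs + fp16 weights only
                  local_dir="models/diffusers/stable-video-diffusion-img2vid-xt-1-1",
                  local_dir_use_symlinks=False)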
@@ -108,16 +108,18 @@ class DownloadAndLoadMimicMotionModel:
             feature_extractor=mimicmotion_models.feature_extractor,
             pose_net=mimicmotion_models.pose_net,
         )
         pipeline.unet.to(dtype)
         pipeline.pose_net.to(dtype)
         pipeline.vae.to(dtype)
         pipeline.image_encoder.to(dtype)
         pipeline.pose_net.to(dtype)

         mimic_model = {
             'pipeline': pipeline,
             'dtype': dtype
         }
+        pbar.update(1)
         return (mimic_model,)
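Note: the node hands downstream nodes a small dict rather than a bare pipeline, so consumers also know which dtype the weights were cast to. Unpacking it looks like this (mirroring what MimicMotionSampler does below):

pipeline = mimic_model['pipeline']   # MimicMotionPipeline, weights already on the GPU
dtype = mimic_model['dtype']         # cast inputs to this to match the weights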
class MimicMotionSampler:
@@ -151,16 +153,26 @@ class MimicMotionSampler:
         pipeline = mimic_pipeline['pipeline']
         B, H, W, C = pose_images.shape

-        ref_image = ref_image.permute(0, 3, 1, 2).to(device).to(dtype)
-        pose_images = pose_images.permute(0, 3, 1, 2).to(device).to(dtype)
-        ref_image = ref_image * 2 - 1
+        ref_image = ref_image.permute(0, 3, 1, 2)
+        pose_images = pose_images.permute(0, 3, 1, 2)
+
+        if ref_image.shape[1:3] != (224, 224):
+            ref_img = comfy.utils.common_upscale(ref_image, 224, 224, "lanczos", "disabled")
+        else:
+            ref_img = ref_image
+
+        ref_img = ref_img * 2 - 1
         pose_images = pose_images * 2 - 1
+        ref_img = ref_img.to(device).to(dtype)
+        pose_images = pose_images.to(device).to(dtype)

         generator = torch.Generator(device=device)
         generator.manual_seed(seed)

         frames = pipeline(
-            ref_image,
+            ref_img,
             image_pose=pose_images,
             num_frames=B,
             tile_size = 16,
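Note: preprocessing now happens in a fixed order: permute ComfyUI's BHWC tensors to BCHW, resize the reference toward CLIP's 224x224 input size, rescale from [0, 1] to [-1, 1], and only then move to the compute device and dtype. One caveat: on a BCHW tensor, shape[1:3] is (C, H), so for RGB input the resize branch effectively always runs; shape[-2:] would compare (H, W) as presumably intended. A sketch of that intent:

import comfy.utils

def prep_ref(ref_image, device, dtype):
    # ref_image: (B, H, W, C) in [0, 1], ComfyUI's image layout
    x = ref_image.permute(0, 3, 1, 2)              # -> (B, C, H, W)
    if x.shape[-2:] != (224, 224):                 # compare (H, W), not (C, H)
        x = comfy.utils.common_upscale(x, 224, 224, "lanczos", "disabled")
    return (x * 2 - 1).to(device).to(dtype)        # [-1, 1], on device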
@@ -177,8 +189,14 @@ class MimicMotionSampler:
             output_type="pt",
             device=device
         ).frames

-        frames = frames.squeeze(0).permute(0, 2, 3, 1).cpu().float()
+        frames = frames.squeeze(0)[1:].permute(0, 2, 3, 1).cpu().float()
+        print(frames.shape)
+
+        if not keep_model_loaded:
+            pipeline.unet.to(offload_device)
+            pipeline.vae.to(offload_device)
+            mm.soft_empty_cache()
+            gc.collect()

         return frames,
@@ -194,8 +212,8 @@ class MimicMotionGetPoses:
             },
         }

-    RETURN_TYPES = ("IMAGE",)
-    RETURN_NAMES = ("images",)
+    RETURN_TYPES = ("IMAGE", "IMAGE",)
+    RETURN_NAMES = ("poses_with_ref", "pose_images")
     FUNCTION = "process"
     CATEGORY = "MimicMotionWrapper"
@@ -246,9 +264,11 @@ class MimicMotionGetPoses:
         pose_images_np = pose_images.cpu().numpy() * 255

         # read input video
+        pbar = comfy.utils.ProgressBar(len(pose_images_np))
         detected_poses_np_list = []
         for img_np in pose_images_np:
             detected_poses_np_list.append(dwprocessor(img_np))
+            pbar.update(1)

         detected_bodies = np.stack(
             [p['bodies']['candidate'] for p in detected_poses_np_list if p['bodies']['candidate'].shape[0] == 18])[:,
@@ -277,10 +297,7 @@ class MimicMotionGetPoses:
         output_tensor = torch.cat((ref_pose_tensor.unsqueeze(0), output_tensor))
         output_tensor = output_tensor.permute(0, 2, 3, 1).cpu().float()

-        return output_tensor,
+        return output_tensor, output_tensor[1:]
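Note: the first output now carries the reference pose as frame 0, the convention that the pipeline's tiling (every tile starts at index 0) and the sampler's frames.squeeze(0)[1:] both depend on; the second output drops that frame. Schematically (ref_pose and pose_frames are illustrative names):

poses_with_ref = torch.cat((ref_pose.unsqueeze(0), pose_frames))  # (1 + N, H, W, C)
pose_images = poses_with_ref[1:]                                  # (N, H, W, C), ref stripped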
 NODE_CLASS_MAPPINGS = {
     "DownloadAndLoadMimicMotionModel": DownloadAndLoadMimicMotionModel,
...