提交 97197011 authored 作者: kijai's avatar kijai

separate decode phase

上级 e3933422
...@@ -259,6 +259,7 @@ class MimicMotionPipeline(DiffusionPipeline): ...@@ -259,6 +259,7 @@ class MimicMotionPipeline(DiffusionPipeline):
accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys()) accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
# decode decode_chunk_size frames at a time to avoid OOM # decode decode_chunk_size frames at a time to avoid OOM
pbar = ProgressBar(latents.shape[0])
frames = [] frames = []
for i in range(0, latents.shape[0], decode_chunk_size): for i in range(0, latents.shape[0], decode_chunk_size):
num_frames_in = latents[i: i + decode_chunk_size].shape[0] num_frames_in = latents[i: i + decode_chunk_size].shape[0]
...@@ -272,6 +273,7 @@ class MimicMotionPipeline(DiffusionPipeline): ...@@ -272,6 +273,7 @@ class MimicMotionPipeline(DiffusionPipeline):
self.vae.to(offload_device) self.vae.to(offload_device)
frames.append(frame.cpu()) frames.append(frame.cpu())
pbar.update(decode_chunk_size)
frames = torch.cat(frames, dim=0) frames = torch.cat(frames, dim=0)
# [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
...@@ -485,6 +487,7 @@ class MimicMotionPipeline(DiffusionPipeline): ...@@ -485,6 +487,7 @@ class MimicMotionPipeline(DiffusionPipeline):
width = width or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor
num_frames = num_frames if num_frames is not None else self.unet.config.num_frames num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
print("num_frames: ", num_frames)
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
......
...@@ -7,15 +7,13 @@ import folder_paths ...@@ -7,15 +7,13 @@ import folder_paths
import comfy.model_management as mm import comfy.model_management as mm
import comfy.utils import comfy.utils
from comfy.clip_vision import clip_preprocess
from diffusers.models import AutoencoderKLTemporalDecoder from diffusers.models import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler from diffusers.schedulers import EulerDiscreteScheduler
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
script_directory = os.path.dirname(os.path.abspath(__file__)) script_directory = os.path.dirname(os.path.abspath(__file__))
from .mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline from .mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline, tensor2vid
from .mimicmotion.modules.unet import UNetSpatioTemporalConditionModel from .mimicmotion.modules.unet import UNetSpatioTemporalConditionModel
from .mimicmotion.modules.pose_net import PoseNet from .mimicmotion.modules.pose_net import PoseNet
...@@ -140,7 +138,6 @@ class DownloadAndLoadMimicMotionModel: ...@@ -140,7 +138,6 @@ class DownloadAndLoadMimicMotionModel:
pipeline.pose_net.to(dtype) pipeline.pose_net.to(dtype)
pipeline.vae.to(dtype) pipeline.vae.to(dtype)
pipeline.image_encoder.to(dtype) pipeline.image_encoder.to(dtype)
pipeline.pose_net.to(dtype)
mimic_model = { mimic_model = {
'pipeline': pipeline, 'pipeline': pipeline,
...@@ -168,8 +165,8 @@ class MimicMotionSampler: ...@@ -168,8 +165,8 @@ class MimicMotionSampler:
}, },
} }
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("LATENT",)
RETURN_NAMES = ("images",) RETURN_NAMES = ("samples",)
FUNCTION = "process" FUNCTION = "process"
CATEGORY = "MimicMotionWrapper" CATEGORY = "MimicMotionWrapper"
...@@ -215,18 +212,47 @@ class MimicMotionSampler: ...@@ -215,18 +212,47 @@ class MimicMotionSampler:
min_guidance_scale=cfg_min, min_guidance_scale=cfg_min,
max_guidance_scale=cfg_max, max_guidance_scale=cfg_max,
decode_chunk_size=4, decode_chunk_size=4,
output_type="pt", output_type="latent",
device=device device=device
).frames ).frames
frames = frames.squeeze(0)[1:].permute(0, 2, 3, 1).cpu().float() #frames = frames.squeeze(0)[1:].permute(0, 2, 3, 1).cpu().float()
if not keep_model_loaded: if not keep_model_loaded:
pipeline.unet.to(offload_device) pipeline.unet.to(offload_device)
pipeline.vae.to(offload_device) pipeline.vae.to(offload_device)
mm.soft_empty_cache() mm.soft_empty_cache()
gc.collect() gc.collect()
return {"samples": frames},
class MimicMotionDecode:
    """ComfyUI node that decodes sampled latents into images as a separate step.

    It takes the loaded pipeline dictionary and a ``LATENT`` dictionary and
    returns the decoded frames as an ``IMAGE`` output.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification for this node.

        All three inputs are required: the loaded pipeline, the latents to
        decode, and the chunk size handed to the decode call.
        """
        return {"required": {
            "mimic_pipeline": ("MIMICPIPE",),
            "samples": ("LATENT",),
            "decode_chunk_size": ("INT", {"default": 4, "min": 1, "max": 200, "step": 1})
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "process"
    CATEGORY = "MimicMotionWrapper"

    def process(self, mimic_pipeline, samples, decode_chunk_size):
        """Decode ``samples['samples']`` with the pipeline and return images.

        Args:
            mimic_pipeline: Dictionary whose ``'pipeline'`` entry provides
                ``decode_latents`` and ``image_processor``.
            samples: Dictionary whose ``'samples'`` entry holds the latents.
            decode_chunk_size: Chunk size passed to ``decode_latents`` on the
                first attempt.

        Returns:
            A one-element tuple containing the decoded frames as a float
            tensor on the CPU.

        Raises:
            Exception: Whatever the fallback decode raises if it fails too.
        """
        mm.soft_empty_cache()

        pipeline = mimic_pipeline['pipeline']
        latents = samples['samples']
        num_frames = latents.shape[0]
        try:
            frames = pipeline.decode_latents(latents, num_frames, decode_chunk_size)
        except Exception:
            # Fall back to the smallest chunk size when the requested one
            # fails (presumably out-of-memory -- TODO confirm). Catching
            # Exception rather than using a bare except lets
            # KeyboardInterrupt and SystemExit propagate instead of starting
            # a second decode pass.
            frames = pipeline.decode_latents(latents, num_frames, 1)
        frames = tensor2vid(frames, pipeline.image_processor, output_type="pt")

        # Drop the singleton dimension at index 1, skip the first frame and
        # move dimension 1 to the end.
        # NOTE(review): assumes the first frame is the reference image and
        # that the result is channels-last as ComfyUI expects -- confirm.
        frames = frames.squeeze(1)[1:].permute(0, 2, 3, 1).cpu().float()

        return frames,
class MimicMotionGetPoses: class MimicMotionGetPoses:
...@@ -251,6 +277,8 @@ class MimicMotionGetPoses: ...@@ -251,6 +277,8 @@ class MimicMotionGetPoses:
from .mimicmotion.dwpose.util import draw_pose from .mimicmotion.dwpose.util import draw_pose
from .mimicmotion.dwpose.dwpose_detector import DWposeDetector from .mimicmotion.dwpose.dwpose_detector import DWposeDetector
assert ref_image.shape[1:3] == pose_images.shape[1:3], "ref_image and pose_images must have the same resolution"
yolo_model = "yolox_l.onnx" yolo_model = "yolox_l.onnx"
dw_pose_model = "dw-ll_ucoco_384.onnx" dw_pose_model = "dw-ll_ucoco_384.onnx"
model_base_path = os.path.join(script_directory, "models", "DWPose") model_base_path = os.path.join(script_directory, "models", "DWPose")
...@@ -331,11 +359,13 @@ class MimicMotionGetPoses: ...@@ -331,11 +359,13 @@ class MimicMotionGetPoses:
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"DownloadAndLoadMimicMotionModel": DownloadAndLoadMimicMotionModel, "DownloadAndLoadMimicMotionModel": DownloadAndLoadMimicMotionModel,
"MimicMotionSampler": MimicMotionSampler, "MimicMotionSampler": MimicMotionSampler,
"MimicMotionGetPoses": MimicMotionGetPoses "MimicMotionGetPoses": MimicMotionGetPoses,
"MimicMotionDecode": MimicMotionDecode
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadMimicMotionModel": "DownloadAndLoadMimicMotionModel", "DownloadAndLoadMimicMotionModel": "DownloadAndLoadMimicMotionModel",
"MimicMotionSampler": "MimicMotionSampler", "MimicMotionSampler": "MimicMotionSampler",
"MimicMotionGetPoses": "MimicMotionGetPoses" "MimicMotionGetPoses": "MimicMotionGetPoses",
"MimicMotionDecode": "MimicMotionDecode"
} }
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请注册或者登录后发表评论