提交 e90ac6dd 作者: kijai

lcm SVD as experimental option

上级 180b6eff
差异被折叠。
...@@ -19,15 +19,18 @@ from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline ...@@ -19,15 +19,18 @@ from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline
from mimicmotion.modules.unet import UNetSpatioTemporalConditionModel from mimicmotion.modules.unet import UNetSpatioTemporalConditionModel
from mimicmotion.modules.pose_net import PoseNet from mimicmotion.modules.pose_net import PoseNet
from lcm_scheduler import AnimateLCMSVDStochasticIterativeScheduler
class MimicMotionModel(torch.nn.Module): class MimicMotionModel(torch.nn.Module):
def __init__(self, base_model_path): def __init__(self, base_model_path, lcm=False):
"""construnct base model components and load pretrained svd model except pose-net """construnct base model components and load pretrained svd model except pose-net
Args: Args:
base_model_path (str): pretrained svd model path base_model_path (str): pretrained svd model path
""" """
super().__init__() super().__init__()
unet_subfolder = "unet_lcm" if lcm else "unet"
self.unet = UNetSpatioTemporalConditionModel.from_config( self.unet = UNetSpatioTemporalConditionModel.from_config(
UNetSpatioTemporalConditionModel.load_config(base_model_path, subfolder="unet", variant="fp16")) UNetSpatioTemporalConditionModel.load_config(base_model_path, subfolder=unet_subfolder, variant="fp16"))
self.vae = AutoencoderKLTemporalDecoder.from_pretrained( self.vae = AutoencoderKLTemporalDecoder.from_pretrained(
base_model_path, subfolder="vae", variant="fp16") base_model_path, subfolder="vae", variant="fp16")
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained( self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
...@@ -55,6 +58,7 @@ class DownloadAndLoadMimicMotionModel: ...@@ -55,6 +58,7 @@ class DownloadAndLoadMimicMotionModel:
], { ], {
"default": 'fp16' "default": 'fp16'
}), }),
"lcm": ("BOOLEAN", {"default": False}),
}, },
} }
...@@ -64,7 +68,7 @@ class DownloadAndLoadMimicMotionModel: ...@@ -64,7 +68,7 @@ class DownloadAndLoadMimicMotionModel:
FUNCTION = "loadmodel" FUNCTION = "loadmodel"
CATEGORY = "MimicMotionWrapper" CATEGORY = "MimicMotionWrapper"
def loadmodel(self, precision, model): def loadmodel(self, precision, model, lcm):
device = mm.get_torch_device() device = mm.get_torch_device()
mm.soft_empty_cache() mm.soft_empty_cache()
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
...@@ -86,27 +90,49 @@ class DownloadAndLoadMimicMotionModel: ...@@ -86,27 +90,49 @@ class DownloadAndLoadMimicMotionModel:
pbar.update(1) pbar.update(1)
svd_path = os.path.join(folder_paths.models_dir, "diffusers", "stable-video-diffusion-img2vid-xt-1-1") svd_path = os.path.join(folder_paths.models_dir, "diffusers", "stable-video-diffusion-img2vid-xt-1-1")
svd_lcm_path = os.path.join(folder_paths.models_dir, "diffusers", "stable-video-diffusion-img2vid-xt-1-1-lcm", "unet_lcm")
if not os.path.exists(svd_path): if lcm and not os.path.exists(svd_lcm_path):
#raise ValueError(f"Please download stable-video-diffusion-img2vid-xt-1-1 to {svd_path}") print(f"Downloading AnimateLCM SVD model to: {model_path}")
print(f"Downloading SVD model to: {model_path}")
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
snapshot_download(repo_id="vdo/stable-video-diffusion-img2vid-xt-1-1", snapshot_download(repo_id="Kijai/AnimateLCM-SVD-Comfy",
allow_patterns=[f"*.json", "*fp16*"], allow_patterns=[f"*.json", "*diffusion_pytorch_model.fp16.safetensors*"],
local_dir=svd_path, local_dir=svd_path,
local_dir_use_symlinks=False) local_dir_use_symlinks=False)
else:
if not os.path.exists(svd_path):
print(f"Downloading SVD model to: {model_path}")
from huggingface_hub import snapshot_download
snapshot_download(repo_id="vdo/stable-video-diffusion-img2vid-xt-1-1",
allow_patterns=[f"*.json", "*fp16*"],
local_dir=svd_path,
local_dir_use_symlinks=False)
pbar.update(1) pbar.update(1)
mimicmotion_models = MimicMotionModel(svd_path).to(device=device).eval() mimicmotion_models = MimicMotionModel(svd_path, lcm=lcm).to(device=device).eval()
mimicmotion_models.load_state_dict(comfy.utils.load_torch_file(model_path), strict=False) mimicmotion_models.load_state_dict(comfy.utils.load_torch_file(model_path), strict=False)
if lcm:
lcm_noise_scheduler = AnimateLCMSVDStochasticIterativeScheduler(
num_train_timesteps=40,
sigma_min=0.002,
sigma_max=700.0,
sigma_data=1.0,
s_noise=1.0,
rho=7,
clip_denoised=False,
)
scheduler = lcm_noise_scheduler
else:
scheduler = mimicmotion_models.noise_scheduler
pipeline = MimicMotionPipeline( pipeline = MimicMotionPipeline(
vae=mimicmotion_models.vae, vae = mimicmotion_models.vae,
image_encoder=mimicmotion_models.image_encoder, image_encoder = mimicmotion_models.image_encoder,
unet=mimicmotion_models.unet, unet = mimicmotion_models.unet,
scheduler=mimicmotion_models.noise_scheduler, scheduler = scheduler,
feature_extractor=mimicmotion_models.feature_extractor, feature_extractor = mimicmotion_models.feature_extractor,
pose_net=mimicmotion_models.pose_net, pose_net = mimicmotion_models.pose_net,
) )
pipeline.unet.to(dtype) pipeline.unet.to(dtype)
...@@ -297,7 +323,7 @@ class MimicMotionGetPoses: ...@@ -297,7 +323,7 @@ class MimicMotionGetPoses:
ref_pose_img = draw_pose(ref_pose, height, width, include_body=include_body, include_hand=include_hand, include_face=include_face) ref_pose_img = draw_pose(ref_pose, height, width, include_body=include_body, include_hand=include_hand, include_face=include_face)
ref_pose_tensor = torch.tensor(np.array(ref_pose_img)) / 255 ref_pose_tensor = torch.tensor(np.array(ref_pose_img)) / 255
output_tensor = torch.cat((ref_pose_tensor.unsqueeze(0), output_tensor)) output_tensor = torch.cat((ref_pose_tensor.unsqueeze(0), output_tensor))
output_tensor = output_tensor.permute(0, 2, 3, 1).cpu().float() output_tensor = output_tensor.permute(0, 2, 3, 1).cpu().float()
return output_tensor, output_tensor[1:] return output_tensor, output_tensor[1:]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论