Commit e90ac6dd authored by kijai

lcm SVD as experimental option

Parent 180b6eff
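
In short, this commit adds an experimental "lcm" toggle to the MimicMotion model loader node: when it is enabled, the UNet configuration is read from the unet_lcm subfolder instead of unet, the AnimateLCM-SVD UNet weights are downloaded from Kijai/AnimateLCM-SVD-Comfy if they are not already present, and the default noise scheduler is replaced with AnimateLCMSVDStochasticIterativeScheduler when the MimicMotionPipeline is assembled.
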
@@ -19,15 +19,18 @@ from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline
 from mimicmotion.modules.unet import UNetSpatioTemporalConditionModel
 from mimicmotion.modules.pose_net import PoseNet
+from lcm_scheduler import AnimateLCMSVDStochasticIterativeScheduler
 class MimicMotionModel(torch.nn.Module):
-    def __init__(self, base_model_path):
+    def __init__(self, base_model_path, lcm=False):
         """construnct base model components and load pretrained svd model except pose-net
         Args:
             base_model_path (str): pretrained svd model path
         """
         super().__init__()
+        unet_subfolder = "unet_lcm" if lcm else "unet"
         self.unet = UNetSpatioTemporalConditionModel.from_config(
-            UNetSpatioTemporalConditionModel.load_config(base_model_path, subfolder="unet", variant="fp16"))
+            UNetSpatioTemporalConditionModel.load_config(base_model_path, subfolder=unet_subfolder, variant="fp16"))
         self.vae = AutoencoderKLTemporalDecoder.from_pretrained(
             base_model_path, subfolder="vae", variant="fp16")
         self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
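
For reference, here is a minimal sketch of the config-only load pattern this hunk parameterizes, written against the stock diffusers class rather than the repo's subclass; the local path and the lcm flag are illustrative assumptions. load_config reads only the chosen subfolder's config.json, and from_config builds an architecture-only UNet whose weights are filled in later from the MimicMotion checkpoint via load_state_dict.

from diffusers import UNetSpatioTemporalConditionModel

# Assumed local layout: an SVD folder containing both "unet" and "unet_lcm" subfolders.
svd_root = "models/diffusers/stable-video-diffusion-img2vid-xt-1-1"
lcm = True
unet_subfolder = "unet_lcm" if lcm else "unet"

# Read only the config (no weights) from the selected subfolder...
config = UNetSpatioTemporalConditionModel.load_config(svd_root, subfolder=unet_subfolder)
# ...and instantiate a randomly initialized UNet with that architecture.
unet = UNetSpatioTemporalConditionModel.from_config(config)
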
@@ -55,6 +58,7 @@ class DownloadAndLoadMimicMotionModel:
                 ], {
                     "default": 'fp16'
                 }),
+                "lcm": ("BOOLEAN", {"default": False}),
             },
         }
@@ -64,7 +68,7 @@ class DownloadAndLoadMimicMotionModel:
     FUNCTION = "loadmodel"
     CATEGORY = "MimicMotionWrapper"
-    def loadmodel(self, precision, model):
+    def loadmodel(self, precision, model, lcm):
         device = mm.get_torch_device()
         mm.soft_empty_cache()
         dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
@@ -86,9 +90,17 @@ class DownloadAndLoadMimicMotionModel:
         pbar.update(1)
         svd_path = os.path.join(folder_paths.models_dir, "diffusers", "stable-video-diffusion-img2vid-xt-1-1")
+        svd_lcm_path = os.path.join(folder_paths.models_dir, "diffusers", "stable-video-diffusion-img2vid-xt-1-1-lcm", "unet_lcm")
+        if lcm and not os.path.exists(svd_lcm_path):
+            print(f"Downloading AnimateLCM SVD model to: {model_path}")
+            from huggingface_hub import snapshot_download
+            snapshot_download(repo_id="Kijai/AnimateLCM-SVD-Comfy",
+                              allow_patterns=[f"*.json", "*diffusion_pytorch_model.fp16.safetensors*"],
+                              local_dir=svd_path,
+                              local_dir_use_symlinks=False)
+        else:
             if not os.path.exists(svd_path):
                 #raise ValueError(f"Please download stable-video-diffusion-img2vid-xt-1-1 to {svd_path}")
                 print(f"Downloading SVD model to: {model_path}")
                 from huggingface_hub import snapshot_download
                 snapshot_download(repo_id="vdo/stable-video-diffusion-img2vid-xt-1-1",
@@ -97,16 +109,30 @@ class DownloadAndLoadMimicMotionModel:
                               local_dir_use_symlinks=False)
         pbar.update(1)
-        mimicmotion_models = MimicMotionModel(svd_path).to(device=device).eval()
+        mimicmotion_models = MimicMotionModel(svd_path, lcm=lcm).to(device=device).eval()
         mimicmotion_models.load_state_dict(comfy.utils.load_torch_file(model_path), strict=False)
+        if lcm:
+            lcm_noise_scheduler = AnimateLCMSVDStochasticIterativeScheduler(
+                num_train_timesteps=40,
+                sigma_min=0.002,
+                sigma_max=700.0,
+                sigma_data=1.0,
+                s_noise=1.0,
+                rho=7,
+                clip_denoised=False,
+            )
+            scheduler = lcm_noise_scheduler
+        else:
+            scheduler = mimicmotion_models.noise_scheduler
         pipeline = MimicMotionPipeline(
-            vae=mimicmotion_models.vae,
-            image_encoder=mimicmotion_models.image_encoder,
-            unet=mimicmotion_models.unet,
-            scheduler=mimicmotion_models.noise_scheduler,
-            feature_extractor=mimicmotion_models.feature_extractor,
-            pose_net=mimicmotion_models.pose_net,
+            vae = mimicmotion_models.vae,
+            image_encoder = mimicmotion_models.image_encoder,
+            unet = mimicmotion_models.unet,
+            scheduler = scheduler,
+            feature_extractor = mimicmotion_models.feature_extractor,
+            pose_net = mimicmotion_models.pose_net,
         )
         pipeline.unet.to(dtype)
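
The practical effect of the scheduler swap is on sampling settings: a consistency-style LCM scheduler is intended to run with very few steps and guidance close to 1, while the scheduler bundled with the SVD base model expects a conventional step count. A rough sketch of how the two paths might be driven; the step and guidance numbers are illustrative assumptions, not values taken from this commit:

# Sketch only: pick sampling settings to match the scheduler chosen above
# (the pipeline itself calls scheduler.set_timesteps internally).
if lcm:
    num_inference_steps = 8   # LCM-distilled SVD is typically sampled in roughly 2-8 steps
    guidance_scale = 1.5      # low guidance; illustrative
else:
    num_inference_steps = 25  # conventional SVD-style step count; illustrative
    guidance_scale = 3.0      # illustrative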