提交 58bd4f5a authored 作者: kijai's avatar kijai

include unet config

上级 9656f049
{
"_class_name": "UNetSpatioTemporalConditionModel",
"_diffusers_version": "0.24.0.dev0",
"_name_or_path": "/home/suraj_huggingface_co/.cache/huggingface/hub/models--diffusers--svd-xt/snapshots/9703ded20c957c340781ee710b75660826deb487/unet",
"addition_time_embed_dim": 256,
"block_out_channels": [
320,
640,
1280,
1280
],
"cross_attention_dim": 1024,
"down_block_types": [
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal"
],
"in_channels": 8,
"layers_per_block": 2,
"num_attention_heads": [
5,
10,
20,
20
],
"num_frames": 25,
"out_channels": 4,
"projection_class_embeddings_input_dim": 768,
"sample_size": 96,
"transformer_layers_per_block": 1,
"up_block_types": [
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal"
]
}
...@@ -19,8 +19,13 @@ from .mimicmotion.modules.pose_net import PoseNet ...@@ -19,8 +19,13 @@ from .mimicmotion.modules.pose_net import PoseNet
from .lcm_scheduler import AnimateLCMSVDStochasticIterativeScheduler from .lcm_scheduler import AnimateLCMSVDStochasticIterativeScheduler
from accelerate import init_empty_weights from contextlib import nullcontext
from accelerate.utils import set_module_tensor_to_device try:
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
is_accelerate_available = True
except:
pass
def loglinear_interp(t_steps, num_steps): def loglinear_interp(t_steps, num_steps):
...@@ -36,28 +41,6 @@ def loglinear_interp(t_steps, num_steps): ...@@ -36,28 +41,6 @@ def loglinear_interp(t_steps, num_steps):
interped_ys = np.exp(new_ys)[::-1].copy() interped_ys = np.exp(new_ys)[::-1].copy()
return interped_ys return interped_ys
class MimicMotionModel(torch.nn.Module):
def __init__(self, base_model_path, lcm=False):
"""construnct base model components and load pretrained svd model except pose-net
Args:
base_model_path (str): pretrained svd model path
"""
super().__init__()
unet_subfolder = "unet_lcm" if lcm else "unet"
self.unet = UNetSpatioTemporalConditionModel.from_config(
UNetSpatioTemporalConditionModel.load_config(base_model_path, subfolder=unet_subfolder, variant="fp16"))
self.vae = AutoencoderKLTemporalDecoder.from_pretrained(
base_model_path, subfolder="vae", variant="fp16")
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
base_model_path, subfolder="image_encoder", variant="fp16")
self.noise_scheduler = EulerDiscreteScheduler.from_pretrained(
base_model_path, subfolder="scheduler")
self.feature_extractor = CLIPImageProcessor.from_pretrained(
base_model_path, subfolder="feature_extractor")
# pose_net
self.pose_net = PoseNet(noise_latent_channels=self.unet.config.block_out_channels[0])
class DownloadAndLoadMimicMotionModel: class DownloadAndLoadMimicMotionModel:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
...@@ -116,13 +99,16 @@ class DownloadAndLoadMimicMotionModel: ...@@ -116,13 +99,16 @@ class DownloadAndLoadMimicMotionModel:
local_dir_use_symlinks=False) local_dir_use_symlinks=False)
pbar.update(1) pbar.update(1)
unet_config = UNetSpatioTemporalConditionModel.load_config(svd_path, subfolder="unet", variant="fp16") unet_config = UNetSpatioTemporalConditionModel.load_config(os.path.join(script_directory, "configs", "unet_config.json"))
print("Loading UNET") print("Loading UNET")
with (init_empty_weights()): with (init_empty_weights() if is_accelerate_available else nullcontext()):
self.unet = UNetSpatioTemporalConditionModel.from_config(unet_config) self.unet = UNetSpatioTemporalConditionModel.from_config(unet_config)
sd = comfy.utils.load_torch_file(os.path.join(model_path)) sd = comfy.utils.load_torch_file(os.path.join(model_path))
for key in sd: if is_accelerate_available:
set_module_tensor_to_device(self.unet, key, dtype=dtype, device=device, value=sd[key]) for key in sd:
set_module_tensor_to_device(self.unet, key, dtype=dtype, device=device, value=sd[key])
else:
self.unet.load_state_dict(sd, strict=False)
del sd del sd
pbar.update(1) pbar.update(1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论