Commit b66fe529 authored by kijai

Add strength controls

Parent c7793ad3
@@ -368,6 +368,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
         encoder_hidden_states: torch.Tensor,
         added_time_ids: torch.Tensor,
         pose_latents: torch.Tensor = None,
+        pose_strength: float = 1.0,
         image_only_indicator: bool = False,
         return_dict: bool = True,
     ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
@@ -437,11 +438,12 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
         emb = emb.repeat_interleave(num_frames, dim=0)
         # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
         encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+        encoder_hidden_states = encoder_hidden_states
         # 2. pre-process
         sample = self.conv_in(sample)
         if pose_latents is not None:
-            sample = sample + pose_latents
+            sample = sample + pose_latents * pose_strength
         image_only_indicator = torch.ones(batch_size, num_frames, dtype=sample.dtype, device=sample.device) \
             if image_only_indicator else torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
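The UNet-side change reduces to one multiply: the latents from the pose net are added onto the conv_in output scaled by pose_strength, so 0.0 disables pose conditioning entirely and 1.0 reproduces the previous behavior. A minimal self-contained sketch of that additive scaling (tensor shapes here are illustrative, not taken from the model config):

from typing import Optional
import torch

def apply_pose_conditioning(sample: torch.Tensor,
                            pose_latents: Optional[torch.Tensor],
                            pose_strength: float = 1.0) -> torch.Tensor:
    # Mirrors the edited forward(): additive conditioning, linearly scaled.
    if pose_latents is None:
        return sample
    return sample + pose_latents * pose_strength

sample = torch.randn(2, 320, 72, 128)        # hypothetical conv_in output
pose_latents = torch.randn(2, 320, 72, 128)  # hypothetical pose_net output
damped = apply_pose_conditioning(sample, pose_latents, pose_strength=0.5)
assert torch.allclose(damped, sample + 0.5 * pose_latents)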
...
@@ -124,7 +124,8 @@ class MimicMotionPipeline(DiffusionPipeline):
         image: PipelineImageInput,
         device: Union[str, torch.device],
         num_videos_per_prompt: int,
-        do_classifier_free_guidance: bool):
+        do_classifier_free_guidance: bool,
+        image_embed_strength: float = 1.0):
         dtype = next(self.image_encoder.parameters()).dtype
         # if not isinstance(image, torch.Tensor):
@@ -160,6 +161,7 @@ class MimicMotionPipeline(DiffusionPipeline):
         bs_embed, seq_len, _ = image_embeddings.shape
         image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
         image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+        image_embeddings = image_embeddings * image_embed_strength
         if do_classifier_free_guidance:
             negative_image_embeddings = torch.zeros_like(image_embeddings)
@@ -169,6 +171,8 @@ class MimicMotionPipeline(DiffusionPipeline):
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
             image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
         return image_embeddings
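Because the unconditional branch is torch.zeros_like(image_embeddings), scaling the conditional embeddings by image_embed_strength scales the guidance delta (conditional minus unconditional) by exactly that factor. A standalone sketch of the path above (the embedding size is assumed, not taken from the CLIP config):

import torch

def scale_image_embeddings(image_embeddings: torch.Tensor,
                           image_embed_strength: float = 1.0,
                           do_classifier_free_guidance: bool = True) -> torch.Tensor:
    # Scale the CLIP image embeddings, then stack the zero unconditional batch.
    image_embeddings = image_embeddings * image_embed_strength
    if do_classifier_free_guidance:
        negative_image_embeddings = torch.zeros_like(image_embeddings)
        image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
    return image_embeddings

emb = torch.randn(1, 1, 1024)  # assumed [batch, seq, channels] image embedding
out = scale_image_embeddings(emb, image_embed_strength=0.5)
assert out.shape == (2, 1, 1024) and out[0].abs().sum() == 0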
@@ -179,8 +183,10 @@ class MimicMotionPipeline(DiffusionPipeline):
     ):
         # Get latents_pose
         pose_latents = self.pose_net(pose_image)
+        print(pose_latents.shape)
         if do_classifier_free_guidance:
+            print("doing classifier free guidance")
            negative_pose_latents = torch.zeros_like(pose_latents)
             # For classifier free guidance, we need to do two forward passes.
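The rest of this hunk is collapsed, but by symmetry with _encode_image the zero latents are presumably concatenated in front of the real ones so both CFG branches run in a single batched forward pass. A sketch of that pattern, under that assumption:

import torch

def build_cfg_pose_latents(pose_latents: torch.Tensor,
                           do_classifier_free_guidance: bool) -> torch.Tensor:
    # Zero latents drive the unconditional pass; cat avoids a second forward.
    if do_classifier_free_guidance:
        negative_pose_latents = torch.zeros_like(pose_latents)
        pose_latents = torch.cat([negative_pose_latents, pose_latents])
    return pose_latents

latents = torch.randn(16, 320, 72, 128)  # hypothetical pose_net output
doubled = build_cfg_pose_latents(latents, do_classifier_free_guidance=True)
assert doubled.shape[0] == 32 and doubled[:16].abs().sum() == 0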
@@ -371,6 +377,10 @@ class MimicMotionPipeline(DiffusionPipeline):
         self,
         image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
         image_pose: Union[torch.FloatTensor],
+        pose_strength: float = 1.0,
+        pose_start_percent: float = 0.0,
+        pose_end_percent: float = 1.0,
+        image_embed_strength: float = 1.0,
         height: int = 576,
         width: int = 1024,
         num_frames: Optional[int] = None,
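All four new knobs default to values that reproduce the previous behavior: full-strength pose over the whole schedule and unscaled image embeddings. A hypothetical invocation, assuming pipeline is an already-loaded MimicMotionPipeline, ref.png exists, and pose_tensor holds preprocessed pose frames (both names are placeholders, not from the source):

import torch
import PIL.Image

ref_image = PIL.Image.open("ref.png")        # assumed reference image
pose_tensor = torch.randn(16, 3, 576, 1024)  # assumed preprocessed pose frames

frames = pipeline(
    image=ref_image,
    image_pose=pose_tensor,
    pose_strength=0.8,         # damp the additive pose conditioning slightly
    pose_start_percent=0.0,    # pose guidance active from the first step...
    pose_end_percent=0.6,      # ...through 60% of the denoising steps
    image_embed_strength=1.0,  # leave the CLIP image conditioning unchanged
    height=576,
    width=1024,
).frames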
@@ -508,7 +518,7 @@ class MimicMotionPipeline(DiffusionPipeline):
         self._guidance_scale = max_guidance_scale
         # 3. Encode input image
-        image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
+        image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance, image_embed_strength=image_embed_strength)
         # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
         # is why it is reduced here.
@@ -589,6 +599,13 @@ class MimicMotionPipeline(DiffusionPipeline):
         self.unet.to(device)
         self._num_timesteps = len(timesteps)
+        # Calculate the actual start and end steps based on percentages
+        start_step_index = round(self._num_timesteps * pose_start_percent)
+        end_step_index = round(self._num_timesteps * pose_end_percent)
+        print(f"start_step_index: {start_step_index}, end_step_index: {end_step_index}")
         pose_latents = einops.rearrange(pose_latents, '(b f) c h w -> b f c h w', f=num_frames)
         indices = [[0, *range(i + 1, min(i + tile_size, num_frames))] for i in
                    range(0, num_frames - tile_size + 1, tile_size - tile_overlap)]
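The window arithmetic rounds to the nearest step, and the comparison in the next hunk is inclusive on both ends, so pose_start_percent=0.0 with pose_end_percent=1.0 covers every denoising step. A worked example:

# 25 denoising steps, pose active over the middle 60% of the schedule
num_timesteps = 25
start_step_index = round(num_timesteps * 0.2)  # 5
end_step_index = round(num_timesteps * 0.8)    # 20
active = [i for i in range(num_timesteps) if start_step_index <= i <= end_step_index]
assert (start_step_index, end_step_index) == (5, 20) and len(active) == 16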
@@ -610,13 +627,26 @@ class MimicMotionPipeline(DiffusionPipeline):
                 # image_pose = pixel_values_pose[:, frame_start:frame_start + self.num_frames, ...]
                 weight = (torch.arange(tile_size, device=device) + 0.5) * 2. / tile_size
                 weight = torch.minimum(weight, 2 - weight)
                 for idx in indices:
+                    # Check if the current timestep is within the start and end step range
+                    if start_step_index <= i <= end_step_index:
+                        # Apply pose_latents as currently done
+                        print(f"Applying pose on step {i}")
+                        pose_latents_to_use = pose_latents[:, idx].flatten(0, 1)
+                    else:
+                        print(f"Not applying pose on step {i}")
+                        # Apply an alternative if pose_latents should not be used outside this range
+                        # This could be zeros, or any other placeholder logic you define.
+                        pose_latents_to_use = torch.zeros_like(pose_latents[:, idx].flatten(0, 1))
                     _noise_pred = self.unet(
                         latent_model_input[:, idx],
                         t,
                         encoder_hidden_states=image_embeddings,
                         added_time_ids=added_time_ids,
-                        pose_latents=pose_latents[:, idx].flatten(0, 1),
+                        pose_latents=pose_latents_to_use,
+                        pose_strength=pose_strength,
                         image_only_indicator=image_only_indicator,
                         return_dict=False,
                     )[0]
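Outside the window the code substitutes zero latents rather than omitting the argument; since the UNet adds pose_latents * pose_strength onto sample, a zero tensor is numerically identical to no pose conditioning. An equivalent standalone helper, with a toy (b, f, c, h, w) tensor to show the tile flattening:

import torch

def select_pose_latents(pose_latents: torch.Tensor, idx: list, step: int,
                        start_step_index: int, end_step_index: int) -> torch.Tensor:
    # Flatten the (batch, tile_frames) dims as the UNet expects, then gate by step.
    tile = pose_latents[:, idx].flatten(0, 1)
    if start_step_index <= step <= end_step_index:
        return tile
    return torch.zeros_like(tile)

pose = torch.randn(1, 8, 4, 8, 8)  # 1 batch, 8 frames, 4 channels, 8x8 latents
inside = select_pose_latents(pose, [0, 1, 2], step=3, start_step_index=0, end_step_index=10)
outside = select_pose_latents(pose, [0, 1, 2], step=15, start_step_index=0, end_step_index=10)
assert inside.shape == (3, 4, 8, 8) and outside.abs().sum() == 0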
...
@@ -242,6 +242,10 @@ class MimicMotionSampler:
             },
             "optional": {
                 "optional_scheduler": ("DIFFUSERS_SCHEDULER",),
+                "pose_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                "pose_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "pose_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "image_embed_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
             }
         }
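On the node side the knobs are plain ComfyUI FLOAT widgets under "optional", so workflows without the inputs fall back to the Python defaults in process. A small sketch of how the option dicts constrain values (the clamp helper is illustrative, not part of ComfyUI):

FLOAT_SPECS = {
    "pose_strength":        {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01},
    "pose_start_percent":   {"default": 0.0, "min": 0.0, "max": 1.0,  "step": 0.01},
    "pose_end_percent":     {"default": 1.0, "min": 0.0, "max": 1.0,  "step": 0.01},
    "image_embed_strength": {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01},
}

def clamp_to_spec(name: str, value: float) -> float:
    # Keep a user-supplied value inside the widget's [min, max] range.
    opts = FLOAT_SPECS[name]
    return min(max(value, opts["min"]), opts["max"])

assert clamp_to_spec("pose_strength", 12.0) == 10.0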
@@ -251,7 +255,7 @@ class MimicMotionSampler:
     CATEGORY = "MimicMotionWrapper"
     def process(self, mimic_pipeline, ref_image, pose_images, cfg_min, cfg_max, steps, seed, noise_aug_strength, fps, keep_model_loaded,
-                context_size, context_overlap, optional_scheduler=None):
+                context_size, context_overlap, optional_scheduler=None, pose_strength=1.0, image_embed_strength=1.0, pose_start_percent=0.0, pose_end_percent=1.0):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         mm.unload_all_models()
@@ -307,7 +311,11 @@ class MimicMotionSampler:
             decode_chunk_size=4,
             output_type="latent",
             device=device,
-            sigmas=sigmas
+            sigmas=sigmas,
+            pose_strength=pose_strength,
+            pose_start_percent=pose_start_percent,
+            pose_end_percent=pose_end_percent,
+            image_embed_strength=image_embed_strength
         ).frames
         if not keep_model_loaded:
...