Skip to content

fix: convert boundary_ratio from timestep space to index space via inverse sigma_shift#32

Open
Auraithm wants to merge 2 commits intoOpenMOSS:mainfrom
Auraithm:main
Open

fix: convert boundary_ratio from timestep space to index space via inverse sigma_shift#32
Auraithm wants to merge 2 commits intoOpenMOSS:mainfrom
Auraithm:main

Conversation

@Auraithm
Copy link

Summary

boundary_ratio=0.9 represents the DiT1/DiT2 switch point at timestep 900. In inference, this value is correctly compared against actual timestep values. However, in training_step, it was incorrectly used as an index fraction for max/min_timestep_boundary.

Due to sigma_shift (flow_match.py L57), the scheduler's timestep array is non-linearly spaced — timestep 900 sits at index ~250 (shift=3), not index 900. Using 0.9 as the index boundary caused DiT1 to be trained on 90% of indices and DiT2 on only 10%, while in inference DiT2 handles the majority of denoising (timestep 900→0).

Fix

Apply inverse sigma_shift to convert boundary_ratio from timestep space to index space before passing it to max/min_timestep_boundary. This aligns training with inference behavior.

  • shift=3 → index boundary = 0.25
  • shift=5 → index boundary = 0.358 (matches DiffSynth-Studio PR #806)

Reference

@tianyilt
Copy link
Collaborator

上下文:这里的 boundary 应该是 sigma_shift 前的 index 边界
比如modelscope/DiffSynth-Studio#806

因此不要硬编码:switch_dit_boundary: float = 0.358, #FIXME(dhyu):

@tianyilt
Copy link
Collaborator

@Auraithm 最好贴个测试通过的例子

…ft formula

Replace the analytical inverse sigma_shift formula with a direct table
lookup on scheduler.timesteps. This avoids floating-point precision
issues (e.g. 0.24999 vs 0.25) and the sigma_min != 0 approximation
error, giving an exact boundary at the intended timestep threshold.
@Auraithm
Copy link
Author

测试命令

cd xxx/MOVA && python -c "
import torch
from mova.diffusion.schedulers.flow_match import FlowMatchScheduler
from mova.diffusion.pipelines.mova_train import TimestepConfig, compute_density_for_timestep_sampling

scheduler = FlowMatchScheduler(num_train_timesteps=1000, shift=3.0)
scheduler.set_timesteps(1000, training=True)

boundary_ratio = 0.9
boundary = (scheduler.timesteps >= boundary_ratio * scheduler.num_train_timesteps).sum().item() / scheduler.num_train_timesteps
boundary_idx = int(boundary * 1000)

print(f'boundary = {boundary}, boundary_idx = {boundary_idx}')
print(f'timesteps[{boundary_idx-1}] = {scheduler.timesteps[boundary_idx-1].item():.2f}')
print(f'timesteps[{boundary_idx}] = {scheduler.timesteps[boundary_idx].item():.2f}')

print()
print('--- 偶数步 (DiT1) 采样 100 次 ---')
min_ts, max_ts = float('inf'), float('-inf')
for _ in range(100):
    u = compute_density_for_timestep_sampling('uniform', 1, min_timestep_boundary=0.0, max_timestep_boundary=boundary)
    tid = torch.clamp(torch.floor(u * 1000).long(), 0, boundary_idx - 1).item()
    ts = scheduler.timesteps[tid].item()
    min_ts, max_ts = min(min_ts, ts), max(max_ts, ts)
assert min_ts >= 900
print(f'[PASS] [{min_ts:.2f}, {max_ts:.2f}], 全部 >= 900')

print()
print('--- 奇数步 (DiT2) 采样 100 次 ---')
min_ts, max_ts = float('inf'), float('-inf')
for _ in range(100):
    u = compute_density_for_timestep_sampling('uniform', 1, min_timestep_boundary=boundary, max_timestep_boundary=1.0)
    tid = torch.clamp(torch.floor(u * 1000).long(), boundary_idx, 999).item()
    ts = scheduler.timesteps[tid].item()
    min_ts, max_ts = min(min_ts, ts), max(max_ts, ts)
assert max_ts < 900
print(f'[PASS] [{min_ts:.2f}, {max_ts:.2f}], 全部 < 900')

print()
print('ALL CHECKS PASSED')
"

结果:

Set TORCH_CUDA_ARCH_LIST to 9.0
flash_attn_interface loaded from flash_attn_interface
boundary = 0.251, boundary_idx = 251
timesteps[250] = 900.24
timesteps[251] = 899.76

--- 偶数步 (DiT1) 采样 100 次 ---
[PASS] [900.72, 1000.00], 全部 >= 900

--- 奇数步 (DiT2) 采样 100 次 ---
[PASS] [14.82, 898.32], 全部 < 900

ALL CHECKS PASSED

Copy link
Collaborator

@yhzx233 yhzx233 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

也许这个boundary可以提前算好存为成员变量

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants