You can't ship htdemucs_ft on iOS. You can't ship it on Android. You can't run it in a browser. PyTorch Mobile is a 2 GB install that is easy to break, MLX needs Apple Silicon, and the obvious answer, ONNX, has been "broken on htdemucs" across four open GitHub issues for three years.
I just shipped the first working ONNX export of htdemucs_ft. It runs in onnxruntime, is 1.31× faster than PyTorch on CPU, and is numerically equivalent to the original (max absolute difference: 0.000163 on drums, 0.000008 on vocals). All four specialist sub-models are on Hugging Face: StemSplitio/htdemucs-ft-onnx.
This is the engineering writeup — the 4 blockers that killed every prior attempt and the patches that beat them.
What You'll Learn
- ✅ Why htdemucs is so hard to export (complex tensors, Python dynamism, fused C++ kernels)
- ✅ How to replace torch.stft with Conv1d without losing accuracy (5 × 10⁻⁶ round-trip diff)
- ✅ How to patch fractions.Fraction, random.randrange, and aten::_native_multi_head_attention
- ✅ How to verify parity to 4 decimals before trusting the export
- ✅ Pure numpy + onnxruntime inference, no PyTorch at runtime
Prerequisites
```bash
pip install "torch>=2.4,<2.5" "torchaudio>=2.4,<2.5" demucs onnx onnxruntime numpy soundfile
```
No GPU required for export or inference. Tested on Apple M4 Pro and Linux x86_64. If you don't have demucs set up yet, follow the complete demucs local setup guide first.
Why doesn't htdemucs export to ONNX out of the box?
Short answer: Because the model uses four PyTorch features that no ONNX exporter has good answers for — complex tensors in the STFT, fractions.Fraction arithmetic in model.segment, random.randrange inside the cross-transformer, and the fused C++ aten::_native_multi_head_attention kernel.
Each one stops torch.onnx.export and torch.onnx.dynamo_export cold. You hit them in order, and fixing one only gets you as far as the next.
Blocker 1: complex64 STFT output
Short answer: Replace torch.stft with a Conv1d using sin/cos kernels that emit two real-valued channels.
The first op in HT-Demucs is:
```python
z = torch.stft(x, n_fft=4096, hop_length=1024, window=torch.hann_window(4096),
               win_length=4096, normalized=True, center=True,
               return_complex=True, pad_mode="reflect")
```
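return_complex=True returns a complex64 tensor. ONNX's STFT op (opset 17+) does not produce complex outputs; it returns the real and imaginary parts on a trailing axis of size 2, so every downstream slice/transpose written against the complex tensor fails. The workaround is to compute the STFT as a strided Conv1d with windowed cosine and sine kernels: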
```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def _make_stft_kernels(n_fft: int):
    n = torch.arange(n_fft, dtype=torch.float64)
    window = torch.hann_window(n_fft, periodic=True, dtype=torch.float64)
    norm = 1.0 / math.sqrt(n_fft)                      # matches normalized=True
    k = torch.arange(n_fft // 2 + 1, dtype=torch.float64).unsqueeze(1)
    angles = 2 * math.pi * k * n.unsqueeze(0) / n_fft
    cos = (window * torch.cos(angles)) * norm
    sin = (window * -torch.sin(angles)) * norm         # negative for forward STFT
    # Conv1d weight layout: (out_channels=F, in_channels=1, kernel=n_fft)
    return cos.float().unsqueeze(1), sin.float().unsqueeze(1)


class RealSTFT(nn.Module):
    def __init__(self, n_fft=4096, hop_length=1024):
        super().__init__()
        cos, sin = _make_stft_kernels(n_fft)
        self.register_buffer("cos_kernel", cos)
        self.register_buffer("sin_kernel", sin)
        self.n_fft, self.hop_length = n_fft, hop_length

    def forward(self, x):
        # center=True with pad_mode="reflect": pad n_fft // 2 on both sides
        x = F.pad(x.reshape(-1, 1, x.shape[-1]),
                  (self.n_fft // 2,) * 2, mode="reflect")
        real = F.conv1d(x, self.cos_kernel, stride=self.hop_length)
        imag = F.conv1d(x, self.sin_kernel, stride=self.hop_length)
        return torch.stack([real, imag], dim=1)  # (BN, 2, F, T) real
```
Verify against the real thing before going further:
```python
x = torch.randn(1, 343980)
ref = torch.stft(x, n_fft=4096, hop_length=1024,
                 window=torch.hann_window(4096), win_length=4096,
                 normalized=True, center=True, return_complex=True,
                 pad_mode="reflect")
ref_real = torch.stack([ref.real, ref.imag], dim=1)  # (1, 2, F, T)

stft = RealSTFT()
out = stft(x)                                        # (1, 2, F, T), same layout
print("max abs diff:", (out - ref_real).abs().max().item())  # ~5e-6
```
5 × 10⁻⁶ is rounding noise. Use the same trick (ConvTranspose1d with conjugate kernels + overlap-add window-squared envelope) for the inverse STFT. Now every view_as_real / view_as_complex in the model's _magnitude and _mask methods can be rewritten to thread real-channel tensors through the whole forward pass.
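The inverse can be sketched in the same style. Below is a minimal RealISTFT, assuming the (BN, 2, F, T) real/imaginary layout produced by RealSTFT and the same settings (periodic Hann window, normalized=True, center=True); the class and helper names are illustrative, not the author's code. Each bin's synthesis kernel is weighted by 2 (1 for DC and Nyquist) to account for the negative frequencies a one-sided spectrum omits, conv_transpose1d does the overlap-add, and a second conv_transpose1d over a row of ones builds the squared-window envelope that the output is divided by.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def _make_istft_kernels(n_fft: int):
    n = torch.arange(n_fft, dtype=torch.float64)
    k = torch.arange(n_fft // 2 + 1, dtype=torch.float64).unsqueeze(1)
    window = torch.hann_window(n_fft, periodic=True, dtype=torch.float64)
    angles = 2 * math.pi * k * n.unsqueeze(0) / n_fft
    # One-sided spectrum: bins 1..N/2-1 also stand in for their negative-frequency twins
    coef = torch.full((n_fft // 2 + 1, 1), 2.0, dtype=torch.float64)
    coef[0] = 1.0
    coef[-1] = 1.0
    # normalized=True: undo the forward 1/sqrt(N), then apply the 1/N of the inverse DFT
    scale = 1.0 / math.sqrt(n_fft)
    # Re(X * e^{+i angle}) = Re(X) cos(angle) - Im(X) sin(angle), times the synthesis window
    cos = coef * torch.cos(angles) * window * scale
    sin = coef * -torch.sin(angles) * window * scale
    # conv_transpose1d weight layout: (in_channels=F, out_channels=1, kernel=n_fft)
    return cos.float().unsqueeze(1), sin.float().unsqueeze(1), (window ** 2).float()


class RealISTFT(nn.Module):
    def __init__(self, n_fft=4096, hop_length=1024):
        super().__init__()
        cos, sin, win_sq = _make_istft_kernels(n_fft)
        self.register_buffer("cos_kernel", cos)
        self.register_buffer("sin_kernel", sin)
        self.register_buffer("win_sq", win_sq.view(1, 1, -1))
        self.n_fft, self.hop_length = n_fft, hop_length

    def forward(self, z, length: int):
        # z: (BN, 2, F, T), channel 0 = real part, channel 1 = imaginary part
        y = F.conv_transpose1d(z[:, 0], self.cos_kernel, stride=self.hop_length)
        y = y + F.conv_transpose1d(z[:, 1], self.sin_kernel, stride=self.hop_length)
        # Overlap-add envelope: sum of squared windows covering each output sample
        ones = torch.ones(1, 1, z.shape[-1], dtype=z.dtype, device=z.device)
        env = F.conv_transpose1d(ones, self.win_sq, stride=self.hop_length)
        y = y / env.clamp_min(1e-11)
        start = self.n_fft // 2                      # drop the center=True padding
        return y[:, 0, start:start + length]        # (BN, length)
```

It can be checked the same way as the forward pass: run torch.stft on a random signal, split the result into real and imaginary channels, and compare RealISTFT's output against torch.istft or the original waveform.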
Blocker 2: fractions.Fraction in model.segment
Short answer: Coerce to float before exporting.
Pretrained htdemucs_ft ships with model.segment = Fraction(39, 5) (= 7.8 seconds). Dynamo dies:
```
torch._dynamo.exc.Unsupported: call_function
UserDefinedClassVariable(<class 'fractions.Fraction'>)
```
Fix:
```python
from fractions import Fraction

if isinstance(model.segment, Fraction):
    model.segment = float(model.segment)  # 7.8
```
Mathematically identical at inference. Trivial — but you don't get to the next blocker without it.
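Why the coercion is harmless: the segment's exactness only matters when it becomes a chunk length in samples, and the float gives the same integer as the exact rational. A quick stdlib check at the 44.1 kHz rate used throughout this post:

```python
from fractions import Fraction

SEGMENT = Fraction(39, 5)   # htdemucs_ft's model.segment, in seconds
SAMPLE_RATE = 44100

exact = SEGMENT * SAMPLE_RATE          # exact rational arithmetic: Fraction(343980, 1)
approx = float(SEGMENT) * SAMPLE_RATE  # float arithmetic after the coercion

print(float(SEGMENT), exact, approx)
print(int(exact) == int(approx))       # same sample count either way
```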
Blocker 3: random.randrange in the cross-transformer
Short answer: Monkey-patch the affected method to hardcode shift=0.
CrossTransformerEncoder._get_pos_embedding calls Python's random:
```python
shift = random.randrange(self.sin_random_shift + 1)
```
At inference, sin_random_shift = 0, so random.randrange(1) always returns 0 — a no-op. But neither exporter can see through random and both bail out. Patch the method directly:
```python
import types
from demucs.transformer import CrossTransformerEncoder, create_sin_embedding


def _get_pos_embedding_no_random(self_, T, B, C, device):
    if self_.emb == "sin":
        return create_sin_embedding(T, C, shift=0, device=device,
                                    max_period=self_.max_period)
    raise RuntimeError(f"emb={self_.emb} not handled at export")


for m in model.modules():
    if isinstance(m, CrossTransformerEncoder):
        m._get_pos_embedding = types.MethodType(_get_pos_embedding_no_random, m)
```
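The patch works per object: types.MethodType binds the replacement function to one instance, and that instance attribute shadows the class method, so the demucs source and any instance you skip are left untouched. A toy illustration of the mechanism (ToyEncoder and get_shift are hypothetical stand-ins, not demucs names):

```python
import random
import types


class ToyEncoder:
    """Hypothetical stand-in for CrossTransformerEncoder, not the demucs class."""

    def __init__(self, sin_random_shift: int = 0):
        self.sin_random_shift = sin_random_shift

    def get_shift(self) -> int:
        # The pattern the exporters reject: a call into Python's random module
        return random.randrange(self.sin_random_shift + 1)


def _get_shift_no_random(self_) -> int:
    # The value randrange(1) always returns, with no call into random
    return 0


patched, untouched = ToyEncoder(), ToyEncoder()
patched.get_shift = types.MethodType(_get_shift_no_random, patched)

print("get_shift" in vars(patched))     # True: the override lives on this instance
print("get_shift" in vars(untouched))   # False: other instances use the class method
print(patched.get_shift(), untouched.get_shift())  # 0 0 when sin_random_shift == 0
```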
Blocker 4: aten::_native_multi_head_attention
Short answer: Replace nn.MultiheadAttention.forward with a plain Linear/bmm/softmax implementation.
Modern PyTorch short-circuits nn.MultiheadAttention.forward to a fused C++ kernel (_native_multi_head_attention) when its preconditions are met. That kernel has no ONNX symbolic at any opset:
```
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator
'aten::_native_multi_head_attention' to ONNX opset version 17 is not supported.
```
Patch every MHA instance's forward with a drop-in that uses only ops with stable ONNX symbolics:
```python
import types
import torch
import torch.nn as nn
import torch.nn.functional as F


def _onnx_friendly_mha_forward(self_, query, key, value,
                               key_padding_mask=None, need_weights=True,
                               attn_mask=None, average_attn_weights=True,
                               is_causal=False):
    # Eval-only drop-in: dropout, attn_mask and key_padding_mask are ignored,
    # and no attention weights are returned.
    # Identity check (not torch.equal) so tracing doesn't bake in a data-dependent branch;
    # it must run before the transposes below create new tensor objects.
    same_qkv = query is key and key is value
    if self_.batch_first:
        query, key, value = (t.transpose(0, 1) for t in (query, key, value))
    tgt_len, bsz, embed_dim = query.shape
    head_dim = embed_dim // self_.num_heads
    if self_._qkv_same_embed_dim and same_qkv:
        # self-attention path: one packed projection, split into q, k, v
        q, k, v = F.linear(query, self_.in_proj_weight, self_.in_proj_bias).chunk(3, dim=-1)
    else:
        # cross-attention path: three separate projections
        w_q, w_k, w_v = self_.in_proj_weight.chunk(3)
        b_q, b_k, b_v = (self_.in_proj_bias.chunk(3)
                         if self_.in_proj_bias is not None else (None, None, None))
        q = F.linear(query, w_q, b_q)
        k = F.linear(key, w_k, b_k)
        v = F.linear(value, w_v, b_v)
    # (L, B, E) -> (B * heads, L, head_dim)
    q = q.contiguous().view(tgt_len, bsz * self_.num_heads, head_dim).transpose(0, 1)
    k = k.contiguous().view(-1, bsz * self_.num_heads, head_dim).transpose(0, 1)
    v = v.contiguous().view(-1, bsz * self_.num_heads, head_dim).transpose(0, 1)
    attn = F.softmax(torch.bmm(q * head_dim ** -0.5, k.transpose(1, 2)), dim=-1)
    out = torch.bmm(attn, v).transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    out = self_.out_proj(out)
    if self_.batch_first:
        out = out.transpose(0, 1)  # back to (B, L, E)
    return out, None


for m in model.modules():
    if isinstance(m, nn.MultiheadAttention):
        m.forward = types.MethodType(_onnx_friendly_mha_forward, m)
```
Parity vs the fused kernel: 1 × 10⁻⁶ max diff. Safe.
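To reproduce a parity number like this, a self-contained harness can compare a stock nn.MultiheadAttention against the same Linear/bmm/softmax math. This is a sketch under stated assumptions: manual_attention is a hypothetical helper that re-implements the computation for batch_first inputs without masks (rather than importing the patched forward), and the sizes (64-dim, 8 heads) are arbitrary. In eval mode with need_weights=False, PyTorch may route the reference call through its fused fast path, which is exactly the kernel being replaced.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def manual_attention(mha: nn.MultiheadAttention, query, key, value):
    """Linear/bmm/softmax attention for batch_first (B, L, E) inputs, no masks."""
    bsz, tgt_len, embed_dim = query.shape
    heads = mha.num_heads
    head_dim = embed_dim // heads
    w_q, w_k, w_v = mha.in_proj_weight.chunk(3)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3)

    def split_heads(t):
        # (B, L, E) -> (B * heads, L, head_dim)
        return (t.reshape(bsz, -1, heads, head_dim)
                 .transpose(1, 2)
                 .reshape(bsz * heads, -1, head_dim))

    q = split_heads(F.linear(query, w_q, b_q))
    k = split_heads(F.linear(key, w_k, b_k))
    v = split_heads(F.linear(value, w_v, b_v))
    attn = F.softmax(torch.bmm(q * head_dim ** -0.5, k.transpose(1, 2)), dim=-1)
    out = torch.bmm(attn, v)                      # (B * heads, L, head_dim)
    out = (out.reshape(bsz, heads, tgt_len, head_dim)
              .transpose(1, 2)
              .reshape(bsz, tgt_len, embed_dim))
    return mha.out_proj(out)


torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True).eval()
x = torch.randn(2, 50, 64)     # self-attention input
mem = torch.randn(2, 30, 64)   # cross-attention memory (key = value)

with torch.no_grad():
    self_diff = (mha(x, x, x, need_weights=False)[0]
                 - manual_attention(mha, x, x, x)).abs().max().item()
    cross_diff = (mha(x, mem, mem, need_weights=False)[0]
                  - manual_attention(mha, x, mem, mem)).abs().max().item()

print(f"self-attention max abs diff:  {self_diff:.2e}")
print(f"cross-attention max abs diff: {cross_diff:.2e}")
```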
The export call
Short answer: Legacy torch.onnx.export at opset 17, dynamo=False. dynamo_export dies on the patches anyway; legacy works.
```python
from demucs.pretrained import get_model

bag = get_model("htdemucs_ft")
drums_model = bag.models[0].eval().cpu()
# apply all 4 patches above to drums_model

n = int(float(drums_model.segment) * int(bag.samplerate))  # 343980
dummy = torch.randn(1, 2, n, dtype=torch.float32)

torch.onnx.export(
    drums_model,
    dummy,
    "htdemucs_ft_drums.onnx",
    input_names=["mix"],
    output_names=["stems"],
    opset_version=17,
    dynamo=False,
    do_constant_folding=True,
)
```
316 MB per specialist. ~6.5 s export time on CPU. Passes onnx.checker.check_model. 24,765 nodes.
Repeat for bag.models[1] (bass), bag.models[2] (other), bag.models[3] (vocals). All four use the same architecture and patches.
Parity verification — the only acceptance test that matters
Short answer: Run the original PyTorch model and the ONNX model on the same fixed input, compute .abs().max(). Should be < 1e-3.
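```python
import numpy as np
import onnxruntime as ort
import torch
from demucs.pretrained import get_model

x = np.random.randn(1, 2, 343980).astype("float32")  # one fixed input for both runs

sess = ort.InferenceSession("htdemucs_ft_drums.onnx",
                            providers=["CPUExecutionProvider"])
onnx_out = sess.run(["stems"], {"mix": x})[0]

# Reference: a fresh, unpatched copy of the same sub-model
ref_model = get_model("htdemucs_ft").models[0].eval().cpu()
with torch.no_grad():
    torch_out = ref_model(torch.from_numpy(x)).numpy()

print("max abs diff:", np.abs(onnx_out - torch_out).max())  # 1.63e-4
```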
Per stem, against the original PyTorch htdemucs_ft at fp32:
| Stem | max abs diff |
|---|---|
| drums | 0.000163 |
| bass | 0.000011 |
| other | 0.000739 |
| vocals | 0.000008 |
All four are under the 1e-3 tolerance, the level of error that fp32 operation reordering normally explains; other, at 0.000739, has the least headroom. SDR scores measured against MUSDB18-HQ are unchanged. If you want to verify this yourself, query the 800-row leaderboard from pandas without re-running the benchmark.
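To build a per-stem table like this from raw outputs, reduce the absolute difference over every axis except the stem axis. A minimal numpy sketch, assuming both outputs have shape (batch, stems, channels, samples) with stems ordered drums, bass, other, vocals (the order the inference code uses); per_stem_max_diff is a hypothetical helper:

```python
import numpy as np

SOURCES = ["drums", "bass", "other", "vocals"]


def per_stem_max_diff(onnx_out: np.ndarray, torch_out: np.ndarray, tol: float = 1e-3):
    """Max |onnx - torch| per stem for arrays shaped (batch, stems, channels, samples)."""
    if onnx_out.shape != torch_out.shape:
        raise ValueError(f"shape mismatch: {onnx_out.shape} vs {torch_out.shape}")
    # Subtract in float64 so the comparison itself adds no rounding error
    diffs = np.abs(onnx_out.astype(np.float64) - torch_out.astype(np.float64))
    per_stem = diffs.max(axis=(0, 2, 3))          # keep only the stem axis
    report = dict(zip(SOURCES, per_stem.tolist()))
    failures = {stem: d for stem, d in report.items() if d >= tol}
    return report, failures
```

Returning the failures separately makes it easy to turn the check into a hard gate (for example, raise if failures is non-empty) in an export script.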
Inference with zero PyTorch (pure numpy + onnxruntime)
Short answer: Load the four ONNX files, run overlap-add chunking, and keep each specialist's own stem from its output.
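```python
import numpy as np
import onnxruntime as ort
import soundfile as sf

SOURCES = ["drums", "bass", "other", "vocals"]
CHUNK = 343980          # 7.8 s * 44100 Hz, the fixed export length
HOP = CHUNK // 4        # 75% overlap between consecutive chunks
SAMPLE_RATE = 44100

sessions = {
    s: ort.InferenceSession(f"htdemucs_ft_{s}.onnx",
                            providers=["CPUExecutionProvider"])
    for s in SOURCES
}


def separate(mix, sr=SAMPLE_RATE):
    """mix: float32 array of shape (channels, samples) at 44.1 kHz."""
    if sr != SAMPLE_RATE:
        raise ValueError(f"expected {SAMPLE_RATE} Hz input, got {sr} Hz")
    if mix.shape[0] == 1:                # mono: duplicate to the stereo the model expects
        mix = np.repeat(mix, 2, axis=0)
    # Zero-pad to a whole number of chunks
    pad = CHUNK - (mix.shape[-1] % CHUNK)
    mix_p = np.pad(mix, ((0, 0), (0, pad)), mode="constant")
    out = {s: np.zeros_like(mix_p) for s in SOURCES}
    weight = np.zeros(mix_p.shape[-1], dtype="float32")
    # Hann taper without its zero endpoints, so no sample ends up with zero weight
    w = np.hanning(CHUNK + 2)[1:-1].astype("float32")
    for start in range(0, mix_p.shape[-1] - CHUNK + 1, HOP):
        chunk = mix_p[:, start:start + CHUNK].astype("float32")[None]
        for row, s in enumerate(SOURCES):
            # Each specialist returns (1, 4, 2, CHUNK); keep only its own stem
            y = sessions[s].run(["stems"], {"mix": chunk})[0][0]
            out[s][:, start:start + CHUNK] += y[row] * w
        weight[start:start + CHUNK] += w
    for s in SOURCES:
        out[s] = (out[s] / np.maximum(weight, 1e-12))[:, :mix.shape[-1]]
    return out


mix, sr = sf.read("song.wav", dtype="float32", always_2d=True)
stems = separate(mix.T, sr)
for s in SOURCES:
    sf.write(f"{s}.wav", stems[s].T, sr)
```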
Zero PyTorch. Zero MLX. Runs anywhere onnxruntime runs — iOS via onnxruntime-objc, Android via onnxruntime-android, browsers via onnxruntime-web.
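The chunking and overlap-add bookkeeping can be tested without any ONNX files: with an identity stand-in for the model, a correct overlap-add must return its input unchanged, so any deviation in real outputs comes from the model rather than the windowing. The sketch below is a stripped-down version of such a loop (overlap_add and the lambda stand-ins are hypothetical); it uses a Hann taper trimmed of its zero endpoints so every sample receives nonzero total weight.

```python
import numpy as np


def overlap_add(mix, model_fn, chunk, hop):
    """Chunked inference with a tapered overlap-add.

    mix: (channels, samples); model_fn maps a (channels, chunk) array to the same shape.
    """
    pad = chunk - (mix.shape[-1] % chunk)                  # zero-pad to whole chunks
    mix_p = np.pad(mix, ((0, 0), (0, pad)))
    out = np.zeros_like(mix_p)
    weight = np.zeros(mix_p.shape[-1], dtype=mix_p.dtype)
    w = np.hanning(chunk + 2)[1:-1].astype(mix_p.dtype)    # Hann without zero endpoints
    for start in range(0, mix_p.shape[-1] - chunk + 1, hop):
        out[:, start:start + chunk] += model_fn(mix_p[:, start:start + chunk]) * w
        weight[start:start + chunk] += w
    # Normalize by the summed taper, then drop the padding
    return (out / np.maximum(weight, 1e-12))[:, :mix.shape[-1]]


# Identity "model": reconstruction should match the input to float32 rounding
rng = np.random.default_rng(0)
mix = rng.standard_normal((2, 1000), dtype=np.float32)
rec = overlap_add(mix, lambda c: c, chunk=256, hop=64)
print("identity max diff:", np.abs(rec - mix).max())
```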
Performance numbers
Short answer: ONNX Runtime's CPU EP is 1.31× faster than PyTorch on CPU. A single-specialist ONNX model (~22 s) is ~5.7× faster than the full PyTorch CPU bag (~125 s).
Apple M4 Pro, 3-minute song:
| Backend | Latency | Notes |
|---|---|---|
| ONNX CPU EP — single specialist | ~22 s | Use this for vocal removers / drum extractors |
| ONNX CPU EP — full 4-stem bag | ~88 s | All stems |
| PyTorch CPU — full bag | ~125 s | Baseline |
| PyTorch MPS — full bag | ~47 s | Apple GPU |
| ONNX CUDA — NVIDIA L4 (extrapolated) | ~6 s | Server-side deployment |
The single-specialist trick works because the htdemucs_ft bag is one-hot:
```python
# the bag's per-model weight matrix
weights = [[1, 0, 0, 0],  # drums  = sub-model 0's drums output
           [0, 1, 0, 0],  # bass   = sub-model 1's bass output
           [0, 0, 1, 0],  # other  = sub-model 2's other output
           [0, 0, 0, 1]]  # vocals = sub-model 3's vocals output
```
Sub-model 3's vocals output is the bag's vocals output, bit-exact. If you're building a vocal remover, ship sub-model 3 alone (~316 MB ONNX, ~75 MB quantized) instead of the full ~1.26 GB bag at identical per-stem quality.
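To see why one-hot weights make a specialist interchangeable with the bag, here is a minimal numpy sketch of per-stem weighted averaging over sub-model outputs. It assumes the bag computes, for each stem, a weighted sum of the sub-models' outputs divided by that stem's total weight; combine_bag is a hypothetical name, and the random arrays stand in for real model outputs.

```python
import numpy as np

SOURCES = ["drums", "bass", "other", "vocals"]


def combine_bag(per_model, weights):
    """Per-stem weighted average over sub-models.

    per_model: (models, stems, channels, samples), one row per sub-model
    weights:   (models, stems) bag weights
    """
    w = np.asarray(weights, dtype=np.float64)[:, :, None, None]
    return (per_model * w).sum(axis=0) / w.sum(axis=0)


one_hot = np.eye(4)   # the htdemucs_ft weight matrix
rng = np.random.default_rng(0)
outs = rng.standard_normal((4, 4, 2, 1000)).astype(np.float32)  # stand-in outputs
bag = combine_bag(outs, one_hot)
for k, stem in enumerate(SOURCES):
    print(stem, "matches sub-model", k, ":", np.array_equal(bag[k], outs[k, k]))
```

With an identity weight matrix, every term except sub-model k's own stem is multiplied by zero and the divisor is 1, so bag[k] equals outs[k, k] exactly rather than approximately.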
Wrapping Up
The five new ONNX repos — all MIT-licensed, all parity-verified:
- StemSplitio/htdemucs-ft-onnx — full 4-stem bag + numpy aggregator
- StemSplitio/htdemucs-ft-drums-onnx
- StemSplitio/htdemucs-ft-bass-onnx
- StemSplitio/htdemucs-ft-other-onnx
- StemSplitio/htdemucs-ft-vocals-onnx
If you'd rather skip the deployment plumbing entirely, the StemSplit hosted vocal remover runs the exact same htdemucs_ft weights with credits, queueing, and a dashboard — same model, just hosted. The full long-form writeup on the StemSplit blog has the iOS/Swift, Android/Kotlin, and onnxruntime-web code samples as well.
Open an issue on any of the repos if you find a stem where your parity diff exceeds 1e-3 — would love to hear about it.