SpectrogramΒΆ
import musdb
mus_train = musdb.DB(download=True, subsets='train', split='train')
mus_valid = musdb.DB(download=True, subsets='train', split='valid')
mus_test = musdb.DB(download=True, subsets='test')
print('number of tracks in mus_train: {}'.format(len(mus_train)))
print('number of tracks in mus_valid: {}'.format(len(mus_valid)))
print('number of tracks in mus_test: {}'.format(len(mus_test)))
number of tracks in mus_train: 80
number of tracks in mus_valid: 14
number of tracks in mus_test: 50
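Before building datasets, it helps to inspect a single track. The snippet below is a small illustrative check, assuming the standard musdb track attributes (name, audio, rate, sources):
track = mus_train[0]
print(track.name)                     # track title
print(track.audio.shape, track.rate)  # (num_samples, 2), 44100
print(list(track.sources.keys()))     # typically ['vocals', 'drums', 'bass', 'other']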
import random
import numpy as np
from torch.utils.data import Dataset
class RandomMUSDB18TrainDataset(Dataset):
    def __init__(self, musdb, sample_size):
        self.musdb = musdb
        self.sample_size = sample_size

    def __len__(self) -> int:
        return len(self.musdb)

    def get_track(self, i, target=None):
        if target is None:
            return self.musdb[i].audio.T.astype(np.float32)
        else:
            return self.musdb[i].sources[target].audio.T.astype(np.float32)

    def __getitem__(self, i):
        mixture = self.get_track(i)
        length = mixture.shape[1]
        rand_start_pos = random.randint(0, length - self.sample_size - 1)
        mixture = mixture[:, rand_start_pos: rand_start_pos + self.sample_size]
        targets = {source: self.get_track(i, source)[:, rand_start_pos: rand_start_pos + self.sample_size]
                   for source
                   in self.musdb[i].sources.keys()}
        return mixture, targets
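A quick sanity check of the random-crop dataset; the sample_size of 44100 (one second) below is arbitrary and only for illustration:
check_dataset = RandomMUSDB18TrainDataset(mus_train, sample_size=44100)
mixture, targets = check_dataset[0]
print(mixture.shape)                                   # (2, 44100)
print({source: stem.shape for source, stem in targets.items()})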
import math

import torch


class SingleTrackSet(Dataset):
    def __init__(self, track, window_length, trim_length, overlap_ratio):
        assert len(track.shape) == 2
        assert track.shape[1] == 2  # check stereo audio
        assert 0 <= overlap_ratio <= 0.5
        self.is_overlap = 0 < overlap_ratio
        self.window_length = window_length
        self.trim_length = trim_length
        self.true_samples = self.window_length - 2 * self.trim_length
        self.hop_length = int(self.true_samples * (1 - overlap_ratio))
        assert 0.5 * self.true_samples <= self.hop_length <= self.true_samples

        self.lengths = [track.shape[0]]
        self.num_tracks = 1
        self.source_names = ['vocals', 'drums', 'bass', 'other']

        num_chunks = [math.ceil((length - self.true_samples) / self.hop_length) + 1 for length in self.lengths]
        self.acc_chunk_final_ids = [sum(num_chunks[:i + 1]) for i in range(self.num_tracks)]

        self.cached = track.astype(np.float32) if track.dtype != np.float32 else track

    def __len__(self):
        return self.acc_chunk_final_ids[-1]

    def __getitem__(self, idx):
        output_mask = torch.ones((self.window_length, 2), requires_grad=False)
        output_mask[:self.trim_length] *= 0
        output_mask[-self.trim_length:] *= 0

        if self.is_overlap:
            self.overlapped_index_prev = self.true_samples - self.hop_length + self.trim_length
            self.overlapped_index_next = - self.overlapped_index_prev

        track_idx, start_pos = self.idx_to_track_offset(idx)

        length = self.true_samples
        left_padding_num = right_padding_num = self.trim_length
        if track_idx is None:
            raise StopIteration
        mixture_length = self.lengths[track_idx]

        if start_pos + length > mixture_length:  # last chunk
            right_padding_num += self.true_samples - (mixture_length - start_pos)
            length = None
            if self.is_overlap:
                if start_pos != 0:
                    output_mask[:self.overlapped_index_prev] *= 0.5
        elif start_pos + length + self.trim_length < mixture_length:
            right_padding_num = 0
            length = length + self.trim_length
            if self.is_overlap:
                if start_pos != 0:
                    output_mask[:self.overlapped_index_prev] *= 0.5
                if start_pos + self.hop_length < mixture_length:
                    output_mask[self.overlapped_index_next:] *= 0.5

        if start_pos - self.trim_length >= 0:
            left_padding_num = 0
            start_pos = start_pos - self.trim_length
            if length is not None:
                length = length + self.trim_length

        mixture = self.get_audio(start_pos, length)
        mixture = np.concatenate((np.zeros((left_padding_num, 2), dtype=np.float32), mixture,
                                  np.zeros((right_padding_num, 2), dtype=np.float32)), 0)
        mixture = torch.from_numpy(mixture)

        return mixture.T, output_mask.T

    def idx_to_track_offset(self, idx):
        for i, last_chunk in enumerate(self.acc_chunk_final_ids):
            if idx < last_chunk:
                if i != 0:
                    offset = (idx - self.acc_chunk_final_ids[i - 1]) * self.hop_length
                else:
                    offset = idx * self.hop_length
                return i, offset
        return None, None

    def get_audio(self, pos=0, length=None):
        track = self.cached
        return track[pos:pos + length] if length is not None else track[pos:]
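To see how the chunking behaves, the following illustrative snippet feeds the first ten seconds of one training mixture through SingleTrackSet; window_length=16384, trim_length=64, and overlap_ratio=0.5 are arbitrary values chosen only for this check:
demo_track = mus_train[0].audio[:44100 * 10].astype(np.float32)  # first 10 seconds, stereo
demo_set = SingleTrackSet(demo_track, window_length=16384, trim_length=64, overlap_ratio=0.5)
print(len(demo_set))               # number of chunks
chunk, mask = demo_set[0]
print(chunk.shape, mask.shape)     # torch.Size([2, 16384]) torch.Size([2, 16384])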
import torch
import torch.nn as nn
import torch.optim as optimizer
import torch.nn.functional as f
from torch.utils.data import DataLoader
import soundfile
import numpy as np
import pytorch_lightning as pl
class UNetConv2DForMDX(pl.LightningModule):
    def __init__(self, n_fft=1024, hop_length=512, num_frame=128, depth=2, num_channels=24, growth_rate=2):
        super().__init__()
        self.sub_net_dict = nn.ModuleDict({
            source: UNetConv2D(n_fft, hop_length, num_frame, depth, num_channels, growth_rate)
            for source
            in ['vocals', 'drums', 'bass', 'other']
        })

        for param in self.parameters():
            if param.dim() > 1:
                nn.init.kaiming_normal_(param)

    def configure_optimizers(self):
        return optimizer.Adam(self.parameters())

    def stft(self, audio):
        return self.sub_net_dict['vocals'].stft(audio)

    def istft(self, spec):
        return self.sub_net_dict['vocals'].istft(spec)

    def forward(self, mixture):
        estimate_dict = {
            source: self.sub_net_dict[source](mixture)
            for source
            in self.sub_net_dict.keys()
        }
        return estimate_dict

    def separate(self, mixture, overlap_ratio=0.5, batch_size=1):
        estimate_dict = {
            source: self.sub_net_dict[source].separate_internal(mixture, overlap_ratio, batch_size, self.device)
            for source
            in self.sub_net_dict.keys()
        }
        return estimate_dict

    def training_step(self, batch, index):
        mixture, source_dict = batch
        estimate_dict = self(mixture)
        loss_dict = {
            source: f.mse_loss(self.stft(estimate_dict[source]), self.stft(source_dict[source]))
            for source
            in source_dict.keys()
        }
        for source in loss_dict.keys():
            self.log('train/loss_{}'.format(source), loss_dict[source], on_epoch=True)
        loss = sum(loss_dict.values())
        self.log('train/loss', loss, on_epoch=True)
        return loss
class UNetConv2D(nn.Module):
    def __init__(self, n_fft, hop_length, num_frame, depth, num_channels, growth_rate):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.num_frame = num_frame
        self.chunk_size = self.hop_length * (self.num_frame - 1)

        self.encoder, self.decoder = nn.ModuleList(), nn.ModuleList()
        in_channels, out_channels = 4, num_channels
        for index in range(depth):
            encode = [nn.Conv2d(in_channels, out_channels, (2, 2), (2, 2)), nn.ReLU()]
            self.encoder.append(nn.Sequential(*encode))
            decode = [nn.ConvTranspose2d(out_channels, in_channels, (2, 2), (2, 2)), nn.ReLU()]
            if index == 0:
                decode[-1] = nn.Sigmoid()
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = out_channels
            out_channels = growth_rate * out_channels

    def forward(self, mixture):
        input_spec = self.stft(mixture)
        output_spec = self.estimate_spec(input_spec)
        return self.istft(output_spec)

    def estimate_spec(self, spec):
        x = spec
        saved = []
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)
        saved[-1] = None  # the bottleneck output is not reused as a skip connection
        for decode, skip in zip(self.decoder, reversed(saved)):
            if skip is not None:
                x = x + skip
            x = decode(x)
        return x * spec  # the decoder output acts as a mask on the input spectrogram

    def separate_internal(self, mixture, overlap_ratio=0.5, batch_size=1, device='cpu') -> np.ndarray:
        self.eval()
        with torch.no_grad():
            trim_length = self.hop_length // 2
            db = SingleTrackSet(mixture, self.chunk_size, trim_length, overlap_ratio)
            separated = []

            for item, mask in DataLoader(db, batch_size):
                res = self(item.to(device))
                res = res * mask.to(device)
                res = res[:, :, trim_length:-trim_length].detach().cpu().transpose(-1, -2).numpy()
                separated.append(res)

            separated = np.concatenate(separated)

            if db.is_overlap:
                # overlap-add; overlapping regions were already scaled by 0.5 in the mask
                output = np.zeros_like(mixture)
                hop_length = db.hop_length
                for i, sep in enumerate(separated):
                    to = sep.shape[0]
                    if i * hop_length + sep.shape[0] > output.shape[0]:
                        to = sep.shape[0] - (i * hop_length + sep.shape[0] - output.shape[0])
                    output[i * hop_length:i * hop_length + to] += sep[:to]
                separated = output
            else:
                separated = np.concatenate(separated, axis=0)

            # round-trip through a wav file before returning the estimate
            soundfile.write('temp.wav', separated, 44100)
            return soundfile.read('temp.wav')[0]

    def stft(self, mixture):
        mixture = mixture.view(-1, mixture.shape[-1])
        # torch.stft returns a complex tensor when return_complex=True
        spec = torch.stft(mixture, n_fft=self.n_fft, hop_length=self.hop_length, return_complex=True)
        spec = torch.view_as_real(spec)
        F, T = spec.shape[-3], spec.shape[-2]
        spec = spec.view(-1, 2, F, T, 2)
        # [B, 2:channel, F, T, 2:complex] => [B, 4, T, F]
        spec = spec.transpose(-1, -3).reshape(-1, 4, T, F)
        # drop the highest frequency bin so that F becomes n_fft // 2
        return spec[:, :, :, :-1]

    def istft(self, spec):
        B, ch, T, _ = spec.shape
        zero_pad = torch.zeros_like(spec[..., :1], requires_grad=True)
        spec = torch.cat([spec, zero_pad], dim=-1)  # restore the dropped frequency bin
        # [B, 4, T, F] => [B, 2:ch, 2:complex, T, F] => [B, 2:ch, F, T, 2:complex]
        spec = spec.view(B, 2, 2, T, -1).transpose(-1, -3)
        spec = spec.view(2 * B, -1, T, 2).contiguous()
        spec = torch.view_as_complex(spec)
        target = torch.istft(spec, n_fft=self.n_fft, hop_length=self.hop_length, return_complex=False)
        return target.view(B, 2, -1)
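A hypothetical shape check of the STFT round trip, with illustrative values: for n_fft=256 and hop_length=128 the spectrogram has 4 channels (2 stereo channels times real/imaginary parts), num_frame time steps, and n_fft // 2 frequency bins, and istft maps it back to chunk_size samples.
net = UNetConv2D(n_fft=256, hop_length=128, num_frame=128, depth=2, num_channels=8, growth_rate=2)
dummy = torch.randn(3, 2, net.chunk_size)   # (batch, stereo, samples)
spec = net.stft(dummy)
print(spec.shape)                           # torch.Size([3, 4, 128, 128])
print(net.istft(spec).shape)                # torch.Size([3, 2, 16256])
print(net(dummy).shape)                     # same shape as the istft output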
Model Definition
model = UNetConv2DForMDX(n_fft=256, hop_length=128, num_frame=128, depth=6, num_channels=8, growth_rate=2)
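Each per-source sub-network operates on fixed-length chunks. The snippet below (illustrative only) prints the chunk length implied by hop_length=128 and num_frame=128, which the training dataset's sample_size has to match, plus the total parameter count:
chunk_size = model.sub_net_dict['vocals'].chunk_size
print(chunk_size)                                    # 128 * (128 - 1) = 16256 samples
print(chunk_size / 44100)                            # roughly 0.37 seconds of audio per chunk
print(sum(p.numel() for p in model.parameters()))    # total number of parameters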
Music Demixing with an Untrained Model
from IPython.display import Audio, display
estimated_dict = model.separate(mus_train[0].audio)
for source in estimated_dict.keys():
    print(source)
    display(Audio(estimated_dict[source].T, rate=44100))
vocals
drums
bass
other
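For comparison, the ground-truth stems of the same track can be played back directly from the MUSDB track object (a small illustrative snippet using the sources attribute shown earlier):
for source in ['vocals', 'drums', 'bass', 'other']:
    print(source + ' (ground truth)')
    display(Audio(mus_train[0].sources[source].audio.T, rate=44100))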
from pytorch_lightning import Trainer
train_dataset = RandomMUSDB18TrainDataset(mus_train, sample_size=128 * (128 - 1))  # hop_length * (num_frame - 1), i.e. the model's chunk_size
train_dataloader = DataLoader(train_dataset, batch_size=8)
if torch.cuda.is_available():
    pl_trainer = Trainer(gpus=1, max_epochs=1)
else:
    pl_trainer = Trainer(max_epochs=1)
pl_trainer.fit(model, train_dataloaders=train_dataloader)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
--------------------------------------------
0 | sub_net_dict | ModuleDict | 1.4 M
--------------------------------------------
1.4 M Trainable params
0 Non-trainable params
1.4 M Total params
5.603 Total estimated model params size (MB)
/home/wschoi/exit/envs/tutorial-environment/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:105: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
/home/wschoi/exit/envs/tutorial-environment/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:326: UserWarning: The number of training samples (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
rank_zero_warn(
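After this (very short) training run, the weights can be kept for later use. A minimal sketch, assuming the standard PyTorch Lightning checkpoint API; the file name is hypothetical, and the constructor arguments are repeated because the model does not call save_hyperparameters():
pl_trainer.save_checkpoint('unet_mdx.ckpt')   # hypothetical file name
restored_model = UNetConv2DForMDX.load_from_checkpoint(
    'unet_mdx.ckpt',
    n_fft=256, hop_length=128, num_frame=128, depth=6, num_channels=8, growth_rate=2)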
from IPython.display import Audio, display
estimated_dict = model.separate(mus_train[0].audio)
for source in estimated_dict.keys():
    print(source)
    display(Audio(estimated_dict[source].T, rate=44100))
vocals
drums
bass
other