MUSDB18 with PyTorch¶
This chapter shows how to build a torch.utils.data.Dataset on top of MUSDB18 for training a music demixing model.
MUSDB18RandomTrainDataset iterates over the songs of MUSDB18. In each iteration, it randomly selects a segment of sample_size samples from the given song.
MUSDB18RandomTrainDataset¶
import random
import numpy as np
from torch.utils.data import Dataset
class MUSDB18RandomTrainDataset(Dataset):
    """Dataset that yields one random fixed-length excerpt per track.

    Item ``i`` is cut from track ``i`` of the wrapped collection. The start
    position is drawn anew on every access, and the same window is applied to
    the mixture and to every source of that track, so they stay aligned.

    Args:
        musdb: Sized, indexable collection of tracks. Each track must expose
            ``audio`` and a ``sources`` mapping whose values expose ``audio``.
            Each ``audio`` is transposed before use, and the second axis of
            the transposed array is treated as time (presumably the arrays
            are (samples, channels) -- confirm against your data).
        sample_size: Length of each excerpt, in samples.
    """

    def __init__(self, musdb, sample_size: int):
        self.musdb = musdb
        self.sample_size = sample_size

    def __len__(self) -> int:
        """Return the number of tracks, i.e. one item per track."""
        return len(self.musdb)

    def __getitem__(self, i):
        """Return ``(mixture, targets)`` for a random excerpt of track ``i``.

        Returns:
            A float32 mixture excerpt and a dict mapping each source name to
            its float32 excerpt taken at the same position.

        Raises:
            ValueError: If the track is shorter than ``sample_size``.
        """
        mixture = self.get_track(i)
        length = mixture.shape[1]

        last_start = length - self.sample_size
        if last_start < 0:
            raise ValueError(
                'track {} has {} samples, fewer than sample_size={}'.format(
                    i, length, self.sample_size))

        # random.randint includes both endpoints, so `last_start` itself must
        # be the upper bound: it is the final position at which a full window
        # still fits (and the only one when the track is exactly sample_size).
        start = random.randint(0, last_start)
        window = slice(start, start + self.sample_size)

        targets = {source: self.get_track(i, source)[:, window]
                   for source in self.musdb[i].sources}
        return mixture[:, window], targets

    def get_track(self, i, target=None):
        """Return the transposed float32 audio of track ``i``.

        With ``target=None`` the track's own ``audio`` (the mixture) is used;
        otherwise the audio of ``sources[target]``.
        """
        track = self.musdb[i]
        stem = track if target is None else track.sources[target]
        return stem.audio.T.astype(np.float32)
Listen to samples!¶
You can listen to the samples that a MUSDB18RandomTrainDataset yields when it is built on the 7-second version of MUSDB18, as shown below.
from IPython.display import Audio, display
import musdb
# Build the dataset on the downloadable MUSDB18 train split.
mus_train = musdb.DB(download=True, subsets='train', split='train')
mus_rtd = MUSDB18RandomTrainDataset(mus_train, sample_size=44100*3)

# Play the mixture and every source of the first track's random excerpt.
for i in range(1):
    mixture, source_dict = mus_rtd[i]
    track = mus_train[i]
    sample_rate = track.rate

    print('track name: {}\n==================='.format(track.title))

    print('mixture:')
    display(Audio(mixture, rate=sample_rate))

    for source, excerpt in source_dict.items():
        print('{}:'.format(source))
        display(Audio(excerpt, rate=sample_rate))
track name: NightOwl
===================
mixture:
vocals:
drums:
bass:
other:
Toy Training Scheme with MUSDB18RandomTrainDataset¶
Below is a toy training scheme that uses a MUSDB18RandomTrainDataset built on the 7-second version of MUSDB18.
Note that this version does not contain the full MUSDB18 dataset.
from torch.utils.data import DataLoader
# Replace this with your own musdb.DB instance to train on other data.
# The bare placeholder that stood here was not valid Python, so the cell now
# uses the same downloadable 7-second train split as the listening example.
mus_train = musdb.DB(download=True, subsets='train', split='train')
mus_rtd = MUSDB18RandomTrainDataset(mus_train, sample_size=44100*3)

for epoch in range(1):
    # Each batch pairs a batch of mixtures with a dict of per-source batches.
    for mixture, source_dict in DataLoader(mus_rtd, batch_size=4):
        # estimation_dict = model(mixture)
        # loss = compute_loss(estimation_dict, source_dict)
        # gradient descent
        pass