MUSDB18 with PyTorch

This chapter shows how to build a torch.utils.data.Dataset on top of MUSDB18 for training a music demixing model. MUSDB18RandomTrainDataset iterates over the songs of MUSDB18. In each iteration, it randomly crops a segment of sample_size samples from the given song.

MUSDB18RandomTrainDataset

import random
import numpy as np
from torch.utils.data import Dataset


class MUSDB18RandomTrainDataset(Dataset):
    """Dataset that yields one random fixed-length excerpt per MUSDB18 track.

    Indexing with ``i`` selects track ``i`` of the wrapped database and crops
    the same randomly positioned window of ``sample_size`` samples from the
    mixture and from every source of that track.

    Args:
        musdb: Indexable, sized collection of tracks (e.g. a ``musdb.DB``).
            Each track must expose ``audio`` and a ``sources`` mapping whose
            values also expose ``audio``.
        sample_size: Length of each excerpt, in samples.
    """

    def __init__(self, musdb, sample_size):
        self.musdb = musdb
        self.sample_size = sample_size

    def __len__(self) -> int:
        """Return the number of tracks; one excerpt is drawn per track."""
        return len(self.musdb)

    def __getitem__(self, i):
        """Return ``(mixture, targets)`` for a random excerpt of track ``i``.

        Returns:
            A tuple ``(mixture, targets)`` where ``mixture`` is a float32
            array cropped to ``sample_size`` along axis 1, and ``targets``
            maps each source name to the identically cropped source audio.

        Raises:
            ValueError: If the track is shorter than ``sample_size``.
        """
        mixture = self.get_track(i)
        length = mixture.shape[1]

        # random.randint includes both endpoints, so the largest valid start
        # is exactly ``length - sample_size`` (the window then ends on the
        # last sample of the track).
        last_start = length - self.sample_size
        if last_start < 0:
            raise ValueError(
                'track {} has {} samples, fewer than sample_size={}'.format(
                    i, length, self.sample_size))
        start = random.randint(0, last_start)
        stop = start + self.sample_size

        mixture = mixture[:, start:stop]
        # Every source is cropped with the same window so that the targets
        # stay aligned with the mixture.
        targets = {source: self.get_track(i, source)[:, start:stop]
                   for source in self.musdb[i].sources}

        return mixture, targets

    def get_track(self, i, target=None):
        """Return the transposed float32 audio of track ``i``.

        Args:
            i: Index of the track in the wrapped database.
            target: Name of a source in the track's ``sources`` mapping, or
                ``None`` to get the mixture.

        Returns:
            The requested ``audio`` array, transposed and cast to float32.
            Presumably (channels, samples) after the transpose -- this relies
            on the database returning (samples, channels); confirm against
            the musdb documentation.
        """
        if target is None:
            return self.musdb[i].audio.T.astype(np.float32)
        return self.musdb[i].sources[target].audio.T.astype(np.float32)

Listen to samples!

You can listen to the excerpts that a MUSDB18RandomTrainDataset yields when it is built on the 7-second version of MUSDB18, as shown below.

from IPython.display import Audio, display
import musdb
# Build the dataset on the MUSDB18 training split; every excerpt is
# 44100*3 samples long.
mus_train = musdb.DB(download=True, subsets='train', split='train')
mus_rtd = MUSDB18RandomTrainDataset(mus_train, sample_size=44100*3)

# Play the mixture and each source of a random excerpt of the first track.
for i in range(1):
    mixture, source_dict = mus_rtd[i]
    track = mus_train[i]
    header = 'track name: {}\n==================='
    print(header.format(track.title))
    print('mixture:')
    display(Audio(mixture, rate=track.rate))

    for source, excerpt in source_dict.items():
        print('{}:'.format(source))
        display(Audio(excerpt, rate=track.rate))
track name: NightOwl
===================
mixture:
vocals:
drums:
bass:
other:

Toy Training Scheme with MUSDB18RandomTrainDataset

Below is a toy training loop that uses a MUSDB18RandomTrainDataset built on the 7-second version of MUSDB18. Note that this version does not contain the full MUSDB18 dataset.

from torch.utils.data import DataLoader

# YOUR MUSDB: replace this with your own database (e.g. musdb.DB(root=...))
# to train on the full dataset. The default below is the same database
# used in the listening example.
mus_train = musdb.DB(download=True, subsets='train', split='train')
mus_rtd = MUSDB18RandomTrainDataset(mus_train, sample_size=44100*3)

# Each batch holds random excerpts from 4 tracks; a new random window is
# drawn every time a track is indexed, so every epoch sees different audio.
train_loader = DataLoader(mus_rtd, batch_size=4)

for epoch in range(1):
    for mixture, source_dict in train_loader:
        # estimation_dict = model(mixture)
        # loss = compute_loss(estimation_dict, source_dict)
        # gradient descent
        pass