Data Augmentation

The MUSDB train set contains 84 songs. Its size is quite large compared to other open datasets (except for Slakh2100), but it might not be sufficient to train deep networks.

In such cases, data augmentation is commonly used. Data augmentation is critical to training good music demixing models; for example, most of the current state-of-the-art models listed on the MUSDB18 leaderboard use data augmentation.

This section introduces a widely used data augmentation method proposed in [UPG+17].

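We will add a `remixing` option, which builds each training example by summing sources cropped from randomly chosen tracks at random offsets, to our previous MUSDB18RandomTrainDataset.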

MUSDB18RandomTrainDataset

import random
import musdb
# Open the MUSDB18 training split; download=True lets musdb fetch the data if needed.
mus_train = musdb.DB(subsets='train', split='train', download=True)

# A uniformly random track index in 0..len(mus_train)-1 (same draw as randint(0, n - 1)).
random.randrange(len(mus_train))
44
import random
import numpy as np
from torch.utils.data import Dataset


class MUSDB18RandomTrainDataset(Dataset):
    """Random fixed-length excerpts from a MUSDB18-style track collection.

    Each item is a ``(mixture, targets)`` pair. ``mixture`` is a float32 array
    whose axis 1 is time with length ``sample_size``. ``targets`` maps each
    source name to an excerpt of the same shape.

    Args:
        musdb: Indexable, sized collection of tracks (e.g. ``musdb.DB``). Each
            track must expose ``.audio`` and ``.sources[name].audio``. These are
            assumed to be ``(samples, channels)`` arrays and are transposed so
            that axis 1 is time.
        sample_size: Excerpt length in samples.
        remixing: If True, apply random-mixing augmentation. Every source in
            ``SOURCES`` is cropped from an independently chosen random track at
            an independent random offset, and the mixture is their sum. The
            index passed to ``__getitem__`` is then only bounds-checked.
    """

    # Stems combined when ``remixing`` is enabled.
    SOURCES = ('vocals', 'drums', 'bass', 'other')

    def __init__(self, musdb, sample_size, remixing=False):
        self.musdb = musdb
        self.sample_size = sample_size
        self.remixing = remixing

    def __len__(self) -> int:
        return len(self.musdb)

    def __getitem__(self, i):
        """Return a random ``(mixture, targets)`` excerpt.

        Raises:
            StopIteration: If ``i >= len(self)``. This terminates a plain
                ``for`` loop over the dataset. It is kept instead of the
                conventional ``IndexError`` for backward compatibility.
            ValueError: If a selected track is shorter than ``sample_size``.
        """
        if i >= len(self):
            raise StopIteration

        if self.remixing:
            targets = {}
            for source in self.SOURCES:
                audio = self.get_track(random.randint(0, len(self.musdb) - 1), source)
                start = self._random_start(audio.shape[1])
                targets[source] = audio[:, start:start + self.sample_size]
            mixture = sum(targets.values())
        else:
            mixture = self.get_track(i)
            start = self._random_start(mixture.shape[1])
            stop = start + self.sample_size
            # One shared offset keeps the mixture and its stems time-aligned.
            mixture = mixture[:, start:stop]
            targets = {
                source: self.get_track(i, source)[:, start:stop]
                for source in self.musdb[i].sources
            }

        return mixture, targets

    def _random_start(self, length):
        """Return a uniform random offset at which a ``sample_size`` excerpt fits.

        ``random.randint`` is inclusive on both ends. So ``length - sample_size``
        is the last offset whose excerpt still ends inside the signal, and a
        track of exactly ``sample_size`` samples yields offset 0.

        Raises:
            ValueError: If ``length < sample_size``.
        """
        max_start = length - self.sample_size
        if max_start < 0:
            raise ValueError(
                f'track has {length} samples, fewer than sample_size={self.sample_size}'
            )
        return random.randint(0, max_start)

    def get_track(self, i, target=None):
        """Return track ``i``'s mixture (or its ``target`` source) as float32.

        The audio is transposed so that axis 1 is time.
        """
        if target is None:
            return self.musdb[i].audio.T.astype(np.float32)
        else:
            return self.musdb[i].sources[target].audio.T.astype(np.float32)

Listen to samples!

You can listen to samples produced by iterating over a MUSDB18RandomTrainDataset created with the remixing=True option, as shown below.

from IPython.display import Audio, display
import musdb
mus_train = musdb.DB(download=True, subsets='train', split='train')

sample_rate = 44100
# 3-second excerpts at the playback rate used below.
mus_rtd = MUSDB18RandomTrainDataset(mus_train, sample_size=sample_rate * 3, remixing=True)

# With remixing=True the index is only bounds-checked, so item 0 is a fresh
# random remix (this is the same call a `for` loop makes on its first step).
mixture, source_dict = mus_rtd[0]

print('track name: random remix\n===================')
print('mixture:')
display(Audio(mixture, rate=sample_rate))

for source, audio in source_dict.items():
    print('{}:'.format(source))
    display(Audio(audio, rate=sample_rate))
track name: random remix
===================
mixture:
vocals:
drums:
bass:
other: