Data Augmentation¶
The MUSDB train set contains 84 songs. It is relatively large compared to other open datasets (except for Slakh2100), but it still might not be sufficient to train deep networks.
In such cases, data augmentation is used. Data augmentation is critical to training good music demixing models; for example, most of the current state-of-the-art models listed in the MUSDB18 leaderboard use it.
This section introduces widely used data augmentation methods proposed in [UPG+17] and [CHRP19].
We will add the following augmentations to our previous MUSDB18RandomTrainDataset (see the sketches after this list):

- remix sources [UPG+17]
- random scaling [UPG+17]
- swap stereo channels [UPG+17]
- pitch shift [CHRP19]
- tempo shift [CHRP19]
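Remixing is implemented directly inside the dataset class in the next section. Random scaling and channel swapping [UPG+17] are simple per-source transforms, so here is a minimal sketch of them; the function names and the gain range of 0.25 to 1.25 are illustrative assumptions, not the exact values from [UPG+17].

import random

import numpy as np


def random_scale(audio, low=0.25, high=1.25):
    """Multiply a (channels, samples) source by a random gain (assumed range)."""
    gain = random.uniform(low, high)
    return gain * audio


def random_channel_swap(audio, p=0.5):
    """Swap the left and right channels with probability p."""
    if audio.shape[0] == 2 and random.random() < p:
        return np.flip(audio, axis=0).copy()  # copy to avoid negative strides
    return audio

Both transforms can be applied to each cropped source inside __getitem__ before the sources are summed into the mixture.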
MUSDB18RandomTrainDataset¶
import random
import musdb

mus_train = musdb.DB(download=True, subsets='train', split='train')
# draw a random track index from the training split (remixing picks source tracks like this)
random.randint(0, len(mus_train)-1)
44
import random

import numpy as np
from torch.utils.data import Dataset


class MUSDB18RandomTrainDataset(Dataset):
    def __init__(self, musdb, sample_size, remixing=False):
        self.musdb = musdb
        self.sample_size = sample_size
        self.remixing = remixing

    def __len__(self) -> int:
        return len(self.musdb)

    def __getitem__(self, i):
        if i >= len(self):
            raise StopIteration
        if self.remixing:
            # pick an independent random track for each source
            targets = {
                source: self.get_track(random.randint(0, len(self.musdb) - 1), source)
                for source
                in ['vocals', 'drums', 'bass', 'other']
            }
            # pick an independent random start position for each source
            start_pos = {
                source: random.randint(0, targets[source].shape[1] - self.sample_size - 1)
                for source
                in ['vocals', 'drums', 'bass', 'other']
            }
            # crop each source and build a remixed mixture by summing the crops
            targets = {
                source: targets[source][:, start_pos[source]:start_pos[source] + self.sample_size]
                for source
                in ['vocals', 'drums', 'bass', 'other']
            }
            mixture = sum(targets.values())
        else:
            # crop the original mixture and its sources at the same random position
            mixture = self.get_track(i)
            length = mixture.shape[1]
            rand_start_pos = random.randint(0, length - self.sample_size - 1)
            mixture = mixture[:, rand_start_pos: rand_start_pos + self.sample_size]
            targets = {source: self.get_track(i, source)[:, rand_start_pos: rand_start_pos + self.sample_size]
                       for source
                       in self.musdb[i].sources.keys()}
        return mixture, targets

    def get_track(self, i, target=None):
        # return the mixture (target=None) or a single source as a (channels, samples) float32 array
        if target is None:
            return self.musdb[i].audio.T.astype(np.float32)
        else:
            return self.musdb[i].sources[target].audio.T.astype(np.float32)
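The class above covers remixing. Pitch shift and tempo shift [CHRP19] require resampling or time-stretching, which is relatively expensive, so they are often applied to the source tracks before training rather than on the fly. Below is a minimal sketch, assuming librosa is installed; the function names, the semitone and rate ranges, and the per-channel processing are illustrative assumptions rather than the exact procedure from [CHRP19].

import random

import librosa
import numpy as np


def random_pitch_shift(audio, sr=44100, max_semitones=2):
    """Pitch-shift a (channels, samples) array by a random number of semitones."""
    n_steps = random.uniform(-max_semitones, max_semitones)
    # process channels separately to stay compatible with older librosa versions
    return np.stack([librosa.effects.pitch_shift(channel, sr=sr, n_steps=n_steps)
                     for channel in audio])


def random_tempo_shift(audio, min_rate=0.9, max_rate=1.1):
    """Time-stretch a (channels, samples) array by a random rate (rate > 1 is faster)."""
    rate = random.uniform(min_rate, max_rate)
    return np.stack([librosa.effects.time_stretch(channel, rate=rate)
                     for channel in audio])

Note that time stretching changes the number of samples, so a fixed-size crop (like sample_size above) should be taken after the stretch.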
Listen to samples!¶
You can listen to the samples that a MUSDB18RandomTrainDataset with the remixing=True option yields, as shown below.
from IPython.display import Audio, display

import musdb

mus_train = musdb.DB(download=True, subsets='train', split='train')
mus_rtd = MUSDB18RandomTrainDataset(mus_train, sample_size=44100*3, remixing=True)

# listen to the first randomly remixed sample and its sources
for mixture, source_dict in mus_rtd:
    print('track name: random remix\n===================')
    print('mixture:')
    display(Audio(mixture, rate=44100))
    for source in source_dict.keys():
        print('{}:'.format(source))
        display(Audio(source_dict[source], rate=44100))
    break
track name: random remix
===================
mixture:
vocals:
drums:
bass:
other:
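To use this dataset for training, it can be wrapped in a standard PyTorch DataLoader. This is a minimal sketch; the batch size and worker count are arbitrary example values.

from torch.utils.data import DataLoader

loader = DataLoader(mus_rtd, batch_size=4, shuffle=True, num_workers=0)

# the default collate function stacks the mixtures and each source into batched tensors
mixture_batch, target_batch = next(iter(loader))
print(mixture_batch.shape)           # torch.Size([4, 2, 132300])
print(target_batch['vocals'].shape)  # torch.Size([4, 2, 132300])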