Spectrogram

import musdb

# Three views of MUSDB18: the 'train' subset divided into its 'train' and
# 'valid' splits, plus the separate 'test' subset.
# download=True lets musdb download the data (see the musdb docs for which version).
mus_train = musdb.DB(subsets='train', split='train', download=True)
mus_valid = musdb.DB(subsets='train', split='valid', download=True)
mus_test = musdb.DB(subsets='test', download=True)

# Report how many tracks each view contains.
print('track numbers of mus_train: {}'.format(len(mus_train)))
print('track numbers of mus_valid: {}'.format(len(mus_valid)))
print('track numbers of mus_test: {}'.format(len(mus_test)))
track numbers of mus_train: 80
track numbers of mus_valid: 14
track numbers of mus_test: 50
import random
import numpy as np
from torch.utils.data import Dataset


class RandomMUSDB18TrainDataset(Dataset):
    """Training dataset that draws one random, time-aligned excerpt per track.

    Item ``i`` is a ``(mixture, targets)`` pair: ``mixture`` is the track's
    mixture audio transposed to ``(channels, sample_size)`` float32, and
    ``targets`` maps every source name of the track to the excerpt cut at the
    same position from that source's audio.
    """

    def __init__(self, musdb, sample_size):
        self.musdb = musdb
        self.sample_size = sample_size

    def __len__(self) -> int:
        return len(self.musdb)

    def get_track(self, i, target=None):
        """Return track ``i`` (or its ``target`` source) as a transposed float32 array."""
        if target is None:
            return self.musdb[i].audio.T.astype(np.float32)
        else:
            return self.musdb[i].sources[target].audio.T.astype(np.float32)

    def __getitem__(self, i):
        """Cut the same random ``sample_size`` window from the mixture and every source.

        Raises:
            ValueError: if track ``i`` is shorter than ``sample_size`` samples.
        """
        mixture = self.get_track(i)
        length = mixture.shape[1]
        if length < self.sample_size:
            raise ValueError('track {} has {} samples, fewer than sample_size={}'.format(
                i, length, self.sample_size))

        # random.randint is inclusive on both ends, so the last valid start is
        # length - sample_size (the excerpt then ends exactly at the track end).
        rand_start_pos = random.randint(0, length - self.sample_size)
        end_pos = rand_start_pos + self.sample_size

        mixture = mixture[:, rand_start_pos:end_pos]
        targets = {source: self.get_track(i, source)[:, rand_start_pos:end_pos]
                   for source
                   in self.musdb[i].sources.keys()}

        return mixture, targets

class SingleTrackSet(Dataset):
    """Split one stereo track into fixed-size windows for chunked inference.

    The track is covered by "true" regions of ``window_length - 2 * trim_length``
    samples that start every ``hop_length`` samples. Each item is a true region
    with ``trim_length`` samples of context on both sides (zero-padded outside
    the track) together with a mask that zeroes the context and, when
    ``overlap_ratio > 0``, halves the samples shared with a neighbouring chunk,
    so that overlap-adding masked chunks gives every sample a total weight of 1.
    """

    def __init__(self, track, window_length, trim_length, overlap_ratio):
        import math

        assert len(track.shape) == 2
        assert track.shape[1] == 2  # check stereo audio
        assert 0 <= overlap_ratio <= 0.5
        self.is_overlap = 0 < overlap_ratio

        self.window_length = window_length
        self.trim_length = trim_length

        self.true_samples = self.window_length - 2 * self.trim_length
        self.hop_length = int(self.true_samples * (1 - overlap_ratio))
        # hop >= true/2 guarantees each sample is shared by at most two chunks.
        assert 0.5 * self.true_samples <= self.hop_length <= self.true_samples

        # Mask boundary (in window coordinates) of the region shared with the
        # previous chunk; its mirror image is the region shared with the next one.
        self.overlapped_index_prev = self.true_samples - self.hop_length + self.trim_length
        self.overlapped_index_next = - self.overlapped_index_prev

        self.lengths = [track.shape[0]]
        self.num_tracks = 1
        self.source_names = ['vocals', 'drums', 'bass', 'other']

        # At least one (padded) chunk per track: for very short tracks the
        # ceil(...) + 1 formula would otherwise drop to zero chunks.
        num_chunks = [max(1, math.ceil((length - self.true_samples) / self.hop_length) + 1)
                      for length in self.lengths]

        self.acc_chunk_final_ids = [sum(num_chunks[:i + 1]) for i in range(self.num_tracks)]

        # Compare dtypes with ==: `dtype is not np.float32` is always True and
        # forced a copy even for float32 input.
        self.cached = track if track.dtype == np.float32 else track.astype(np.float32)

    def __len__(self):
        return self.acc_chunk_final_ids[-1]

    def __getitem__(self, idx):
        """Return ``(audio, mask)`` tensors, both shaped ``(2, window_length)``.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``.
        """
        if not 0 <= idx < len(self):
            raise IndexError('chunk index {} out of range for {} chunks'.format(idx, len(self)))

        track_idx, start_pos = self.idx_to_track_offset(idx)
        mixture_length = self.lengths[track_idx]
        is_last_chunk = idx == self.acc_chunk_final_ids[track_idx] - 1

        # 1 on the true region, 0 on the trimmed context at both ends. Positive
        # end index so that trim_length == 0 does not zero the whole mask.
        output_mask = torch.ones((self.window_length, 2))
        output_mask[:self.trim_length] = 0
        output_mask[self.window_length - self.trim_length:] = 0
        if self.is_overlap:
            # Halve every region shared with a neighbour, whichever padding case
            # applies, so overlap-add sums to exactly one.
            if start_pos != 0:
                output_mask[:self.overlapped_index_prev] *= 0.5
            if not is_last_chunk:
                output_mask[self.trim_length + self.hop_length:] *= 0.5

        # The window spans [start_pos - trim, start_pos + true + trim); read the
        # part inside the track and zero-pad whatever falls outside it.
        window_start = start_pos - self.trim_length
        window_end = window_start + self.window_length
        read_start = max(window_start, 0)
        read_end = min(window_end, mixture_length)
        left_padding_num = read_start - window_start
        right_padding_num = window_end - read_end

        mixture = self.get_audio(read_start, read_end - read_start)

        mixture = np.concatenate((np.zeros((left_padding_num, 2), dtype=np.float32), mixture,
                                  np.zeros((right_padding_num, 2), dtype=np.float32)), 0)

        mixture = torch.from_numpy(mixture)

        return mixture.T, output_mask.T

    def idx_to_track_offset(self, idx):
        """Map a global chunk index to ``(track index, start sample)``, or ``(None, None)``."""
        for i, last_chunk in enumerate(self.acc_chunk_final_ids):
            if idx < last_chunk:
                if i != 0:
                    offset = (idx - self.acc_chunk_final_ids[i - 1]) * self.hop_length
                else:
                    offset = idx * self.hop_length
                return i, offset

        return None, None

    def get_audio(self, pos=0, length=None):
        """Return ``length`` samples of the cached track from ``pos`` (to the end if None)."""
        track = self.cached
        return track[pos:pos + length] if length is not None else track[pos:]
import torch
import torch.nn as nn
import torch.optim as optimizer
import torch.nn.functional as f
from torch.utils.data import DataLoader

import soundfile
import numpy as np
import pytorch_lightning as pl


class UNetConv2DForMDX(pl.LightningModule):
    """LightningModule bundling one ``UNetConv2D`` separator per MUSDB18 stem.

    Each stem ('vocals', 'drums', 'bass', 'other') gets its own sub-network,
    built with identical constructor arguments. Training minimises the sum over
    stems of the MSE between the STFTs of the estimated and reference waveforms.
    """

    def __init__(self, n_fft=1024, hop_length=512, num_frame=128, depth=2, num_channels=24, groth_rate=2):
        super().__init__()
        networks = {}
        for stem in ['vocals', 'drums', 'bass', 'other']:
            networks[stem] = UNetConv2D(n_fft, hop_length, num_frame, depth, num_channels, groth_rate)
        self.sub_net_dict = nn.ModuleDict(networks)

        # Kaiming-normal re-initialisation of every multi-dimensional parameter
        # (the conv / transposed-conv kernels); 1-D parameters keep their defaults.
        for weight in filter(lambda p: p.dim() > 1, self.parameters()):
            nn.init.kaiming_normal_(weight)

    def configure_optimizers(self):
        # Adam with library defaults over all sub-network parameters.
        params = self.parameters()
        return optimizer.Adam(params)

    def stft(self, audio):
        # Every sub-network is built with the same STFT settings, so the vocals
        # network serves as the shared transform.
        reference = self.sub_net_dict['vocals']
        return reference.stft(audio)

    def istft(self, spec):
        reference = self.sub_net_dict['vocals']
        return reference.istft(spec)

    def forward(self, mixture):
        """Return a dict mapping each stem name to its estimated waveform batch."""
        return {stem: net(mixture) for stem, net in self.sub_net_dict.items()}

    def separate(self, mixture, overlap_ratio=0.5, batch_size=1):
        """Separate a whole track with every sub-network (see ``UNetConv2D.separate_internal``)."""
        device = self.device
        return {
            stem: net.separate_internal(mixture, overlap_ratio, batch_size, device)
            for stem, net in self.sub_net_dict.items()
        }

    def training_step(self, batch, index):
        """Log per-stem and total spectrogram MSE; return the total for backprop."""
        mixture, source_dict = batch
        estimate_dict = self(mixture)

        loss_dict = {}
        for source in source_dict:
            estimate_spec = self.stft(estimate_dict[source])
            target_spec = self.stft(source_dict[source])
            loss_dict[source] = f.mse_loss(estimate_spec, target_spec)

        for source, per_source_loss in loss_dict.items():
            self.log('train/loss_{}'.format(source), per_source_loss, on_epoch=True)

        total = sum(loss_dict.values())
        self.log('train/loss', total, on_epoch=True)
        return total


class UNetConv2D(nn.Module):
    """Single-stem U-Net that predicts a bounded mask over a complex spectrogram.

    A stereo waveform is turned into a 4-channel spectrogram (real/imag parts of
    both channels), passed through ``depth`` stride-2 ``Conv2d`` encoder stages
    and mirrored ``ConvTranspose2d`` decoder stages with additive skip
    connections; the final decoder's Sigmoid output multiplies the input
    spectrogram and the result is transformed back to a waveform.
    """

    def __init__(self, n_fft, hop_length, num_frame, depth, num_channels, groth_rate):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.num_frame = num_frame
        # Waveform length whose STFT (torch.stft, center=True) has num_frame frames.
        self.chunk_size = self.hop_length * (self.num_frame - 1)

        self.encoder, self.decoder = nn.ModuleList(), nn.ModuleList()
        in_channels, out_channels = 4, num_channels

        for index in range(depth):
            encode = [nn.Conv2d(in_channels, out_channels, (2, 2), (2, 2)), nn.ReLU()]
            self.encoder.append(nn.Sequential(*encode))

            decode = [nn.ConvTranspose2d(out_channels, in_channels, (2, 2), (2, 2)), nn.ReLU()]
            if index == 0:
                # Outermost decoder produces the mask, bounded to [0, 1].
                decode[-1] = nn.Sigmoid()

            self.decoder.insert(0, nn.Sequential(*decode))

            in_channels = out_channels
            out_channels = groth_rate * out_channels

    def forward(self, mixture):
        """Estimate this stem from ``mixture`` (presumably ``(batch, 2, time)``).

        Returns a ``(batch, 2, time)`` waveform with as many samples as the input.
        """
        input_spec = self.stft(mixture)
        output_spec = self.estimate_spec(input_spec)

        # Without an explicit length, istft returns hop_length * (frames - 1)
        # samples, which is shorter than the input when time % hop_length != 0.
        return self.istft(output_spec, length=mixture.shape[-1])

    def estimate_spec(self, spec):
        """Apply the U-Net to a ``[B, 4, T, F]`` spectrogram and return the masked spectrogram."""
        x = spec
        saved = []
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)

        # The bottleneck output is the decoder input itself; it is not added again as a skip.
        saved[-1] = None

        for decode, skip in zip(self.decoder, reversed(saved)):
            if skip is not None:
                x = x + skip
            x = decode(x)

        return x * spec

    def separate_internal(self, mixture, overlap_ratio=0.5, batch_size=1, device='cpu') -> np.ndarray:
        """Separate a whole ``(samples, 2)`` track chunk by chunk.

        Chunks come from ``SingleTrackSet``; masked chunk outputs are
        overlap-added (or concatenated when ``overlap_ratio == 0``) and trimmed
        to the mixture length. The result is written to and read back from a
        WAV file at 44100 Hz, so it is returned as the array ``soundfile.read``
        yields. The module's train/eval mode is restored afterwards.
        """
        import os
        import tempfile

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                trim_length = self.hop_length // 2
                db = SingleTrackSet(mixture, self.chunk_size, trim_length, overlap_ratio)
                separated = []

                for item, mask in DataLoader(db, batch_size):
                    res = self(item.to(device)) * mask.to(device)
                    # Explicit end index: `-trim_length` would be empty for trim 0.
                    res = res[:, :, trim_length:res.shape[-1] - trim_length]
                    separated.append(res.detach().cpu().transpose(-1, -2).numpy())
        finally:
            self.train(was_training)

        chunks = np.concatenate(separated)  # [num_chunks, true_samples, 2]
        num_samples = mixture.shape[0]
        if db.is_overlap:
            output = np.zeros_like(mixture)
            for i, chunk in enumerate(chunks):
                begin = i * db.hop_length
                if begin >= num_samples:
                    break
                end = min(begin + chunk.shape[0], num_samples)
                output[begin:end] += chunk[:end - begin]
            separated = output
        else:
            # The last chunk is right-padded past the track end; drop that tail.
            separated = np.concatenate(chunks, axis=0)[:num_samples]

        # Write/read in a private temporary directory instead of a fixed
        # 'temp.wav' in the working directory (clobbering, races).
        with tempfile.TemporaryDirectory() as tmp_dir:
            wav_path = os.path.join(tmp_dir, 'temp.wav')
            soundfile.write(wav_path, separated, 44100)
            return soundfile.read(wav_path)[0]

    def stft(self, mixture):
        """Waveforms ``[B, 2, L]`` -> real spectrogram ``[B, 4, T, n_fft // 2]``."""
        mixture = mixture.view(-1, mixture.shape[-1])

        # note: torch.stft will only return complex tensors
        spec = torch.stft(mixture, n_fft=self.n_fft, hop_length=self.hop_length, return_complex=True)
        spec = torch.view_as_real(spec)

        F, T = spec.shape[-3], spec.shape[-2]
        spec = spec.view(-1, 2, F, T, 2)

        # [B, 2:channel, F, T, 2:complex] => [B, 4, T, F]
        spec = spec.transpose(-1, -3).reshape(-1, 4, T, F)

        # drop last frequency bin so F is n_fft // 2
        return spec[:, :, :, :-1]

    def istft(self, spec, length=None):
        """Inverse of ``stft``: ``[B, 4, T, F]`` -> waveforms ``[B, 2, samples]``.

        ``length`` is forwarded to ``torch.istft``; when None the output has
        ``hop_length * (T - 1)`` samples.
        """
        B, _, T, _ = spec.shape
        # Restore the dropped last frequency bin as zeros (a constant: no grad needed).
        zero_pad = torch.zeros_like(spec[..., :1])
        spec = torch.cat([spec, zero_pad], dim=-1)

        # [B, 4, T, F] => [B, 2:ch, 2:complex, T, F] => [B, 2:ch, F, T, 2:complex]
        spec = spec.view(B, 2, 2, T, -1).transpose(-1, -3)

        spec = spec.view(2 * B, -1, T, 2).contiguous()
        spec = torch.view_as_complex(spec)
        target = torch.istft(spec, n_fft=self.n_fft, hop_length=self.hop_length,
                             length=length, return_complex=False)
        return target.view(B, 2, -1)

Model Definition

# n_fft=256 gives 129 STFT bins, of which 128 are kept; chunks of
# hop_length * (num_frame - 1) samples give 128 frames, so the 6 stride-2
# encoder stages reduce the 128x128 spectrogram to 2x2.
model = UNetConv2DForMDX(
    n_fft=256,
    hop_length=128,
    num_frame=128,
    depth=6,
    num_channels=8,
    groth_rate=2,
)

Music Demixing with an Untrained Model

from IPython.display import Audio, display

# Separate the first training track with the (still untrained) model and play
# every estimated stem at 44.1 kHz.
estimated_dict = model.separate(mus_train[0].audio)
for source, estimate in estimated_dict.items():
    print(source)
    display(Audio(estimate.T, rate=44100))
    
vocals
drums
bass
other
from pytorch_lightning import Trainer

# Each training example is one model chunk: hop_length * (num_frame - 1)
# samples, taken from the model itself so the two cannot drift apart.
train_dataset = RandomMUSDB18TrainDataset(mus_train, sample_size=model.sub_net_dict['vocals'].chunk_size)
train_dataloader = DataLoader(train_dataset, batch_size=8)

if torch.cuda.is_available():
    pl_trainer = Trainer(gpus=1, max_epochs=1)
else:
    pl_trainer = Trainer(max_epochs=1)

# Pass the loader positionally: the `train_dataloader=` keyword is deprecated
# in Lightning 1.4 and removed in 1.6 (renamed `train_dataloaders`).
pl_trainer.fit(model, train_dataloader)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
/home/wschoi/exit/envs/tutorial-environment/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:530: LightningDeprecationWarning: `trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6. Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  | Name         | Type       | Params
--------------------------------------------
0 | sub_net_dict | ModuleDict | 1.4 M 
--------------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.603     Total estimated model params size (MB)
/home/wschoi/exit/envs/tutorial-environment/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:105: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/wschoi/exit/envs/tutorial-environment/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:326: UserWarning: The number of training samples (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
from IPython.display import Audio, display

# Separate the same training track again, now with the trained weights, and
# play every estimated stem at 44.1 kHz.
estimated_dict = model.separate(mus_train[0].audio)
for source, estimate in estimated_dict.items():
    print(source)
    display(Audio(estimate.T, rate=44100))
    
vocals
drums
bass
other