diff --git a/spatialscaper/core.py b/spatialscaper/core.py index 5046434..267329a 100644 --- a/spatialscaper/core.py +++ b/spatialscaper/core.py @@ -3,15 +3,16 @@ from collections import namedtuple import librosa -import scipy import numpy as np +import scipy +import soundfile as sf # Local application/library specific imports +from .sofa_utils import load_rir_pos, load_pos from .utils import ( get_label_list, get_files_list, new_event_exceeds_max_overlap, - count_leading_zeros_in_period, generate_trajectory, db2scale, traj_2_ir_idx, @@ -22,9 +23,10 @@ get_labels, save_output, sort_matrix_by_columns, + swap_label, + swap_foa, + swap_mic, ) -from .sofa_utils import load_rir_pos, load_pos - # Sound event classes for DCASE Challenge __DCASE_SOUND_EVENT_CLASSES__ = { @@ -78,18 +80,18 @@ class Scaper: def __init__( - self, - duration=60, - foreground_dir="", - background_dir="", - rir_dir="", - room="metu", - fmt="mic", - sr=24000, - DCASE_format=True, - max_event_overlap=2, - max_event_dur=10.0, - ref_db=-60, + self, + duration=60, + foreground_dir="", + background_dir="", + rir_dir="", + room="metu", + fmt="mic", + sr=24000, + DCASE_format=True, + max_event_overlap=2, + max_event_dur=10.0, + ref_db=-60, ): """ Initializes a Scaper object for generating soundscapes. @@ -163,15 +165,15 @@ def add_background(self): ) def add_event( - self, - label=("choose", []), - source_file=("choose", []), - source_time=("const", 0), - event_time=None, - event_duration=None, - event_position=("choose", ("uniform", None, None)), - snr=("uniform"), - split=None, + self, + label=("choose", []), + source_file=("choose", []), + source_time=("const", 0), + event_time=None, + event_duration=None, + event_position=("choose", ("uniform", None, None)), + snr=("uniform"), + split=None, ): """ Adds a foreground event to the soundscape. @@ -263,7 +265,7 @@ def add_event( ) def define_event_onset_time( - self, event_time, event_duration, other_events, max_overlap, increment + self, event_time, event_duration, other_events, max_overlap, increment ): """ Recursively finds a start time for an event that doesn't exceed the maximum overlap with other events. @@ -288,7 +290,7 @@ def define_event_onset_time( # Check if the selected time overlaps with more than max_overlap events if new_event_exceeds_max_overlap( - random_start_time, event_duration, other_events, max_overlap, increment + random_start_time, event_duration, other_events, max_overlap, increment ): # If it does overlap, recursively try again return self.define_event_onset_time( @@ -509,7 +511,7 @@ def synthesize_events_and_labels(self, all_irs, all_ir_xyzs, out_audio): # add to out_audio onsamp = int(event.event_time * self.sr) - out_audio[onsamp : onsamp + len(xS)] += xS + out_audio[onsamp: onsamp + len(xS)] += xS # generate ground truth time_grid = get_timegrid( @@ -527,8 +529,8 @@ def synthesize_events_and_labels(self, all_irs, all_ir_xyzs, out_audio): ) labels[:, 0] = labels[:, 0] + int(event.event_time * self.label_rate) xS = xS[ - : int(time_grid[-1] * self.sr) - ] # trim audio signal to exactly match labels + : int(time_grid[-1] * self.sr) + ] # trim audio signal to exactly match labels all_labels.append(labels) labels = sort_matrix_by_columns(np.vstack(all_labels)) @@ -579,3 +581,102 @@ def generate(self, audiopath, labelpath): # save output save_output(audiopath, labelpath, out_audio, self.sr, labels) + + +class ScaperAug: + def __init__(self, ss_dir, aug_dir, fmt, method): + """ + Initialize an augmented soundscape. + + Args: + ss_dir: soundscape directory + aug_dir: new folder to store augmented soundscapes + fmt: format of original soundscape + method: {'swap', 'rotate', 'mask', 'remix'} + augmentation method to use, all based on https://arxiv.org/abs/2101.02919 + 'swap': + channel swapping + 'rotate': + soundscape rotation + 'mask': + random time frequency masking + 'remix': + time domain remixing + """ + + self.ss_dir = ss_dir + self.aug_dir = aug_dir + self.format = fmt + + # decide which augmentation method to use + if method == 'swap': # channel swapping + self.method = self.swap_channels + elif method == 'rotate': # soundscape rotation + self.method = self.rotation + elif method == 'mask': # time frequency masking + self.method = self.tf_masking + elif method == 'remix': # time domain remixing + self.method = self.remixing + else: + raise NotImplementedError("The augmentation method is not found.") + + def augment(self): + """ + Augment audio files and modify labels. + """ + + in_folder = self.ss_dir + in_data = os.path.join(in_folder, self.format) + in_label = os.path.join(in_folder, 'labels') + out_folder = self.aug_dir + + print("Start augmenting audio in in_folder {} to in_folder {}".format(in_folder, out_folder)) + + for file_cnt, file in enumerate(os.listdir(in_data)): + filename = file.split('.')[0] + data_file = os.path.join(in_data, file) + label_file = os.path.join(in_label, filename + '.csv') + data_file_aug = os.path.join(out_folder, self.format, filename) + label_file_aug = os.path.join(out_folder, 'labels', filename) + + # read and augment audio related data + data, fs = sf.read(data_file) + # read and modify label data + label = np.genfromtxt(label_file, dtype=int, delimiter=',') + data_aug, fs_aug, label_aug = self.method(data, fs, label) + for i, (d, f, l) in enumerate(zip(data_aug, fs_aug, label_aug)): + save_output(data_file_aug + f'_{i}', label_file_aug + f'_{i}', d, f, l, self.format) + + print("Completed augmentation and saved results in {}".format(out_folder)) + + def swap_channels(self, data, fs, label): + """ + Args: + data: multichannel audio data + fs: sample rate + label: [sample ID, class ID, source ID, azimuth, elevation, radius] + + Returns: + Three list of augmented files. + + """ + if self.format == 'mic': + data_aug = swap_mic(data) + elif self.format == 'foa': + data_aug = swap_foa(data) + else: + raise NotImplementedError("Format not supported.") + + label_aug = swap_label(label) + fs_aug = [fs] * 7 + + return data_aug, fs_aug, label_aug + + def rotation(self, data, fs, label): + pass + + def tf_masking(self, data, fs, label): + pass + + def remixing(self, data, fs, label): + pass diff --git a/spatialscaper/utils.py b/spatialscaper/utils.py index 5799ad3..51a5f25 100644 --- a/spatialscaper/utils.py +++ b/spatialscaper/utils.py @@ -777,3 +777,89 @@ def count_leading_zeros_in_period(frequency_hz): for i, digit in enumerate(fractional_part) if digit == "0" and "1" not in fractional_part[: i + 1] ) + + +def swap_mic(audio): + """ + + Args: + audio: multichannel mic format audio + Returns: + A list of 7 types of mic channel swapping. + + """ + # separate channels + chan_1 = audio[:, 0] + chan_2 = audio[:, 1] + chan_3 = audio[:, 2] + chan_4 = audio[:, 3] + + # swapping columns + audio_aug = [] + audio_aug.append(np.dstack((chan_2, chan_4, chan_1, chan_3)).squeeze()) + audio_aug.append(np.dstack((chan_4, chan_2, chan_3, chan_1)).squeeze()) + audio_aug.append(np.dstack((chan_2, chan_1, chan_4, chan_3)).squeeze()) + audio_aug.append(np.dstack((chan_3, chan_1, chan_4, chan_2)).squeeze()) + audio_aug.append(np.dstack((chan_1, chan_3, chan_2, chan_4)).squeeze()) + audio_aug.append(np.dstack((chan_4, chan_3, chan_2, chan_1)).squeeze()) + audio_aug.append(np.dstack((chan_3, chan_4, chan_1, chan_2)).squeeze()) + + return audio_aug + + +def swap_foa(audio): + """ + + Args: + audio: multichannel foa format audio + + Returns: + a list of 7 types of channel swapping foa audio + + """ + # separate channels + chan_1 = audio[:, 0] + chan_2 = audio[:, 1] + chan_3 = audio[:, 2] + chan_4 = audio[:, 3] + + # swapping columns + audio_aug = [] + audio_aug.append(np.dstack((chan_1, -chan_4, -chan_3, chan_2)).squeeze()) + audio_aug.append(np.dstack((chan_1, -chan_4, chan_3, -chan_2)).squeeze()) + audio_aug.append(np.dstack((chan_1, -chan_2, -chan_3, chan_4)).squeeze()) + audio_aug.append(np.dstack((chan_1, chan_4, -chan_3, chan_2)).squeeze()) + audio_aug.append(np.dstack((chan_1, chan_4, chan_3, chan_2)).squeeze()) + audio_aug.append(np.dstack((chan_1, -chan_2, chan_3, -chan_4)).squeeze()) + audio_aug.append(np.dstack((chan_1, chan_2, -chan_3, -chan_4)).squeeze()) + + return audio_aug + +def swap_label(label): + """ + + Args: + label: original label + + Returns: + labels after channel swapping + + """ + frame = label[:, 0] + id = label[:, 1] + source = label[:, 2] + azimuth = label[:, 3] + elevation = label[:, 4] + distance = label[:, 5] if label.shape[1] > 5 else np.full(data.shape[0], None) + + # transform azimuth and elevation + label_aug = [] + label_aug.append(np.dstack((frame, id, source, azimuth - 90, -elevation, distance)).squeeze()) + label_aug.append(np.dstack((frame, id, source, -azimuth - 90, elevation, distance)).squeeze()) + label_aug.append(np.dstack((frame, id, source, -azimuth, -elevation, distance)).squeeze()) + label_aug.append(np.dstack((frame, id, source, azimuth + 90, -elevation, distance)).squeeze()) + label_aug.append(np.dstack((frame, id, source, -azimuth + 90, elevation, distance)).squeeze()) + label_aug.append(np.dstack((frame, id, source, azimuth + 180, elevation, distance)).squeeze()) + label_aug.append(np.dstack((frame, id, source, -azimuth + 180, -elevation, distance)).squeeze()) + + return label_aug \ No newline at end of file