161 changes: 131 additions & 30 deletions spatialscaper/core.py
@@ -3,15 +3,16 @@
from collections import namedtuple

import librosa
import scipy
import numpy as np
import scipy
import soundfile as sf

# Local application/library specific imports
from .sofa_utils import load_rir_pos, load_pos
from .utils import (
get_label_list,
get_files_list,
new_event_exceeds_max_overlap,
count_leading_zeros_in_period,
generate_trajectory,
db2scale,
traj_2_ir_idx,
@@ -22,9 +23,10 @@
get_labels,
save_output,
sort_matrix_by_columns,
swap_label,
swap_foa,
swap_mic,
)
from .sofa_utils import load_rir_pos, load_pos


# Sound event classes for DCASE Challenge
__DCASE_SOUND_EVENT_CLASSES__ = {
@@ -78,18 +80,18 @@

class Scaper:
def __init__(
self,
duration=60,
foreground_dir="",
background_dir="",
rir_dir="",
room="metu",
fmt="mic",
sr=24000,
DCASE_format=True,
max_event_overlap=2,
max_event_dur=10.0,
ref_db=-60,
self,
duration=60,
foreground_dir="",
background_dir="",
rir_dir="",
room="metu",
fmt="mic",
sr=24000,
DCASE_format=True,
max_event_overlap=2,
max_event_dur=10.0,
ref_db=-60,
):
"""
Initializes a Scaper object for generating soundscapes.
@@ -163,15 +165,15 @@ def add_background(self):
)

def add_event(
self,
label=("choose", []),
source_file=("choose", []),
source_time=("const", 0),
event_time=None,
event_duration=None,
event_position=("choose", ("uniform", None, None)),
snr=("uniform"),
split=None,
self,
label=("choose", []),
source_file=("choose", []),
source_time=("const", 0),
event_time=None,
event_duration=None,
event_position=("choose", ("uniform", None, None)),
snr=("uniform"),
split=None,
):
"""
Adds a foreground event to the soundscape.
@@ -263,7 +265,7 @@ def add_event(
)

def define_event_onset_time(
self, event_time, event_duration, other_events, max_overlap, increment
self, event_time, event_duration, other_events, max_overlap, increment
):
"""
Recursively finds a start time for an event that doesn't exceed the maximum overlap with other events.
@@ -288,7 +290,7 @@

# Check if the selected time overlaps with more than max_overlap events
if new_event_exceeds_max_overlap(
random_start_time, event_duration, other_events, max_overlap, increment
random_start_time, event_duration, other_events, max_overlap, increment
):
# If it does overlap, recursively try again
return self.define_event_onset_time(
@@ -509,7 +511,7 @@ def synthesize_events_and_labels(self, all_irs, all_ir_xyzs, out_audio):

# add to out_audio
onsamp = int(event.event_time * self.sr)
out_audio[onsamp : onsamp + len(xS)] += xS
out_audio[onsamp: onsamp + len(xS)] += xS

# generate ground truth
time_grid = get_timegrid(
Expand All @@ -527,8 +529,8 @@ def synthesize_events_and_labels(self, all_irs, all_ir_xyzs, out_audio):
)
labels[:, 0] = labels[:, 0] + int(event.event_time * self.label_rate)
xS = xS[
: int(time_grid[-1] * self.sr)
] # trim audio signal to exactly match labels
: int(time_grid[-1] * self.sr)
] # trim audio signal to exactly match labels
all_labels.append(labels)

labels = sort_matrix_by_columns(np.vstack(all_labels))
@@ -579,3 +581,102 @@ def generate(self, audiopath, labelpath):

# save output
save_output(audiopath, labelpath, out_audio, self.sr, labels)


class ScaperAug:
def __init__(self, ss_dir, aug_dir, fmt, method):
"""
Initialize an augmenter for pre-rendered soundscapes.

Args:
ss_dir: soundscape directory
aug_dir: new folder to store augmented soundscapes
fmt: format of original soundscape
method: {'swap', 'rotate', 'mask', 'remix'}
augmentation method to use, all based on https://arxiv.org/abs/2101.02919
'swap':
channel swapping
'rotate':
soundscape rotation
'mask':
random time frequency masking
'remix':
time domain remixing
"""

self.ss_dir = ss_dir
self.aug_dir = aug_dir
self.format = fmt

# decide which augmentation method to use
if method == 'swap': # channel swapping
self.method = self.swap_channels
elif method == 'rotate': # soundscape rotation
self.method = self.rotation
elif method == 'mask': # time frequency masking
self.method = self.tf_masking
elif method == 'remix': # time domain remixing
self.method = self.remixing
else:
raise NotImplementedError("The augmentation method is not found.")

def augment(self):
"""
Augment audio files and modify labels.
"""

in_folder = self.ss_dir
in_data = os.path.join(in_folder, self.format)
in_label = os.path.join(in_folder, 'labels')
out_folder = self.aug_dir

print("Start augmenting audio in in_folder {} to in_folder {}".format(in_folder, out_folder))

for file_cnt, file in enumerate(os.listdir(in_data)):
filename = file.split('.')[0]
data_file = os.path.join(in_data, file)
label_file = os.path.join(in_label, filename + '.csv')
data_file_aug = os.path.join(out_folder, self.format, filename)
label_file_aug = os.path.join(out_folder, 'labels', filename)

# read and augment audio related data
data, fs = sf.read(data_file)
# read and modify label data
label = np.genfromtxt(label_file, dtype=int, delimiter=',')
data_aug, fs_aug, label_aug = self.method(data, fs, label)
for i, (d, f, l) in enumerate(zip(data_aug, fs_aug, label_aug)):
save_output(data_file_aug + f'_{i}', label_file_aug + f'_{i}', d, f, l, self.format)

print("Completed augmentation and saved results in {}".format(out_folder))

def swap_channels(self, data, fs, label):
"""
Args:
data: multichannel audio data
fs: sample rate
label: [sample ID, class ID, source ID, azimuth, elevation, radius]

Returns:
Three lists of augmented data: audio, sample rates, and labels.

"""
if self.format == 'mic':
data_aug = swap_mic(data)
elif self.format == 'foa':
data_aug = swap_foa(data)
else:
raise NotImplementedError("Format not supported.")

label_aug = swap_label(label)
fs_aug = [fs] * 7

return data_aug, fs_aug, label_aug

def rotation(self, data, fs, label):
pass

def tf_masking(self, data, fs, label):
pass
Comment on lines +678 to +679

You can use PyTorch's built-in TimeMasking and FrequencyMasking transforms (in torchaudio.transforms) to make completing the tf_masking() method easier
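
For reference, here is a minimal sketch of what that could look like, assuming torchaudio is available and keeping the same (data, fs, label) -> three-lists interface as swap_channels. The function name, STFT sizes, and mask parameters below are illustrative only, not part of this PR:

import torch
import torchaudio

def tf_masking_sketch(data, fs, label, n_fft=1024, hop=256):
    # data: (n_samples, n_channels) float array, as returned by sf.read
    wav = torch.from_numpy(data.T).float()                 # (channels, samples)
    window = torch.hann_window(n_fft)
    spec = torch.stft(wav, n_fft=n_fft, hop_length=hop,
                      window=window, return_complex=True)  # (channels, freq, frames)
    mag, phase = spec.abs(), spec.angle()

    # Mask the magnitude only; the original phase is reused for reconstruction.
    mag = torchaudio.transforms.FrequencyMasking(freq_mask_param=48)(mag)
    mag = torchaudio.transforms.TimeMasking(time_mask_param=64)(mag)

    masked = torch.polar(mag, phase)
    out = torch.istft(masked, n_fft=n_fft, hop_length=hop,
                      window=window, length=wav.shape[-1])
    # Labels stay unchanged: masking removes energy but does not move sources.
    return [out.T.numpy()], [fs], [label]

Masking the magnitude and reusing the phase keeps the inverse STFT simple, and returning single-element lists keeps the zip loop in augment() working as-is.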


def remixing(self, data, fs, label):
pass
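
A quick usage sketch of the new class may help reviewers; the paths below are hypothetical and assume the layout augment() expects, i.e. audio under <ss_dir>/<fmt>/ with matching CSVs under <ss_dir>/labels/:

from spatialscaper.core import ScaperAug

aug = ScaperAug(
    ss_dir="output/soundscapes",   # hypothetical input folder
    aug_dir="output/augmented",    # hypothetical output folder
    fmt="foa",                     # must match the audio subfolder name
    method="swap",                 # one of 'swap', 'rotate', 'mask', 'remix'
)
aug.augment()  # with 'swap', writes 7 channel-swapped copies of each soundscape and its labels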
86 changes: 86 additions & 0 deletions spatialscaper/utils.py
@@ -777,3 +777,89 @@ def count_leading_zeros_in_period(frequency_hz):
for i, digit in enumerate(fractional_part)
if digit == "0" and "1" not in fractional_part[: i + 1]
)


def swap_mic(audio):
"""

Args:
audio: multichannel mic format audio
Returns:
A list of 7 channel-swapped versions of the mic-format audio.

"""
# separate channels
chan_1 = audio[:, 0]
chan_2 = audio[:, 1]
chan_3 = audio[:, 2]
chan_4 = audio[:, 3]

# swapping columns
audio_aug = []
audio_aug.append(np.dstack((chan_2, chan_4, chan_1, chan_3)).squeeze())
audio_aug.append(np.dstack((chan_4, chan_2, chan_3, chan_1)).squeeze())
audio_aug.append(np.dstack((chan_2, chan_1, chan_4, chan_3)).squeeze())
audio_aug.append(np.dstack((chan_3, chan_1, chan_4, chan_2)).squeeze())
audio_aug.append(np.dstack((chan_1, chan_3, chan_2, chan_4)).squeeze())
audio_aug.append(np.dstack((chan_4, chan_3, chan_2, chan_1)).squeeze())
audio_aug.append(np.dstack((chan_3, chan_4, chan_1, chan_2)).squeeze())

return audio_aug


def swap_foa(audio):
"""

Args:
audio: multichannel foa format audio

Returns:
A list of 7 channel-swapped versions of the FOA audio.

"""
# separate channels
chan_1 = audio[:, 0]
chan_2 = audio[:, 1]
chan_3 = audio[:, 2]
chan_4 = audio[:, 3]

# swapping columns
audio_aug = []
audio_aug.append(np.dstack((chan_1, -chan_4, -chan_3, chan_2)).squeeze())
audio_aug.append(np.dstack((chan_1, -chan_4, chan_3, -chan_2)).squeeze())
audio_aug.append(np.dstack((chan_1, -chan_2, -chan_3, chan_4)).squeeze())
audio_aug.append(np.dstack((chan_1, chan_4, -chan_3, chan_2)).squeeze())
audio_aug.append(np.dstack((chan_1, chan_4, chan_3, chan_2)).squeeze())
audio_aug.append(np.dstack((chan_1, -chan_2, chan_3, -chan_4)).squeeze())
audio_aug.append(np.dstack((chan_1, chan_2, -chan_3, -chan_4)).squeeze())

return audio_aug

def swap_label(label):
"""

Args:
label: original label

Returns:
A list of 7 label arrays, one per channel-swapped audio variant.

"""
frame = label[:, 0]
id = label[:, 1]
source = label[:, 2]
azimuth = label[:, 3]
elevation = label[:, 4]
distance = label[:, 5] if label.shape[1] > 5 else np.full(label.shape[0], None)

# transform azimuth and elevation
label_aug = []
label_aug.append(np.dstack((frame, id, source, azimuth - 90, -elevation, distance)).squeeze())
label_aug.append(np.dstack((frame, id, source, -azimuth - 90, elevation, distance)).squeeze())
label_aug.append(np.dstack((frame, id, source, -azimuth, -elevation, distance)).squeeze())
label_aug.append(np.dstack((frame, id, source, azimuth + 90, -elevation, distance)).squeeze())
label_aug.append(np.dstack((frame, id, source, -azimuth + 90, elevation, distance)).squeeze())
label_aug.append(np.dstack((frame, id, source, azimuth + 180, elevation, distance)).squeeze())
label_aug.append(np.dstack((frame, id, source, -azimuth + 180, -elevation, distance)).squeeze())

return label_aug
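
A small self-check can make the intended pairing explicit: the i-th array returned by swap_foa (or swap_mic) is meant to go with the i-th array returned by swap_label. This is only an illustrative sketch; the label values are made up and include the distance column so the np.full branch is not exercised:

import numpy as np

audio = np.random.randn(24000, 4)            # 1 s of 4-channel audio at 24 kHz
label = np.array([[0, 3, 0, 30, 10, 2],      # [frame, class, source, azimuth, elevation, distance]
                  [1, 3, 0, 32, 12, 2]])

for a, l in zip(swap_foa(audio), swap_label(label)):
    assert a.shape == audio.shape            # channel swaps preserve the audio shape
    assert l.shape == label.shape            # one transformed row per original row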