Skip to content
3 changes: 3 additions & 0 deletions specparam/compare/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Model comparison sub-module."""

from .compare import ModelComparison
118 changes: 118 additions & 0 deletions specparam/compare/compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Model comparison object."""

from copy import deepcopy

from specparam.models import SpectralModel
from specparam.plts.compare import plot_model_comparison
from specparam.reports.strings import gen_model_comparison_str
from specparam.modutils.docs import (copy_func_docstring_drop_first,
docs_get_section, replace_docstring_sections)

###################################################################################################
###################################################################################################

class ModelComparison():
"""Model a power spectrum with multiple models.

Parameters
----------
models : list of SpectralModel
Model definitions to fit and compare.
"""

def __init__(self, models=None):
"""Initialize model comparison object."""

self.models = []
if models:
self.add_models(models)


def __len__(self):
"""Define length of object as the number of defined models."""

return len(self.models)


def __iter__(self):
"""Define iteration as stepping across models within the object."""

for model in self.models:
yield(model)


def copy(self):
"""Return a copy of the current object."""

return deepcopy(self)


@replace_docstring_sections([docs_get_section(SpectralModel.fit.__doc__, 'Parameters')])
def fit(self, freqs=None, data=None, freq_range=None, prechecks=True):
"""Fit models to a power spectrum.

Parameters
----------
% copied in from SpectralModel object
"""

self.models[0].fit(freqs, data, freq_range, prechecks)
for model in self.models[1:]:
model.fit(prechecks=False)


@replace_docstring_sections([docs_get_section(SpectralModel.report.__doc__, 'Parameters')])
def report(self, freqs=None, data=None, freq_range=None,
plt_log=False, plot_full_range=False, **plot_kwargs):
"""Run model fit, and display a report, which includes a plot, and printed results.

Parameters
----------
% copied in from SpectralModel object
"""

self.fit(freqs, data, freq_range)
self.print('comparison')
self.plot(plt_log=plt_log,
freqs=freqs if plot_full_range else plot_kwargs.pop('plot_freqs', None),
power_spectrum=power_spectrum if \
plot_full_range else plot_kwargs.pop('plot_power_spectrum', None),
freq_range=plot_kwargs.pop('plot_freq_range', None),
**plot_kwargs)


def add_models(self, models, clear=False):
"""Add model definitions.

Parameters
----------
models : list of SpectralModel
Model definitions to add to the object.
clear : bool, optional, default: False
Whether to clear the object of previous model definitions before adding.
"""

for model in models:
self.models.append(deepcopy(model))

if self.models:
self.data = deepcopy(models[0].data)
for model in self.models:
model.data = self.data
model.algorithm._reset_subobjects(data=self.data)


@copy_func_docstring_drop_first(plot_model_comparison)
def plot(self, ax=None, **plot_kwargs):

plot_model_comparison(self, ax=ax, **plot_kwargs)


def print(self, info='comparison'):
"""Print out result information."""

if info == 'comparison':
print(gen_model_comparison_str(self))
else:
for model in self.models:
model.print(info)
16 changes: 11 additions & 5 deletions specparam/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ def __len__(self):
return len(self.labels)


def __iter__(self):
"""Make object iterable, across individual metrics."""

return iter(self.metrics)


def __getitem__(self, label):
"""Index into the object based on metric label.

Expand All @@ -43,11 +49,11 @@ def __getitem__(self, label):
Label of the metric to access.
"""

for ind, clabel in enumerate(self.labels):
if label == clabel:
return self.metrics[ind]

raise ValueError('Requested label not found.')
for metric in self:
if metric.label == label:
return metric
else:
raise ValueError('Requested label not found.')


def add_metric(self, metric):
Expand Down
7 changes: 7 additions & 0 deletions specparam/modes/modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def __init__(self, aperiodic, periodic, model=None):
self.model = model


@property
def label(self):
"""Define label for the current set of modes."""

return 'ap-{:s}_pe-{:s}'.format(self.aperiodic.name, self.periodic.name)


def get_modes(self):
"""Get the modes definition.

Expand Down
31 changes: 31 additions & 0 deletions specparam/plts/compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Plots for model comparison."""

from specparam.plts import plot_spectra
from specparam.plts.utils import check_ax
from specparam.plts.settings import PLT_FIGSIZES

###################################################################################################
###################################################################################################

def plot_model_comparison(modelcomp, ax=None, **plot_kwargs):
"""Plot and compare multiple model fits.

Parameters
----------
modelcomp : ModelComparison
Model comparison object.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**plot_kwargs
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
"""

import matplotlib as mpl
cmap = mpl.colormaps['Set1']

ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral']))
plot_spectra(modelcomp.data.freqs, modelcomp.data.power_spectrum,
label='Original Spectrum', lw=3, color='black', ax=ax)
for ind, model in enumerate(modelcomp.models):
plot_spectra(modelcomp.data.freqs, model.results.model.modeled_spectrum,
color=cmap.colors[ind], alpha=0.65, label=model.modes.label, ax=ax)
44 changes: 44 additions & 0 deletions specparam/reports/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,50 @@ def gen_event_results_str(event, concise=False):
return _format(str_lst, concise)


## MODEL COMPARISONS

def gen_model_comparison_str(models, concise=False):
"""Generate a string representation of model comparisons.

Parameters
----------
models : ModelComparison
Set of model objects to access results from.
concise : bool, optional, default: False
Whether to generate a concise version of the string.

Returns
-------
str
Formatted string of model comparisons.
"""

str_lst = [
'SPECTRUM MODEL COMPARISON',
'',
]

for ind, model in enumerate(models.models):
str_lst.append('Model {}: '.format(ind + 1) + model.modes.label)

metric_str = 'Model metrics - '
for m_ind in range(len(model.results.metrics)):
metric_str = metric_str + '{}: {:1.4f} '.format(\
model.results.metrics.flabels[m_ind],
list(model.results.metrics.results.values())[m_ind])
str_lst.append(metric_str)

str_lst.append(\
'Model size: {:d} parameters [{:d} ap + {:d} pe ({:d} peak(s) x {:d})].'.format(\
model.results.n_params, model.modes.aperiodic.n_params,
model.results.n_peaks * model.modes.periodic.n_params,
model.results.n_peaks, model.modes.periodic.n_params))

str_lst.append('')

return _format(str_lst, concise)


## HELPER SUB-FUNCTIONS FOR MODEL REPORT STRINGS

def _report_str_algo(model):
Expand Down
Empty file.
47 changes: 47 additions & 0 deletions specparam/tests/compare/test_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Test functions for specparam.compare.compare."""

from specparam import SpectralModel

from specparam.compare.compare import *

###################################################################################################
###################################################################################################

def test_model_comparison():

mc1 = ModelComparison()
assert isinstance(mc1, ModelComparison)

# Test initializing with some models
models = [SpectralModel(aperiodic_mode='fixed', periodic_mode='gaussian'),
SpectralModel(aperiodic_mode='knee', periodic_mode='gaussian')]
mc2 = ModelComparison(models)
assert isinstance(mc2, ModelComparison)
assert len(mc2) == len(models)
for model in mc2:
assert model

# Test data gets linked / results do not
for ind, model in enumerate(mc2.models[:-1]):
model.data is mc2.models[ind+1].data
model.results is not mc2.models[ind+1].results

def test_model_comparison_fit(tdata):

mc = ModelComparison(\
[SpectralModel(aperiodic_mode='fixed', periodic_mode='gaussian'),
SpectralModel(aperiodic_mode='knee', periodic_mode='gaussian'),
])

mc.fit(tdata.freqs, tdata.get_data('full', 'linear'))
for model in mc:
assert model.results.has_model

def test_model_comparison_reporting(tmodelcomp, tdata, skip_if_no_mpl):

mc = tmodelcomp.copy()

mc.report(tdata.freqs, tdata.get_data('full', 'linear'))

mc.print()
mc.plot()
7 changes: 6 additions & 1 deletion specparam/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

from specparam.tests.tdata import (get_tdata, get_tdata2d, get_tdata2dt, get_tdata3d,
get_tfm, get_tfm2, get_tfg, get_tfg2, get_tft, get_tfe,
get_tbands, get_tresults, get_tmodes, get_tdocstring)
get_tmodelcomp, get_tbands, get_tresults, get_tmodes,
get_tdocstring)
from specparam.tests.tsettings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH,
TEST_REPORTS_PATH, TEST_PLOTS_PATH)

Expand Down Expand Up @@ -99,6 +100,10 @@ def tft():
def tfe():
yield get_tfe()

@pytest.fixture(scope='session')
def tmodelcomp():
yield get_tmodelcomp()

@pytest.fixture(scope='session')
def tbands():
yield get_tbands()
Expand Down
4 changes: 4 additions & 0 deletions specparam/tests/metrics/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def test_metrics_obj(tfm):
metrics.compute_metrics(tfm.data, tfm.results)
assert isinstance(metrics.results, dict)

# Check iteration
for metric in metrics:
assert metric

# Check indexing
met_out = metrics['error_mae']
assert isinstance(met_out, Metric)
Expand Down
1 change: 1 addition & 0 deletions specparam/tests/modes/test_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test_modes():
assert modes
assert isinstance(modes.aperiodic, Mode)
assert isinstance(modes.periodic, Mode)
assert modes.label
modes.print()

def test_modes_gets():
Expand Down
14 changes: 14 additions & 0 deletions specparam/tests/plts/test_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Tests for specparam.plts.compare."""

from specparam.tests.tutils import plot_test
from specparam.tests.tsettings import TEST_PLOTS_PATH

from specparam.plts.compare import *

###################################################################################################
###################################################################################################

@plot_test
def test_plot_model_comparison(tmodelcomp, skip_if_no_mpl):

plot_model_comparison(tmodelcomp)
6 changes: 6 additions & 0 deletions specparam/tests/reports/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ def test_no_model_str():

assert _no_model_str()

## MODEL COMPARISONS

def test_gen_model_comparison_str(tmodelcomp):

assert gen_model_comparison_str(tmodelcomp)

## UTILITIES

def test_format():
Expand Down
15 changes: 15 additions & 0 deletions specparam/tests/tdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from specparam.data.stores import FitResults
from specparam.models import (SpectralModel, SpectralGroupModel,
SpectralTimeModel, SpectralTimeEventModel)
from specparam.compare import ModelComparison
from specparam.sim.params import param_sampler
from specparam.sim.sim import sim_power_spectrum, sim_group_power_spectra, sim_spectrogram

Expand Down Expand Up @@ -137,6 +138,20 @@ def get_tfe():

return tfe

## TEST MODEL COMPARISON OBJECTS

def get_tmodelcomp():
"""Get a model comparison object, with fit power spectra, for testing."""

modelcomp = ModelComparison([
SpectralModel(aperiodic_mode='fixed', periodic_mode='gaussian'),
SpectralModel(aperiodic_mode='knee', periodic_mode='gaussian'),
])

modelcomp.fit(*sim_power_spectrum(*default_spectrum_params()))

return modelcomp

## TEST OTHER OBJECTS

def get_tbands():
Expand Down
Loading