Skip to content

Commit 4b548b5

Browse files
authored
Merge pull request #6 from Hekstra-Lab/preprocess
FEAT: add baselining
2 parents 456fbf0 + a216843 commit 4b548b5

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ classifiers = [
2525
dynamic = ["version"]
2626
dependencies = [
2727
"scipy",
28+
"pybaselines",
29+
"pentapy",
2830
]
2931

3032
# extras

src/raman_analysis/preprocessing.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
from __future__ import annotations
22

3+
from typing import Any
4+
35
import numpy as np
46
import pandas as pd
7+
from pybaselines import Baseline
58
from scipy.signal import find_peaks
69

7-
__all__ = ["find_cosmic_rays", "remove_cosmic_rays", "group_spectra_points"]
10+
__all__ = [
11+
"find_cosmic_rays",
12+
"remove_cosmic_rays",
13+
"group_spectra_points",
14+
"baseline",
15+
]
816

917

1018
def find_cosmic_rays(
11-
spectra: np.ndarray, ignore_region: tuple[int, int] = (200, 400), **kwargs
19+
spectra: np.ndarray, ignore_region: tuple[int, int] = (200, 400), **kwargs: Any
1220
) -> np.ndarray:
1321
"""
1422
Find the indices of cosmic rays.
@@ -45,7 +53,9 @@ def find_cosmic_rays(
4553
return np.asarray(idx)
4654

4755

48-
def remove_cosmic_rays(df: pd.DataFrame, plot: bool = False, **kwargs) -> pd.DataFrame:
56+
def remove_cosmic_rays(
57+
df: pd.DataFrame, plot: bool = False, **kwargs: Any
58+
) -> pd.DataFrame:
4959
"""
5060
Process a dataframe by removing all spectra with detected cosmic rays.
5161
@@ -91,7 +101,7 @@ def remove_cosmic_rays(df: pd.DataFrame, plot: bool = False, **kwargs) -> pd.Dat
91101
return df.iloc[keep_idx]
92102

93103

94-
def group_spectra_points(df, multiplier: int) -> pd.DataFrame:
104+
def group_spectra_points(df: pd.DataFrame, multiplier: int) -> pd.DataFrame:
95105
"""
96106
Add which point each spectra is from to the multiindex.
97107
@@ -121,3 +131,35 @@ def group_spectra_points(df, multiplier: int) -> pd.DataFrame:
121131
offset = df.loc[pos, "pt"].max()
122132
df["pt"] = df["pt"].astype(int)
123133
return df.set_index("pt", append=True)
134+
135+
136+
def baseline(spectra: np.ndarray, method: str = "arpls", **params: Any) -> np.ndarray:
    """
    Calculate the baseline of [many] spectra using pybaselines.

    Parameters
    ----------
    spectra : array-like ([N], wns)
        The spectra to calculate the baseline of. Either a single 1D
        spectrum or a 2D stack with one spectrum per row. Integer input is
        converted to float32 before fitting.
    method : str, default: "arpls"
        The pybaselines method name, looked up as an attribute of
        ``pybaselines.Baseline``.
    **params:
        Passed to pybaselines

    Returns
    -------
    baseline : np.ndarray ([N], wns)
        The calculated baselines. Length-1 axes are squeezed out, so a
        single input spectrum yields a 1D result.

    Raises
    ------
    AttributeError
        If ``method`` is not an attribute of ``pybaselines.Baseline``.
    """
    spectra = np.atleast_2d(spectra)
    if np.issubdtype(spectra.dtype, np.integer):
        # Integer storage would truncate the fitted baseline values written
        # into ``baselines`` below.
        spectra = spectra.astype(np.float32)

    # Size the x-axis from the data rather than assuming a fixed 1340-point
    # detector, so spectra of any length can be baselined.
    baseliner = Baseline(np.arange(spectra.shape[-1]))
    baseline_func = getattr(baseliner, method)

    baselines = np.zeros_like(spectra)

    for i, spec in enumerate(spectra):
        # The second item of the returned pair is not needed here.
        baselines[i], _ = baseline_func(spec, **params)
    return baselines.squeeze()

0 commit comments

Comments
 (0)