|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from typing import Any |
| 4 | + |
3 | 5 | import numpy as np |
4 | 6 | import pandas as pd |
| 7 | +from pybaselines import Baseline |
5 | 8 | from scipy.signal import find_peaks |
6 | 9 |
|
7 | | -__all__ = ["find_cosmic_rays", "remove_cosmic_rays", "group_spectra_points"] |
| 10 | +__all__ = [ |
| 11 | + "find_cosmic_rays", |
| 12 | + "remove_cosmic_rays", |
| 13 | + "group_spectra_points", |
| 14 | + "baseline", |
| 15 | +] |
8 | 16 |
|
9 | 17 |
|
10 | 18 | def find_cosmic_rays( |
11 | | - spectra: np.ndarray, ignore_region: tuple[int, int] = (200, 400), **kwargs |
| 19 | + spectra: np.ndarray, ignore_region: tuple[int, int] = (200, 400), **kwargs: Any |
12 | 20 | ) -> np.ndarray: |
13 | 21 | """ |
14 | 22 | Find the indices of cosmic rays. |
@@ -45,7 +53,9 @@ def find_cosmic_rays( |
45 | 53 | return np.asarray(idx) |
46 | 54 |
|
47 | 55 |
|
48 | | -def remove_cosmic_rays(df: pd.DataFrame, plot: bool = False, **kwargs) -> pd.DataFrame: |
| 56 | +def remove_cosmic_rays( |
| 57 | + df: pd.DataFrame, plot: bool = False, **kwargs: Any |
| 58 | +) -> pd.DataFrame: |
49 | 59 | """ |
50 | 60 | Process a dataframe by removing all spectra with detected cosmic rays. |
51 | 61 |
|
@@ -91,7 +101,7 @@ def remove_cosmic_rays(df: pd.DataFrame, plot: bool = False, **kwargs) -> pd.Dat |
91 | 101 | return df.iloc[keep_idx] |
92 | 102 |
|
93 | 103 |
|
94 | | -def group_spectra_points(df, multiplier: int) -> pd.DataFrame: |
| 104 | +def group_spectra_points(df: pd.DataFrame, multiplier: int) -> pd.DataFrame: |
95 | 105 | """ |
96 | 106 | Add which point each spectra is from to the multiindex. |
97 | 107 |
|
@@ -121,3 +131,35 @@ def group_spectra_points(df, multiplier: int) -> pd.DataFrame: |
121 | 131 | offset = df.loc[pos, "pt"].max() |
122 | 132 | df["pt"] = df["pt"].astype(int) |
123 | 133 | return df.set_index("pt", append=True) |
| 134 | + |
| 135 | + |
def baseline(spectra: np.ndarray, method: str = "arpls", **params: Any) -> np.ndarray:
    """
    Calculate the baseline of [many] spectra using pybaselines.

    Parameters
    ----------
    spectra : array-like ([N], wns)
        The spectra to calculate the baseline of. Either a single 1D
        spectrum or a 2D stack with one spectrum per row.
    method : str, default: "arpls"
        The pybaselines method name, looked up as an attribute of the
        ``pybaselines.Baseline`` instance.
    **params
        Passed to the selected pybaselines method on every call.

    Returns
    -------
    baseline : np.ndarray ([N], wns)
        The calculated baselines. Length-1 axes are squeezed out, so a
        single input spectrum yields a 1D result.

    Raises
    ------
    AttributeError
        If ``method`` is not an attribute of ``pybaselines.Baseline``.
    """
    spectra = np.atleast_2d(spectra)
    if np.issubdtype(spectra.dtype, np.integer):
        # The output array inherits this dtype; an integer dtype would
        # truncate the fitted baseline values on assignment.
        spectra = spectra.astype(np.float32)

    # Size the x-axis from the data instead of assuming a fixed 1340-point
    # detector, so spectra of any length can be processed.
    baseliner = Baseline(np.arange(spectra.shape[-1]))
    baseline_func = getattr(baseliner, method)

    baselines = np.zeros_like(spectra)
    for i, spec in enumerate(spectra):
        # The method returns a 2-tuple; only the first item (the baseline)
        # is kept, the second is discarded.
        baselines[i], _ = baseline_func(spec, **params)
    return baselines.squeeze()
0 commit comments