diff --git a/Cargo.toml b/Cargo.toml index 012dc2a..a022415 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ thiserror = "1.0" serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.145" rand = "0.8" +nalgebra = "0.32" [dev-dependencies] approx = "0.5" diff --git a/python/eo_processor/__init__.py b/python/eo_processor/__init__.py index 76e6bad..249a562 100644 --- a/python/eo_processor/__init__.py +++ b/python/eo_processor/__init__.py @@ -47,7 +47,7 @@ binary_erosion as _binary_erosion, binary_opening as _binary_opening, binary_closing as _binary_closing, - detect_breakpoints as _detect_breakpoints, + bfast_monitor as _bfast_monitor, complex_classification as _complex_classification, random_forest_predict as _random_forest_predict, random_forest_train as _random_forest_train, @@ -131,7 +131,7 @@ "binary_erosion", "binary_opening", "binary_closing", - "detect_breakpoints", + "bfast_monitor", "complex_classification", "haralick_features", "random_forest_predict", @@ -164,11 +164,27 @@ def random_forest_predict(model_json, features): return _random_forest_predict(model_json, features) -def detect_breakpoints(stack, dates, threshold): +def bfast_monitor( + stack, + dates, + history_start_date, + monitor_start_date, + order=1, + h=0.25, + alpha=0.05, +): """ - Scaffold for a time-series breakpoint detection workflow (e.g., BFAST-like). + BFAST Monitor workflow for change detection. """ - return _detect_breakpoints(stack, dates, threshold) + return _bfast_monitor( + stack, + dates, + history_start_date, + monitor_start_date, + order, + h, + alpha, + ) def complex_classification(blue, green, red, nir, swir1, swir2, temp): diff --git a/python/eo_processor/__init__.pyi b/python/eo_processor/__init__.pyi index 41687e4..26640ed 100644 --- a/python/eo_processor/__init__.pyi +++ b/python/eo_processor/__init__.pyi @@ -161,4 +161,15 @@ def binary_closing( input: NDArray[np.uint8], kernel_size: int = ... ) -> NDArray[np.uint8]: ... 
+# Workflows +def bfast_monitor( + stack: NumericArray, + dates: Sequence[int], + history_start_date: int, + monitor_start_date: int, + order: int = ..., + h: float = ..., + alpha: float = ..., +) -> NDArray[np.float64]: ... + # Raises ValueError if p < 1.0 diff --git a/src/lib.rs b/src/lib.rs index c2d9ef3..9184b3f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,8 @@ pub enum CoreError { InvalidArgument(String), #[error("Computation error: {0}")] ComputationError(String), + #[error("Not enough data: {0}")] + NotEnoughData(String), } impl From<CoreError> for PyErr { @@ -27,6 +29,7 @@ impl From<CoreError> for PyErr { match err { CoreError::InvalidArgument(msg) => PyValueError::new_err(msg), CoreError::ComputationError(msg) => PyValueError::new_err(msg), + CoreError::NotEnoughData(msg) => PyValueError::new_err(msg), } } } @@ -95,7 +98,7 @@ fn _core(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(morphology::binary_closing, m)?)?; // --- Workflows --- - m.add_function(wrap_pyfunction!(workflows::detect_breakpoints, m)?)?; + m.add_function(wrap_pyfunction!(workflows::bfast_monitor, m)?)?; m.add_function(wrap_pyfunction!(workflows::complex_classification, m)?)?; // --- Texture --- diff --git a/src/texture.rs b/src/texture.rs index d6edaf2..ae68d9f 100644 --- a/src/texture.rs +++ b/src/texture.rs @@ -175,4 +175,4 @@ pub fn haralick_features_py( homogeneity_out.into_pyarray(py).to_owned(), entropy_out.into_pyarray(py).to_owned(), )) -} \ No newline at end of file +} diff --git a/src/workflows.rs b/src/workflows.rs index 9e71d5c..251b87d 100644 --- a/src/workflows.rs +++ b/src/workflows.rs @@ -1,17 +1,231 @@ use crate::CoreError; +use nalgebra::{DMatrix, DVector}; use ndarray::{Axis, IxDyn, Zip}; use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn}; use pyo3::prelude::*; use rayon::prelude::*; -// --- 1. Iterative Time-Series Fitter Example --- +const TWO_PI: f64 = 2.0 * std::f64::consts::PI; + +// --- 1. 
BFAST Monitor Workflow --- + +/// Represents the fitted harmonic model. +struct HarmonicModel { + coefficients: DVector<f64>, + sigma: f64, +} + +/// Constructs the design matrix for a harmonic model. +/// +/// # Arguments +/// +/// * `dates` - A slice of fractional years. +/// * `order` - The order of the harmonic model (e.g., 1 for one sine/cosine pair). +/// +/// # Returns +/// +/// A 2D array representing the design matrix `X`. +fn build_design_matrix(dates: &[f64], order: usize) -> DMatrix<f64> { + let n = dates.len(); + let num_coeffs = 2 * order + 2; // intercept, trend, and sin/cos pairs + let mut x = DMatrix::<f64>::zeros(n, num_coeffs); + + for i in 0..n { + let t = dates[i]; + x[(i, 0)] = 1.0; // Intercept + x[(i, 1)] = t; // Trend + for j in 1..=order { + let freq = TWO_PI * j as f64 * t; + x[(i, 2 * j)] = freq.cos(); + x[(i, 2 * j + 1)] = freq.sin(); + } + } + x +} + +/// Fits a harmonic model to the stable history period using Ordinary Least Squares (OLS). +fn fit_harmonic_model(y: &[f64], dates: &[f64], order: usize) -> Result<HarmonicModel, CoreError> { + if y.len() < (2 * order + 2) { + return Err(CoreError::NotEnoughData( + "Not enough historical data to fit model".to_string(), + )); + } + + let y_vec = DVector::from_vec(y.to_vec()); + let x = build_design_matrix(dates, order); + + let decomp = x.clone().svd(true, true); + let coeffs = decomp.solve(&y_vec, 1e-10).map_err(|e| { + CoreError::ComputationError(format!("Failed to solve OLS with nalgebra: {}", e)) + })?; + + let y_pred = &x * &coeffs; + let residuals = &y_vec - &y_pred; + let sum_sq_err = residuals.iter().map(|&r| r * r).sum::<f64>(); + let df = (y.len() - (2 * order + 2)) as f64; + if df <= 0.0 { + return Err(CoreError::ComputationError( + "Degrees of freedom is non-positive".to_string(), + )); + } + let sigma = (sum_sq_err / df).sqrt(); + + Ok(HarmonicModel { + coefficients: coeffs, + sigma, + }) +} + +/// Predicts values for the monitoring period based on the fitted model. 
+fn predict_harmonic_model(model: &HarmonicModel, dates: &[f64], order: usize) -> DVector<f64> { + let x_mon = build_design_matrix(dates, order); + &x_mon * &model.coefficients +} + +/// Detects a break using the OLS-MOSUM process. +fn detect_mosum_break( + y_monitor: &[f64], + y_pred: &DVector<f64>, + monitor_dates: &[f64], + hist_len: usize, + sigma: f64, + h: f64, + alpha: f64, +) -> (f64, f64) { + if y_monitor.is_empty() { + return (0.0, 0.0); + } + + let n_hist = hist_len as f64; + let window_size = (h * n_hist).floor() as usize; + + let residuals: Vec<f64> = y_monitor + .iter() + .zip(y_pred.iter()) + .map(|(obs, pred)| obs - pred) + .collect(); + + let mut cusum = vec![0.0; residuals.len() + 1]; + for i in 0..residuals.len() { + cusum[i + 1] = cusum[i] + residuals[i]; + } + + // We can only start calculating MOSUM after `window_size` observations + if residuals.len() < window_size { + return (0.0, 0.0); + } + + let mosum_process: Vec<f64> = (window_size..residuals.len()) + .map(|i| cusum[i] - cusum[i - window_size]) + .collect(); + + let standardizer = sigma * n_hist.sqrt(); + let standardized_mosum: Vec<f64> = mosum_process + .iter() + .map(|&m| (m / standardizer).abs()) + .collect(); + + // Simplified critical boundary based on a lookup for alpha=0.05 and h=0.25 + // A full implementation would use a precomputed table or a more complex calculation. + let critical_value = if alpha <= 0.05 { 1.36 } else { 1.63 }; // Approximations + + for (i, &mosum_val) in standardized_mosum.iter().enumerate() { + // The index k starts from 1 for the monitoring period + let k = (i + 1) as f64; + let boundary = critical_value * (1.0 + k / n_hist).sqrt(); + + if mosum_val > boundary { + let break_idx = i + window_size; + let magnitude = (y_monitor[break_idx] - y_pred[break_idx]).abs(); + return (monitor_dates[break_idx], magnitude); + } + } + + (0.0, 0.0) // No break detected +} + +/// Converts integer dates (YYYYMMDD) to fractional years. 
+fn dates_to_frac_years(dates: &[i64]) -> Vec<f64> { + dates + .iter() + .map(|&date| { + let year = (date / 10000) as f64; + let month = ((date % 10000) / 100) as f64; + let day = (date % 100) as f64; + // Simple approximation + year + (month - 1.0) / 12.0 + (day - 1.0) / 365.25 + }) + .collect() +} + +// This is the main logic function that runs for each pixel. +fn run_bfast_monitor_per_pixel( + pixel_ts: &[f64], + dates: &[f64], + history_start: f64, + monitor_start: f64, + order: usize, + h: f64, // h parameter for MOSUM window size + alpha: f64, // Significance level +) -> (f64, f64) { + // 1. Find the indices for the history and monitoring periods + let history_indices: Vec<usize> = dates + .iter() + .enumerate() + .filter(|(_, &d)| d >= history_start && d < monitor_start) + .map(|(i, _)| i) + .collect(); + + let monitor_indices: Vec<usize> = dates + .iter() + .enumerate() + .filter(|(_, &d)| d >= monitor_start) + .map(|(i, _)| i) + .collect(); + + if history_indices.is_empty() || monitor_indices.is_empty() { + return (0.0, 0.0); + } + + // 2. Extract the data for these periods + let history_ts: Vec<f64> = history_indices.iter().map(|&i| pixel_ts[i]).collect(); + let history_dates: Vec<f64> = history_indices.iter().map(|&i| dates[i]).collect(); + let monitor_ts: Vec<f64> = monitor_indices.iter().map(|&i| pixel_ts[i]).collect(); + let monitor_dates: Vec<f64> = monitor_indices.iter().map(|&i| dates[i]).collect(); + + // 3. Fit model on the historical period + let model_result = fit_harmonic_model(&history_ts, &history_dates, order); + let model = match model_result { + Ok(m) => m, + Err(_) => return (0.0, 0.0), // Return no-break if model fails + }; + + // 4. Predict for the monitoring period + let predicted_ts = predict_harmonic_model(&model, &monitor_dates, order); + + // 5. 
Detect break using MOSUM process on residuals + detect_mosum_break( + &monitor_ts, + &predicted_ts, + &monitor_dates, + history_ts.len(), + model.sigma, + h, + alpha, + ) +} #[pyfunction] -pub fn detect_breakpoints( +#[allow(clippy::too_many_arguments)] +pub fn bfast_monitor( py: Python, stack: PyReadonlyArrayDyn<f64>, - dates: Vec<i64>, // Julian dates - threshold: f64, + dates: Vec<i64>, + history_start_date: i64, + monitor_start_date: i64, + order: usize, + h: f64, + alpha: f64, ) -> PyResult<Py<PyArrayDyn<f64>>> { let stack_arr = stack.as_array(); @@ -27,8 +241,22 @@ pub fn detect_breakpoints( let height = stack_arr.shape()[1]; let width = stack_arr.shape()[2]; - // Output channels: [break_date, magnitude, confidence] - let mut out_array = ndarray::ArrayD::<f64>::zeros(IxDyn(&[3, height, width])); + if time_len != dates.len() { + return Err(CoreError::InvalidArgument(format!( + "Time dimension of stack ({}) does not match length of dates vector ({})", + time_len, + dates.len() + )) + .into()); + } + + // Convert integer dates to fractional years for modeling + let frac_dates = dates_to_frac_years(&dates); + let history_start_frac = dates_to_frac_years(&[history_start_date])[0]; + let monitor_start_frac = dates_to_frac_years(&[monitor_start_date])[0]; + + // Output channels: [break_date, magnitude] + let mut out_array = ndarray::ArrayD::<f64>::zeros(IxDyn(&[2, height, width])); // Flatten spatial dimensions for parallel processing let num_pixels = height * width; @@ -38,85 +266,35 @@ let mut out_flat = out_array .view_mut() - .into_shape((3, num_pixels)) + .into_shape((2, num_pixels)) .map_err(|e| CoreError::ComputationError(e.to_string()))?; // Get mutable 1D views for each output channel let mut out_slices = out_flat.axis_iter_mut(Axis(0)); let mut break_dates = out_slices.next().unwrap(); let mut magnitudes = out_slices.next().unwrap(); - let mut confidences = out_slices.next().unwrap(); // Iterate over each pixel's time series in parallel Zip::from(&mut break_dates)
.and(&mut magnitudes) - .and(&mut confidences) .and(stack_flat.axis_iter(Axis(1))) - .par_for_each(|break_date, magnitude, confidence, pixel_ts| { - let (bk_date, mag, conf) = - run_bfast_lite_logic(pixel_ts.as_slice().unwrap(), &dates, threshold); + .par_for_each(|break_date, magnitude, pixel_ts| { + let (bk_date, mag) = run_bfast_monitor_per_pixel( + pixel_ts.as_slice().unwrap(), + &frac_dates, + history_start_frac, + monitor_start_frac, + order, + h, + alpha, + ); *break_date = bk_date; *magnitude = mag; - *confidence = conf; }); Ok(out_array.into_pyarray(py).to_owned()) } -// Pure Rust: The compiler optimizes this loop heavily. -fn run_bfast_lite_logic(pixel_ts: &[f64], dates: &[i64], thresh: f64) -> (f64, f64, f64) { - if pixel_ts.len() <= 10 { - return (-1.0, 0.0, 0.0); - } - - let mut max_diff = 0.0; - let mut break_idx = 0; - - // Iterate through possible breakpoints, ensuring enough data on each side - for i in 5..(pixel_ts.len() - 5) { - let (slope1, _) = calculate_linear_regression(&pixel_ts[..i]); - let (slope2, _) = calculate_linear_regression(&pixel_ts[i..]); - - let diff = (slope1 - slope2).abs(); - if diff > max_diff { - max_diff = diff; - break_idx = i; - } - } - - if max_diff > thresh { - ( - dates.get(break_idx).map_or(-1.0, |d| *d as f64), - max_diff, - 1.0, // Confidence is simplified to 1.0 if a break is found - ) - } else { - (-1.0, 0.0, 0.0) - } -} - -// Local helper for linear regression, adapted from src/trends.rs -fn calculate_linear_regression(y: &[f64]) -> (f64, f64) { - if y.is_empty() { - return (0.0, 0.0); - } - let n = y.len() as f64; - let x_sum: f64 = (0..y.len()).map(|i| i as f64).sum(); - let y_sum: f64 = y.iter().sum(); - let xy_sum: f64 = y.iter().enumerate().map(|(i, &yi)| i as f64 * yi).sum(); - let x_sq_sum: f64 = (0..y.len()).map(|i| (i as f64).powi(2)).sum(); - - let denominator = n * x_sq_sum - x_sum.powi(2); - if denominator.abs() < 1e-10 { - return (0.0, y.iter().sum::() / n); // Vertical line, return mean as intercept 
- } - - let slope = (n * xy_sum - x_sum * y_sum) / denominator; - let intercept = (y_sum - slope * x_sum) / n; - - (slope, intercept) -} - // --- 2. Short-Circuit Classifier Example --- #[pyfunction] @@ -141,16 +319,18 @@ pub fn complex_classification( let mut out = ndarray::ArrayD::::zeros(blue_arr.raw_dim()); - out.indexed_iter_mut().par_bridge().for_each(|(idx, res)| { - let b = blue_arr[&idx]; - let g = green_arr[&idx]; - let r = red_arr[&idx]; - let n = nir_arr[&idx]; - let s1 = swir1_arr[&idx]; - let s2 = swir2_arr[&idx]; - let t = temp_arr[&idx]; - *res = classify_pixel(b, g, r, n, s1, s2, t); - }); + out.indexed_iter_mut() + .par_bridge() + .for_each(|(idx, res)| { + let b = blue_arr[&idx]; + let g = green_arr[&idx]; + let r = red_arr[&idx]; + let n = nir_arr[&idx]; + let s1 = swir1_arr[&idx]; + let s2 = swir2_arr[&idx]; + let t = temp_arr[&idx]; + *res = classify_pixel(b, g, r, n, s1, s2, t); + }); Ok(out.into_pyarray(py).to_owned()) } diff --git a/tests/test_workflows.py b/tests/test_workflows.py index 2e36837..8e976cf 100644 --- a/tests/test_workflows.py +++ b/tests/test_workflows.py @@ -1,47 +1,83 @@ import numpy as np -from eo_processor import detect_breakpoints, complex_classification +import pandas as pd +from eo_processor import bfast_monitor, complex_classification -def test_detect_breakpoints(): +def test_bfast_monitor_logic(): """ - Test the detect_breakpoints function with a synthetic time series. + Test the bfast_monitor function with synthetic time series + for both break and no-break scenarios. """ - # Create a time series with a clear breakpoint - time = 100 - breakpoint_time = 50 + # --- 1. 
Generate common data --- + # Create a date range + history_dates = pd.to_datetime(pd.date_range(start="2010-01-01", end="2014-12-31", freq="16D")) + monitor_dates = pd.to_datetime(pd.date_range(start="2015-01-01", end="2017-12-31", freq="16D")) + all_dates = history_dates.union(monitor_dates) + # Convert dates to fractional years for generating the signal + time_frac = all_dates.year + all_dates.dayofyear / 365.25 + + # Convert dates to integer format YYYYMMDD for the function input + dates_int = (all_dates.year * 10000 + all_dates.month * 100 + all_dates.day).to_numpy(dtype=np.int64) + + history_start_date = 20100101 + monitor_start_date = 20150101 + + # Generate a base harmonic signal np.random.seed(42) - # Reduce noise to make the breakpoint more obvious and the test more stable - noise = np.random.normal(0, 0.1, time) - y = ( - np.concatenate( - [ - np.linspace(0, 10, breakpoint_time), - np.linspace(10, 0, time - breakpoint_time), - ] - ) - + noise + noise = np.random.normal(0, 0.05, len(all_dates)) + signal = 0.5 + 0.2 * np.cos(2 * np.pi * time_frac) + 0.1 * np.sin(4 * np.pi * time_frac) + noise + + # --- 2. 
Test break detection scenario --- + + # Introduce a sudden drop in the monitoring period + break_signal = signal.values.copy() + monitor_start_index = len(history_dates) + break_signal[monitor_start_index:] -= 0.4 + + # Create a 3D stack (Time, Y, X) + stack_break = np.zeros((len(all_dates), 1, 1)) + stack_break[:, 0, 0] = break_signal + + # Run bfast_monitor for the break scenario + result_break = bfast_monitor( + stack_break, + dates_int.tolist(), + history_start_date=history_start_date, + monitor_start_date=monitor_start_date, + order=1, + h=0.25, + alpha=0.05, ) - # Create a 3D stack (time, y, x) - stack = np.zeros((time, 1, 1)) - stack[:, 0, 0] = y + break_date_frac = result_break[0, 0, 0] + magnitude = result_break[1, 0, 0] - # Create corresponding dates - dates = np.arange(time).astype(np.int64) + # Assert that a breakpoint was detected near the start of the monitoring period + # The exact date depends on the MOSUM window, so we check a range + assert 2015.0 < break_date_frac < 2016.5 + assert magnitude > 0.3 # Should be around 0.4 - # Run the breakpoint detection - result = detect_breakpoints(stack, dates.tolist(), threshold=0.1) + # --- 3. 
Test no-break scenario --- - # Extract the results - break_date = result[0, 0, 0] - magnitude = result[1, 0, 0] - confidence = result[2, 0, 0] + # Use the original stable signal + stack_stable = np.zeros((len(all_dates), 1, 1)) + stack_stable[:, 0, 0] = signal + + # Run bfast_monitor for the stable scenario + result_stable = bfast_monitor( + stack_stable, + dates_int.tolist(), + history_start_date=history_start_date, + monitor_start_date=monitor_start_date, + order=1, + h=0.25, + alpha=0.05, + ) - # Assert that a breakpoint was detected near the expected time - assert np.isclose(break_date, breakpoint_time, atol=5) - assert magnitude > 0 - assert confidence == 1.0 + # Assert that no breakpoint was detected + assert result_stable[0, 0, 0] == 0.0 + assert result_stable[1, 0, 0] == 0.0 def test_complex_classification(): diff --git a/tox.ini b/tox.ini index 8ca2083..4e42d16 100644 --- a/tox.ini +++ b/tox.ini @@ -12,6 +12,7 @@ deps = maturin>=1.9.6 pytest>=7.0 numpy>=1.20.0 + pandas pillow>=9.0.0 scikit-learn>=1.0 scikit-image>=0.18.0 @@ -42,6 +43,7 @@ deps = pytest>=7.0 pytest-cov>=4.0.0 numpy>=1.20.0 + pandas pillow>=9.0.0 scikit-learn>=1.0 scikit-image>=0.18.0