-
Notifications
You must be signed in to change notification settings - Fork 45
RANSAC estimator #320
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
santiago-imelio
wants to merge
7
commits into
elixir-nx:main
Choose a base branch
from
santiago-imelio:simelio/ransac-regressor
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
RANSAC estimator #320
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
343974c
RANSAC linear regression
santiago-imelio 80f64ac
correct losses
santiago-imelio a9e20e6
add metrics alias
santiago-imelio 1d1752f
support polynomial regression base estimator
santiago-imelio bd1d387
use pattern matching
santiago-imelio 88a99a2
use polynomial regression
santiago-imelio 244967c
correct namespace
santiago-imelio File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| defmodule Scholar.Linear.RANSACRegression do | ||
| @moduledoc """ | ||
| The Random Sample Consensus algorithm is an iterative | ||
| method that fits a robust estimator, able to cope with | ||
| a large proportion of outliers in the data. | ||
|
|
||
| ## References | ||
|
|
||
| [Random sample consensus: a paradigm for model fitting with applications to image analysis and automated cartography](https://www.cs.ait.ac.th/~mdailey/cvreadings/Fischler-RANSAC.pdf) | ||
| """ | ||
|
|
||
| import Nx.Defn | ||
| alias Scholar.Metrics.Regression, as: Metrics | ||
| alias Scholar.Linear.PolynomialRegression | ||
|
|
||
| @derive {Nx.Container, containers: [:model, :inliers_indices, :n_inliers]} | ||
| defstruct [:model, :inliers_indices, :n_inliers] | ||
|
|
||
| @losses [mae: 1, mse: 2] | ||
|
|
||
| opts = [ | ||
| max_iters: [ | ||
| type: :integer, | ||
| doc: """ | ||
| Maximum number of iterations for random sample selection. | ||
| """, | ||
| default: 100 | ||
| ], | ||
| threshold: [ | ||
| type: :float, | ||
| required: true, | ||
| doc: """ | ||
| Maximum error for a sample to be considered classified as inlier. | ||
| """ | ||
| ], | ||
| min_samples: [ | ||
| type: :integer, | ||
| required: true, | ||
| doc: """ | ||
| Minimum number of samples chosen randomly from original data. | ||
| """ | ||
| ], | ||
| degree: [ | ||
| type: :integer, | ||
| doc: """ | ||
| Degree for the polynomial regression base estimator. Default is 1, | ||
| which is equivalente to a linear regression. | ||
| """, | ||
| default: 1 | ||
| ], | ||
| loss: [ | ||
| type: {:in, Keyword.keys(@losses)}, | ||
| default: :mae, | ||
| doc: """ | ||
| Loss function to evaluate estimator. If the loss in a sample is strictly | ||
| lesser than `threshold`, then this sample is classified as an inlier. | ||
| """ | ||
| ], | ||
| random_seed: [ | ||
| type: :integer, | ||
| default: 42 | ||
| ] | ||
| ] | ||
|
|
||
| @opts_schema NimbleOptions.new!(opts) | ||
|
|
||
| @doc """ | ||
| Fits a robust regressor using RANSAC algorithm, using | ||
| a polynomial regression as base estimator. | ||
|
|
||
| #{NimbleOptions.docs(@opts_schema)} | ||
| """ | ||
| def fit(x, y, opts) do | ||
| opts = NimbleOptions.validate!(opts, @opts_schema) | ||
|
|
||
| inliers_mask = fit_n(x, y, opts) | ||
| n_inliers = Nx.to_number(Nx.sum(inliers_mask)) | ||
|
|
||
| if n_inliers == 0 do | ||
| raise "RANSAC was not able to find consensus set that | ||
| meets the required criteria." | ||
| end | ||
|
|
||
| inliers_idx = Nx.argsort(inliers_mask, direction: :desc)[0..(n_inliers - 1)] | ||
|
|
||
| x_inliers = Nx.take(x, inliers_idx) | ||
| y_inliers = Nx.take(y, inliers_idx) | ||
|
|
||
| model = PolynomialRegression.fit(x_inliers, y_inliers, degree: opts[:degree]) | ||
|
|
||
| %__MODULE__{inliers_indices: inliers_idx, n_inliers: n_inliers, model: model} | ||
| end | ||
|
|
||
| def predict(%__MODULE__{model: m}, x), do: PolynomialRegression.predict(m, x) | ||
|
|
||
| defnp loss_fn(loss_t, y_true, y_pred) do | ||
| {loss} = loss_t.shape | ||
|
|
||
| cond do | ||
| loss == @losses[:mae] -> Metrics.mean_absolute_error(y_true, y_pred, axes: [1]) | ||
| loss == @losses[:mse] -> Metrics.mean_square_error(y_true, y_pred, axes: [1]) | ||
| true -> Metrics.mean_absolute_error(y_true, y_pred, axis: 0) | ||
| end | ||
| end | ||
|
|
||
| defnp fit_fn(x_train, y_train, degree_t) do | ||
| {d} = degree_t.shape | ||
| PolynomialRegression.fit(x_train, y_train, degree: d) | ||
| end | ||
|
|
||
| defnp predict_fn(model, x) do | ||
| PolynomialRegression.predict(model, x) | ||
| end | ||
|
|
||
| defnp fit_n(x, y, opts) do | ||
| max_iters = opts[:max_iters] | ||
| thr = opts[:threshold] | ||
| min_samples = opts[:min_samples] | ||
| loss = @losses[opts[:loss]] | ||
| degree = opts[:degree] | ||
|
|
||
| loss_t = Nx.broadcast(:nan, {loss}) | ||
| min_samples_t = Nx.broadcast(:nan, {min_samples}) | ||
| degree_t = Nx.broadcast(:nan, {degree}) | ||
|
|
||
| rand_key = Nx.Random.key(opts[:random_seed]) | ||
| data = Nx.concatenate([x, y], axis: 1) | ||
| inliers = Nx.broadcast(0, {max_iters, elem(x.shape, 0)}) | ||
|
|
||
| {inliers_masks, _} = | ||
| while {inliers, {i = 0, x, y, rand_key, data, min_samples_t, thr, loss_t, degree_t}}, | ||
| i < max_iters do | ||
| n_samples = elem(min_samples_t.shape, 0) | ||
| {rand_samples, rand_key} = Nx.Random.choice(rand_key, data, axis: 0, samples: n_samples) | ||
|
|
||
| {rand_x, rand_y} = Nx.split(rand_samples, elem(x.shape, 1), axis: 1) | ||
| model = fit_fn(rand_x, rand_y, degree_t) | ||
| y_pred = predict_fn(model, x) | ||
| y_pred = Nx.reshape(y_pred, {elem(y_pred.shape, 0), 1}) | ||
|
|
||
| error = loss_fn(loss_t, y, y_pred) | ||
| inliers_i = Nx.less(error, thr) | ||
|
|
||
| updated_inliers = | ||
| Nx.put_slice(inliers, [i, 0], Nx.reshape(inliers_i, {1, elem(x.shape, 0)})) | ||
|
|
||
| {updated_inliers, {i + 1, x, y, rand_key, data, min_samples_t, thr, loss_t, degree_t}} | ||
| end | ||
|
|
||
| best_mask_idx = | ||
| inliers_masks | ||
| |> Nx.vectorize(:masks) | ||
| |> Nx.sum() | ||
| |> Nx.devectorize(keep_names: false) | ||
| |> Nx.argmax() | ||
|
|
||
| inliers_masks[best_mask_idx] | ||
| end | ||
| end | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You want to avoid Nx.to_number because it means
fitthen cannot be called from anotherdefn. The best way is to write some tests calling this function directly and also from anotherdefn.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@josevalim Hey! thanks for the suggestions. I've been looking for a way to use
n_inlierswithout converting to number but failed to do so. Since it is aNx.Defn.Expr, I cannot use it for indexing. Any suggestion on how to approach this? The objective is to extract fromxandythe inliers, meaning the indices frominliers_masksuch that the value is1.