-
Notifications
You must be signed in to change notification settings - Fork 45
RANSAC estimator #320
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
RANSAC estimator #320
Conversation
| @@ -0,0 +1,159 @@ | |||
| defmodule RANSACRegression do | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remember to namespace the module :)
| opts = NimbleOptions.validate!(opts, @opts_schema) | ||
|
|
||
| inliers_mask = fit_n(x, y, opts) | ||
| n_inliers = Nx.to_number(Nx.sum(inliers_mask)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You want to avoid Nx.to_number because it means fit then cannot be called from another defn. The best way is to write some tests calling this function directly and also from another defn.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@josevalim Hey! thanks for the suggestions. I've been looking for a way to use n_inliers without converting to number but failed to do so. Since it is a Nx.Defn.Expr, I cannot use it for indexing. Any suggestion on how to approach this? The objective is to extract from x and y the inliers, meaning the indices from inliers_mask such that the value is 1.
Everything that is not an option is assumed to be an expression and converted to a tensor or nx.defn.expr. But that's generally good. When it is an expression, it can be cached and optimised. |
Hello! This is work in progress but I would love to get some feedback on how this looks so far.
Specifically on the
whilecode, not sure if I'm passing things correctly (loss_fnandmin_samplesare being passed as tensor shapes because otherwise they turn intoNx.Defn.Expr0_o).Ideally, RANSAC could take any fit-predict model as base estimator, so I'd love to hear suggestions on how to go about implementing that (could be something similar to the
loss_fnpattern, but not sure).