Add KNNImputer #303
lib/scholar/impute/knn_imputer.ex
Outdated
if opts[:missing_values] != :nan and
     Nx.any(Nx.is_nan(x)) == Nx.tensor(1, type: :u8) do
  raise ArgumentError,
        ":missing_values other than :nan possible only if there is no Nx.Constant.nan() in the array"
end
This check does not really work in Nx. If you call `fit` inside `Nx.Defn.jit`, then `x` is an expression, and we can't read its values to find out if there is a NaN or not. The best we can do is to remove this check and document it.
I found this check in simple imputer
https://github.com/elixir-nx/scholar/blob/main/lib/scholar/impute/simple_imputer.ex
Are you sure it won't work?
It is also broken there. :)
I have fixed it there: c024c5b
lib/scholar/impute/knn_imputer.ex
Outdated
all_nan_rows_count = Nx.sum(all_nan_rows)

if num_neighbors > rows - 1 - Nx.to_number(all_nan_rows_count) do
Same here, this code won't work because, when you have an expression, you can't get a number from it. Can we remove this check? What happens if we don't check for this condition?
You can test this by calling `fit` after jitting it with `Nx.Defn.jit`.
lib/scholar/impute/knn_imputer.ex
Outdated
# if potential neighbor has nan in nan_col, we don't want to calculate distance and the case if potential_neighbour is the row to impute
{potential_neighbor} =
  if potential_neighbor[nan_col] == Nx.Constants.nan() do
I am not sure if this check is guaranteed to work, given two NaNs are not guaranteed to be equal. Using `Nx.is_nan` would be more appropriate.
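A minimal sketch of the difference, assuming Nx is installed via `Mix.install` and run eagerly outside `defn` (so `Nx.equal/2` plays the role that `==` plays inside `defn`):

```elixir
Mix.install([{:nx, "~> 0.7"}])

nan = Nx.Constants.nan()

# Comparing NaN with itself yields 0 (false) under IEEE 754 semantics.
equal_result = nan |> Nx.equal(nan) |> Nx.to_number()

# `Nx.is_nan/1` detects NaN directly and yields 1 (true).
is_nan_result = nan |> Nx.is_nan() |> Nx.to_number()

IO.inspect(equal_result, label: "Nx.equal(nan, nan)")
IO.inspect(is_nan_result, label: "Nx.is_nan(nan)")
```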
lib/scholar/impute/knn_imputer.ex
Outdated
x =
  if opts[:missing_values] != :nan,
    do: Nx.select(Nx.equal(x, opts[:missing_values]), Nx.Constants.nan(), x),
Use `Nx.is_nan` here; NaN is not equal to itself.
lib/scholar/impute/knn_imputer.ex
Outdated
coordinates = coordinates - 1

# inputes zeros in nan_col to calculate distance with squared_euclidean
new_row = Nx.indexed_put(row, Nx.new_axis(nan_col, 0), Nx.tensor(0))
Generally, when you write in `defn`, you don't need to wrap this zero in `Nx.tensor`. I prefer to explicitly use `Nx.<type>` or `Nx.tensor(x, type: type)` to indicate the type of the tensor. Now, there are some cases where the imputer has a fixed type like `:f32`. I think this might cause undesired upcasts when, e.g., I have a tensor of type `:bf16`. So I suggest checking whether there are any unwanted casts or upcasts.
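A small sketch of the upcast concern, assuming Nx is installed via `Mix.install`. The tensors here are made up for illustration and do not come from the PR:

```elixir
Mix.install([{:nx, "~> 0.7"}])

x = Nx.tensor([1.0, 2.0], type: :bf16)

# A constant with a fixed :f32 type upcasts the result to f32.
upcast = Nx.add(x, Nx.tensor(0.0, type: :f32))

# A constant created with the input's own type keeps bf16.
same_type = Nx.add(x, Nx.tensor(0.0, type: Nx.type(x)))

IO.inspect(Nx.type(upcast), label: "fixed :f32 constant")
IO.inspect(Nx.type(same_type), label: "constant typed like the input")
```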
I changed it, but I don't know how to change this line:
row_distances = Nx.iota({rows}, type: {:f, 32})
because I don't know what type the calculated distance will have at this point.
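One possible way to handle this, offered as a sketch rather than what the PR ended up doing, is to derive the distance type from the input with `Nx.Type.to_floating/1`. It leaves float types such as `:bf16` unchanged and maps integer types to `:f32`. The input tensor below is invented for illustration:

```elixir
Mix.install([{:nx, "~> 0.7"}])

x = Nx.tensor([[1, 2], [3, 4]], type: :s32)
{rows, _cols} = Nx.shape(x)

# Distances are floats, so pick the float type that matches the input.
distance_type = Nx.Type.to_floating(Nx.type(x))
row_distances = Nx.iota({rows}, type: distance_type)

IO.inspect(Nx.type(row_distances), label: "distance type for s32 input")
IO.inspect(Nx.Type.to_floating({:bf, 16}), label: "distance type for bf16 input")
```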
lib/scholar/impute/knn_imputer.ex
Outdated
# if row has all nans we skip it
{weight, potential_neighbor} =
  if present_coordinates == 0 do
As mentioned in the comment above, try to replace "bare" numbers with typed tensors.
lib/scholar/impute/knn_imputer.ex
Outdated
@@ -0,0 +1,256 @@
defmodule Scholar.Impute.KNNImputer do
I think it should be written with a double t, `KNNImputter`, like formatter etc.
Thanks for the PR, I dropped some comments :))
Hi @srzeszut and thanks for the pull request. I’m traveling now and don’t have my laptop with me. Will be back this Sunday, so I will have a look probably next week.
Thanks for the review. I applied the suggested changes and left some comments.
lib/scholar/impute/knn_imputter.ex
Outdated
if num_neighbors > rows - 1 - Nx.to_number(all_nan_rows_count) do
  raise ArgumentError,
        "Number of neighbors rows must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value)"
error messages start in lowercase. :)
- "Number of neighbors rows must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value)"
+ "number of neighbors rows must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value)"
lib/scholar/impute/knn_imputter.ex
Outdated
all_nan_rows_count = Nx.sum(all_nan_rows)

if num_neighbors > rows - 1 - Nx.to_number(all_nan_rows_count) do
Can you please add some tests? In particular, please add a test where you jit this function and then call it: `Nx.Defn.jit(...).(arg1, arg2)`. It should reveal some errors around here. :)
I added tests and checked it. I removed those checks and added them in the description
`n_neighbors` nearest neighbors found in the training set. Two samples are
close if the features that neither is missing are close.
- `n_neighbors` nearest neighbors found in the training set. Two samples are
- close if the features that neither is missing are close.
+ `n_neighbors` nearest neighbors found in the training set. Two samples are
+ close if the features that neither is missing are close.
Preconditions:
* `number_of_neighbors` is a positive integer.
* number of neighbors must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value) otherwise it is better to use simple imputter
Please try to break this long line :)
test "Wrong impute rank" do
  x = Nx.tensor([1, 2, 2, 3])

  assert_raise ArgumentError,
               "Wrong input rank. Expected: 2, got: 1",
               fn ->
                 KNNImputter.fit(x, missing_values: 1, number_of_neighbors: 2)
               end
end

test "Invalid n_neighbors value" do
Test names start in lowercase :)
- test "Wrong impute rank" do
-   x = Nx.tensor([1, 2, 2, 3])
-   assert_raise ArgumentError,
-                "Wrong input rank. Expected: 2, got: 1",
-                fn ->
-                  KNNImputter.fit(x, missing_values: 1, number_of_neighbors: 2)
-                end
- end
- test "Invalid n_neighbors value" do
+ test "invalid impute rank" do
+   x = Nx.tensor([1, 2, 2, 3])
+   assert_raise ArgumentError,
+                "Wrong input rank. Expected: 2, got: 1",
+                fn ->
+                  KNNImputter.fit(x, missing_values: 1, number_of_neighbors: 2)
+                end
+ end
+ test "invalid n_neighbors value" do
I dropped the last round of nitpicks and we are good to go!
First review. Some features we might wanna have:
- Make k-NN algorithm configurable.
- Make the metric configurable.
You can leave these for another pull request. Have a look at e.g. `KNNClassifier` to see how it is done over there.
I should have another look tonight.
The default value expects there are no NaNs in the input tensor.
"""
],
number_of_neighbors: [
I would suggest changing this to `num_neighbors` to be consistent with the rest of `Scholar`.
Several minor comments for now. I have to go through the code at least once more as I don't exactly understand the logic here.
x =
  if opts[:missing_values] != :nan,
    do: Nx.select(Nx.equal(x, opts[:missing_values]), Nx.Constants.nan(), x),
You should be able to use `==` instead of `Nx.equal/2`.
This is a `deftransform`, so `Nx.equal` is the proper function. `==` will be `Elixir.Kernel.==`.
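A short sketch of that distinction, assuming Nx is installed via `Mix.install`. It runs in plain Elixir, where operators resolve the same way as inside a `deftransform` body (unlike `defn`, which rewrites `==` to `Nx.equal/2`). The tensors are made up for illustration:

```elixir
Mix.install([{:nx, "~> 0.7"}])

a = Nx.tensor([1, 2, 3])
b = Nx.tensor([1, 0, 3])

# In regular Elixir code (which includes `deftransform` bodies), `==` is
# `Kernel.==/2`: it compares the two structs and returns a single boolean.
kernel_result = a == b

# `Nx.equal/2` compares element-wise and returns a tensor of 0s and 1s.
nx_result = Nx.equal(a, b)

IO.inspect(kernel_result, label: "a == b")
IO.inspect(nx_result, label: "Nx.equal(a, b)")
```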
placeholder_value = Nx.Constants.nan() |> Nx.tensor()

statistics = knn_impute(x, placeholder_value, num_neighbors: num_neighbors)
missing_values = opts[:missing_values]
I would move this line above so that you don't access `opts[:missing_values]` multiple times.
{_, values_to_impute} =
  while {{row = 0, mask, num_neighbors, num_rows, x}, values_to_impute},
        Nx.less(row, num_rows) do
You can use `<` instead of `Nx.less/2` over here.
      Nx.less(row, num_rows) do
  {_, values_to_impute} =
    while {{col = 0, mask, num_neighbors, num_cols, row, x}, values_to_impute},
          Nx.less(col, num_cols) do
Same here.
{_, values_to_impute} =
  while {{col = 0, mask, num_neighbors, num_cols, row, x}, values_to_impute},
        Nx.less(col, num_cols) do
    if mask[row][col] > 0 do
I think `if mask[row][col] do` should work here.
* `number_of_neighbors` is a positive integer.
* number of neighbors must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value) otherwise it is better to use simple imputter
- * `number_of_neighbors` is a positive integer.
- * number of neighbors must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value) otherwise it is better to use simple imputter
+ * The number of neighbors must be less than the number of valid rows - 1.
+   A valid row is a row with more than 1 non-NaN values. Otherwise it is better to use a simpler imputer.
Preconditions:
* `number_of_neighbors` is a positive integer.
* number of neighbors must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value) otherwise it is better to use simple imputter
* when you set a value different than :nan in `missing_values` there should be no NaNs in the input tensor
- * when you set a value different than :nan in `missing_values` there should be no NaNs in the input tensor
+ * When you set a value different than `:nan` in `missing_values` there should be no NaNs in the input tensor
* `:missing_values` - the same value as in `:missing_values`

* `:statistics` - The imputation fill value for each feature. Computing statistics can result in
  [`Nx.Constant.nan/0`](https://hexdocs.pm/nx/Nx.Constants.html#nan/0) values.
- [`Nx.Constant.nan/0`](https://hexdocs.pm/nx/Nx.Constants.html#nan/0) values.
+ [`Nx.Constants.nan/0`](https://hexdocs.pm/nx/Nx.Constants.html#nan/0) values.
Do you need the explicit linking in hexdoc?
The function returns a struct with the following parameters:

* `:missing_values` - the same value as in `:missing_values`
- * `:missing_values` - the same value as in `:missing_values`
+ * `:missing_values` - the same value as in the `:missing_values` option
num_neighbors = opts[:number_of_neighbors]

placeholder_value = Nx.Constants.nan() |> Nx.tensor()
- placeholder_value = Nx.Constants.nan() |> Nx.tensor()
+ placeholder_value = Nx.Constants.nan()
You probably want to pass the input type here to avoid upcasts.
opts_schema = [
  missing_values: [
    type: {:or, [:float, :integer, {:in, [:nan]}]},
- type: {:or, [:float, :integer, {:in, [:nan]}]},
+ type: {:or, [:float, :integer, {:in, [:nan]}]},
I believe this should allow `:infinity` and `:neg_infinity` too, for completeness.
indices =
  [Nx.stack(row), Nx.stack(col)]
  |> Nx.concatenate()
  |> Nx.stack()
- indices =
-   [Nx.stack(row), Nx.stack(col)]
-   |> Nx.concatenate()
-   |> Nx.stack()
+ indices = Nx.stack([row, col]) |> Nx.reshape({1, 2})
If I read the code correctly, `row` and `col` are scalars and this should yield the same result.
  |> Nx.concatenate()
  |> Nx.stack()

values_to_impute = Nx.indexed_put(values_to_impute, indices, Nx.stack(neighbor_avg))
- values_to_impute = Nx.indexed_put(values_to_impute, indices, Nx.stack(neighbor_avg))
+ values_to_impute = Nx.put_slice(values_to_impute, [row, col], Nx.reshape(neighbor_avg, {1, 1}))
I think this is even simpler.
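A minimal sketch comparing the two ways of writing one scalar into a matrix, assuming Nx is installed via `Mix.install`. The 3x3 zero matrix and the values of `row`, `col`, and `neighbor_avg` are invented for illustration; only the two update calls mirror the suggestions above:

```elixir
Mix.install([{:nx, "~> 0.7"}])

values_to_impute = Nx.broadcast(0.0, {3, 3})
row = Nx.tensor(1)
col = Nx.tensor(2)
neighbor_avg = Nx.tensor(7.5)

# Variant 1: build a {1, 2} index tensor and scatter a single update.
indices = Nx.stack([row, col]) |> Nx.reshape({1, 2})
via_indexed_put = Nx.indexed_put(values_to_impute, indices, Nx.reshape(neighbor_avg, {1}))

# Variant 2: write a {1, 1} slice starting at position [row, col].
via_put_slice = Nx.put_slice(values_to_impute, [row, col], Nx.reshape(neighbor_avg, {1, 1}))

IO.inspect(via_indexed_put, label: "indexed_put")
IO.inspect(via_put_slice, label: "put_slice")
```

Both variants only touch the entry at `[1, 2]`, so for a single scalar update `Nx.put_slice/3` avoids building the index tensor at all.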
{_, row_distances} =
  while {{i = 0, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances},
        Nx.less(i, rows) do
    potential_donor = x[i]

    distance =
      if i == nan_row do
        Nx.Constants.infinity(Nx.type(row_with_value_to_fill))
      else
        nan_euclidian(row_with_value_to_fill, nan_col, potential_donor)
      end

    row_distances = Nx.indexed_put(row_distances, Nx.new_axis(i, 0), distance)
    {{i + 1, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances}
  end
try this:
potential_donors = Nx.vectorize(x, :rows)
distances = nan_euclidean(row_with_value_to_fill, nan_col, potential_donors) |> Nx.devectorize()
row_distances = Nx.indexed_put(distances, [i], Nx.Constants.infinity())
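To show the shape of this idea, here is a simplified sketch, assuming Nx 0.6 or later (for `Nx.vectorize/2`) installed via `Mix.install`. It uses a plain squared Euclidean distance instead of the PR's NaN-aware distance function, and the data, `nan_row`, and variable names are made up for illustration:

```elixir
Mix.install([{:nx, "~> 0.7"}])

x = Nx.tensor([[1.0, 2.0], [3.0, 4.0], [1.0, 5.0]])
nan_row = 0
row_with_value_to_fill = x[nan_row]

# Treat every row of `x` as one vectorized entry, so a single expression
# computes the distance to all potential donors without a `while` loop.
potential_donors = Nx.vectorize(x, :rows)

distances =
  row_with_value_to_fill
  |> Nx.subtract(potential_donors)
  |> Nx.pow(2)
  |> Nx.sum()
  |> Nx.devectorize()

# A row must not be its own donor, so its distance is set to infinity.
row_distances =
  Nx.indexed_put(
    distances,
    Nx.tensor([[nan_row]]),
    Nx.reshape(Nx.Constants.infinity(), {1})
  )

IO.inspect(row_distances, label: "row_distances")
```

After devectorizing, `distances` has one entry per row of `x`, which is the same layout the `while` loop above fills in step by step.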
I have added the KNNImputer and I am currently implementing tests to ensure that it behaves as expected across various scenarios, including edge cases.