diff --git a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/IModelFitter b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/IModelFitter.java similarity index 97% rename from hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/IModelFitter rename to hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/IModelFitter.java index 0dee299f2..f3c1a6322 100644 --- a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/IModelFitter +++ b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/IModelFitter.java @@ -1,45 +1,45 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.hipparchus.fitting.ransac; - -import java.util.List; - -/** - * Base class for mathematical model fitter used with {@link RansacFitter}. - * @param mathematical model representing the parameters to estimate - * @since 4.1 - */ -public interface IModelFitter { - - /** - * Fits the mathematical model parameters based on the set of observed data. - * @param points set of observed data - * @return the fitted model parameters - */ - M fitModel(final List points); - - /** - * Computes the error between the model and an observed data. - *

- * This method is used to determine if the observed data is an inlier or an outlier. - *

- * @param model fitted model - * @param point observed data - * @return the error between the model and the observed data - */ - double computeModelError(final M model, final double[] point); -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.hipparchus.fitting.ransac; + +import java.util.List; + +/** + * Base class for mathematical model fitter used with {@link RansacFitter}. + * @param mathematical model representing the parameters to estimate + * @since 4.1 + */ +public interface IModelFitter { + + /** + * Fits the mathematical model parameters based on the set of observed data. + * @param points set of observed data + * @return the fitted model parameters + */ + M fitModel(final List points); + + /** + * Computes the error between the model and an observed data. + *

+ * This method is used to determine if the observed data is an inlier or an outlier. + *

+ * @param model fitted model + * @param point observed data + * @return the error between the model and the observed data + */ + double computeModelError(final M model, final double[] point); +} diff --git a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/PolynomialModelFitter b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/PolynomialModelFitter.java similarity index 97% rename from hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/PolynomialModelFitter rename to hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/PolynomialModelFitter.java index 56445a5e4..710abd97c 100644 --- a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/PolynomialModelFitter +++ b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/PolynomialModelFitter.java @@ -1,133 +1,133 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.hipparchus.fitting.ransac; - -import java.util.List; -import java.util.stream.IntStream; -import org.hipparchus.exception.LocalizedCoreFormats; -import org.hipparchus.exception.MathIllegalArgumentException; -import org.hipparchus.linear.Array2DRowRealMatrix; -import org.hipparchus.linear.ArrayRealVector; -import org.hipparchus.linear.RealMatrix; -import org.hipparchus.linear.RealVector; -import org.hipparchus.linear.SingularValueDecomposition; -import org.hipparchus.util.FastMath; - -/** - * Fitter for polynomial model. - * @since 4.1 - */ -public class PolynomialModelFitter implements IModelFitter { - - /** Class representing the polynomial model to fit. */ - public static final class Model { - - /** Coefficients of the polynomial model. */ - private final double[] coefficients; - - /** - * Constructor. - * @param coefficients coefficients of the polynomial model - */ - public Model(final double[] coefficients) { - this.coefficients = coefficients.clone(); - } - - /** - * Predicts the model value for the input point. - * @param x point - * @return the model value for the given point - */ - public double predict(final double x) { - return IntStream.range(0, coefficients.length).mapToDouble(i -> coefficients[i] * FastMath.pow(x, i)).sum(); - } - - /** - * Get the coefficients of the polynomial model. - *

- * The coefficients are sort by degree. - * For instance, for a quadratic equation the coefficients are as followed: - * y = coefficients[2] * x * x + coefficients[1] * x + coefficients[0] - *

- * @return the coefficients of the polynomial model - */ - public double[] getCoefficients() { - return coefficients; - } - } - - /** Degree of the polynomial to fit. */ - private final int degree; - - /** - * Constructor. - * @param degree degree of the polynomial to fit - */ - public PolynomialModelFitter(final int degree) { - if (degree < 1) { - throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, degree, 1); - } - this.degree = degree; - } - - /** {@inheritDoc. */ - @Override - public Model fitModel(final List points) { - // Reference: Wikipedia page "Polynomial regression" - final int size = points.size(); - checkSampleSize(size); - - // Fill the data - final double[][] x = new double[size][degree + 1]; - final double[] y = new double[size]; - for (int i = 0; i < size; i++) { - final double currentX = points.get(i)[0]; - final double currentY = points.get(i)[1]; - double value = 1.0; - for (int j = 0; j <= degree; j++) { - x[i][j] = value; - value *= currentX; - } - y[i] = currentY; - } - - // Computes (X^T.X)^-1 X^T.Y to determine the coefficients "C" of the polynomial (Y = X.C) - final RealMatrix matrixX = new Array2DRowRealMatrix(x); - final RealVector matrixY = new ArrayRealVector(y); - final RealMatrix matrixXTranspose = matrixX.transpose(); - final RealMatrix xTx = matrixXTranspose.multiply(matrixX); - final RealVector xTy = matrixXTranspose.operate(matrixY); - final RealVector coefficients = new SingularValueDecomposition(xTx).getSolver().solve(xTy); - return new Model(coefficients.toArray()); - } - - /** {@inheritDoc}. */ - @Override - public double computeModelError(final Model model, final double[] point) { - return FastMath.abs(point[1] - model.predict(point[0])); - } - - /** - * Verifies that the size of the set of observed data is consistent with the degree of the polynomial to fit. - * @param size size of the set of observed data - */ - private void checkSampleSize(final int size) { - if (size < degree + 1) { - throw new IllegalArgumentException(String.format("Not enough points to fit polynomial model, at least %d points are required", degree + 1)); - } - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.hipparchus.fitting.ransac; + +import java.util.List; +import java.util.stream.IntStream; +import org.hipparchus.exception.LocalizedCoreFormats; +import org.hipparchus.exception.MathIllegalArgumentException; +import org.hipparchus.linear.Array2DRowRealMatrix; +import org.hipparchus.linear.ArrayRealVector; +import org.hipparchus.linear.RealMatrix; +import org.hipparchus.linear.RealVector; +import org.hipparchus.linear.SingularValueDecomposition; +import org.hipparchus.util.FastMath; + +/** + * Fitter for polynomial model. + * @since 4.1 + */ +public class PolynomialModelFitter implements IModelFitter { + + /** Class representing the polynomial model to fit. */ + public static final class Model { + + /** Coefficients of the polynomial model. */ + private final double[] coefficients; + + /** + * Constructor. + * @param coefficients coefficients of the polynomial model + */ + public Model(final double[] coefficients) { + this.coefficients = coefficients.clone(); + } + + /** + * Predicts the model value for the input point. + * @param x point + * @return the model value for the given point + */ + public double predict(final double x) { + return IntStream.range(0, coefficients.length).mapToDouble(i -> coefficients[i] * FastMath.pow(x, i)).sum(); + } + + /** + * Get the coefficients of the polynomial model. + *

+ * The coefficients are sort by degree. + * For instance, for a quadratic equation the coefficients are as followed: + * y = coefficients[2] * x * x + coefficients[1] * x + coefficients[0] + *

+ * @return the coefficients of the polynomial model + */ + public double[] getCoefficients() { + return coefficients; + } + } + + /** Degree of the polynomial to fit. */ + private final int degree; + + /** + * Constructor. + * @param degree degree of the polynomial to fit + */ + public PolynomialModelFitter(final int degree) { + if (degree < 1) { + throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, degree, 1); + } + this.degree = degree; + } + + /** {@inheritDoc. */ + @Override + public Model fitModel(final List points) { + // Reference: Wikipedia page "Polynomial regression" + final int size = points.size(); + checkSampleSize(size); + + // Fill the data + final double[][] x = new double[size][degree + 1]; + final double[] y = new double[size]; + for (int i = 0; i < size; i++) { + final double currentX = points.get(i)[0]; + final double currentY = points.get(i)[1]; + double value = 1.0; + for (int j = 0; j <= degree; j++) { + x[i][j] = value; + value *= currentX; + } + y[i] = currentY; + } + + // Computes (X^T.X)^-1 X^T.Y to determine the coefficients "C" of the polynomial (Y = X.C) + final RealMatrix matrixX = new Array2DRowRealMatrix(x); + final RealVector matrixY = new ArrayRealVector(y); + final RealMatrix matrixXTranspose = matrixX.transpose(); + final RealMatrix xTx = matrixXTranspose.multiply(matrixX); + final RealVector xTy = matrixXTranspose.operate(matrixY); + final RealVector coefficients = new SingularValueDecomposition(xTx).getSolver().solve(xTy); + return new Model(coefficients.toArray()); + } + + /** {@inheritDoc}. */ + @Override + public double computeModelError(final Model model, final double[] point) { + return FastMath.abs(point[1] - model.predict(point[0])); + } + + /** + * Verifies that the size of the set of observed data is consistent with the degree of the polynomial to fit. + * @param size size of the set of observed data + */ + private void checkSampleSize(final int size) { + if (size < degree + 1) { + throw new IllegalArgumentException(String.format("Not enough points to fit polynomial model, at least %d points are required", degree + 1)); + } + } +} diff --git a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitter b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitter.java similarity index 97% rename from hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitter rename to hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitter.java index 13fbd08e9..65255f413 100644 --- a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitter +++ b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitter.java @@ -1,157 +1,157 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.hipparchus.fitting.ransac; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.Random; -import java.util.stream.Collectors; -import org.hipparchus.exception.LocalizedCoreFormats; -import org.hipparchus.exception.MathIllegalArgumentException; - -/** - * Class implementing Random sample consensus (RANSAC) algorithm. - *

- * RANSAC is a robust method for estimating the parameters of a - * mathematical model from a set of observed data. - * It works iteratively selecting random subsets of the input data, - * fitting a model to these subsets, and then determining how many - * data points from the entire set are consistent with the estimated - * model parameters. - * The model can yields the largest number of inliers (i.e., point - * that fit well) is considered the best estimate. - *

- *

- * This implementation is designed to be generic and can be used with - * different types of models, such as {@link PolynomialModelFitter - * polynomial models}. - *

- * @param mathematical model representing the parameters to estimate - * @since 4.1 - */ -public class RansacFitter { - - /** Mathematical model fitter. */ - private final IModelFitter fitter; - - /** The minimum number of data points to estimate the model parameters. */ - private final int sampleSize; - - /** The maximum number of iterations allowed to fit the model. */ - private final int maxIterations; - - /** Threshold to assert that a data point fits the model. */ - private final double threshold; - - /** The minimum number of close data points required to assert that the model fits the input data. */ - private final int minInliers; - - /** Random generator. */ - private final Random random; - - /** - * Constructor. - * @param fitter mathematical model fitter - * @param sampleSize minimum number of data points to estimate the model parameters - * @param maxIterations maximum number of iterations allowed to fit the model - * @param threshold threshold to assert that a data point fits the model - * @param minInliers minimum number of close data points required to assert that the model fits the input data - * @param seed seed for the random generator - */ - public RansacFitter(final IModelFitter fitter, final int sampleSize, - final int maxIterations, final double threshold, - final int minInliers, final int seed) { - this.fitter = fitter; - this.sampleSize = sampleSize; - this.maxIterations = maxIterations; - this.threshold = threshold; - this.minInliers = minInliers; - this.random = new Random(seed); - checkInputs(); - } - - /** - * Fits the set of observed data to determine the model parameters. - * @param points set of observed data - * @return a java class containing the best estimate of the model parameters - */ - public RansacFitterOutputs fit(final List points) { - - // Initialize the best model data - final List data = new ArrayList<>(points); - Optional bestModel = Optional.empty(); - List bestInliers = new ArrayList<>(); - - // Iterative loop to determine the best model - for (int iteration = 0; iteration < maxIterations; iteration++) { - - // Random permute the set of observed data and determine the inliers - Collections.shuffle(data, random); - final List inliers = determineCurrentInliersFromRandomlyPermutedPoints(data); - - // Verifies if the current inliers are fit better the model than the previous ones - if (isCurrentInliersSetBetterThanPreviousOne(inliers, bestInliers)) { - bestModel = Optional.of(fitter.fitModel(inliers)); - bestInliers = inliers; - } - - } - - // Returns the best model data - return new RansacFitterOutputs<>(bestModel, bestInliers, fitter); - } - - /** - * Determines the current inliers (i.e., points that fit well the model) from the input randomly permuted data. - * @param permutedPoints randomly permuted data - * @return the list of inliers - */ - private List determineCurrentInliersFromRandomlyPermutedPoints(final List permutedPoints) { - M model = fitter.fitModel(permutedPoints.subList(0, sampleSize)); - return permutedPoints.stream().filter(point -> fitter.computeModelError(model, point) < threshold).collect(Collectors.toList()); - } - - /** - * Verifies is the current inliers are better than the previous ones. - * @param current current inliers - * @param previous previous inliers - * @return true is the current inlier are better than the previous ones - */ - private boolean isCurrentInliersSetBetterThanPreviousOne(final List current, final List previous) { - return current.size() > previous.size() && current.size() >= minInliers; - } - - /** - * Checks that the fitter inputs are correct. - */ - private void checkInputs() { - if (maxIterations < 0) { - throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, maxIterations, 0); - } - if (sampleSize < 0) { - throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, sampleSize, 0); - } - if (threshold < 0.) { - throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, threshold, 0); - } - if (minInliers < 0) { - throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, minInliers, 0); - } - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.hipparchus.fitting.ransac; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.Random; +import java.util.stream.Collectors; +import org.hipparchus.exception.LocalizedCoreFormats; +import org.hipparchus.exception.MathIllegalArgumentException; + +/** + * Class implementing Random sample consensus (RANSAC) algorithm. + *

+ * RANSAC is a robust method for estimating the parameters of a + * mathematical model from a set of observed data. + * It works iteratively selecting random subsets of the input data, + * fitting a model to these subsets, and then determining how many + * data points from the entire set are consistent with the estimated + * model parameters. + * The model can yields the largest number of inliers (i.e., point + * that fit well) is considered the best estimate. + *

+ *

+ * This implementation is designed to be generic and can be used with + * different types of models, such as {@link PolynomialModelFitter + * polynomial models}. + *

+ * @param mathematical model representing the parameters to estimate + * @since 4.1 + */ +public class RansacFitter { + + /** Mathematical model fitter. */ + private final IModelFitter fitter; + + /** The minimum number of data points to estimate the model parameters. */ + private final int sampleSize; + + /** The maximum number of iterations allowed to fit the model. */ + private final int maxIterations; + + /** Threshold to assert that a data point fits the model. */ + private final double threshold; + + /** The minimum number of close data points required to assert that the model fits the input data. */ + private final int minInliers; + + /** Random generator. */ + private final Random random; + + /** + * Constructor. + * @param fitter mathematical model fitter + * @param sampleSize minimum number of data points to estimate the model parameters + * @param maxIterations maximum number of iterations allowed to fit the model + * @param threshold threshold to assert that a data point fits the model + * @param minInliers minimum number of close data points required to assert that the model fits the input data + * @param seed seed for the random generator + */ + public RansacFitter(final IModelFitter fitter, final int sampleSize, + final int maxIterations, final double threshold, + final int minInliers, final int seed) { + this.fitter = fitter; + this.sampleSize = sampleSize; + this.maxIterations = maxIterations; + this.threshold = threshold; + this.minInliers = minInliers; + this.random = new Random(seed); + checkInputs(); + } + + /** + * Fits the set of observed data to determine the model parameters. + * @param points set of observed data + * @return a java class containing the best estimate of the model parameters + */ + public RansacFitterOutputs fit(final List points) { + + // Initialize the best model data + final List data = new ArrayList<>(points); + Optional bestModel = Optional.empty(); + List bestInliers = new ArrayList<>(); + + // Iterative loop to determine the best model + for (int iteration = 0; iteration < maxIterations; iteration++) { + + // Random permute the set of observed data and determine the inliers + Collections.shuffle(data, random); + final List inliers = determineCurrentInliersFromRandomlyPermutedPoints(data); + + // Verifies if the current inliers are fit better the model than the previous ones + if (isCurrentInliersSetBetterThanPreviousOne(inliers, bestInliers)) { + bestModel = Optional.of(fitter.fitModel(inliers)); + bestInliers = inliers; + } + + } + + // Returns the best model data + return new RansacFitterOutputs<>(bestModel, bestInliers, fitter); + } + + /** + * Determines the current inliers (i.e., points that fit well the model) from the input randomly permuted data. + * @param permutedPoints randomly permuted data + * @return the list of inliers + */ + private List determineCurrentInliersFromRandomlyPermutedPoints(final List permutedPoints) { + M model = fitter.fitModel(permutedPoints.subList(0, sampleSize)); + return permutedPoints.stream().filter(point -> fitter.computeModelError(model, point) < threshold).collect(Collectors.toList()); + } + + /** + * Verifies is the current inliers are better than the previous ones. + * @param current current inliers + * @param previous previous inliers + * @return true is the current inlier are better than the previous ones + */ + private boolean isCurrentInliersSetBetterThanPreviousOne(final List current, final List previous) { + return current.size() > previous.size() && current.size() >= minInliers; + } + + /** + * Checks that the fitter inputs are correct. + */ + private void checkInputs() { + if (maxIterations < 0) { + throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, maxIterations, 0); + } + if (sampleSize < 0) { + throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, sampleSize, 0); + } + if (threshold < 0.) { + throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, threshold, 0); + } + if (minInliers < 0) { + throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, minInliers, 0); + } + } +} diff --git a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitterOutputs b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitterOutputs.java similarity index 97% rename from hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitterOutputs rename to hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitterOutputs.java index 35f2c217f..0efdff678 100644 --- a/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitterOutputs +++ b/hipparchus-fitting/src/main/java/org/hipparchus/fitting/ransac/RansacFitterOutputs.java @@ -1,79 +1,79 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.hipparchus.fitting.ransac; - -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; - -/** - * Class containing the best estimate of the model parameters. - * @param mathematical model representing the parameters to estimate - * @since 4.1 - */ -public class RansacFitterOutputs { - - /** Mathematical model fitter used by RANSAC algorithm. */ - private final IModelFitter fitter; - - /** Best model parameters. */ - private final Optional bestModel; - - /** List of points used to determine the best model parameters. */ - private final List bestInliers; - - /** - * Constructor. - * @param bestModel best model parameters - * @param bestInliers list of points used to determine the best model parameters - * @param fitter mathematical model fitter used by RANSAC algorithm - */ - public RansacFitterOutputs(final Optional bestModel, final List bestInliers, final IModelFitter fitter) { - this.bestModel = bestModel; - this.bestInliers = bestInliers; - this.fitter = fitter; - } - - /** - * Get the best model parameters. - * @return the best model parameters - */ - public Optional getBestModel() { - return bestModel; - } - - /** - * Get the list of points used to determine the best model parameters. - * @return the list of points used to determine the best model parameters - */ - public List getBestInliers() { - return bestInliers; - } - - /** - * Finds the points below a given threshold based on the computed best model parameters. - * @param points input list of points - * @param threshold threshold to use - * @return the list of points below the given threshold based on the computed best model parameters - * (can be empty if the all points are above the threshold or if no best model has been found) - */ - public List filterPointsBelowThreshold(final List points, final double threshold) { - return bestModel.map(model -> points.stream().filter(point -> fitter.computeModelError(model, point) < threshold).collect(Collectors.toList())) - .orElse(Collections.emptyList()); - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.hipparchus.fitting.ransac; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * Class containing the best estimate of the model parameters. + * @param mathematical model representing the parameters to estimate + * @since 4.1 + */ +public class RansacFitterOutputs { + + /** Mathematical model fitter used by RANSAC algorithm. */ + private final IModelFitter fitter; + + /** Best model parameters. */ + private final Optional bestModel; + + /** List of points used to determine the best model parameters. */ + private final List bestInliers; + + /** + * Constructor. + * @param bestModel best model parameters + * @param bestInliers list of points used to determine the best model parameters + * @param fitter mathematical model fitter used by RANSAC algorithm + */ + public RansacFitterOutputs(final Optional bestModel, final List bestInliers, final IModelFitter fitter) { + this.bestModel = bestModel; + this.bestInliers = bestInliers; + this.fitter = fitter; + } + + /** + * Get the best model parameters. + * @return the best model parameters + */ + public Optional getBestModel() { + return bestModel; + } + + /** + * Get the list of points used to determine the best model parameters. + * @return the list of points used to determine the best model parameters + */ + public List getBestInliers() { + return bestInliers; + } + + /** + * Finds the points below a given threshold based on the computed best model parameters. + * @param points input list of points + * @param threshold threshold to use + * @return the list of points below the given threshold based on the computed best model parameters + * (can be empty if the all points are above the threshold or if no best model has been found) + */ + public List filterPointsBelowThreshold(final List points, final double threshold) { + return bestModel.map(model -> points.stream().filter(point -> fitter.computeModelError(model, point) < threshold).collect(Collectors.toList())) + .orElse(Collections.emptyList()); + } +} diff --git a/hipparchus-fitting/src/test/java/org/hipparchus/fitting/ransac/PolynomialModelFitterTest b/hipparchus-fitting/src/test/java/org/hipparchus/fitting/ransac/PolynomialModelFitterTest.java similarity index 97% rename from hipparchus-fitting/src/test/java/org/hipparchus/fitting/ransac/PolynomialModelFitterTest rename to hipparchus-fitting/src/test/java/org/hipparchus/fitting/ransac/PolynomialModelFitterTest.java index 659135678..14c3fb3e0 100644 --- a/hipparchus-fitting/src/test/java/org/hipparchus/fitting/ransac/PolynomialModelFitterTest +++ b/hipparchus-fitting/src/test/java/org/hipparchus/fitting/ransac/PolynomialModelFitterTest.java @@ -1,38 +1,38 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.hipparchus.fitting.ransac; - -import java.util.ArrayList; -import org.hipparchus.exception.MathIllegalArgumentException; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.function.Executable; - -class PolynomialModelFitterTest { - - @Test - void testExceptions() { - assertThrows(MathIllegalArgumentException.class, () -> new PolynomialModelFitter(0), "0 is smaller than the minimum (1)"); - assertThrows(IllegalArgumentException.class, () -> new PolynomialModelFitter(1).fitModel(new ArrayList<>()), "Not enough points to fit polynomial model, at least 2 points are required"); - } - - private static void assertThrows(final Class expectedType, final Executable executable, final String message) { - final T exception = Assertions.assertThrows(expectedType, executable); - Assertions.assertEquals(message, exception.getMessage()); - } - +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.hipparchus.fitting.ransac; + +import java.util.ArrayList; +import org.hipparchus.exception.MathIllegalArgumentException; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +class PolynomialModelFitterTest { + + @Test + void testExceptions() { + assertThrows(MathIllegalArgumentException.class, () -> new PolynomialModelFitter(0), "0 is smaller than the minimum (1)"); + assertThrows(IllegalArgumentException.class, () -> new PolynomialModelFitter(1).fitModel(new ArrayList<>()), "Not enough points to fit polynomial model, at least 2 points are required"); + } + + private static void assertThrows(final Class expectedType, final Executable executable, final String message) { + final T exception = Assertions.assertThrows(expectedType, executable); + Assertions.assertEquals(message, exception.getMessage()); + } + } \ No newline at end of file diff --git a/hipparchus-fitting/src/test/java/org/hipparchus/fitting/ransac/RansacFitterTest b/hipparchus-fitting/src/test/java/org/hipparchus/fitting/ransac/RansacFitterTest.java similarity index 98% rename from hipparchus-fitting/src/test/java/org/hipparchus/fitting/ransac/RansacFitterTest rename to hipparchus-fitting/src/test/java/org/hipparchus/fitting/ransac/RansacFitterTest.java index e54ba54b0..036f3f7c8 100644 --- a/hipparchus-fitting/src/test/java/org/hipparchus/fitting/ransac/RansacFitterTest +++ b/hipparchus-fitting/src/test/java/org/hipparchus/fitting/ransac/RansacFitterTest.java @@ -1,127 +1,127 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.hipparchus.fitting.ransac; - -import java.io.IOException; -import java.io.InputStream; -import java.util.List; -import java.util.Random; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -import org.hipparchus.exception.MathIllegalArgumentException; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.function.Executable; -import org.mockito.Mockito; -import org.mockito.internal.util.io.IOUtil; - -class RansacFitterTest { - - @Test - void testExceptionsOnInitialValues() { - assertThrows(MathIllegalArgumentException.class, () -> new RansacFitter<>(mockModel(), -1, 6, 1e-6, 10, 1), "-1 is smaller than the minimum (0)"); - assertThrows(MathIllegalArgumentException.class, () -> new RansacFitter<>(mockModel(), 1, -6, 1e-6, 10, 1), "-6 is smaller than the minimum (0)"); - assertThrows(MathIllegalArgumentException.class, () -> new RansacFitter<>(mockModel(), 1, 6, -1e-6, 10, 1), "-0 is smaller than the minimum (0)"); - assertThrows(MathIllegalArgumentException.class, () -> new RansacFitter<>(mockModel(), 1, 6, 1e-6, -10, 1), "-10 is smaller than the minimum (0)"); - } - - @Test - void testCanFitALineWithNegligibleAmountOfNoiseAndSmallNumberOfOutliers() { - doTestLineFittingWithSmallNumberOfOutliers(7e-4, 3e-2, 0.05); - } - - @Test - void testCanPerfectlyFitALineWithoutNoiseButWithSmallNumberOfOutliers() { - doTestLineFittingWithSmallNumberOfOutliers(1e-12, 1e-12, 0.0); - } - - @Test - void testCanFitALineWithLargeNumberOfOutliers() throws IOException { - // This test reproduces the example provided in RANSAC wikipedia page. Results are strongly consistent - final List points = loadData("line_dataset.csv"); - final double standardDeviation = 0.6159842899599051; - final RansacFitterOutputs fitted = new RansacFitter<>(new PolynomialModelFitter(1), 10, 100, standardDeviation / 3, 10, 1).fit(points); - Assertions.assertNotNull(fitted); - Assertions.assertEquals(0.957302, getBestModel(fitted).getCoefficients()[1], 1.0e-6); - Assertions.assertEquals(-0.106412, getBestModel(fitted).getCoefficients()[0], 1.0e-6); - Assertions.assertEquals(49, fitted.getBestInliers().size()); - Assertions.assertEquals(48, fitted.filterPointsBelowThreshold(points, standardDeviation / 5).size()); // Exact number of "true" points! - } - - @Test - void testCanFitAPolynomialOfDegree2WithOutliers() throws IOException { - // Reference: https://forum.orekit.org/t/addition-of-ransac-algorithm/5102 - final List points = loadData("quadratic_dataset.csv"); - final double standardDeviation = 72.59099534185657; - final RansacFitterOutputs fitted = new RansacFitter<>(new PolynomialModelFitter(2), 10, 1000, standardDeviation / 3, 10, 1).fit(points); - Assertions.assertNotNull(fitted); - Assertions.assertEquals(-0.002086, getBestModel(fitted).getCoefficients()[2], 1.0e-6); - Assertions.assertEquals(1.048147, getBestModel(fitted).getCoefficients()[1], 1.0e-6); - Assertions.assertEquals(-56.274050, getBestModel(fitted).getCoefficients()[0], 1.0e-6); - Assertions.assertEquals(205, fitted.getBestInliers().size()); - Assertions.assertEquals(214, fitted.filterPointsBelowThreshold(points, standardDeviation).size()); // Exact number of "true" points! - } - - private void doTestLineFittingWithSmallNumberOfOutliers(final double slopeDelta, final double interceptDelta, final double noiseFactor) { - final double expectedSlope = 2.0; - final double expectedIntercept = 1.0; - final int numberOfTrueData = 15; - final int numberOfFalseData = 5; - final int seed = 1; - final RansacFitter ransac = new RansacFitter<>(new PolynomialModelFitter(1), 10, 500, 0.5, 10, seed); - final RansacFitterOutputs fitted = ransac.fit(generateLine(seed, expectedSlope, expectedIntercept, numberOfTrueData, numberOfFalseData, noiseFactor)); - Assertions.assertNotNull(fitted); - Assertions.assertEquals(expectedSlope, getBestModel(fitted).getCoefficients()[1], slopeDelta); - Assertions.assertEquals(expectedIntercept, getBestModel(fitted).getCoefficients()[0], interceptDelta); - Assertions.assertEquals(numberOfTrueData, fitted.getBestInliers().size()); - } - - private List generateLine(final int seed, final double expectedSlope, final double expectedIntercept, - final int trueDataCount, final int falseDataCount, final double noiseFactor) { - final Random random = new Random(seed); - final PolynomialModelFitter.Model trueModel = new PolynomialModelFitter.Model(new double[]{expectedIntercept, expectedSlope}); - final List points = IntStream.range(0, trueDataCount) - .mapToObj(x -> new double[]{x, trueModel.predict(x) + random.nextGaussian() * noiseFactor}) - .collect(Collectors.toList()); - points.addAll(IntStream.range(0, falseDataCount).mapToObj(x -> new double[]{x * 3, random.nextDouble() * 20}).collect(Collectors.toList())); - return points; - } - - private List loadData(final String fileName) { - final InputStream inputStream = this.getClass().getResourceAsStream("/" + this.getClass().getSimpleName() + "/" + fileName); - Assertions.assertNotNull(inputStream, "Could not find resource " + fileName); - return IOUtil.readLines(inputStream) - .stream() - .map(line -> line.split(",")) - .map(values -> new double[]{Double.parseDouble(values[0]), Double.parseDouble(values[1])}) - .collect(Collectors.toList()); - } - - private static IModelFitter mockModel() { - return Mockito.mock(IModelFitter.class); - } - - private static void assertThrows(final Class expectedType, final Executable executable, final String message) { - final T exception = Assertions.assertThrows(expectedType, executable); - Assertions.assertEquals(message, exception.getMessage()); - } - - private PolynomialModelFitter.Model getBestModel(final RansacFitterOutputs fitted) { - return fitted.getBestModel().orElseThrow(() -> new RuntimeException("No model found")); - } - +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.hipparchus.fitting.ransac; + +import java.io.IOException; +import java.io.InputStream; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.hipparchus.exception.MathIllegalArgumentException; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.mockito.Mockito; +import org.mockito.internal.util.io.IOUtil; + +class RansacFitterTest { + + @Test + void testExceptionsOnInitialValues() { + assertThrows(MathIllegalArgumentException.class, () -> new RansacFitter<>(mockModel(), -1, 6, 1e-6, 10, 1), "-1 is smaller than the minimum (0)"); + assertThrows(MathIllegalArgumentException.class, () -> new RansacFitter<>(mockModel(), 1, -6, 1e-6, 10, 1), "-6 is smaller than the minimum (0)"); + assertThrows(MathIllegalArgumentException.class, () -> new RansacFitter<>(mockModel(), 1, 6, -1e-6, 10, 1), "-0 is smaller than the minimum (0)"); + assertThrows(MathIllegalArgumentException.class, () -> new RansacFitter<>(mockModel(), 1, 6, 1e-6, -10, 1), "-10 is smaller than the minimum (0)"); + } + + @Test + void testCanFitALineWithNegligibleAmountOfNoiseAndSmallNumberOfOutliers() { + doTestLineFittingWithSmallNumberOfOutliers(7e-4, 3e-2, 0.05); + } + + @Test + void testCanPerfectlyFitALineWithoutNoiseButWithSmallNumberOfOutliers() { + doTestLineFittingWithSmallNumberOfOutliers(1e-12, 1e-12, 0.0); + } + + @Test + void testCanFitALineWithLargeNumberOfOutliers() throws IOException { + // This test reproduces the example provided in RANSAC wikipedia page. Results are strongly consistent + final List points = loadData("line_dataset.csv"); + final double standardDeviation = 0.6159842899599051; + final RansacFitterOutputs fitted = new RansacFitter<>(new PolynomialModelFitter(1), 10, 100, standardDeviation / 3, 10, 1).fit(points); + Assertions.assertNotNull(fitted); + Assertions.assertEquals(0.957302, getBestModel(fitted).getCoefficients()[1], 1.0e-6); + Assertions.assertEquals(-0.106412, getBestModel(fitted).getCoefficients()[0], 1.0e-6); + Assertions.assertEquals(49, fitted.getBestInliers().size()); + Assertions.assertEquals(48, fitted.filterPointsBelowThreshold(points, standardDeviation / 5).size()); // Exact number of "true" points! + } + + @Test + void testCanFitAPolynomialOfDegree2WithOutliers() throws IOException { + // Reference: https://forum.orekit.org/t/addition-of-ransac-algorithm/5102 + final List points = loadData("quadratic_dataset.csv"); + final double standardDeviation = 72.59099534185657; + final RansacFitterOutputs fitted = new RansacFitter<>(new PolynomialModelFitter(2), 10, 1000, standardDeviation / 3, 10, 1).fit(points); + Assertions.assertNotNull(fitted); + Assertions.assertEquals(-0.002086, getBestModel(fitted).getCoefficients()[2], 1.0e-6); + Assertions.assertEquals(1.048147, getBestModel(fitted).getCoefficients()[1], 1.0e-6); + Assertions.assertEquals(-56.274050, getBestModel(fitted).getCoefficients()[0], 1.0e-6); + Assertions.assertEquals(205, fitted.getBestInliers().size()); + Assertions.assertEquals(214, fitted.filterPointsBelowThreshold(points, standardDeviation).size()); // Exact number of "true" points! + } + + private void doTestLineFittingWithSmallNumberOfOutliers(final double slopeDelta, final double interceptDelta, final double noiseFactor) { + final double expectedSlope = 2.0; + final double expectedIntercept = 1.0; + final int numberOfTrueData = 15; + final int numberOfFalseData = 5; + final int seed = 1; + final RansacFitter ransac = new RansacFitter<>(new PolynomialModelFitter(1), 10, 500, 0.5, 10, seed); + final RansacFitterOutputs fitted = ransac.fit(generateLine(seed, expectedSlope, expectedIntercept, numberOfTrueData, numberOfFalseData, noiseFactor)); + Assertions.assertNotNull(fitted); + Assertions.assertEquals(expectedSlope, getBestModel(fitted).getCoefficients()[1], slopeDelta); + Assertions.assertEquals(expectedIntercept, getBestModel(fitted).getCoefficients()[0], interceptDelta); + Assertions.assertEquals(numberOfTrueData, fitted.getBestInliers().size()); + } + + private List generateLine(final int seed, final double expectedSlope, final double expectedIntercept, + final int trueDataCount, final int falseDataCount, final double noiseFactor) { + final Random random = new Random(seed); + final PolynomialModelFitter.Model trueModel = new PolynomialModelFitter.Model(new double[]{expectedIntercept, expectedSlope}); + final List points = IntStream.range(0, trueDataCount) + .mapToObj(x -> new double[]{x, trueModel.predict(x) + random.nextGaussian() * noiseFactor}) + .collect(Collectors.toList()); + points.addAll(IntStream.range(0, falseDataCount).mapToObj(x -> new double[]{x * 3, random.nextDouble() * 20}).collect(Collectors.toList())); + return points; + } + + private List loadData(final String fileName) { + final InputStream inputStream = this.getClass().getResourceAsStream("/" + this.getClass().getSimpleName() + "/" + fileName); + Assertions.assertNotNull(inputStream, "Could not find resource " + fileName); + return IOUtil.readLines(inputStream) + .stream() + .map(line -> line.split(",")) + .map(values -> new double[]{Double.parseDouble(values[0]), Double.parseDouble(values[1])}) + .collect(Collectors.toList()); + } + + private static IModelFitter mockModel() { + return Mockito.mock(IModelFitter.class); + } + + private static void assertThrows(final Class expectedType, final Executable executable, final String message) { + final T exception = Assertions.assertThrows(expectedType, executable); + Assertions.assertEquals(message, exception.getMessage()); + } + + private PolynomialModelFitter.Model getBestModel(final RansacFitterOutputs fitted) { + return fitted.getBestModel().orElseThrow(() -> new RuntimeException("No model found")); + } + } \ No newline at end of file