# SciRS2 Metrics

Evaluation metrics module for the SciRS2 scientific computing library. This module provides functions to evaluate prediction performance for classification, regression, and clustering tasks.
## Features

- Classification Metrics: Accuracy, precision, recall, F1-score, ROC curves, AUC, etc.
- Regression Metrics: MSE, MAE, R² score, explained variance, etc.
- Clustering Metrics: Silhouette score, Calinski-Harabasz index, Davies-Bouldin index, etc.
- General Evaluation: Cross-validation, learning curves, confusion matrices
## Usage

Add the following to your `Cargo.toml`:

```toml
[dependencies]
scirs2-metrics = "0.1.0-alpha.1"
```

Basic usage examples:
```rust
use scirs2_metrics::{classification, regression, clustering};
use ndarray::array;
use scirs2_core::error::CoreResult;

// Classification metrics example
fn classification_metrics_example() -> CoreResult<()> {
    // True labels
    let y_true = array![0, 1, 0, 1, 0, 1, 0, 1];
    // Predicted labels
    let y_pred = array![0, 1, 1, 1, 0, 0, 0, 1];

    // Calculate classification metrics
    let accuracy = classification::accuracy_score(&y_true, &y_pred)?;
    let precision = classification::precision_score(&y_true, &y_pred, None, None, None)?;
    let recall = classification::recall_score(&y_true, &y_pred, None, None, None)?;
    let f1 = classification::f1_score(&y_true, &y_pred, None, None, None)?;

    println!("Accuracy: {}", accuracy);
    println!("Precision: {}", precision);
    println!("Recall: {}", recall);
    println!("F1 Score: {}", f1);

    // Predicted probabilities for the ROC curve
    let y_scores = array![0.1, 0.9, 0.8, 0.7, 0.2, 0.3, 0.4, 0.8];

    // Calculate the ROC curve (underscores silence unused-variable warnings)
    let (_fpr, _tpr, _thresholds) = classification::roc_curve(&y_true, &y_scores, None, None)?;

    // Calculate the area under the ROC curve (AUC)
    let auc = classification::roc_auc_score(&y_true, &y_scores)?;
    println!("AUC: {}", auc);

    Ok(())
}

// Regression metrics example
fn regression_metrics_example() -> CoreResult<()> {
    // True values
    let y_true = array![3.0, -0.5, 2.0, 7.0, 2.0];
    // Predicted values
    let y_pred = array![2.5, 0.0, 2.1, 7.8, 1.8];

    // Calculate regression metrics
    let mse = regression::mean_squared_error(&y_true, &y_pred, None)?;
    let mae = regression::mean_absolute_error(&y_true, &y_pred, None)?;
    let r2 = regression::r2_score(&y_true, &y_pred, None)?;
    let explained_variance = regression::explained_variance_score(&y_true, &y_pred, None)?;

    println!("Mean Squared Error: {}", mse);
    println!("Mean Absolute Error: {}", mae);
    println!("R² Score: {}", r2);
    println!("Explained Variance: {}", explained_variance);

    Ok(())
}

// Clustering metrics example
fn clustering_metrics_example() -> CoreResult<()> {
    // Sample data points
    let data = array![
        [1.0, 2.0],
        [1.5, 1.8],
        [5.0, 8.0],
        [8.0, 8.0],
        [1.0, 0.6],
        [9.0, 11.0]
    ];
    // Cluster labels
    let labels = array![0, 0, 1, 1, 0, 1];

    // Calculate clustering metrics
    let silhouette = clustering::silhouette_score(&data, &labels, None, None)?;
    let calinski_harabasz = clustering::calinski_harabasz_score(&data, &labels)?;
    let davies_bouldin = clustering::davies_bouldin_score(&data, &labels)?;

    println!("Silhouette Score: {}", silhouette);
    println!("Calinski-Harabasz Index: {}", calinski_harabasz);
    println!("Davies-Bouldin Index: {}", davies_bouldin);

    Ok(())
}
```
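To run all three examples together, they can be driven from a small `main`. Returning `CoreResult<()>` from `main` assumes the crate's error type implements `std::fmt::Debug`, which Rust requires for a `Result`-returning `main`:

```rust
fn main() -> CoreResult<()> {
    // Run each example in turn, propagating any metric errors
    classification_metrics_example()?;
    regression_metrics_example()?;
    clustering_metrics_example()?;
    Ok(())
}
```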
## Components

### Classification Metrics

Functions for classification evaluation:
```rust
use scirs2_metrics::classification::{
    // Basic Metrics
    accuracy_score,                  // Calculate accuracy score
    precision_score,                 // Calculate precision score
    recall_score,                    // Calculate recall score
    f1_score,                        // Calculate F1 score
    fbeta_score,                     // Calculate F-beta score
    precision_recall_fscore_support, // Calculate precision, recall, F-score, and support

    // Multi-class and Multi-label Metrics
    jaccard_score, // Calculate Jaccard similarity coefficient
    hamming_loss,  // Calculate Hamming loss

    // Probability-based Metrics
    log_loss,         // Calculate logarithmic loss
    brier_score_loss, // Calculate Brier score loss

    // ROC and AUC
    roc_curve,               // Calculate Receiver Operating Characteristic (ROC) curve
    roc_auc_score,           // Calculate Area Under the ROC Curve (AUC)
    average_precision_score, // Calculate average precision score
    precision_recall_curve,  // Calculate precision-recall curve

    // Confusion Matrix and Derived Metrics
    confusion_matrix,      // Calculate confusion matrix
    classification_report, // Generate text report of classification metrics

    // Probabilities to Labels
    binarize,       // Transform probabilities to binary labels
    label_binarize, // Transform multi-class labels to binary labels

    // Other Metrics
    cohen_kappa_score, // Calculate Cohen's kappa
    matthews_corrcoef, // Calculate Matthews correlation coefficient
    hinge_loss,        // Calculate hinge loss
};
```
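Most of these follow the `(&y_true, &y_pred, ...)` pattern shown in the usage section. A minimal sketch for two of the agreement metrics, assuming `cohen_kappa_score` and `matthews_corrcoef` take the same two-argument form as `accuracy_score` above (check the crate docs for the exact signatures):

```rust
use ndarray::array;
use scirs2_core::error::CoreResult;
use scirs2_metrics::classification::{cohen_kappa_score, matthews_corrcoef};

fn main() -> CoreResult<()> {
    let y_true = array![0, 1, 0, 1, 1, 0];
    let y_pred = array![0, 1, 1, 1, 0, 0];

    // Assumed two-argument form, mirroring accuracy_score;
    // consult the crate documentation for the exact API.
    let kappa = cohen_kappa_score(&y_true, &y_pred)?;
    let mcc = matthews_corrcoef(&y_true, &y_pred)?;

    println!("Cohen's kappa: {}", kappa);
    println!("Matthews correlation coefficient: {}", mcc);
    Ok(())
}
```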
### Regression Metrics

Functions for regression evaluation:
```rust
use scirs2_metrics::regression::{
    // Error Metrics
    mean_squared_error,             // Calculate mean squared error
    mean_absolute_error,            // Calculate mean absolute error
    mean_absolute_percentage_error, // Calculate mean absolute percentage error
    median_absolute_error,          // Calculate median absolute error
    max_error,                      // Calculate maximum error

    // Goodness of Fit
    r2_score,                 // Calculate R² score (coefficient of determination)
    explained_variance_score, // Calculate explained variance score

    // Other Metrics
    mean_squared_log_error, // Calculate mean squared logarithmic error
    mean_poisson_deviance,  // Calculate mean Poisson deviance
    mean_gamma_deviance,    // Calculate mean Gamma deviance
    mean_tweedie_deviance,  // Calculate mean Tweedie deviance
};
```
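A minimal sketch for the robust error metrics, assuming `median_absolute_error` mirrors `mean_absolute_error` above (trailing `None` for optional sample weights) and `max_error` takes only the two arrays; the exact signatures may differ, so consult the crate docs:

```rust
use ndarray::array;
use scirs2_core::error::CoreResult;
use scirs2_metrics::regression::{max_error, median_absolute_error};

fn main() -> CoreResult<()> {
    let y_true = array![3.0, -0.5, 2.0, 7.0, 2.0];
    let y_pred = array![2.5, 0.0, 2.1, 7.8, 1.8];

    // Assumed signatures, extrapolated from mean_absolute_error above;
    // check the crate documentation before relying on them.
    let med_ae = median_absolute_error(&y_true, &y_pred, None)?;
    let max_err = max_error(&y_true, &y_pred)?;

    println!("Median Absolute Error: {}", med_ae);
    println!("Max Error: {}", max_err);
    Ok(())
}
```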
### Clustering Metrics

Functions for clustering evaluation:
```rust
use scirs2_metrics::clustering::{
    // Internal Metrics (no ground truth)
    silhouette_score,        // Calculate Silhouette Coefficient
    calinski_harabasz_score, // Calculate Calinski-Harabasz Index
    davies_bouldin_score,    // Calculate Davies-Bouldin Index

    // External Metrics (with ground truth)
    adjusted_rand_score,          // Calculate Adjusted Rand Index
    normalized_mutual_info_score, // Calculate normalized mutual information
    adjusted_mutual_info_score,   // Calculate adjusted mutual information
    fowlkes_mallows_score,        // Calculate Fowlkes-Mallows Index

    // Contingency Matrix
    contingency_matrix,    // Calculate contingency matrix
    pair_confusion_matrix, // Calculate pair confusion matrix
};
```
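The external metrics compare a predicted labeling against ground truth. A minimal sketch, assuming `adjusted_rand_score` and `fowlkes_mallows_score` take `(&labels_true, &labels_pred)`; check the crate docs for the exact parameters:

```rust
use ndarray::array;
use scirs2_core::error::CoreResult;
use scirs2_metrics::clustering::{adjusted_rand_score, fowlkes_mallows_score};

fn main() -> CoreResult<()> {
    // Ground-truth assignments and the labels produced by a clustering run
    let labels_true = array![0, 0, 1, 1, 0, 1];
    let labels_pred = array![0, 0, 1, 2, 0, 1];

    // Assumed (&labels_true, &labels_pred) signatures; check the crate docs.
    let ari = adjusted_rand_score(&labels_true, &labels_pred)?;
    let fmi = fowlkes_mallows_score(&labels_true, &labels_pred)?;

    println!("Adjusted Rand Index: {}", ari);
    println!("Fowlkes-Mallows Index: {}", fmi);
    Ok(())
}
```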
### Evaluation Functions

General evaluation tools:
```rust
use scirs2_metrics::evaluation::{
    // Cross-validation
    cross_val_score, // Evaluate a score by cross-validation
    cross_validate,  // Evaluate metrics by cross-validation

    // Train/Test Splitting
    train_test_split, // Split arrays into random train and test subsets

    // Learning Curves
    learning_curve,   // Generate learning curve
    validation_curve, // Generate validation curve

    // Hyperparameter Optimization
    grid_search_cv,       // Exhaustive search over parameter grid
    randomized_search_cv, // Random search over parameters
};
```
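A sketch of splitting a dataset with `train_test_split`. The argument list here (data, targets, test fraction, optional seed) and the four-tuple return value are assumptions loosely following the scikit-learn convention, not the confirmed API; consult the crate documentation for the real signature:

```rust
use ndarray::array;
use scirs2_core::error::CoreResult;
use scirs2_metrics::evaluation::train_test_split;

fn main() -> CoreResult<()> {
    let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
    let y = array![0, 1, 0, 1];

    // Hypothetical argument order: data, targets, test fraction, seed.
    let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.25, Some(42))?;

    println!("Training rows: {}", x_train.nrows());
    println!("Test rows: {}", x_test.nrows());
    println!("Training targets: {}", y_train.len());
    println!("Test targets: {}", y_test.len());
    Ok(())
}
```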
## Contributing

See the CONTRIBUTING.md file for contribution guidelines.
## License

This project is licensed under the Apache License, Version 2.0 - see the LICENSE file for details.