4 releases
- 0.1.3 (Jan 24, 2024)
- 0.1.2 (Nov 7, 2023)
- 0.1.1 (May 15, 2019)
- 0.1.0 (Dec 19, 2018)
MesaTEE GBDT-RS: a fast and secure GBDT library, supporting TEEs such as Intel SGX and ARM TrustZone
MesaTEE GBDT-RS is a gradient boosting decision tree (GBDT) library written in safe Rust. The library contains no unsafe Rust code.
MesaTEE GBDT-RS provides both training and inference. It can also run inference with models trained by xgboost.
New! The MesaTEE GBDT-RS paper has been accepted by IEEE S&P'19!
Supported Tasks
The following tasks are supported for both training and inference:
- Linear regression: use the SquaredError or LAD loss type
- Binary classification (labels 1 and -1): use the LogLikelyhood loss type
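To show what each loss type asks the next tree to fit, here is a minimal std-only sketch of the per-sample negative gradient (the "pseudo-residual") for the three losses. It is an illustration, not code from the crate: the `Loss` enum and `negative_gradient` function are hypothetical, and the LogLikelyhood formula assumes Friedman's formulation `ln(1 + exp(-2yF))` for labels y in {1, -1}.

```rust
/// Hypothetical stand-in for the loss types named above (not the crate's type).
#[derive(Debug, Clone, Copy)]
pub enum Loss {
    SquaredError,
    Lad, // least absolute deviation ("LAD")
    LogLikelyhood,
}

/// Negative gradient of the loss with respect to the current prediction `f`:
/// the target the next tree is fit to. For LogLikelyhood, `y` must be 1.0 or -1.0.
pub fn negative_gradient(loss: Loss, y: f64, f: f64) -> f64 {
    match loss {
        // L = (y - f)^2 / 2  =>  -dL/df = y - f
        Loss::SquaredError => y - f,
        // L = |y - f|  =>  -dL/df = sign(y - f), taken as 0 when y == f
        Loss::Lad => {
            let r = y - f;
            if r > 0.0 {
                1.0
            } else if r < 0.0 {
                -1.0
            } else {
                0.0
            }
        }
        // L = ln(1 + exp(-2yf))  =>  -dL/df = 2y / (1 + exp(2yf))
        Loss::LogLikelyhood => 2.0 * y / (1.0 + (2.0 * y * f).exp()),
    }
}
```

With a zero initial prediction, the LogLikelyhood pseudo-residual equals the label itself, so under this assumption the first tree is fit directly to the ±1 labels.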
Compatibility with xgboost
Currently, MesaTEE GBDT-RS can use models trained by xgboost for inference. The model must be trained by xgboost with the following configuration:
- booster: gbtree
- objective: "reg:linear", "reg:logistic", "binary:logistic", "binary:logitraw", "multi:softprob", "multi:softmax" or "rank:pairwise".
MesaTEE GBDT-RS has been tested for compatibility with xgboost 0.81 and 0.82.
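For intuition about the objectives listed above, the std-only sketch below shows how xgboost turns raw ensemble margins into final predictions for each one: identity for "reg:linear", "binary:logitraw" and "rank:pairwise", a sigmoid for "reg:logistic" and "binary:logistic", softmax probabilities for "multi:softprob", and the winning class index for "multi:softmax". The `transform` function is hypothetical and is not part of the MesaTEE GBDT-RS API; it assumes one margin per class has already been computed for the multi-class objectives.

```rust
/// Map raw margins to xgboost-style outputs for the given objective.
/// Single-output objectives read `margins[0]`. Returns None for an empty
/// input or for an objective not in the supported list.
pub fn transform(objective: &str, margins: &[f64]) -> Option<Vec<f64>> {
    if margins.is_empty() {
        return None;
    }
    let sigmoid = |x: f64| 1.0 / (1.0 + (-x).exp());
    match objective {
        // The margin itself is the prediction (or ranking score).
        "reg:linear" | "binary:logitraw" | "rank:pairwise" => Some(vec![margins[0]]),
        // Logistic link: squash the margin into (0, 1).
        "reg:logistic" | "binary:logistic" => Some(vec![sigmoid(margins[0])]),
        "multi:softprob" => Some(softmax(margins)),
        // Index of the largest margin, encoded as f64.
        "multi:softmax" => {
            let (best, _) = margins
                .iter()
                .enumerate()
                .fold((0, f64::NEG_INFINITY), |acc, (i, &m)| if m > acc.1 { (i, m) } else { acc });
            Some(vec![best as f64])
        }
        _ => None,
    }
}

/// Numerically stable softmax: subtract the largest margin before exponentiating.
fn softmax(margins: &[f64]) -> Vec<f64> {
    let max = margins.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = margins.iter().map(|m| (m - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```

Because softmax is monotonic, "multi:softmax" and "multi:softprob" always agree on the predicted class; the former simply drops the probabilities.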
Quick Start
Training Steps
- Set configuration
- Load training data
- Train the model
- (optional) Save the model
Inference Steps
- Load the model
- Load the test data
- Run inference on the test data
Example
```rust
use gbdt::config::Config;
use gbdt::decision_tree::{DataVec, PredVec};
use gbdt::gradient_boost::GBDT;
use gbdt::input::{InputFormat, load};

fn main() {
    // Configure a binary classifier: 22 features, depth-3 trees, 50 boosting rounds.
    let mut cfg = Config::new();
    cfg.set_feature_size(22);
    cfg.set_max_depth(3);
    cfg.set_iterations(50);
    cfg.set_shrinkage(0.1);
    cfg.set_loss("LogLikelyhood");
    cfg.set_debug(true);
    cfg.set_data_sample_ratio(1.0); // use all training samples
    cfg.set_feature_sample_ratio(1.0); // use all features
    cfg.set_training_optimization_level(2);

    // Load data: CSV files with 22 feature columns and the label at index 22.
    let train_file = "dataset/agaricus-lepiota/train.txt";
    let test_file = "dataset/agaricus-lepiota/test.txt";
    let mut input_format = InputFormat::csv_format();
    input_format.set_feature_size(22);
    input_format.set_label_index(22);
    let mut train_dv: DataVec = load(train_file, input_format).expect("failed to load training data");
    let test_dv: DataVec = load(test_file, input_format).expect("failed to load test data");

    // Train and save the model.
    let mut gbdt = GBDT::new(&cfg);
    gbdt.fit(&mut train_dv);
    gbdt.save_model("gbdt.model").expect("failed to save the model");

    // Load the model and run inference.
    let model = GBDT::load_model("gbdt.model").expect("failed to load the model");
    let predicted: PredVec = model.predict(&test_dv);
    println!("predicted {} samples", predicted.len());
}
```
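The loader settings in the example describe a CSV layout with 22 feature columns plus a label at index 22. Assuming zero-based column indices (consistent with 22 features followed by a final label column), the std-only sketch below shows how one such row could be split into features and a label. `parse_row` is a hypothetical helper for illustration only; the crate's own `load` function does this internally and may differ in details.

```rust
/// Split one CSV row into (features, label): the column at `label_index`
/// (zero-based) is the label, every other column is a feature.
pub fn parse_row(line: &str, feature_size: usize, label_index: usize) -> Result<(Vec<f64>, f64), String> {
    let mut features = Vec::with_capacity(feature_size);
    let mut label = None;
    for (i, field) in line.trim().split(',').enumerate() {
        // Every column is expected to be numeric.
        let value: f64 = field
            .trim()
            .parse()
            .map_err(|e| format!("column {}: {}", i, e))?;
        if i == label_index {
            label = Some(value);
        } else {
            features.push(value);
        }
    }
    if features.len() != feature_size {
        return Err(format!("expected {} features, found {}", feature_size, features.len()));
    }
    label
        .map(|l| (features, l))
        .ok_or_else(|| "missing label column".to_string())
}
```

Checking the feature count per row catches a mismatch between `set_feature_size` and the actual file early, instead of producing silently misaligned samples.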
Example code
- Linear regression: examples/iris.rs
- Binary classification: examples/agaricus-lepiota.rs
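The `set_iterations`, `set_shrinkage` and `set_max_depth` settings in the example correspond to the standard gradient boosting recipe: each iteration fits one tree to the current negative gradient and adds it to the ensemble, scaled by the shrinkage factor. The std-only sketch below illustrates that recipe for squared-error loss, assuming depth-1 trees (stumps) and a zero initial prediction. `Stump`, `train` and `predict` are hypothetical names; this is not the crate's implementation, which supports deeper trees, sampling and other loss types.

```rust
/// A depth-1 regression tree (stump) that splits on a single feature.
#[derive(Debug, Clone)]
pub struct Stump {
    feature: usize,
    threshold: f64,
    left: f64,  // output when x[feature] <= threshold
    right: f64, // output when x[feature] > threshold
}

/// Mean of a slice, or 0.0 for an empty slice.
fn mean(v: &[f64]) -> f64 {
    if v.is_empty() { 0.0 } else { v.iter().sum::<f64>() / v.len() as f64 }
}

impl Stump {
    pub fn predict(&self, x: &[f64]) -> f64 {
        if x[self.feature] <= self.threshold { self.left } else { self.right }
    }

    /// Exhaustively pick the (feature, threshold) split with the lowest squared
    /// error against `targets`. Assumes `xs` is non-empty.
    pub fn fit(xs: &[Vec<f64>], targets: &[f64]) -> Stump {
        let mut best = Stump { feature: 0, threshold: f64::INFINITY, left: mean(targets), right: 0.0 };
        let mut best_err = f64::INFINITY;
        for feature in 0..xs[0].len() {
            for candidate in xs {
                let threshold = candidate[feature];
                let (mut l, mut r) = (Vec::new(), Vec::new());
                for (x, &t) in xs.iter().zip(targets) {
                    if x[feature] <= threshold { l.push(t) } else { r.push(t) }
                }
                let (lm, rm) = (mean(&l), mean(&r));
                let err: f64 = l.iter().map(|t| (t - lm).powi(2)).sum::<f64>()
                    + r.iter().map(|t| (t - rm).powi(2)).sum::<f64>();
                if err < best_err {
                    best_err = err;
                    best = Stump { feature, threshold, left: lm, right: rm };
                }
            }
        }
        best
    }
}

/// Squared-error boosting: each round fits a stump to the residuals
/// (the negative gradient) and adds it to the predictions scaled by `shrinkage`.
pub fn train(xs: &[Vec<f64>], ys: &[f64], iterations: usize, shrinkage: f64) -> Vec<Stump> {
    let mut preds = vec![0.0; ys.len()];
    let mut trees = Vec::with_capacity(iterations);
    for _ in 0..iterations {
        let residuals: Vec<f64> = ys.iter().zip(&preds).map(|(y, p)| y - p).collect();
        let tree = Stump::fit(xs, &residuals);
        for (p, x) in preds.iter_mut().zip(xs) {
            *p += shrinkage * tree.predict(x);
        }
        trees.push(tree);
    }
    trees
}

/// Ensemble prediction: shrinkage-weighted sum of all tree outputs.
pub fn predict(trees: &[Stump], shrinkage: f64, x: &[f64]) -> f64 {
    trees.iter().map(|t| shrinkage * t.predict(x)).sum()
}
```

On a cleanly separable toy dataset with shrinkage 0.1, each round removes 10% of the remaining residual, so 100 iterations bring predictions close to the targets. A smaller shrinkage generally needs more iterations to reach the same fit.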
Use models trained by xgboost
Steps
- Use xgboost to train a model
- Use examples/convert_xgboost.py to convert the model
- Usage: python convert_xgboost.py xgboost_model_path objective output_path
- Note: convert_xgboost.py depends on the xgboost Python library, but the converted model can be used on machines without xgboost installed.
- In Rust code, call GBDT::load_from_xgboost(model_path, objective) to load the model
- Run inference
- (optional) Call GBDT::save_model to save the model in the MesaTEE GBDT-RS native format.
Example code
- "reg:linear": examples/test-xgb-reg-linear.rs
- "reg:logistic": examples/test-xgb-reg-logistic.rs
- "binary:logistic": examples/test-xgb-binary-logistic.rs
- "binary:logitraw": examples/test-xgb-binary-logistic.rs
- "multi:softprob": examples/test-xgb-multi-softprob.rs
- "multi:softmax": examples/test-xgb-multi-softmax.rs
- "rank:pairwise": examples/test-xgb-rank-pairwise.rs
Multi-threading
Training:
At this time, training in MesaTEE GBDT-RS is single-threaded.
Inference:
The inference functions are single-threaded but thread-safe, so a trained model can be shared across threads. An example of multi-threaded inference is provided in examples/test-multithreads.rs.
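Because inference only needs shared, read-only access to a trained model, a thread-safe model can serve several worker threads at once. The std-only sketch below assumes a hypothetical `Predictor` trait standing in for a trained GBDT model (it requires `Sync` so `&self` can be used from several threads) and splits the test rows into chunks handled with `std::thread::scope` (Rust 1.63 or later). It is an illustration of the pattern, not the code from the repository's multi-threading example.

```rust
use std::thread;

/// Hypothetical stand-in for a trained model; `Sync` lets `&self` be shared across threads.
pub trait Predictor: Sync {
    fn predict_one(&self, features: &[f64]) -> f64;
}

/// Predict every row using up to `n_threads` worker threads.
/// Rows are split into contiguous chunks, so output order matches input order.
pub fn predict_parallel<P: Predictor>(model: &P, rows: &[Vec<f64>], n_threads: usize) -> Vec<f64> {
    if rows.is_empty() {
        return Vec::new();
    }
    let n = n_threads.max(1);
    let chunk_size = (rows.len() + n - 1) / n; // ceiling division
    thread::scope(|s| {
        // Spawn one scoped thread per chunk; each borrows the model immutably.
        let handles: Vec<_> = rows
            .chunks(chunk_size)
            .map(|chunk| s.spawn(move || chunk.iter().map(|r| model.predict_one(r)).collect::<Vec<f64>>()))
            .collect();
        // Join in spawn order to reassemble results in the original row order.
        handles
            .into_iter()
            .flat_map(|h| h.join().expect("worker thread panicked"))
            .collect()
    })
}
```

Scoped threads let workers borrow the model and the rows directly, so no `Arc` or cloning is needed as long as the model type is `Sync`.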
SGX usage
Because MesaTEE GBDT-RS is written in pure Rust, it can easily be used inside an SGX enclave with the help of rust-sgx-sdk. Add the following to the [dependencies] section of Cargo.toml:
```toml
gbdt_sgx = { git = "https://github.com/mesalock-linux/gbdt-rs" }
```
This imports a crate named gbdt_sgx. If you prefer to use it under the usual name gbdt, rename the package:
```toml
gbdt = { package = "gbdt_sgx", git = "https://github.com/mesalock-linux/gbdt-rs" }
```
For more information and concrete examples, see the sgx/gbdt-sgx-test directory.
License
Apache 2.0
Authors
Tianyi Li @n0b0dyCN n0b0dypku@gmail.com
Tongxin Li @litongxin1991 litongxin1991@gmail.com
Yu Ding @dingelish dingelish@gmail.com
Steering Committee
Tao Wei, Yulong Zhang
Acknowledgment
Thanks to @qiyiping for the excellent earlier work on gbdt. We studied that code before starting this project.