For an introduction to kun, please refer to the homepage. The Chinese version is already available there, while the English homepage has not been fully translated yet.
TL;DR
- Both Burn and Rust are very suitable for IoT scenarios, especially industrial ones.
- If you have a basic understanding of Rust and some experience with deep learning development, switching to Burn will feel easy.
- Burn is portable: your team can switch between R&D and production environments by changing as little as one argument.
- Burn is flexible: your code can stay very concise, so project maintenance is no longer a difficult task.
The programming language that saved our lives
The first version of kun was implemented in PHP (for the web) and Python (for computation) for simplicity.
However, when deployed on-site, we suffered from two main problems:
- Poor performance: the PHP HTTP server had to invoke a Python script for inference on every request.
- Weak robustness: although our program had almost 100% line coverage from unit tests, some scenarios (such as an empty tray without any battery) were still not handled in the code, which occasionally caused runtime errors.
This meant we needed a strongly typed programming language with high performance. For the sake of safe concurrency and memory safety, we finally chose Rust.
The framework that won me over
I used to work as an algorithm engineer and spent many years running deep learning experiments with PyTorch and TensorFlow.
For someone new to deep learning, one of the best ways to understand it is to learn the concepts rather than the frameworks. In this respect, both PyTorch and TensorFlow v2 (Keras) can be helpful. Thanks to Keras, TensorFlow may even feel simpler than PyTorch, since the latter forces you to write the training loop yourself; on the other hand, PyTorch tends to be more popular with researchers because it gives them finer control over every part of an experiment.
Turning to Rust, which framework should we use for deep learning?
There are several options: Burn, Candle, and tch-rs.
Typically, the R&D process of a deep learning application can be roughly divided into two stages:
- Research stage: designing the model architecture and tuning parameters through experiments
- Development stage: combining the model functionality with the business logic
Both Candle and tch-rs have their own limitations:
- Candle: Candle can be extremely fast, but it only provides a minimalistic API, so it may not be suitable for experiments during the research stage.
- tch-rs: tch-rs provides Rust bindings for PyTorch, which lets developers leverage the extensive PyTorch ecosystem and makes it an ideal framework for experiments. However, its dependency on libtorch can become a burden during the development stage.
With Burn, all of these inconveniences disappear.
Since almost all tensor operations have direct equivalents in PyTorch, Burn has a gentle learning curve, which greatly reduces the effort required from programmers. Even in the very early stages of development, Burn's performance on CPU was already slightly better than PyTorch's[^1].
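To give a feel for this, here is a small illustration of my own (not code from kun), using the NdArray CPU backend that ships with Burn (behind the `ndarray` feature): familiar PyTorch idioms such as `torch.zeros`, `matmul`, `relu`, `reshape`, and `mean` map almost one-to-one onto Burn's tensor API.

```rust
use burn::backend::NdArray;
use burn::tensor::{activation::relu, Tensor};

fn main() {
    // CPU backend bundled with Burn; any other backend works the same way.
    type B = NdArray;
    let device = Default::default();

    // torch.zeros(2, 3) / torch.ones(3, 2)
    let a = Tensor::<B, 2>::zeros([2, 3], &device);
    let b = Tensor::<B, 2>::ones([3, 2], &device);

    // a @ b in PyTorch
    let c = a.matmul(b);

    // relu, reshape and mean also mirror their PyTorch counterparts
    let d = relu(c.clone()).reshape([1, 4]).mean();

    println!("{c}");
    println!("{d}");
}
```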
Our company has servers with AMD GPUs for running experiments, but we prefer to use the CPU when deployed on-site to save budget. Burn is so portable, supporting a wide variety of backends, that we can easily implement two versions via Rust's conditional compilation feature:
```rust
#[cfg(feature = "on-site")]
pub fn build() -> Result<()> {
    training::<Autodiff<Candle>>(
        dataset,
        // ...
    )?;
    Ok(())
}

#[cfg(feature = "dev")]
pub fn build() -> Result<()> {
    training::<Autodiff<Wgpu<f32, f32>>>(
        dataset,
        WgpuDevice::default(),
        // ...
    )?;
    Ok(())
}
```
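Selecting a build is then just standard Cargo feature handling. The wrapper below is a hypothetical sketch (the `run` function is mine, and it assumes both features are declared under `[features]` in Cargo.toml), but it shows that the call site never changes between environments.

```rust
// Hypothetical wrapper, not from kun's codebase. The backend is decided
// entirely at compile time by the Cargo feature you pass, for example:
//   cargo build --release --features on-site   # Candle (CPU) for deployment
//   cargo build --release --features dev       # Wgpu (GPU) for experiments
pub fn run() -> Result<()> {
    // Resolves to whichever cfg-gated `build` was compiled in.
    build()?;
    Ok(())
}
```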
Burn also provides great flexibility. Just as PyTorch and TensorFlow allow, the training loop can be fully customized with Burn. This is very useful for us, because the loss and its timestamp need to be saved to a time-series database in real time. On top of that, we have a regression model and a classification model that share exactly the same training process; with other frameworks the code would be very verbose and difficult to maintain. Thanks to Rust's trait bounds and some of Burn's pre-defined traits, we only need to implement the training loop[^2] once for both purposes:
```rust
pub(super) fn fit<M, B: AutodiffBackend>(
    model: &mut M,
    model_type: ModelType,
    dataset: Vec<BatteryItem>,
    device: B::Device,
    // ...
) -> Result<()>
where
    M: AutodiffModule<B> + Forward<B>,
    M::InnerModule: Forward<B::InnerBackend>,
{
    let config = TrainingConfig::new();
    let mut optim = AdamWConfig::new().init();

    // Build the train/validation splits and their data loaders.
    let train_dataset = BatteryDataset::train(&dataset);
    let test_dataset = BatteryDataset::val(&dataset);
    let batcher_train = BatteryBatcher::<B>::new(device.clone());
    let batcher_test = BatteryBatcher::<B::InnerBackend>::new(device);
    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(train_dataset);
    let dataloader_test = DataLoaderBuilder::new(batcher_test)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(test_dataset);

    // One epoch of training followed by validation.
    let mut fit_step =
        |batch_name: Arc<str>,
         // ..
         model: &mut M| {
            // ...
            let mut last_accuracy = None;

            // Training pass: forward, loss, backward, optimizer step.
            let loss_vec = dataloader_train
                .iter()
                .map(|batch| {
                    let output = model.forward(batch.features);
                    let loss = model_type.train_step(output, batch.targets);
                    let grads = loss.backward();
                    let grads = GradientsParams::from_grads(grads, model);
                    *model = optim.step(config.learning_rate, model.clone(), grads);
                    loss
                })
                .collect();

            // Validation pass on the (non-autodiff) inner module.
            let model_valid = model.valid();
            let loss_valid_vec = dataloader_test
                .iter()
                .map(|batch| {
                    let output = model_valid.forward(batch.features);
                    let (loss, accuracy) = model_type.valid_step(output, batch.targets);
                    last_accuracy = accuracy;
                    loss
                })
                .collect();

            let loss = cal_mean_loss(loss_vec);
            // ..
        };

    let mut best_loss = fit_step(batch_name.clone(), model, td_sender.clone());
    // some special operations on the first epoch
    let mut best_epoch = 0;
    let mut best_model = (*model).clone();

    // Early stopping: keep training while the validation loss still improves
    // by at least `min_delta` within `patience` epochs.
    for epoch in 1..config.num_epochs {
        if epoch - best_epoch < config.patience {
            let valid_loss = fit_step(batch_name.clone(), model, td_sender.clone());
            if best_loss - valid_loss > config.min_delta {
                best_epoch = epoch;
                best_loss = valid_loss;
                best_model = model.clone();
            }
        } else {
            break;
        }
    }

    *model = best_model;
    Ok(())
}
```
From this code snippet you can also see that it is very convenient to implement a custom early stop[^3].
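The `Forward` trait and `ModelType` used above are defined elsewhere in our codebase and are not shown in this post. The sketch below is only an assumption about what such definitions could look like (the variant names and the hand-written MSE are illustrative); it is meant to show why a single `fit` can serve both the regression and the classification model.

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Sketch of the trait bound used by `fit`: any model that maps a feature
/// tensor to an output tensor can go through the shared training loop.
pub trait Forward<B: Backend> {
    fn forward(&self, features: Tensor<B, 2>) -> Tensor<B, 2>;
}

/// Sketch of `ModelType`: only the loss computation differs per model kind,
/// so the loop itself stays generic.
#[derive(Clone, Copy)]
pub enum ModelType {
    Regression,
    Classification,
}

impl ModelType {
    pub fn train_step<B: Backend>(
        &self,
        output: Tensor<B, 2>,
        targets: Tensor<B, 2>,
    ) -> Tensor<B, 1> {
        match self {
            // Mean squared error written with plain tensor ops.
            ModelType::Regression => {
                let diff = output - targets;
                (diff.clone() * diff).mean()
            }
            // A real implementation would use a classification loss
            // (e.g. cross-entropy) here; elided in this sketch.
            ModelType::Classification => todo!(),
        }
    }
    // `valid_step` (returning the loss plus an optional accuracy) would
    // follow the same pattern.
}
```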
Then the only thing left is to define each model (its architecture and forward pass) and its metrics.
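To show how little that is, here is a hypothetical minimal model (the name and layer sizes are made up and are not kun's actual architecture). A real model would additionally implement the `Forward` trait sketched above, plus whatever metrics the classification variant needs.

```rust
use burn::nn::{Linear, LinearConfig};
use burn::prelude::*;
use burn::tensor::activation::relu;

/// Hypothetical two-layer MLP: a Module-derived struct plus a forward method.
#[derive(Module, Debug)]
pub struct TinyModel<B: Backend> {
    hidden: Linear<B>,
    output: Linear<B>,
}

impl<B: Backend> TinyModel<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            hidden: LinearConfig::new(8, 16).init(device),
            output: LinearConfig::new(16, 1).init(device),
        }
    }

    pub fn forward(&self, features: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = relu(self.hidden.forward(features));
        self.output.forward(x)
    }
}
```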
Feels easy, right? Don't hesitate to join Burn's community!
[^1]: See https://burn.dev/blog/burn-rusty-approach-to-tensor-handling. Burn's performance is even higher nowadays thanks to new backends; check the progress at https://burn.dev/blog.
[^2]: Burn also supports a custom Learner type: https://burn.dev/book/custom-training-loop.html#custom-type
[^3]: One can also easily add an early stopping strategy to Burn's pre-defined Learner.