For an introduction to kun, please refer to the homepage.

The Chinese version is already available here; the English homepage has not been fully translated yet.

TL;DR

(Figure: factory.png)

The programming language saving our life

The first version of kun was implemented in PHP (for the web) and Python (for calculation) for simplicity. However, when deployed on-site, we suffered from two main problems: the calculation code was too slow, and the dynamically typed code kept producing runtime errors that were hard to catch before deployment.

This meant we needed a strongly typed programming language with high performance; for the sake of safe concurrency and memory safety, we finally chose Rust.

The framework flooding me with desire

I used to be an algorithm engineer and spent many years running deep learning experiments with PyTorch and TensorFlow.

For someone new to deep learning, one of the best ways to understand it is to learn concepts rather than frameworks. In this respect, both PyTorch and TensorFlow v2 (Keras) can be helpful. Thanks to Keras, TensorFlow may even feel simpler than PyTorch, since the latter forces you to write the training loop yourself; on the other hand, PyTorch tends to be more popular with researchers, because it gives them finer control over each part of an experiment.

When we turn to Rust, which framework should we use to perform deep learning?

There are several options: Burn, Candle, or tch-rs.

Typically, the R&D process of a deep learning app can be roughly divided into two stages: experimentation (usually on GPU servers) and on-site deployment (in our case, on CPUs).

Both Candle and tch-rs have their own limitations: Candle is oriented mainly toward inference rather than the full training workflow, while tch-rs is a binding to libtorch, which drags a heavyweight C++ dependency into the build.

When it comes to Burn, all of these inconveniences disappear.

Since almost all of its tensor operations have equivalents in PyTorch, Burn has a gradual learning curve, which greatly reduces the learning cost for programmers. And even in the very early stages of its development, Burn's performance on CPU was slightly better than that of PyTorch[^1].
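As a quick illustration (my own minimal sketch, not from the Burn docs), everyday PyTorch-style tensor operations translate almost one-to-one; here with the NdArray CPU backend:

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    type B = NdArray;
    let device = Default::default();

    // torch.ones(2, 3)
    let x = Tensor::<B, 2>::ones([2, 3], &device);

    // x * 2 + 1, elementwise, just like in PyTorch
    let y = x * 2.0 + 1.0;

    // y.mean().item()
    println!("{}", y.mean().into_scalar());
}
```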

We have servers with AMD GPUs at our company for running experiments, but we prefer CPUs when deploying on-site to save budget. Burn is so portable, thanks to its wide variety of backends, that we can easily maintain the two versions via Rust's conditional compilation feature, like:

```rust
// CPU (Candle) backend for on-site deployment.
#[cfg(feature = "on-site")]
pub fn build() -> Result<()> {
    training::<Autodiff<Candle>>(
        dataset,
        // ...
    )?;
    Ok(())
}

// GPU (Wgpu) backend for in-house experiments.
#[cfg(feature = "dev")]
pub fn build() -> Result<()> {
    training::<Autodiff<Wgpu<f32, f32>>>(
        dataset,
        WgpuDevice::default(),
        // ...
    )?;
    Ok(())
}
```
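For completeness, the corresponding Cargo features can be declared roughly like this (a sketch assuming the burn crate's `candle` and `wgpu` feature names):

```toml
[features]
# Mutually exclusive: enable exactly one per build.
dev = ["burn/wgpu"]
on-site = ["burn/candle"]
```

Building the on-site version is then just `cargo build --release --no-default-features --features on-site`.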

Burn also provides extreme flexibility. Just as PyTorch and TensorFlow allow, the training loop can be fully customized with Burn. This is essential for us, since the loss and its timestamps need to be saved in real time to a time-series database. Moreover, we have a regression model and a classification model that share exactly the same training process; with other frameworks the code would be very verbose and difficult to maintain. Thanks to Rust's trait bounds and some of Burn's pre-defined traits, we only need to implement the training loop[^2] once for both purposes:

```rust
pub(super) fn fit<M, B: AutodiffBackend>(
    model: &mut M,
    model_type: ModelType,
    dataset: Vec<BatteryItem>,
    device: B::Device,
    // ...
) -> Result<()>
where
    // The model must be trainable on B and expose a forward pass;
    // its inner (non-autodiff) module must expose one too, for validation.
    M: AutodiffModule<B> + Forward<B>,
    M::InnerModule: Forward<B::InnerBackend>,
{
    let config = TrainingConfig::new();
    let mut optim = AdamWConfig::new().init();
    let train_dataset = BatteryDataset::train(&dataset);
    let test_dataset = BatteryDataset::val(&dataset);

    // The validation batcher runs on the inner (non-autodiff) backend.
    let batcher_train = BatteryBatcher::<B>::new(device.clone());
    let batcher_test = BatteryBatcher::<B::InnerBackend>::new(device);

    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(train_dataset);

    let dataloader_test = DataLoaderBuilder::new(batcher_test)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(test_dataset);

    // One epoch of training plus validation; returns the epoch's loss.
    let mut fit_step =
        |batch_name: Arc<str>,
         model: &mut M,
         // ...
        | {
            // ...
            let mut last_accuracy = None;
            let loss_vec: Vec<_> = dataloader_train
                .iter()
                .map(|batch| {
                    let output = model.forward(batch.features);
                    // `model_type` dispatches to the regression or
                    // classification loss, so one loop serves both models.
                    let loss = model_type.train_step(output, batch.targets);
                    let grads = loss.backward();
                    let grads = GradientsParams::from_grads(grads, model);
                    *model = optim.step(config.learning_rate, model.clone(), grads);
                    loss
                })
                .collect();

            // Validation runs on the inner module, without autodiff.
            let model_valid = model.valid();
            let loss_valid_vec: Vec<_> = dataloader_test
                .iter()
                .map(|batch| {
                    let output = model_valid.forward(batch.features);
                    let (loss, accuracy) = model_type.valid_step(output, batch.targets);
                    last_accuracy = accuracy;
                    loss
                })
                .collect();

            let loss = cal_mean_loss(loss_vec);
            // ...
        };

    // `batch_name` and `td_sender` come from the elided parameters above;
    // `td_sender` streams losses to the time-series database.
    let mut best_loss = fit_step(batch_name.clone(), model, td_sender.clone());
    // some special operations on the first epoch
    let mut best_epoch = 0;
    let mut best_model = (*model).clone();
    for epoch in 1..config.num_epochs {
        // Early stopping: give up once `patience` epochs pass
        // without an improvement of at least `min_delta`.
        if epoch - best_epoch < config.patience {
            let valid_loss = fit_step(batch_name.clone(), model, td_sender.clone());
            if best_loss - valid_loss > config.min_delta {
                best_epoch = epoch;
                best_loss = valid_loss;
                best_model = model.clone();
            }
        } else {
            break;
        }
    }
    // Restore the best checkpoint before returning.
    *model = best_model;
    Ok(())
}
```

From the code snippet you can also see that it is very convenient to implement a custom early-stopping strategy[^3].
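The `ModelType` dispatch used above is elided in the snippet; a minimal sketch of what it could look like (the one-hot target layout and the hand-written losses are my assumptions, not kun's actual code):

```rust
use burn::tensor::activation::log_softmax;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

/// Hypothetical sketch: lets one training loop serve two tasks.
#[derive(Clone, Copy)]
pub enum ModelType {
    Regression,
    Classification,
}

impl ModelType {
    /// Training loss for one batch; `valid_step` would be analogous.
    pub fn train_step<B: Backend>(
        &self,
        output: Tensor<B, 2>,
        targets: Tensor<B, 2>, // one-hot rows for classification (assumption)
    ) -> Tensor<B, 1> {
        match self {
            // Mean squared error, written with plain tensor ops.
            ModelType::Regression => {
                let diff = output - targets;
                (diff.clone() * diff).mean()
            }
            // Cross-entropy against one-hot targets.
            ModelType::Classification => {
                let log_probs = log_softmax(output, 1);
                -(log_probs * targets).sum_dim(1).mean()
            }
        }
    }
}
```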

Then the only thing left is to define the model (architecture and forward pass) and the metrics for each model.
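For instance, a model could be as simple as the following (a hypothetical sketch; the layer sizes are illustrative, not kun's actual architecture):

```rust
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::tensor::activation::relu;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

#[derive(Module, Debug)]
pub struct BatteryModel<B: Backend> {
    hidden: Linear<B>,
    output: Linear<B>,
}

impl<B: Backend> BatteryModel<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            hidden: LinearConfig::new(16, 64).init(device),
            output: LinearConfig::new(64, 1).init(device),
        }
    }

    /// The forward pass required by the `Forward` bound in `fit`.
    pub fn forward(&self, features: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = relu(self.hidden.forward(features));
        self.output.forward(x)
    }
}
```

Deriving `Module` is what makes the model satisfy `AutodiffModule<B>` when `B` is an autodiff backend, so it plugs straight into the `fit` function above.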

Easy, right? Don't hesitate to join Burn's community!

[^1]: See https://burn.dev/blog/burn-rusty-approach-to-tensor-handling. Burn achieves even higher performance with its newer backends nowadays; follow the progress at https://burn.dev/blog.

[^3]: One can easily add an early-stopping strategy to Burn's pre-defined Learner.