Rust and Node.js: Harmonizing Performance and Safety

Prelude

In the Rust world, the interaction between Python and Rust is well known through the amazing PyO3 ecosystem. There is a similar relationship between Rust and JavaScript, in particular Node.js, that I’m going to describe in this post. All the code is available here.

Most interaction between programming languages happens through the C ABI, i.e. FFI. Rust and JavaScript, however, commonly interact through WebAssembly (WASM). Node.js (written in C++) also offers an addon API for extending Node with native code through FFI, without stepping into WASM, and the Rust ecosystem has built two frameworks on top of it.

We are going to explore neon, since it is the more mature of the two.
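To make the C ABI route mentioned above concrete, here is a minimal sketch of a Rust function that uses the C calling convention. The name count_bytes is hypothetical and is not part of neon or Node.js. A real exported symbol would also need the no_mangle attribute and a cdylib crate type, both left out here for brevity.

```rust
use std::ffi::CStr;
use std::os::raw::c_char;

/// Hypothetical function using the C calling convention.
/// Returns the length in bytes of a NUL-terminated C string, or 0 for NULL.
pub extern "C" fn count_bytes(text: *const c_char) -> usize {
    if text.is_null() {
        return 0;
    }
    // SAFETY: the caller promises that `text` points to a valid,
    // NUL-terminated string that stays alive for the duration of the call.
    unsafe { CStr::from_ptr(text) }.to_bytes().len()
}
```

Frameworks such as neon exist so that you do not have to write this kind of raw pointer handling by hand.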

As a quick recap, companies such as 1Password and Signal have adopted Rust in their Node applications, and more recently companies like LogRocket and RisingStack have supercharged their Node apps the same way. They delegate the critical parts where Node.js falls short to Rust. Rust brings memory and type safety to these applications while using CPU and memory more efficiently. This can lead to orders-of-magnitude higher requests per second (RPS).

I’m assuming you have a working Rust toolchain, npm, and Node.js. Install the neon CLI with npm i -g neon-cli. First, the “hello, world!” example.

  1. npm init neon hello creates
.
├── Cargo.toml
├── README.md
├── package.json
└── src
    └── lib.rs
2 directories, 4 files

The content of the preloaded src/lib.rs is

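use neon::prelude::*;
// A JavaScript-callable function that returns a string
fn hello(mut cx: FunctionContext) -> JsResult<JsString> {
    Ok(cx.string("hello node"))
}
// Module entry point: exports `hello` to JavaScript
#[neon::main]
fn main(mut cx: ModuleContext) -> NeonResult<()> {
    cx.export_function("hello", hello)?;
    Ok(())
}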

  2. Run npm install, which also compiles the Rust code, and

  3. Open a node prompt and call the function

> require('.').hello()
 'hello node'

Super easy! All the compilation dependencies are included, which makes for a great developer experience. For more details, check out the official neon documentation.

Cheat Table

| Rust Neon Construct | Description | Example Usage |
| --- | --- | --- |
| neon::prelude::* | Imports the essential traits and types for Neon modules. | use neon::prelude::*; |
| FunctionContext<'a> | Represents the execution context of a JavaScript function call. | fn my_function(mut cx: FunctionContext) -> JsResult<JsValue> { ... } |
| JsResult<T> | A result type for Neon functions, either Ok(Handle<T>) or Err(Throw). | fn my_function(...) -> JsResult<JsString> { ... } |
| JsValue | Represents any JavaScript value. | let js_value: Handle<JsValue> = cx.argument(0)?; |
| JsString, JsNumber, etc. | Specific JavaScript value types. | let js_string: Handle<JsString> = cx.argument(0)?; |
| Handle<T> | A handle to a JavaScript value, keeping it alive across the JS-Rust boundary. | let handle: Handle<JsString> = cx.argument(0)?; |
| ModuleContext<'a> | Represents the context of a module during initialization. | fn neon_module_init(cx: ModuleContext) -> NeonResult<()> { ... } |
| register_module! | A macro to register the module with Node.js (the hello example uses #[neon::main] instead). | register_module!(mut cx, { cx.export_function("hello", hello) }); |
| JsArray, JsObject | Types for JavaScript arrays and objects. | let js_array: Handle<JsArray> = JsArray::new(&mut cx, 3); |
| .to_string(), .to_number(), etc. | Methods to convert Neon types to JavaScript types. | let js_num = cx.number(42.0).upcast::<JsValue>(); |
| cx.argument::<T>(i) | Retrieves the ith argument of a function call. | let arg0: Handle<JsString> = cx.argument::<JsString>(0)?; |
| cx.throw_error() | Throws a JavaScript error from Rust. | return cx.throw_error("Something went wrong"); |
| cx.borrow(), cx.borrow_mut() | Borrows a reference to a JavaScript value. | let guard = cx.borrow(&js_array); |

Resize image example

For another example, let’s say you want to speed up resizing a JPEG image.

  1. npm init neon image-resize and cd image-resize
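  2. Then cargo add image@0.24 (the code below uses image::ImageOutputFormat from the 0.24 API, the same version pinned later in this post) and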
  3. Include the following in src/lib.rs. Check out the comments below

use std::io::Cursor;
use image;
use neon::{prelude::*, types::buffer::TypedArray};
fn resize_image(mut cx: FunctionContext) -> JsResult<JsBuffer> {
    // Retrieve image buffer and dimensions from JavaScript arguments
    let buffer = cx.argument::<JsBuffer>(0)?; // <- gets the first argument in node
    let width = cx.argument::<JsNumber>(1)?.value(&mut cx) as u32;
    let height = cx.argument::<JsNumber>(2)?.value(&mut cx) as u32;
    // Convert JS Buffer to a byte slice
    let image_data: &[u8] = buffer.as_slice(&cx);
    // Perform image resizing in rust
    let img = image::load_from_memory(&image_data).expect("Failed to load image from memory");
    let resized = img.resize(width, height, image::imageops::FilterType::Nearest);
    let mut resized_buffer = Cursor::new(Vec::new());
    resized
        .write_to(&mut resized_buffer, image::ImageOutputFormat::Jpeg(100))
        .expect("Failed to write image to buffer");
    let img_data = resized_buffer.into_inner();
    // Convert the byte vector back to a JS Buffer
    let js_buffer = JsBuffer::external(&mut cx, img_data);
    Ok(js_buffer)
}
// finally export the function as a module
register_module!(mut cx, { cx.export_function("resizeImage", resize_image) });

4. npm install

5. Run node prompt and test with an image

const nativeModule = require('.');
const fs = require('fs');
let imageBuf = fs.readFileSync('cat.jpeg');
const resizedBuf = nativeModule.resizeImage(imageBuf, 50, 50);
fs.writeFileSync('resized_cat.jpeg', resizedBuf);
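One detail worth knowing: the resize method of the image crate preserves the aspect ratio, so the 50, 50 passed above acts as a bounding box and not as an exact output size. The helper below is a rough sketch of that calculation. fit_within is a hypothetical name, and the rounding may differ slightly from what the crate does internally.

```rust
/// Largest size that fits inside `max_w` x `max_h` while keeping the aspect
/// ratio of a `w` x `h` image. Hypothetical helper, not the image crate's code.
/// Assumes `w` and `h` are non-zero.
fn fit_within(w: u32, h: u32, max_w: u32, max_h: u32) -> (u32, u32) {
    // Use the smaller of the two scale factors so that both sides fit.
    let scale = f64::min(max_w as f64 / w as f64, max_h as f64 / h as f64);
    // Never return a zero-sized side.
    let new_w = ((w as f64 * scale).round() as u32).max(1);
    let new_h = ((h as f64 * scale).round() as u32).max(1);
    (new_w, new_h)
}
```

Under these assumptions, a 1000 x 500 source resized with bounds 50, 50 comes out as 50 x 25.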

neon also provides a “promise API” and a “task API” for working with JavaScript promises and for running jobs on the Node worker pool, so long-running Rust code does not have to block the event loop.
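The idea behind those APIs is to run the expensive part off the main thread and deliver the result later. The sketch below shows only that general pattern in plain std Rust. It is not neon’s API, and heavy_work and run_in_background are hypothetical names. In neon, the result would settle a JavaScript promise instead of being read from a channel.

```rust
use std::sync::mpsc;
use std::thread;

/// Stand-in for an expensive computation (hypothetical), e.g. resizing an image.
fn heavy_work(input: u64) -> u64 {
    (1..=input).sum()
}

/// Run `heavy_work` on a background thread and hand the result back over a
/// channel, so the calling thread stays free until it asks for the value.
fn run_in_background(input: u64) -> mpsc::Receiver<u64> {
    let (tx, rx) = mpsc::channel();
    thread::spawn(move || {
        // The receiver may already be gone; ignore the send error in that case.
        let _ = tx.send(heavy_work(input));
    });
    rx
}
```

The caller can keep doing other work and call recv() on the returned receiver only when it needs the value.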

Segment-Anything in Node

Our examples wouldn’t be complete without some heavy deep learning computation. We use Segment-Anything, a state-of-the-art segmentation model by Meta. It produces high-quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. For this final example we will use Candle, Hugging Face’s deep learning library for Rust, and delegate the heavy lifting to Rust.

  1. npm init neon sam-node and cd sam-node
  2. Include the dependencies in Cargo.toml as follows

[dependencies]
anyhow = "1.0.75"
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.1", features = [
    "accelerate",
] }
candle-examples = { git = "https://github.com/huggingface/candle.git", version = "0.3.1", features = [
    "accelerate",
] }
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1", features = [
    "accelerate",
] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.3.1", features = [
    "accelerate",
] }
hf-hub = "0.3.2"
image = "0.24.7"
imageproc = "0.23.0"

Note that you may need to set your Hugging Face API key if you haven’t already.

3. Now in src/lib.rs

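// candle-core is imported under the name `candle` because the code below refers to candle::Error
use candle_core::{self as candle, DType};
use candle_nn::VarBuilder;
use candle_transformers::models::segment_anything::sam;
use neon::prelude::*;
fn generate_sam(mut cx: FunctionContext) -> JsResult<JsString> {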
    let image_path = cx.argument::<JsString>(0)?.value(&mut cx);
    let points = cx.argument::<JsArray>(1)?;
    let neg_points = cx.argument::<JsArray>(2)?;
    let points = get_points(&mut cx, points);
    let neg_points = get_points(&mut cx, neg_points);
    _generate_sam(image_path, points, neg_points).expect("error generating sam");
    Ok(cx.string("ok"))
}
fn get_points(cx: &mut FunctionContext, handle: Handle<JsArray>) -> Vec<String> {
    let points = handle.to_vec(cx).expect("error converting to vec");
    let mut ret: Vec<String> = Vec::new();
    for point in points {
        let point_string = point
            .downcast::<JsString, FunctionContext>(cx)
            .or_else(|_| cx.throw_error("Array element is not a string"))
            .unwrap()
            .value(cx);
        ret.push(point_string);
    }
    // join the array elements into a single "x,y" string, the format _generate_sam parses
    ret = vec![ret.join(",")];
    ret
}

where

fn _generate_sam(
    image_path: String,
    points: Vec<String>,
    neg_points: Vec<String>,
) -> anyhow::Result<()> {
    let device = candle_examples::device(true)?; // use CPU
    let (image, initial_h, initial_w) =
        candle_examples::load_image(&image_path, Some(sam::IMAGE_SIZE))?;
    let image = image.to_device(&device)?;
    println!("loaded image {image:?}");
    let api = hf_hub::api::sync::Api::new()?;
    let api = api.model("lmz/candle-sam".to_string());
    let filename = "mobile_sam-tiny-vitt.safetensors";
    let model = api.get(filename)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
    let sam = sam::Sam::new_tiny(vb)?;
    // Default options similar to the Python version.
    let bboxes = sam.generate_masks(
        &image,
        /* points_per_side */ 32,
        /* crop_n_layer */ 0,
        /* crop_overlap_ratio */ 512. / 1500.,
        /* crop_n_points_downscale_factor */ 1,
    )?;
    for (idx, bbox) in bboxes.iter().enumerate() {
        println!("{idx} {bbox:?}");
        let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
        let (h, w) = mask.dims2()?;
        let mask = mask.broadcast_as((3, h, w))?;
        candle_examples::save_image_resize(
            &mask,
            format!("sam_mask{idx}.png"),
            initial_h,
            initial_w,
        )?;
    }
    let iter_points = points.iter().map(|p| (p, true));
    let iter_neg_points = neg_points.iter().map(|p| (p, false));
    let points = iter_points
        .chain(iter_neg_points)
        .map(|(point, b)| {
            use std::str::FromStr;
            let xy = point.split(',').collect::<Vec<_>>();
            if xy.len() != 2 {
                anyhow::bail!("expected format for points is 0.4,0.2")
            }
            Ok((f64::from_str(xy[0])?, f64::from_str(xy[1])?, b))
        })
        .collect::<anyhow::Result<Vec<_>>>()?;
    let start_time = std::time::Instant::now();
    let (mask, iou_predictions) = sam.forward(&image, &points, false)?;
    println!(
        "mask generated in {:.2}s",
        start_time.elapsed().as_secs_f32()
    );
    println!("mask:\n{mask}");
    println!("iou_predictions: {iou_predictions}");
    let mask = (mask.ge(0.)? * 255.)?;
    let (_one, h, w) = mask.dims3()?;
    let mask = mask.expand((3, h, w))?;
    let mut img = image::io::Reader::open(&image_path)?
        .decode()
        .map_err(candle::Error::wrap)?;
    let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
    let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
        match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
            Some(image) => image,
            None => anyhow::bail!("error saving merged image"),
        };
    let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
        img.width(),
        img.height(),
        image::imageops::FilterType::CatmullRom,
    );
    for x in 0..img.width() {
        for y in 0..img.height() {
            let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
            if mask_p.0[0] > 100 {
                let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
                img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
                img_p.0[1] /= 2;
                img_p.0[0] /= 2;
                imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
            }
        }
    }
    for (x, y, b) in points {
        let x = (x * img.width() as f64) as i32;
        let y = (y * img.height() as f64) as i32;
        let color = if b {
            image::Rgba([255, 0, 0, 200])
        } else {
            image::Rgba([0, 255, 0, 200])
        };
        imageproc::drawing::draw_filled_circle_mut(&mut img, (x, y), 3, color);
    }
    img.save("sam_merged.jpg")?;
    Ok(())
}
// finally register as node module
register_module!(mut cx, { cx.export_function("generateSam", generate_sam) });
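The nested loop over x and y above is what colors the masked region. For each pixel under the mask it halves red and green and moves blue halfway toward 255. Here is that arithmetic isolated as a small sketch. highlight is a hypothetical helper that works on a plain RGBA array, not on the image crate’s pixel type.

```rust
/// Tint one RGBA pixel toward blue, mirroring the arithmetic in the loop:
/// red and green are halved, blue moves halfway to 255, alpha is untouched.
fn highlight(pixel: [u8; 4]) -> [u8; 4] {
    let [r, g, b, a] = pixel;
    [r / 2, g / 2, 255 - (255 - b) / 2, a]
}
```

A white pixel therefore turns into a light blue one, while a pixel that is already pure blue stays as it is.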

Next, in sam-node/index.js, we add a simple Node.js Express server. (Install its dependencies with npm install axios express.)

const express = require('express');
const nativeModule = require('.');
const app = express();
const port = 3000;
const axios = require('axios');
const fs = require('fs');
const path = require('path');
async function downloadImage(url, filePath) {
  try {
    const response = await axios({
      method: 'GET',
      url: url,
      responseType: 'stream'
    });
    const writer = fs.createWriteStream(filePath);
    response.data.pipe(writer);
    return new Promise((resolve, reject) => {
      writer.on('finish', resolve);
      writer.on('error', reject);
    });
  } catch (error) {
    console.error('Error downloading the image:', error);
    throw error;
  }
}
app.use(express.json());
app.listen(port, () => console.log(`Listening on port ${port}`));
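app.post("/generate-sam", async (req, res) => {
    try {
        const { imagUrl, points, negPoints } = req.body;
        const name = imagUrl.split('/').pop();
        const filePath = path.join(__dirname, name);
        await downloadImage(imagUrl, filePath);
        // generateSam is synchronous and writes sam_merged.jpg to the working directory
        nativeModule.generateSam(filePath, points, negPoints);
        // send the merged image back so the request does not hang
        res.sendFile(path.resolve('sam_merged.jpg'));
    } catch (error) {
        console.log(error);
        res.status(500).send(error.message);
    }
});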

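Then start the server with node index.js and test it with the script below in test.js, which asks the server to download and segment a sample JPG image. (fetch is built into Node.js 18 and later; on older versions, install and import node-fetch.)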

// send a POST request to the server using fetch
const url = 'http://localhost:3000/generate-sam';
const imageUrl = 'https://githubraw.com/huggingface/candle/main/candle-examples/examples/yolo-v8/assets/bike.jpg';
const points = ['0.6', '0.6'];
const negPoints = ['0.6', '0.55'];
fetch(url, {
    method: 'POST',
    headers: {
        'Content-Type': 'application/json',
    },
    body: JSON.stringify({
        imagUrl: imageUrl,
        points: points,
        negPoints: negPoints
    })
})
.then(response => {
    if (!response.ok) {
        throw new Error('Network response was not ok');
    }
    return response.blob();
})
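.then(async blob => {
    console.log('Image received:', blob);
    // This script runs under Node, where there is no DOM, so save the image to disk.
    // The file name sam_result.jpg is an arbitrary choice.
    const fs = require('fs');
    fs.writeFileSync('sam_result.jpg', Buffer.from(await blob.arrayBuffer()));
})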
.catch(error => {
    console.error('Fetch error:', error);
});

In the resulting sam_merged.jpg, segmentation was applied to the right-most cyclist’s right foot.
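The points sent by the test script are normalized coordinates. get_points joins ['0.6', '0.6'] into "0.6,0.6", and _generate_sam later multiplies each value by the image width and height to place the marker. The sketch below isolates those two steps. parse_point and to_pixel are hypothetical helper names, and trimming whitespace is an addition the original code does not make.

```rust
/// Parse a point written as "x,y", e.g. "0.4,0.2".
fn parse_point(text: &str) -> Result<(f64, f64), String> {
    let parts: Vec<&str> = text.split(',').collect();
    if parts.len() != 2 {
        return Err(format!("expected format for points is 0.4,0.2, got {text:?}"));
    }
    let x = parts[0].trim().parse::<f64>().map_err(|e| e.to_string())?;
    let y = parts[1].trim().parse::<f64>().map_err(|e| e.to_string())?;
    Ok((x, y))
}

/// Scale a normalized point to pixel coordinates for a `width` x `height` image.
fn to_pixel((x, y): (f64, f64), width: u32, height: u32) -> (i32, i32) {
    ((x * width as f64) as i32, (y * height as f64) as i32)
}
```

Because the coordinates are fractions of the image size, the same request body works for images of any resolution.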

Hope this post has ignited some spark to explore neon and the Rust-Node.js interactions.
