Clean Code with Rust & Axum

One of my favorite parts of the book Clean Code was seeing a code snippet that starts off gross and unmanageable and watching it iteratively improve.

In this post, we’re going to do the same thing with a close-to-real-world example using Rust and Axum. For each refactoring, we’ll also call out why the change improves the code.

We’ll start with a single, messy API route:

#[derive(Deserialize)]
struct CreateUrl {
    url: String,
}

#[derive(Serialize)]
struct CreateUrlResponse {
    id: String,
}

// Save a URL and return an ID associated with it
async fn save_url(
    headers: HeaderMap,
    State(pool): State<PgPool>,
    Json(create_url): Json<CreateUrl>,
) -> Response {
    // First grab an auth token from a custom header
    let token_header = headers.get("X-Auth-Token");
    let token = match token_header {
        None => return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response(),
        Some(header) => header.to_str().unwrap(),
    };

    // Then verify the token is correct
    let verify_result = verify_auth_token(token).await;
    let user = match verify_result {
        Some(user) => user,
        None => return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response(),
    };

    // Insert our URL into the database and get back an ID that the database generated
    let insert_result =
        sqlx::query!(
            "INSERT INTO urls (url, user_id) VALUES (lower($1), $2) RETURNING id",
            create_url.url, user.user_id()
        )
            .map(|row| row.id)
            .fetch_one(&pool)
            .await;
    let id = match insert_result {
        Ok(id) => id,
        Err(_) => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                "An unexpected exception has occurred",
            )
                .into_response()
        }
    };

    (StatusCode::CREATED, Json(CreateUrlResponse { id })).into_response()
}

Take a second to think about all the things you don’t like about that code snippet, and then let’s dive in and make it better!

FromRequest

The first thing you might notice is that there are 10 lines of code dedicated to grabbing the X-Auth-Token header and verifying it. A reasonable first attempt at addressing this is to turn this into a method so it can be re-used:

pub async fn extract_and_verify_auth_token(headers: HeaderMap) -> Option<User>

which we can call:

async fn save_url(
    headers: HeaderMap,
    State(pool): State<PgPool>,
    Json(create_url): Json<CreateUrl>,
) -> Response {
    let user = match extract_and_verify_auth_token(headers).await {
        Some(user) => user,
        None => return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response(),
    };
    // ...
}

But we can actually do a bit better. What if our function signature looked like this:

async fn save_url(
    user: User, // <-- new
    State(pool): State<PgPool>,
    Json(create_url): Json<CreateUrl>
) -> Response

If a valid X-Auth-Token header is present, we automatically extract and verify it. And if a valid X-Auth-Token header is NOT present, we never even call save_url.

That’s essentially what FromRequest does. Axum actually has both FromRequest and FromRequestParts where the difference is FromRequest will consume the body of the request. Since we only need a header, we can just use FromRequestParts:

#[async_trait]
impl<S> FromRequestParts<S> for User
    where
        S: Send + Sync,
{
    type Rejection = (StatusCode, &'static str);

    async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
        // A missing or non-UTF-8 header is rejected before we try to verify anything
        let auth_header = parts.headers.get("X-Auth-Token")
            .and_then(|header| header.to_str().ok())
            .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))?;

        // verify_auth_token returns an Option<User>, just like in the original route
        verify_auth_token(auth_header).await
            .ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))
    }
}

And now our API looks like this:

// Save a URL and return an ID associated with it
async fn save_url(
    user: User,
    State(pool): State<PgPool>,
    Json(create_url): Json<CreateUrl>,
) -> Response {
    // Insert our URL into the database and get back an ID that the database generated
    let insert_result = sqlx::query!(
        "INSERT INTO urls (url, user_id) VALUES (lower($1), $2) RETURNING id",
        create_url.url, user.user_id())
        .map(|row| row.id)
        .fetch_one(&pool)
        .await;
    let id = match insert_result {
        Ok(id) => id,
        Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "An unexpected exception has occurred").into_response(),
    };

    (StatusCode::CREATED, Json(CreateUrlResponse { id })).into_response()
}

Why is this change beneficial?

The biggest advantage here actually has to do with what happens when we make our next API route. Instead of copying and pasting that large block of code, we just add User to the function parameters and we are good.

Another benefit is that, when I am thinking about the save_url function, I don’t really care about the details of how we extract and verify tokens. If I want to understand that, I can opt-in by reading where the User comes from. I care way more about the details of how we save urls, and that’s all that’s left here.
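If the extractor mechanism feels a bit magical, it can help to see it without the framework. Below is a minimal, std-only sketch of the idea, not Axum's actual traits: Parts, FromParts, dispatch, whoami, parts_with_token, and the hard-coded "secret" token are all hypothetical stand-ins made up for illustration.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for the head of a request (no body).
pub struct Parts {
    pub headers: HashMap<String, String>,
}

#[derive(Debug, PartialEq)]
pub struct User {
    pub id: u32,
}

// Toy extractor trait: build Self from the request head,
// or reject with a status code and a message.
pub trait FromParts: Sized {
    fn from_parts(parts: &Parts) -> Result<Self, (u16, &'static str)>;
}

// Hypothetical token check; a real one would verify a signature or hit a database.
fn verify_auth_token(token: &str) -> Option<User> {
    if token == "secret" {
        Some(User { id: 1 })
    } else {
        None
    }
}

impl FromParts for User {
    fn from_parts(parts: &Parts) -> Result<Self, (u16, &'static str)> {
        parts
            .headers
            .get("x-auth-token")
            .and_then(|token| verify_auth_token(token))
            .ok_or((401, "Unauthorized"))
    }
}

// The "framework": run the extractor first, and only call the handler on success.
pub fn dispatch<T, H>(parts: &Parts, handler: H) -> (u16, String)
where
    T: FromParts,
    H: Fn(T) -> (u16, String),
{
    match T::from_parts(parts) {
        Ok(value) => handler(value),
        Err((status, message)) => (status, message.to_string()),
    }
}

// A handler that only ever sees a verified user.
pub fn whoami(user: User) -> (u16, String) {
    (200, format!("user {}", user.id))
}

// Small helper for building a request head in examples.
pub fn parts_with_token(token: Option<&str>) -> Parts {
    let mut headers = HashMap::new();
    if let Some(token) = token {
        headers.insert("x-auth-token".to_string(), token.to_string());
    }
    Parts { headers }
}
```

Calling dispatch::<User, _>(&parts_with_token(None), whoami) never reaches whoami; the rejection is returned instead. That is the same shape of guarantee we get from putting User in an Axum handler's parameter list.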

Note: Be careful about extractor ordering

FromRequest and FromRequestParts are both called extractors. Since FromRequest consumes the body of the request, you can only have one FromRequest extractor per route - and it must be the last argument.

So while this function signature:

async fn save_url(
    user: User,
    State(pool): State<PgPool>,
    Json(create_url): Json<CreateUrl>,
) -> Response

and this function signature:

async fn save_url(
    Json(create_url): Json<CreateUrl>,
    user: User,
    State(pool): State<PgPool>,
) -> Response

might look the same, only the first one will compile. Thanks to j_platte for pointing it out. If you are running into compile errors with your routes, axum-macros can provide more detailed errors about what's wrong.
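One way to build intuition for the ordering rule is plain Rust ownership. The sketch below is a std-only model, not Axum's real types: Request, auth_token, into_body, and handle are hypothetical names. Reading the head only borrows the request, while taking the body moves it, so the body step has to come last.

```rust
// Hypothetical request type: a head we can borrow and a body we can only move out.
pub struct Request {
    auth_token: Option<String>,
    body: String,
}

impl Request {
    pub fn new(auth_token: Option<&str>, body: &str) -> Self {
        Request {
            auth_token: auth_token.map(str::to_string),
            body: body.to_string(),
        }
    }

    // Head-style extraction only borrows, so it can run any number of times.
    pub fn auth_token(&self) -> Option<&str> {
        self.auth_token.as_deref()
    }

    // Body-style extraction takes `self` by value: the request is gone afterwards.
    pub fn into_body(self) -> String {
        self.body
    }
}

pub fn handle(req: Request) -> String {
    // Anything that needs the head has to run first...
    let token = req.auth_token().unwrap_or("anonymous").to_string();
    // ...because this call moves `req`.
    let body = req.into_body();
    // Calling req.auth_token() at this point would not compile: `req` was moved.
    format!("{} sent {}", token, body)
}
```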

IntoResponse

Rust’s ? operator is one of my favorite parts of the language—it allows you to propagate Errors or Options succinctly. However, our save_url function doesn’t use it at all, opting to instead match on the result.

Let’s fix that:


async fn save_url(
    user: User,
    State(pool): State<PgPool>,
    Json(create_url): Json<CreateUrl>,
) -> Response {
    // Insert our URL into the database and get back an ID that the database generated
    let id = sqlx::query!(
        "INSERT INTO urls (url, user_id) VALUES (lower($1), $2) RETURNING id",
        create_url.url, user.user_id())
        .map(|row| row.id)
        .fetch_one(&pool)
        .await?; // <-- new

    (StatusCode::CREATED, Json(CreateUrlResponse { id })).into_response()
}

This won’t compile just yet, and Rust tells us why:

the `?` operator can only be used in an async function that returns `Result` or `Option`

We can change the return type to a Result to signify that it can fail. We’ll define our own error type and implement From<sqlx::Error> so our sql errors will automatically get converted to our error:

// Our Error
pub enum ApiError {
    DatabaseError(sqlx::Error),
}

// The ? operator will automatically convert sqlx::Error to ApiError
impl From<sqlx::Error> for ApiError {
    fn from(e: sqlx::Error) -> Self {
        ApiError::DatabaseError(e)
    }
}

async fn save_url(
    user: User,
    State(pool): State<PgPool>,
    Json(create_url): Json<CreateUrl>,
) -> Result<Response, ApiError> {
    // Insert our URL into the database and get back an ID that the database generated
    let id = sqlx::query!(
        "INSERT INTO urls (url, user_id) VALUES (lower($1), $2) RETURNING id",
        create_url.url, user.user_id())
        .map(|row| row.id)
        .fetch_one(&pool)
        .await?;

    Ok((StatusCode::CREATED, Json(CreateUrlResponse { id })).into_response())
}

We’re pretty close. The only remaining piece is that Axum has no idea what to do with an ApiError. We can fix that by implementing IntoResponse:

impl IntoResponse for ApiError {
    fn into_response(self) -> Response {
        match self {
            // Never leak database details to the client
            ApiError::DatabaseError(_) =>
                (StatusCode::INTERNAL_SERVER_ERROR, "An unexpected exception has occurred").into_response(),
        }
    }
}

Why is this change beneficial?

For common errors, this saves us a lot of boilerplate. Otherwise, every time we make a DB call, we need to match the result and return an internal server error.
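To see the boilerplate savings in isolation, here is a small std-only sketch of the same pattern. It assumes a few hypothetical stand-ins: DbError plays the role of sqlx::Error, insert_url is a fake database call that fails on empty input, and respond mimics what the framework does with the handler's result.

```rust
// Hypothetical stand-in for sqlx::Error.
#[derive(Debug)]
pub struct DbError;

#[derive(Debug)]
pub enum ApiError {
    Database(DbError),
}

// This is what lets `?` turn a DbError into an ApiError.
impl From<DbError> for ApiError {
    fn from(e: DbError) -> Self {
        ApiError::Database(e)
    }
}

impl ApiError {
    // Toy counterpart to IntoResponse: one place that maps errors to responses.
    pub fn into_response(self) -> (u16, &'static str) {
        match self {
            ApiError::Database(_) => (500, "An unexpected exception has occurred"),
        }
    }
}

// Fake database call: fails on empty input, otherwise returns an ID.
fn insert_url(url: &str) -> Result<u32, DbError> {
    if url.is_empty() {
        Err(DbError)
    } else {
        Ok(42)
    }
}

// The handler has no match on the DB result; `?` does the conversion.
pub fn save_url(url: &str) -> Result<(u16, String), ApiError> {
    let id = insert_url(url)?;
    Ok((201, format!("{{\"id\":{}}}", id)))
}

// Roughly what the framework does with the handler's Result.
pub fn respond(url: &str) -> (u16, String) {
    match save_url(url) {
        Ok(response) => response,
        Err(e) => {
            let (status, message) = e.into_response();
            (status, message.to_string())
        }
    }
}
```

Every additional handler that touches the database gets the 500 mapping for free, just by using ? on the call.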

It is worth calling out that this change isn’t universally beneficial. If you have a function that can fail in 3-4 different ways, you may want to explicitly show how your route turns those different errors into an HTTP response.
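As a rough illustration of that trade-off, here is a std-only sketch where the route spells out each mapping itself. Everything in it is made up for the example: the SaveError variants, the try_save rules, and the status codes chosen for each case.

```rust
// Hypothetical error type with several distinct failure modes.
#[derive(Debug)]
pub enum SaveError {
    InvalidUrl,
    Duplicate,
    QuotaExceeded,
}

// Fake save logic: `existing` stands in for the URLs already stored for this user.
fn try_save(url: &str, existing: &[&str], quota: usize) -> Result<usize, SaveError> {
    if !url.starts_with("http") {
        return Err(SaveError::InvalidUrl);
    }
    if existing.iter().any(|saved| *saved == url) {
        return Err(SaveError::Duplicate);
    }
    if existing.len() >= quota {
        return Err(SaveError::QuotaExceeded);
    }
    Ok(existing.len() + 1)
}

// The route shows, in one place, how each failure becomes an HTTP response.
pub fn save_url_route(url: &str, existing: &[&str], quota: usize) -> (u16, String) {
    match try_save(url, existing, quota) {
        Ok(id) => (201, id.to_string()),
        Err(SaveError::InvalidUrl) => (400, "Invalid URL".to_string()),
        Err(SaveError::Duplicate) => (409, "URL already saved".to_string()),
        Err(SaveError::QuotaExceeded) => (429, "Too many saved URLs".to_string()),
    }
}
```

Here the match is the documentation: a reader of the route can see every status code it can return without jumping to a separate error-conversion impl.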

A much-needed refactor

If you’ve been yelling the whole time “Don’t put SQL calls directly in the API route,” I’ve got some good news for you.

pub struct Db;

impl Db {
    pub async fn save_url(
        url: &str, 
        user: &User, 
        pool: &PgPool
    ) -> Result<String, sqlx::Error> {
        sqlx::query!(
            "INSERT INTO urls (url, user_id) VALUES (lower($1), $2) RETURNING id",
            url, user.user_id())
        .map(|row| row.id )
        .fetch_one(pool)
        .await
    }
}

(Note this function originally returned a CreateUrlResponse - but as Kulinda pointed out, that is a route-specific type and should be the responsibility of the API to construct. Thank you to them for the fix)

By abstracting all of that away, our route gets even more succinct:

async fn save_url(
    user: User,
    State(pool): State<PgPool>,
    Json(create_url): Json<CreateUrl>,
) -> Result<Response, ApiError> {

    let id = Db::save_url(&create_url.url, &user, &pool).await?;
    Ok((StatusCode::CREATED, Json(CreateUrlResponse { id })).into_response())
}

Why is this change beneficial?

I like to call this the “at a glance” problem. If you are scanning through the final code snippet, and you are just talking out loud to yourself, you might say:

“Ok so the save_url route takes in a User and CreateUrl JSON, then it saves that to the database alongside the user and returns the response”

If you look at the example with the full SQL query in the route, you might say:

“Ok so the save_url route takes in a User and CreateUrl JSON, then it does some SQL… what does this sql do.. oh ok it just saves the URL and user id and returns the ID”

In this example, you might be able to parse the SQL query quickly enough, but imagine if you had a join or anything remotely complicated there. By replacing the SQL query with a short, descriptive name like save_url, you can immediately understand what it does and then opt in if you want to know the details.

Looking to write more Rust?

I love Rust. Our whole backend at PropelAuth is written in Rust. If you are considering writing your backend in Rust and need an authentication provider with multi-tenant/B2B support built in, you can learn more about us here.

Now back to regularly scheduled programming.

Summary

In the end, we took a pretty verbose code snippet and broke it down into its key components. Not only did this make the code significantly easier to read at a glance, but it also will help us when we create more API routes.