Skip to content

Commit 301db9e

Browse files
committed
Avoid ID clashes
Currently the the ID for a new paste is randomly generated in the caller of the database insert() function. Then the insert() function tries to insert a new row into the database with that passed ID. There can however already exists a paste in the database with the same ID leading to an insert failure, due to a constraint violation due to the PRIMARY KEY attribute. Checking prior the the INSERT via a SELECT query would open the window for a race condition. A failure to push a new paste is quite severe, since the user might have spent some some to format the input. Generate the ID in a loop inside, until the INSERT succeeds.
1 parent 9d7df4f commit 301db9e

4 files changed

Lines changed: 78 additions & 65 deletions

File tree

src/db.rs

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -247,32 +247,68 @@ impl Database {
247247
Ok(Self { conn })
248248
}
249249

250-
/// Insert `entry` under `id` into the database and optionally set owner to `uid`.
251-
pub async fn insert(&self, id: Id, entry: write::Entry) -> Result<(), Error> {
250+
/// Insert `entry` with a new generated `id` into the database and optionally set owner to `uid`.
251+
pub async fn insert(&self, entry: write::Entry) -> Result<Id, Error> {
252252
let conn = self.conn.clone();
253-
let id = id.as_u32();
254253
let write::DatabaseEntry { entry, data, nonce } = entry.compress().await?.encrypt().await?;
255254

256-
spawn_blocking(move || match entry.expires {
257-
None => conn.lock().execute(
258-
"INSERT INTO entries (id, uid, data, burn_after_reading, nonce) VALUES (?1, ?2, ?3, ?4, ?5)",
259-
params![id, entry.uid, data, entry.burn_after_reading, nonce],
260-
),
261-
Some(expires) => conn.lock().execute(
262-
"INSERT INTO entries (id, uid, data, burn_after_reading, nonce, expires) VALUES (?1, ?2, ?3, ?4, ?5, datetime('now', ?6))",
263-
params![
264-
id,
265-
entry.uid,
266-
data,
267-
entry.burn_after_reading,
268-
nonce,
269-
format!("{expires} seconds")
270-
],
271-
),
255+
let id = spawn_blocking(move || {
256+
const COUNTER_LIMIT: u32 = 10;
257+
let mut counter = 0;
258+
259+
let mut rng = rand::thread_rng();
260+
261+
loop {
262+
let id: Id = rand::Rng::gen::<u32>(&mut rng).into();
263+
let id_inner = id.as_u32();
264+
265+
let result = match entry.expires {
266+
None => conn.lock().execute(
267+
"INSERT INTO entries (id, uid, data, burn_after_reading, nonce) VALUES (?1, ?2, ?3, ?4, ?5)",
268+
params![id_inner, entry.uid, data, entry.burn_after_reading, nonce],
269+
),
270+
Some(expires) => conn.lock().execute(
271+
"INSERT INTO entries (id, uid, data, burn_after_reading, nonce, expires) VALUES (?1, ?2, ?3, ?4, ?5, datetime('now', ?6))",
272+
params![
273+
id_inner,
274+
entry.uid,
275+
data,
276+
entry.burn_after_reading,
277+
nonce,
278+
format!("{expires} seconds")
279+
],
280+
),
281+
};
282+
283+
match result {
284+
Err(rusqlite::Error::SqliteFailure(rusqlite::ffi::Error { code, extended_code }, Some(ref _message)))
285+
if code == rusqlite::ErrorCode::ConstraintViolation && extended_code == rusqlite::ffi::SQLITE_CONSTRAINT_PRIMARYKEY && counter < COUNTER_LIMIT => {
286+
/* Retry if ID is already existent */
287+
counter += 1;
288+
continue;
289+
},
290+
Err(err) => {
291+
if counter >= COUNTER_LIMIT {
292+
tracing::error!("Failed to generate ID after {counter} retries");
293+
}
294+
295+
break Err(err)
296+
},
297+
Ok(rows) => {
298+
debug_assert!(rows == 1);
299+
300+
if counter > 4 {
301+
tracing::warn!("Required {counter} retries to generate new ID");
302+
}
303+
304+
break Ok(id)
305+
},
306+
}
307+
}
272308
})
273309
.await??;
274310

275-
Ok(())
311+
Ok(id)
276312
}
277313

278314
/// Get entire entry for `id`.
@@ -383,8 +419,7 @@ mod tests {
383419
..Default::default()
384420
};
385421

386-
let id = Id::from(1234);
387-
db.insert(id, entry).await?;
422+
let id = db.insert(entry).await?;
388423

389424
let entry = db.get(id, None).await?;
390425
assert_eq!(entry.text, "hello world");
@@ -406,8 +441,7 @@ mod tests {
406441
..Default::default()
407442
};
408443

409-
let id = Id::from(1234);
410-
db.insert(id, entry).await?;
444+
let id = db.insert(entry).await?;
411445

412446
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
413447

@@ -422,8 +456,7 @@ mod tests {
422456
async fn delete() -> Result<(), Box<dyn std::error::Error>> {
423457
let db = new_db()?;
424458

425-
let id = Id::from(1234);
426-
db.insert(id, write::Entry::default()).await?;
459+
let id = db.insert(write::Entry::default()).await?;
427460

428461
assert!(db.get(id, None).await.is_ok());
429462
assert!(db.delete(id).await.is_ok());

src/id.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use crate::db::write::Entry;
21
use crate::errors::Error;
32
use std::fmt;
43
use std::str::FromStr;
@@ -23,11 +22,8 @@ impl Id {
2322
}
2423

2524
/// Generate a URL path from the string representation and `entry`'s extension.
26-
pub fn to_url_path(self, entry: &Entry) -> String {
27-
entry
28-
.extension
29-
.as_ref()
30-
.map_or_else(|| format!("{self}"), |ext| format!("{self}.{ext}"))
25+
pub fn to_url_path(self, extension: Option<&str>) -> String {
26+
extension.map_or_else(|| format!("{self}"), |ext| format!("{self}.{ext}"))
3127
}
3228
}
3329

src/routes/form.rs

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@ use std::num::NonZeroU32;
22

33
use crate::db::write;
44
use crate::env::BASE_PATH;
5-
use crate::id::Id;
65
use crate::{pages, AppState, Error};
76
use axum::extract::{Form, State};
87
use axum::response::Redirect;
98
use axum_extra::extract::cookie::{Cookie, SameSite, SignedCookieJar};
10-
use rand::Rng;
119
use serde::{Deserialize, Serialize};
1210

1311
#[derive(Debug, Serialize, Deserialize)]
@@ -45,14 +43,6 @@ pub async fn insert(
4543
Form(entry): Form<Entry>,
4644
is_https: bool,
4745
) -> Result<(SignedCookieJar, Redirect), pages::ErrorResponse<'static>> {
48-
let id: Id = tokio::task::spawn_blocking(|| {
49-
let mut rng = rand::thread_rng();
50-
rng.gen::<u32>()
51-
})
52-
.await
53-
.map_err(Error::from)?
54-
.into();
55-
5646
// Retrieve uid from cookie or generate a new one.
5747
let uid = if let Some(cookie) = jar.get("uid") {
5848
cookie
@@ -66,22 +56,24 @@ pub async fn insert(
6656
let mut entry: write::Entry = entry.into();
6757
entry.uid = Some(uid);
6858

69-
let mut url = id.to_url_path(&entry);
70-
71-
let burn_after_reading = entry.burn_after_reading.unwrap_or(false);
72-
if burn_after_reading {
73-
url = format!("burn/{url}");
74-
}
75-
76-
let url_with_base = BASE_PATH.join(&url);
77-
7859
if let Some(max_exp) = state.max_expiration {
7960
entry.expires = entry
8061
.expires
8162
.map_or_else(|| Some(max_exp), |value| Some(value.min(max_exp)));
8263
}
8364

84-
state.db.insert(id, entry).await?;
65+
let burn = entry.burn_after_reading.unwrap_or(false);
66+
let extension = entry.extension.clone();
67+
68+
let id = state.db.insert(entry).await?;
69+
70+
let mut url = id.to_url_path(extension.as_deref());
71+
72+
if burn {
73+
url = format!("burn/{url}");
74+
}
75+
76+
let url_with_base = BASE_PATH.join(&url);
8577

8678
let cookie = Cookie::build(("uid", uid.to_string()))
8779
.http_only(true)

src/routes/json.rs

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@ use std::num::NonZeroU32;
22

33
use crate::db::write;
44
use crate::env::BASE_PATH;
5-
use crate::errors::{Error, JsonErrorResponse};
6-
use crate::id::Id;
5+
use crate::errors::JsonErrorResponse;
76
use crate::AppState;
87
use axum::extract::State;
98
use axum::Json;
10-
use rand::Rng;
119
use serde::{Deserialize, Serialize};
1210

1311
#[derive(Debug, Serialize, Deserialize)]
@@ -41,14 +39,6 @@ pub async fn insert(
4139
state: State<AppState>,
4240
Json(entry): Json<Entry>,
4341
) -> Result<Json<RedirectResponse>, JsonErrorResponse> {
44-
let id: Id = tokio::task::spawn_blocking(|| {
45-
let mut rng = rand::thread_rng();
46-
rng.gen::<u32>()
47-
})
48-
.await
49-
.map_err(Error::from)?
50-
.into();
51-
5242
let mut entry: write::Entry = entry.into();
5343

5444
if let Some(max_exp) = state.max_expiration {
@@ -57,9 +47,11 @@ pub async fn insert(
5747
.map_or_else(|| Some(max_exp), |value| Some(value.min(max_exp)));
5848
}
5949

60-
let url = id.to_url_path(&entry);
50+
let extension = entry.extension.clone();
51+
52+
let id = state.db.insert(entry).await?;
53+
let url = id.to_url_path(extension.as_deref());
6154
let path = BASE_PATH.join(&url);
62-
state.db.insert(id, entry).await?;
6355

6456
Ok(Json::from(RedirectResponse { path }))
6557
}

0 commit comments

Comments
 (0)