diff --git a/Cargo.toml b/Cargo.toml index 00d5d656c1..fcd87f451d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -299,7 +299,7 @@ required-features = ["sqlite"] [[test]] name = "sqlite-any" path = "tests/sqlite/any.rs" -required-features = ["sqlite"] +required-features = ["sqlite", "any"] [[test]] name = "sqlite-types" @@ -401,6 +401,11 @@ name = "mysql-rustsec" path = "tests/mysql/rustsec.rs" required-features = ["mysql"] +[[test]] +name = "mysql-any" +path = "tests/mysql/any.rs" +required-features = ["mysql", "any"] + # # PostgreSQL # @@ -454,3 +459,8 @@ required-features = ["postgres"] name = "postgres-rustsec" path = "tests/postgres/rustsec.rs" required-features = ["postgres", "macros", "migrate"] + +[[test]] +name = "postgres-any" +path = "tests/postgres/any.rs" +required-features = ["postgres", "any"] diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index 241900560e..92a8142711 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -96,7 +96,7 @@ impl AnyConnectionBackend for MySqlConnection { .try_flatten_stream() .map(|res| { Ok(match res? { - Either::Left(result) => Either::Left(map_result(result)), + Either::Left(result) => Either::Left(result.into()), Either::Right(row) => Either::Right(AnyRow::try_from(&row)?), }) }), @@ -210,11 +210,12 @@ impl<'a> TryFrom<&'a AnyConnectOptions> for MySqlConnectOptions { } } -fn map_result(result: MySqlQueryResult) -> AnyQueryResult { - AnyQueryResult { - rows_affected: result.rows_affected, - // Don't expect this to be a problem - #[allow(clippy::cast_possible_wrap)] - last_insert_id: Some(result.last_insert_id as i64), +/// This conversion attempts to save last_insert_id by converting to i64. +impl From for AnyQueryResult { + fn from(done: MySqlQueryResult) -> Self { + AnyQueryResult { + rows_affected: done.rows_affected(), + last_insert_id: done.last_insert_id().try_into().ok(), + } } } diff --git a/sqlx-mysql/src/query_result.rs b/sqlx-mysql/src/query_result.rs index f008db06ae..9951b26acb 100644 --- a/sqlx-mysql/src/query_result.rs +++ b/sqlx-mysql/src/query_result.rs @@ -24,13 +24,3 @@ impl Extend for MySqlQueryResult { } } } -#[cfg(feature = "any")] -/// This conversion attempts to save last_insert_id by converting to i64. -impl From for sqlx_core::any::AnyQueryResult { - fn from(done: MySqlQueryResult) -> Self { - sqlx_core::any::AnyQueryResult { - rows_affected: done.rows_affected(), - last_insert_id: done.last_insert_id().try_into().ok(), - } - } -} diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index 1f248505ae..891b430e98 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -98,7 +98,7 @@ impl AnyConnectionBackend for PgConnection { .try_flatten_stream() .map( move |res: sqlx_core::Result>| match res? { - Either::Left(result) => Ok(Either::Left(map_result(result))), + Either::Left(result) => Ok(Either::Left(result.into())), Either::Right(row) => Ok(Either::Right(AnyRow::try_from(&row)?)), }, ), @@ -243,9 +243,11 @@ impl<'a> TryFrom<&'a AnyConnectOptions> for PgConnectOptions { } } -fn map_result(res: PgQueryResult) -> AnyQueryResult { - AnyQueryResult { - rows_affected: res.rows_affected(), - last_insert_id: None, +impl From for AnyQueryResult { + fn from(done: PgQueryResult) -> Self { + AnyQueryResult { + rows_affected: done.rows_affected(), + last_insert_id: None, + } } } diff --git a/sqlx-postgres/src/query_result.rs b/sqlx-postgres/src/query_result.rs index 3a243f3ee6..f96f4f3a6a 100644 --- a/sqlx-postgres/src/query_result.rs +++ b/sqlx-postgres/src/query_result.rs @@ -18,13 +18,3 @@ impl Extend for PgQueryResult { } } } - -#[cfg(feature = "any")] -impl From for sqlx_core::any::AnyQueryResult { - fn from(done: PgQueryResult) -> Self { - sqlx_core::any::AnyQueryResult { - rows_affected: done.rows_affected, - last_insert_id: None, - } - } -} diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index 83b141decd..7915eb5cb1 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -96,7 +96,7 @@ impl AnyConnectionBackend for SqliteConnection { .try_flatten_stream() .map( move |res: sqlx_core::Result>| match res? { - Either::Left(result) => Ok(Either::Left(map_result(result))), + Either::Left(result) => Ok(Either::Left(result.into())), Either::Right(row) => Ok(Either::Right(AnyRow::try_from(&row)?)), }, ), @@ -231,9 +231,16 @@ fn map_arguments(args: AnyArguments) -> SqliteArguments { } } -fn map_result(res: SqliteQueryResult) -> AnyQueryResult { - AnyQueryResult { - rows_affected: res.rows_affected(), - last_insert_id: None, +impl From for AnyQueryResult { + fn from(done: SqliteQueryResult) -> Self { + // logic as per: https://www.sqlite.org/c3ref/last_insert_rowid.html + let last_insert_id = match done.last_insert_rowid() { + 0 => None, + n => Some(n), + }; + AnyQueryResult { + rows_affected: done.rows_affected(), + last_insert_id, + } } } diff --git a/sqlx-sqlite/src/query_result.rs b/sqlx-sqlite/src/query_result.rs index 8c8c27fcf4..088e032db6 100644 --- a/sqlx-sqlite/src/query_result.rs +++ b/sqlx-sqlite/src/query_result.rs @@ -24,17 +24,3 @@ impl Extend for SqliteQueryResult { } } } - -#[cfg(feature = "any")] -impl From for sqlx_core::any::AnyQueryResult { - fn from(done: SqliteQueryResult) -> Self { - let last_insert_id = match done.last_insert_rowid() { - 0 => None, - n => Some(n), - }; - sqlx_core::any::AnyQueryResult { - rows_affected: done.rows_affected(), - last_insert_id, - } - } -} diff --git a/tests/mysql/any.rs b/tests/mysql/any.rs new file mode 100644 index 0000000000..309d804152 --- /dev/null +++ b/tests/mysql/any.rs @@ -0,0 +1,28 @@ +use sqlx::Any; +use sqlx_test::new; + +/// ensure Any type with MySQL backing returns last_insert_id properly +/// https://github.com/launchbadge/sqlx/issues/2982 +#[sqlx_macros::test] +async fn any_sets_last_insert_id() -> anyhow::Result<()> { + sqlx::any::install_default_drivers(); + + let mut conn = new::().await?; + // syntax as per: https://dev.mysql.com/doc/refman/9.6/en/example-auto-increment.html + let _ = conn + .execute( + r#" +CREATE TEMPORARY TABLE users (id INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT, name TEXT NOT NULL) + "#, + ) + .await?; + + let result = sqlx::query("INSERT INTO users (name) VALUES (?)") + .bind("Glorbo") + .execute(&mut conn) + .await?; + + assert_eq!(result.last_insert_id(), Some(1)); + + Ok(()) +} diff --git a/tests/postgres/any.rs b/tests/postgres/any.rs new file mode 100644 index 0000000000..ec5dc0e38b --- /dev/null +++ b/tests/postgres/any.rs @@ -0,0 +1,29 @@ +use sqlx::Any; +use sqlx_test::new; + +/// ensure Any type with PostgreSQL backing returns last_insert_id properly +/// https://github.com/launchbadge/sqlx/issues/2982 +#[sqlx_macros::test] +async fn any_sets_last_insert_id() -> anyhow::Result<()> { + sqlx::any::install_default_drivers(); + + let mut conn = new::().await?; + // syntax as per: https://www.postgresql.org/docs/current/ddl-identity-columns.html + let _ = conn + .execute( + r#" +CREATE TEMPORARY TABLE users (id INTEGER GENERATED ALWAYS AS IDENTITY, name TEXT NOT NULL) + "#, + ) + .await?; + + let result = sqlx::query("INSERT INTO users (name) VALUES (?)") + .bind("Glorbo") + .execute(&mut conn) + .await?; + + // NOTE: PgQueryResult does not implement an equivalent concept and can only return None + assert_eq!(result.last_insert_id(), None); + + Ok(()) +} diff --git a/tests/sqlite/any.rs b/tests/sqlite/any.rs index b71c3ba43d..e83e3a021e 100644 --- a/tests/sqlite/any.rs +++ b/tests/sqlite/any.rs @@ -33,3 +33,29 @@ async fn issue_3179() -> anyhow::Result<()> { Ok(()) } + +/// ensure Any type with SQLite backing returns last_insert_id properly +/// https://github.com/launchbadge/sqlx/issues/2982 +#[sqlx_macros::test] +async fn any_sets_last_insert_id() -> anyhow::Result<()> { + sqlx::any::install_default_drivers(); + + let mut conn = new::().await?; + // syntax as per: https://sqlite.org/autoinc.html + let _ = conn + .execute( + r#" +CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL) + "#, + ) + .await?; + + let result = sqlx::query("INSERT INTO users (name) VALUES (?)") + .bind("Glorbo") + .execute(&mut conn) + .await?; + + assert_eq!(result.last_insert_id(), Some(1)); + + Ok(()) +}