diff --git a/Cargo.toml b/Cargo.toml index b7d610a..ee60e6e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,7 +54,7 @@ esplora = ["bdk_esplora", "_payjoin-dependencies"] rpc = ["bdk_bitcoind_rpc", "_payjoin-dependencies"] # Internal features -_payjoin-dependencies = ["payjoin", "reqwest", "url"] +_payjoin-dependencies = ["payjoin", "reqwest", "url", "sqlite"] # Use this to consensus verify transactions at sync time verify = [] diff --git a/README.md b/README.md index d905197..8a18ea4 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ And yes, it can do Taproot!! This crate can be used for the following purposes: - Instantly create a miniscript based wallet and connect to your backend of choice (Electrum, Esplora, Core RPC, Kyoto etc) and quickly play around with your own complex bitcoin scripting workflow. With one or many wallets, connected with one or many backends. - The `tests/integration.rs` module is used to document high level complex workflows between BDK and different Bitcoin infrastructure systems, like Core, Electrum and Lightning(soon TM). - - Receive and send Async Payjoins. Note that even though Async Payjoin as a protocol allows the receiver and sender to go offline during the payjoin, the BDK CLI implementation currently does not support persisting. + - Receive and send Async Payjoins with session persistence. Sessions can be resumed if interrupted. - (Planned) Expose the basic command handler via `wasm` to integrate `bdk-cli` functionality natively into the web platform. See also the [playground](https://bitcoindevkit.org/bdk-cli/playground/) page. If you are considering using BDK in your own wallet project bdk-cli is a nice playground to get started with. It allows easy testnet and regtest wallet operations, to try out what's possible with descriptors, miniscript, and BDK APIs. For more information on BDK refer to the [website](https://bitcoindevkit.org/) and the [rust docs](https://docs.rs/bdk_wallet/1.0.0/bdk_wallet/index.html) @@ -140,6 +140,31 @@ cargo run --features rpc -- wallet --wallet payjoin_wallet2 balance cargo run --features rpc -- wallet --wallet payjoin_wallet2 send_payjoin --ohttp_relay "https://pj.bobspacebkk.com" --ohttp_relay "https://pj.benalleng.com" --fee_rate 1 --uri "" ``` +### Payjoin Session Persistence + +Payjoin sessions are automatically persisted to a SQLite database (`payjoin.sqlite`) in the data directory. This allows sessions to be resumed if interrupted. + +#### Resume Payjoin Sessions + +Resume all pending sessions: +``` +cargo run --features rpc -- wallet --wallet resume_payjoin --directory "https://payjo.in" --ohttp_relay "https://pj.bobspacebkk.com" +``` + +Resume a specific session by ID: +``` +cargo run --features rpc -- wallet --wallet resume_payjoin --directory "https://payjo.in" --ohttp_relay "https://pj.bobspacebkk.com" --session_id +``` + +Sessions are processed sequentially (not concurrently) due to BDK-CLI's architecture. Each session waits up to 30 seconds for updates before timing out. If no session ID is specified, the most recent active sessions are resumed first. + +#### View Session History + +View all payjoin sessions (active and completed) and also see their status: +``` +cargo run -- wallet --wallet payjoin_history +``` + ## Justfile We have added the `just` command runner to help you with common commands (during development) and running regtest `bitcoind` if you are using the `rpc` feature. diff --git a/src/commands.rs b/src/commands.rs index 14ad9ea..4cebcf0 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -479,6 +479,20 @@ pub enum OnlineWalletSubCommand { )] fee_rate: u64, }, + /// Resume pending payjoin sessions. + ResumePayjoin { + /// Payjoin directory for the session + #[arg(env = "PAYJOIN_DIRECTORY", long = "directory", required = true)] + directory: String, + /// URL of the Payjoin OHTTP relay. Can be repeated multiple times. + #[arg(env = "PAYJOIN_OHTTP_RELAY", long = "ohttp_relay", required = true)] + ohttp_relay: Vec, + /// Resume only a specific active session ID (sender and/or receiver). + #[arg(env = "PAYJOIN_SESSION_ID", long = "session_id")] + session_id: Option, + }, + /// Show payjoin session history. + PayjoinHistory, } /// Subcommands for Key operations. diff --git a/src/error.rs b/src/error.rs index 064a928..5f40aec 100644 --- a/src/error.rs +++ b/src/error.rs @@ -112,6 +112,38 @@ pub enum BDKCliError { ))] #[error("Reqwest error: {0}")] ReqwestError(#[from] reqwest::Error), + + #[cfg(feature = "payjoin")] + #[error("Payjoin URL parse error: {0}")] + PayjoinUrlParse(#[from] payjoin::IntoUrlError), + + #[cfg(feature = "payjoin")] + #[error("Payjoin send response error: {0}")] + PayjoinSendResponse(#[from] payjoin::send::ResponseError), + + #[cfg(feature = "payjoin")] + #[error("Payjoin sender build error: {0}")] + PayjoinSenderBuild(#[from] payjoin::send::BuildSenderError), + + #[cfg(feature = "payjoin")] + #[error("Payjoin receive error: {0}")] + PayjoinReceive(#[from] payjoin::receive::Error), + + #[cfg(feature = "payjoin")] + #[error("Payjoin selection error: {0}")] + PayjoinSelection(#[from] payjoin::receive::SelectionError), + + #[cfg(feature = "payjoin")] + #[error("Payjoin input contribution error: {0}")] + PayjoinInputContribution(#[from] payjoin::receive::InputContributionError), + + #[cfg(feature = "payjoin")] + #[error("Payjoin create request error: {0}")] + PayjoinCreateRequest(#[from] payjoin::send::v2::CreateRequestError), + + #[cfg(feature = "payjoin")] + #[error("Payjoin database error: {0}")] + PayjoinDb(#[from] PayjoinDbError), } impl From for BDKCliError { @@ -119,3 +151,64 @@ impl From for BDKCliError { BDKCliError::PsbtExtractTxError(Box::new(value)) } } + +/// Error type for payjoin database operations +#[cfg(feature = "payjoin")] +#[derive(Debug)] +pub enum PayjoinDbError { + /// SQLite database error + Rusqlite(bdk_wallet::rusqlite::Error), + /// JSON serialization error + Serialize(serde_json::Error), + /// JSON deserialization error + Deserialize(serde_json::Error), +} + +#[cfg(feature = "payjoin")] +impl std::fmt::Display for PayjoinDbError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PayjoinDbError::Rusqlite(e) => write!(f, "Database operation failed: {e}"), + PayjoinDbError::Serialize(e) => write!(f, "Serialization failed: {e}"), + PayjoinDbError::Deserialize(e) => write!(f, "Deserialization failed: {e}"), + } + } +} + +#[cfg(feature = "payjoin")] +impl std::error::Error for PayjoinDbError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + PayjoinDbError::Rusqlite(e) => Some(e), + PayjoinDbError::Serialize(e) => Some(e), + PayjoinDbError::Deserialize(e) => Some(e), + } + } +} + +#[cfg(feature = "payjoin")] +impl From for PayjoinDbError { + fn from(error: bdk_wallet::rusqlite::Error) -> Self { + PayjoinDbError::Rusqlite(error) + } +} + +#[cfg(feature = "payjoin")] +impl From for payjoin::ImplementationError { + fn from(error: PayjoinDbError) -> Self { + payjoin::ImplementationError::new(error) + } +} + +#[cfg(feature = "payjoin")] +impl + From> for BDKCliError +where + ApiErr: std::error::Error, + StorageErr: std::error::Error, + ErrorState: std::fmt::Debug, +{ + fn from(e: payjoin::persist::PersistedError) -> Self { + BDKCliError::Generic(e.to_string()) + } +} diff --git a/src/handlers.rs b/src/handlers.rs index 1f867b4..03abbb8 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -15,6 +15,8 @@ use crate::config::{WalletConfig, WalletConfigInner}; use crate::error::BDKCliError as Error; #[cfg(any(feature = "sqlite", feature = "redb"))] use crate::persister::Persister; +#[cfg(feature = "cbf")] +use crate::utils::BlockchainClient::KyotoClient; use crate::utils::*; #[cfg(feature = "redb")] use bdk_redb::Store as RedbStore; @@ -46,8 +48,6 @@ use bdk_wallet::{ }; use cli_table::{Cell, CellStruct, Style, Table, format::Justify}; use serde_json::json; -#[cfg(feature = "cbf")] -use {crate::utils::BlockchainClient::KyotoClient, bdk_kyoto::LightClient, tokio::select}; #[cfg(feature = "electrum")] use crate::utils::BlockchainClient::Electrum; @@ -594,6 +594,16 @@ pub fn handle_offline_wallet_subcommand( } } +#[cfg(feature = "payjoin")] +pub fn open_payjoin_db( + datadir: Option, +) -> Result, Error> { + use crate::payjoin::db::{DB_FILENAME, Database}; + let home_dir = prepare_home_dir(datadir)?; + let db_path = home_dir.join(DB_FILENAME); + Ok(std::sync::Arc::new(Database::create(&db_path)?)) +} + /// Execute an online wallet sub-command /// /// Online wallet sub-commands are described in [`OnlineWalletSubCommand`]. @@ -605,8 +615,9 @@ pub fn handle_offline_wallet_subcommand( ))] pub(crate) async fn handle_online_wallet_subcommand( wallet: &mut Wallet, - client: BlockchainClient, + client: &BlockchainClient, online_subcommand: OnlineWalletSubCommand, + datadir: Option, ) -> Result { match online_subcommand { FullScan { @@ -632,7 +643,7 @@ pub(crate) async fn handle_online_wallet_subcommand( client .populate_tx_cache(wallet.tx_graph().full_txs().map(|tx_node| tx_node.tx)); - let update = client.full_scan(request, _stop_gap, batch_size, false)?; + let update = client.full_scan(request, _stop_gap, *batch_size, false)?; wallet.apply_update(update)?; } #[cfg(feature = "esplora")] @@ -641,7 +652,7 @@ pub(crate) async fn handle_online_wallet_subcommand( parallel_requests, } => { let update = client - .full_scan(request, _stop_gap, parallel_requests) + .full_scan(request, _stop_gap, *parallel_requests) .await .map_err(|e| *e)?; wallet.apply_update(update)?; @@ -658,7 +669,7 @@ pub(crate) async fn handle_online_wallet_subcommand( hash: genesis_block.block_hash(), }); let mut emitter = Emitter::new( - &*client, + client.as_ref(), genesis_cp.clone(), genesis_cp.height(), NO_EXPECTED_MEMPOOL_TXS, @@ -724,7 +735,8 @@ pub(crate) async fn handle_online_wallet_subcommand( max_fee_rate, } => { let relay_manager = Arc::new(Mutex::new(RelayManager::new())); - let mut payjoin_manager = PayjoinManager::new(wallet, relay_manager); + let db = open_payjoin_db(datadir.clone())?; + let mut payjoin_manager = PayjoinManager::new(wallet, relay_manager, db); return payjoin_manager .receive_payjoin(amount, directory, max_fee_rate, ohttp_relay, client) .await; @@ -735,11 +747,27 @@ pub(crate) async fn handle_online_wallet_subcommand( fee_rate, } => { let relay_manager = Arc::new(Mutex::new(RelayManager::new())); - let mut payjoin_manager = PayjoinManager::new(wallet, relay_manager); + let db = open_payjoin_db(datadir.clone())?; + let mut payjoin_manager = PayjoinManager::new(wallet, relay_manager, db); return payjoin_manager .send_payjoin(uri, fee_rate, ohttp_relay, client) .await; } + ResumePayjoin { + directory, + ohttp_relay, + session_id, + } => { + let relay_manager = Arc::new(Mutex::new(RelayManager::new())); + let db = open_payjoin_db(datadir)?; + let mut payjoin_manager = PayjoinManager::new(wallet, relay_manager, db); + return payjoin_manager + .resume_payjoins(directory, ohttp_relay, session_id, client) + .await; + } + PayjoinHistory => { + return crate::payjoin::PayjoinManager::history(datadir); + } } } @@ -1209,7 +1237,7 @@ pub(crate) async fn handle_command(cli_opts: CliOpts) -> Result { wallet, subcommand: WalletSubCommand::OnlineWalletSubCommand(online_subcommand), } => { - let home_dir = prepare_home_dir(cli_opts.datadir)?; + let home_dir = prepare_home_dir(cli_opts.datadir.clone())?; let (wallet_opts, network) = load_wallet_config(&home_dir, &wallet)?; @@ -1246,8 +1274,9 @@ pub(crate) async fn handle_command(cli_opts: CliOpts) -> Result { let result = handle_online_wallet_subcommand( &mut wallet, - blockchain_client, + &blockchain_client, online_subcommand, + cli_opts.datadir.clone(), ) .await?; wallet.persist(&mut persister)?; @@ -1258,8 +1287,13 @@ pub(crate) async fn handle_command(cli_opts: CliOpts) -> Result { let mut wallet = new_wallet(network, wallet_opts)?; let blockchain_client = crate::utils::new_blockchain_client(wallet_opts, &wallet, database_path)?; - handle_online_wallet_subcommand(&mut wallet, blockchain_client, online_subcommand) - .await? + handle_online_wallet_subcommand( + &mut wallet, + &blockchain_client, + online_subcommand, + cli_opts.datadir.clone(), + ) + .await? }; Ok(result) } @@ -1452,9 +1486,14 @@ async fn respond( } => { let blockchain = new_blockchain_client(wallet_opts, wallet, _datadir).map_err(|e| e.to_string())?; - let value = handle_online_wallet_subcommand(wallet, blockchain, online_subcommand) - .await - .map_err(|e| e.to_string())?; + let value = handle_online_wallet_subcommand( + wallet, + &blockchain, + online_subcommand, + cli_opts.datadir.clone(), + ) + .await + .map_err(|e| e.to_string())?; Some(value) } ReplSubCommand::Wallet { @@ -1508,7 +1547,7 @@ async fn respond( feature = "rpc" ))] /// Syncs a given wallet using the blockchain client. -pub async fn sync_wallet(client: BlockchainClient, wallet: &mut Wallet) -> Result<(), Error> { +pub async fn sync_wallet(client: &BlockchainClient, wallet: &mut Wallet) -> Result<(), Error> { #[cfg(any(feature = "electrum", feature = "esplora"))] let request = wallet .start_sync_with_revealed_spks() @@ -1523,7 +1562,7 @@ pub async fn sync_wallet(client: BlockchainClient, wallet: &mut Wallet) -> Resul // already have. client.populate_tx_cache(wallet.tx_graph().full_txs().map(|tx_node| tx_node.tx)); - let update = client.sync(request, batch_size, false)?; + let update = client.sync(request, *batch_size, false)?; wallet .apply_update(update) .map_err(|e| Error::Generic(e.to_string())) @@ -1534,7 +1573,7 @@ pub async fn sync_wallet(client: BlockchainClient, wallet: &mut Wallet) -> Resul parallel_requests, } => { let update = client - .sync(request, parallel_requests) + .sync(request, *parallel_requests) .await .map_err(|e| *e)?; wallet @@ -1549,7 +1588,7 @@ pub async fn sync_wallet(client: BlockchainClient, wallet: &mut Wallet) -> Resul // reload the last 200 blocks in case of a reorg let emitter_height = wallet_cp.height().saturating_sub(200); let mut emitter = Emitter::new( - &*client, + client.as_ref(), wallet_cp, emitter_height, wallet @@ -1600,7 +1639,7 @@ pub async fn sync_wallet(client: BlockchainClient, wallet: &mut Wallet) -> Resul ))] /// Broadcasts a given transaction using the blockchain client. pub async fn broadcast_transaction( - client: BlockchainClient, + client: &BlockchainClient, tx: Transaction, ) -> Result { match client { @@ -1627,38 +1666,15 @@ pub async fn broadcast_transaction( #[cfg(feature = "cbf")] KyotoClient { client } => { - let LightClient { - requester, - mut info_subscriber, - mut warning_subscriber, - update_subscriber: _, - node, - } = *client; - - let subscriber = tracing_subscriber::FmtSubscriber::new(); - tracing::subscriber::set_global_default(subscriber) - .map_err(|e| Error::Generic(format!("SetGlobalDefault error: {e}")))?; - - tokio::task::spawn(async move { node.run().await }); - tokio::task::spawn(async move { - select! { - info = info_subscriber.recv() => { - if let Some(info) = info { - tracing::info!("{info}"); - } - }, - warn = warning_subscriber.recv() => { - if let Some(warn) = warn { - tracing::warn!("{warn}"); - } - } - } - }); let txid = tx.compute_txid(); - let wtxid = requester.broadcast_random(tx.clone()).await.map_err(|_| { - tracing::warn!("Broadcast was unsuccessful"); - Error::Generic("Transaction broadcast timed out after 30 seconds".into()) - })?; + let wtxid = client + .requester + .broadcast_random(tx.clone()) + .await + .map_err(|_| { + tracing::warn!("Broadcast was unsuccessful"); + Error::Generic("Transaction broadcast timed out after 30 seconds".into()) + })?; tracing::info!("Successfully broadcast WTXID: {wtxid}"); Ok(txid) } diff --git a/src/payjoin/db.rs b/src/payjoin/db.rs new file mode 100644 index 0000000..3525e44 --- /dev/null +++ b/src/payjoin/db.rs @@ -0,0 +1,397 @@ +use std::fmt; +use std::path::Path; +use std::sync::{Arc, Mutex, MutexGuard}; + +use bdk_wallet::rusqlite::{Connection, params}; +use payjoin::HpkePublicKey; +use payjoin::bitcoin::OutPoint; +use payjoin::bitcoin::consensus::encode::serialize; +use payjoin::persist::SessionPersister; +use payjoin::receive::v2::SessionEvent as ReceiverSessionEvent; +use payjoin::send::v2::SessionEvent as SenderSessionEvent; + +use crate::error::PayjoinDbError as Error; + +pub type Result = std::result::Result; + +/// Default filename for the payjoin database +pub const DB_FILENAME: &str = "payjoin.sqlite"; + +/// Returns the current Unix timestamp in seconds +#[inline] +fn now() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64 +} + +pub struct Database { + conn: Mutex, +} + +impl Database { + pub fn create(path: impl AsRef) -> Result { + let conn = Connection::open(path.as_ref())?; + Self::init_schema(&conn)?; + Ok(Self { + conn: Mutex::new(conn), + }) + } + + fn conn(&self) -> MutexGuard<'_, Connection> { + self.conn + .lock() + .expect("Database mutex should not be poisoned") + } + + fn init_schema(conn: &Connection) -> Result<()> { + conn.execute("PRAGMA foreign_keys = ON", [])?; + + conn.execute( + "CREATE TABLE IF NOT EXISTS send_sessions ( + session_id INTEGER PRIMARY KEY AUTOINCREMENT, + receiver_pubkey BLOB NOT NULL, + completed_at INTEGER + )", + [], + )?; + + conn.execute( + "CREATE TABLE IF NOT EXISTS receive_sessions ( + session_id INTEGER PRIMARY KEY AUTOINCREMENT, + completed_at INTEGER + )", + [], + )?; + + conn.execute( + "CREATE TABLE IF NOT EXISTS send_session_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id INTEGER NOT NULL, + event_data TEXT NOT NULL, + created_at INTEGER NOT NULL, + FOREIGN KEY(session_id) REFERENCES send_sessions(session_id) + )", + [], + )?; + + conn.execute( + "CREATE TABLE IF NOT EXISTS receive_session_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id INTEGER NOT NULL, + event_data TEXT NOT NULL, + created_at INTEGER NOT NULL, + FOREIGN KEY(session_id) REFERENCES receive_sessions(session_id) + )", + [], + )?; + + conn.execute( + "CREATE TABLE IF NOT EXISTS inputs_seen ( + outpoint BLOB PRIMARY KEY, + created_at INTEGER NOT NULL + )", + [], + )?; + + Ok(()) + } + + /// Inserts the input and returns true if the input was seen before, false otherwise. + /// This is used for replay protection to prevent probing attacks. + pub fn insert_input_seen_before(&self, input: OutPoint) -> Result { + let key = serialize(&input); + let was_seen_before = self.conn().execute( + "INSERT OR IGNORE INTO inputs_seen (outpoint, created_at) VALUES (?1, ?2)", + params![key, now()], + )? == 0; + Ok(was_seen_before) + } + + /// Returns IDs of all active (incomplete) receive sessions + pub fn get_recv_session_ids(&self) -> Result> { + let conn = self.conn(); + let mut stmt = + conn.prepare("SELECT session_id FROM receive_sessions WHERE completed_at IS NULL ORDER BY session_id DESC")?; + + let session_rows = stmt.query_map([], |row| { + let session_id: i64 = row.get(0)?; + Ok(SessionId(session_id)) + })?; + + let mut session_ids = Vec::new(); + for session_row in session_rows { + session_ids.push(session_row?); + } + + Ok(session_ids) + } + + /// Returns IDs of all active (incomplete) send sessions + pub fn get_send_session_ids(&self) -> Result> { + let conn = self.conn(); + let mut stmt = + conn.prepare("SELECT session_id FROM send_sessions WHERE completed_at IS NULL ORDER BY session_id DESC")?; + + let session_rows = stmt.query_map([], |row| { + let session_id: i64 = row.get(0)?; + Ok(SessionId(session_id)) + })?; + + let mut session_ids = Vec::new(); + for session_row in session_rows { + session_ids.push(session_row?); + } + + Ok(session_ids) + } + + /// Returns the receiver public key for a send session + pub fn get_send_session_receiver_pk(&self, session_id: &SessionId) -> Result { + let conn = self.conn(); + let mut stmt = + conn.prepare("SELECT receiver_pubkey FROM send_sessions WHERE session_id = ?1")?; + let receiver_pubkey: Vec = stmt.query_row(params![session_id.0], |row| row.get(0))?; + Ok(HpkePublicKey::from_compressed_bytes(&receiver_pubkey).expect("Valid receiver pubkey")) + } + + /// Returns IDs and completion timestamps of all completed send sessions + pub fn get_inactive_send_session_ids(&self) -> Result> { + let conn = self.conn(); + let mut stmt = conn.prepare( + "SELECT session_id, completed_at FROM send_sessions WHERE completed_at IS NOT NULL", + )?; + let session_rows = stmt.query_map([], |row| { + let session_id: i64 = row.get(0)?; + let completed_at: u64 = row.get(1)?; + Ok((SessionId(session_id), completed_at)) + })?; + + let mut session_ids = Vec::new(); + for session_row in session_rows { + session_ids.push(session_row?); + } + Ok(session_ids) + } + + /// Returns IDs and completion timestamps of all completed receive sessions + pub fn get_inactive_recv_session_ids(&self) -> Result> { + let conn = self.conn(); + let mut stmt = conn.prepare( + "SELECT session_id, completed_at FROM receive_sessions WHERE completed_at IS NOT NULL", + )?; + let session_rows = stmt.query_map([], |row| { + let session_id: i64 = row.get(0)?; + let completed_at: u64 = row.get(1)?; + Ok((SessionId(session_id), completed_at)) + })?; + + let mut session_ids = Vec::new(); + for session_row in session_rows { + session_ids.push(session_row?); + } + Ok(session_ids) + } + + /// Formats a Unix timestamp into local date time text. + pub fn format_unix_timestamp(&self, timestamp: u64) -> Result { + let Ok(timestamp) = i64::try_from(timestamp) else { + return Ok(format!("Invalid timestamp ({timestamp})")); + }; + let conn = self.conn(); + let dt: Option = conn.query_row( + "SELECT datetime(?1, 'unixepoch', 'localtime')", + params![timestamp], + |row| row.get(0), + )?; + Ok(dt.unwrap_or_else(|| format!("Invalid timestamp ({timestamp})"))) + } +} + +/// Wrapper type for session IDs +#[derive(Debug, Clone)] +pub struct SessionId(i64); + +impl core::ops::Deref for SessionId { + type Target = i64; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl fmt::Display for SessionId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl SessionId { + pub fn as_i64(&self) -> i64 { + self.0 + } +} + +/// Persister for payjoin v2 send sessions +#[derive(Clone)] +pub struct SenderPersister { + db: Arc, + session_id: SessionId, +} + +impl SenderPersister { + /// Creates a new sender persister, creating a new session in the database + pub fn new(db: Arc, receiver_pubkey: HpkePublicKey) -> Result { + let session_id: i64 = db.conn().query_row( + "INSERT INTO send_sessions (session_id, receiver_pubkey) VALUES (NULL, ?1) RETURNING session_id", + params![receiver_pubkey.to_compressed_bytes()], + |row| row.get(0), + )?; + + Ok(Self { + db, + session_id: SessionId(session_id), + }) + } + + /// Creates a persister from an existing session ID + pub fn from_id(db: Arc, id: SessionId) -> Self { + Self { db, session_id: id } + } +} + +impl SessionPersister for SenderPersister { + type SessionEvent = SenderSessionEvent; + type InternalStorageError = Error; + + fn save_event( + &self, + event: SenderSessionEvent, + ) -> std::result::Result<(), Self::InternalStorageError> { + let event_data = serde_json::to_string(&event).map_err(Error::Serialize)?; + + self.db.conn().execute( + "INSERT INTO send_session_events (session_id, event_data, created_at) VALUES (?1, ?2, ?3)", + params![*self.session_id, event_data, now()], + )?; + + Ok(()) + } + + fn load( + &self, + ) -> std::result::Result>, Self::InternalStorageError> + { + let conn = self.db.conn(); + let mut stmt = conn.prepare( + "SELECT event_data FROM send_session_events WHERE session_id = ?1 ORDER BY id ASC", + )?; + + let event_rows = stmt.query_map(params![*self.session_id], |row| { + let event_data: String = row.get(0)?; + Ok(event_data) + })?; + + let events: Vec = event_rows + .map(|row| { + let event_data = row.expect("Failed to read event data from database"); + serde_json::from_str::(&event_data) + .expect("Database corruption: failed to deserialize session event") + }) + .collect(); + + Ok(Box::new(events.into_iter())) + } + + fn close(&self) -> std::result::Result<(), Self::InternalStorageError> { + self.db.conn().execute( + "UPDATE send_sessions SET completed_at = ?1 WHERE session_id = ?2", + params![now(), *self.session_id], + )?; + + Ok(()) + } +} + +/// Persister for payjoin v2 receive sessions +#[derive(Clone)] +pub struct ReceiverPersister { + db: Arc, + session_id: SessionId, +} + +impl ReceiverPersister { + /// Creates a new receiver persister, creating a new session in the database + pub fn new(db: Arc) -> Result { + let session_id: i64 = db.conn().query_row( + "INSERT INTO receive_sessions (session_id) VALUES (NULL) RETURNING session_id", + [], + |row| row.get(0), + )?; + + Ok(Self { + db, + session_id: SessionId(session_id), + }) + } + + /// Creates a persister from an existing session ID + pub fn from_id(db: Arc, id: SessionId) -> Self { + Self { db, session_id: id } + } +} + +impl SessionPersister for ReceiverPersister { + type SessionEvent = ReceiverSessionEvent; + type InternalStorageError = Error; + + fn save_event( + &self, + event: ReceiverSessionEvent, + ) -> std::result::Result<(), Self::InternalStorageError> { + let event_data = serde_json::to_string(&event).map_err(Error::Serialize)?; + + self.db.conn().execute( + "INSERT INTO receive_session_events (session_id, event_data, created_at) VALUES (?1, ?2, ?3)", + params![*self.session_id, event_data, now()], + )?; + + Ok(()) + } + + fn load( + &self, + ) -> std::result::Result< + Box>, + Self::InternalStorageError, + > { + let conn = self.db.conn(); + let mut stmt = conn.prepare( + "SELECT event_data FROM receive_session_events WHERE session_id = ?1 ORDER BY id ASC", + )?; + + let event_rows = stmt.query_map(params![*self.session_id], |row| { + let event_data: String = row.get(0)?; + Ok(event_data) + })?; + + let events: Vec = event_rows + .map(|row| { + let event_data = row.expect("Failed to read event data from database"); + serde_json::from_str::(&event_data) + .expect("Database corruption: failed to deserialize session event") + }) + .collect(); + + Ok(Box::new(events.into_iter())) + } + + fn close(&self) -> std::result::Result<(), Self::InternalStorageError> { + self.db.conn().execute( + "UPDATE receive_sessions SET completed_at = ?1 WHERE session_id = ?2", + params![now(), *self.session_id], + )?; + + Ok(()) + } +} diff --git a/src/payjoin/mod.rs b/src/payjoin/mod.rs index f5e1274..974ddf9 100644 --- a/src/payjoin/mod.rs +++ b/src/payjoin/mod.rs @@ -1,29 +1,34 @@ use crate::error::BDKCliError as Error; -use crate::handlers::{broadcast_transaction, sync_wallet}; +use crate::handlers::{broadcast_transaction, open_payjoin_db, sync_wallet}; use crate::utils::BlockchainClient; use bdk_wallet::{ SignOptions, Wallet, bitcoin::{FeeRate, Psbt, Txid, consensus::encode::serialize_hex}, }; +use cli_table::{Cell, CellStruct, Style, Table}; use payjoin::bitcoin::TxIn; use payjoin::persist::{OptionalTransitionOutcome, SessionPersister}; use payjoin::receive::InputPair; use payjoin::receive::v2::{ HasReplyableError, Initialized, MaybeInputsOwned, MaybeInputsSeen, Monitor, OutputsUnknown, PayjoinProposal, ProvisionalProposal, ReceiveSession, Receiver, - SessionEvent as ReceiverSessionEvent, UncheckedOriginalPayload, WantsFeeRange, WantsInputs, - WantsOutputs, + SessionEvent as ReceiverSessionEvent, SessionOutcome as ReceiverSessionOutcome, + UncheckedOriginalPayload, WantsFeeRange, WantsInputs, WantsOutputs, + replay_event_log as replay_receiver_event_log, }; use payjoin::send::v2::{ PollingForProposal, SendSession, Sender, SessionEvent as SenderSessionEvent, SessionOutcome as SenderSessionOutcome, WithReplyKey, + replay_event_log as replay_sender_event_log, }; -use payjoin::{ImplementationError, UriExt}; +use payjoin::{HpkePublicKey, ImplementationError, UriExt}; use serde_json::{json, to_string_pretty}; use std::sync::{Arc, Mutex}; +use crate::payjoin::db::{ReceiverPersister, SenderPersister}; use crate::payjoin::ohttp::{RelayManager, fetch_ohttp_keys}; +pub mod db; pub mod ohttp; /// Implements all of the functions required to go through the Payjoin receive and send processes. @@ -35,13 +40,74 @@ pub mod ohttp; pub(crate) struct PayjoinManager<'a> { wallet: &'a mut Wallet, relay_manager: Arc>, + db: Arc, +} + +trait StatusText { + fn status_text(&self) -> &'static str; +} + +impl StatusText for SendSession { + fn status_text(&self) -> &'static str { + match self { + SendSession::WithReplyKey(_) | SendSession::PollingForProposal(_) => { + "Waiting for proposal" + } + SendSession::Closed(session_outcome) => match session_outcome { + SenderSessionOutcome::Failure => "Session failure", + SenderSessionOutcome::Success(_) => "Session success", + SenderSessionOutcome::Cancel => "Session cancelled", + }, + } + } +} + +impl StatusText for ReceiveSession { + fn status_text(&self) -> &'static str { + match self { + ReceiveSession::Initialized(_) => "Waiting for original proposal", + ReceiveSession::UncheckedOriginalPayload(_) + | ReceiveSession::MaybeInputsOwned(_) + | ReceiveSession::MaybeInputsSeen(_) + | ReceiveSession::OutputsUnknown(_) + | ReceiveSession::WantsOutputs(_) + | ReceiveSession::WantsInputs(_) + | ReceiveSession::WantsFeeRange(_) + | ReceiveSession::ProvisionalProposal(_) => "Processing original proposal", + ReceiveSession::PayjoinProposal(_) => "Payjoin proposal sent", + ReceiveSession::HasReplyableError(_) => { + "Session failure, waiting to post error response" + } + ReceiveSession::Monitor(_) => "Monitoring payjoin proposal", + ReceiveSession::Closed(session_outcome) => match session_outcome { + ReceiverSessionOutcome::Failure => "Session failure", + ReceiverSessionOutcome::Success(_) => { + "Session success, Payjoin proposal was broadcasted" + } + ReceiverSessionOutcome::Cancel => "Session cancelled", + ReceiverSessionOutcome::FallbackBroadcasted => "Fallback broadcasted", + }, + } + } +} + +struct SessionHistoryRow { + id: String, + role: &'static str, + status: String, + completed_at: Option, } impl<'a> PayjoinManager<'a> { - pub fn new(wallet: &'a mut Wallet, relay_manager: Arc>) -> Self { + pub fn new( + wallet: &'a mut Wallet, + relay_manager: Arc>, + db: Arc, + ) -> Self { Self { wallet, relay_manager, + db, } } @@ -51,7 +117,7 @@ impl<'a> PayjoinManager<'a> { directory: String, max_fee_rate: Option, ohttp_relays: Vec, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, ) -> Result { let address = self .wallet @@ -71,8 +137,8 @@ impl<'a> PayjoinManager<'a> { let ohttp_keys = fetch_ohttp_keys(ohttp_relays, &directory, self.relay_manager.clone()).await?; - // TODO: Implement proper persister. - let persister = payjoin::persist::NoopSessionPersister::::default(); + + let persister = crate::payjoin::db::ReceiverPersister::new(self.db.clone())?; let checked_max_fee_rate = max_fee_rate .map(|rate| FeeRate::from_sat_per_kwu(rate)) @@ -82,12 +148,7 @@ impl<'a> PayjoinManager<'a> { address.address, directory, ohttp_keys.ohttp_keys, - ) - .map_err(|e| { - Error::Generic(format!( - "Failed to initialize a Payjoin ReceiverBuilder: {e}" - )) - })? + )? .with_amount(payjoin::bitcoin::Amount::from_sat(amount)) .with_max_fee_rate(checked_max_fee_rate) .build() @@ -119,7 +180,7 @@ impl<'a> PayjoinManager<'a> { uri: String, fee_rate: u64, ohttp_relays: Vec, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, ) -> Result { let uri = payjoin::Uri::try_from(uri) .map_err(|e| Error::Generic(format!("Failed parsing to Payjoin URI: {}", e)))?; @@ -143,11 +204,7 @@ impl<'a> PayjoinManager<'a> { .add_recipient(uri.address.script_pubkey(), sats) .fee_rate(fee_rate); - tx_builder.finish().map_err(|e| { - Error::Generic(format!( - "Error occurred when building original Payjoin transaction: {e}" - )) - })? + tx_builder.finish()? }; if !self .wallet @@ -161,25 +218,16 @@ impl<'a> PayjoinManager<'a> { let txid = match uri.extras.pj_param() { payjoin::PjParam::V1(_) => { let (req, ctx) = payjoin::send::v1::SenderBuilder::new(original_psbt.clone(), uri) - .build_recommended(fee_rate) - .map_err(|e| { - Error::Generic(format!("Failed to build a Payjoin v1 sender: {e}")) - })? + .build_recommended(fee_rate)? .create_v1_post_request(); - let response = self - .send_payjoin_post_request(req) - .await - .map_err(|e| Error::Generic(format!("Failed to send request: {e}")))?; - - let psbt = ctx - .process_response(&response.bytes().await?) - .map_err(|e| Error::Generic(format!("Failed to send a Payjoin v1: {e}")))?; + let response = self.send_payjoin_post_request(req).await?; + let psbt = ctx.process_response(&response.bytes().await?)?; self.process_payjoin_proposal(psbt, blockchain_client) .await? } - payjoin::PjParam::V2(_) => { + payjoin::PjParam::V2(v2_param) => { let ohttp_relays: Vec = ohttp_relays .into_iter() .map(|s| url::Url::parse(&s)) @@ -194,31 +242,64 @@ impl<'a> PayjoinManager<'a> { )); } - // TODO: Implement proper persister. - let persister = - payjoin::persist::NoopSessionPersister::::default(); + use payjoin::send::v2::replay_event_log as replay_sender_event_log; - let sender = payjoin::send::v2::SenderBuilder::new(original_psbt.clone(), uri) - .build_recommended(fee_rate) - .map_err(|e| { - Error::Generic(format!("Failed to build a Payjoin v2 sender: {e}")) - })? - .save(&persister) - .map_err(|e| { - Error::Generic(format!( - "Failed to save the Payjoin v2 sender in the persister: {e}" - )) - })?; + // Check for existing session with the same receiver pubkey + let receiver_pubkey = v2_param.receiver_pubkey(); + let existing_session = self + .db + .get_send_session_ids() + .map_err(|e| Error::Generic(format!("{e}")))? + .into_iter() + .find_map(|session_id| { + let session_receiver_pubkey = self + .db + .get_send_session_receiver_pk(&session_id) + .expect("Receiver pubkey should exist if session id exists"); + if session_receiver_pubkey == *receiver_pubkey { + let sender_persister = + SenderPersister::from_id(self.db.clone(), session_id); + let (send_session, _) = replay_sender_event_log(&sender_persister) + .map_err(|e| { + Error::Generic(format!( + "Failed to replay sender event log: {:?}", + e + )) + }) + .ok()?; + Some((send_session, sender_persister)) + } else { + None + } + }); + + let (sender_state, persister) = if let Some((sender_state, persister)) = + existing_session + { + println!("Resuming existing sender session"); + (sender_state, persister) + } else { + let persister = { + let receiver_pubkey: HpkePublicKey = v2_param.receiver_pubkey().clone(); + SenderPersister::new(self.db.clone(), receiver_pubkey)? + }; - let selected_relay = - fetch_ohttp_keys(ohttp_relays, &sender.endpoint(), self.relay_manager.clone()) - .await? - .relay_url; + let sender = payjoin::send::v2::SenderBuilder::new(original_psbt.clone(), uri) + .build_recommended(fee_rate)? + .save(&persister) + .map_err(|e| { + Error::Generic(format!( + "Failed to save the Payjoin v2 sender in the persister: {e}" + )) + })?; + + (SendSession::WithReplyKey(sender), persister) + }; self.proceed_sender_session( - SendSession::WithReplyKey(sender), + sender_state, &persister, - selected_relay.to_string(), + ohttp_relays, blockchain_client, ) .await? @@ -237,7 +318,7 @@ impl<'a> PayjoinManager<'a> { persister: &impl SessionPersister, relay: impl payjoin::IntoUrl, max_fee_rate: FeeRate, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, ) -> Result<(), Error> { match session { ReceiveSession::Initialized(proposal) => { @@ -306,18 +387,11 @@ impl<'a> PayjoinManager<'a> { persister: &impl SessionPersister, relay: impl payjoin::IntoUrl, max_fee_rate: FeeRate, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, ) -> Result<(), Error> { let mut current_receiver_typestate = receiver; let next_receiver_typestate = loop { - let (req, context) = current_receiver_typestate - .create_poll_request(relay.as_str()) - .map_err(|e| { - Error::Generic(format!( - "Failed to create a poll request to read from the Payjoin directory: {e}" - )) - })?; - println!("Polling receive request..."); + let (req, context) = current_receiver_typestate.create_poll_request(relay.as_str())?; let response = self.send_payjoin_post_request(req).await?; let state_transition = current_receiver_typestate .process_response(response.bytes().await?.to_vec().as_slice(), context) @@ -353,7 +427,7 @@ impl<'a> PayjoinManager<'a> { receiver: Receiver, persister: &impl SessionPersister, max_fee_rate: FeeRate, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, ) -> Result<(), Error> { let next_receiver_typestate = receiver .assume_interactive_receiver() @@ -386,16 +460,11 @@ impl<'a> PayjoinManager<'a> { receiver: Receiver, persister: &impl SessionPersister, max_fee_rate: FeeRate, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, ) -> Result<(), Error> { let next_receiver_typestate = receiver - .check_inputs_not_owned(&mut |input| { - Ok(self.wallet.is_mine(input.to_owned())) - }) - .save(persister) - .map_err(|e| { - Error::Generic(format!("Error occurred when saving after checking if inputs in the original proposal are not owned: {e}")) - })?; + .check_inputs_not_owned(&mut |input| Ok(self.wallet.is_mine(input.to_owned()))) + .save(persister)?; self.check_no_inputs_seen_before( next_receiver_typestate, @@ -411,18 +480,13 @@ impl<'a> PayjoinManager<'a> { receiver: Receiver, persister: &impl SessionPersister, max_fee_rate: FeeRate, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, ) -> Result<(), Error> { - // This is not supported as there is no persistence of previous Payjoin attempts in BDK CLI - // yet. If there is support either in the BDK persister or Payjoin persister, this can be - // implemented, but it is not a concern as the use cases of the CLI does not warrant - // protection against probing attacks. - println!( - "Checking whether the inputs in the proposal were seen before to protect from probing attacks is not supported. Skipping the check..." - ); - let next_receiver_typestate = receiver.check_no_inputs_seen_before(&mut |_| Ok(false)).save(persister).map_err(|e| { - Error::Generic(format!("Error occurred when saving after checking if the inputs in the proposal were seen before: {e}")) - })?; + let db = self.db.clone(); + let next_receiver_typestate = receiver + .check_no_inputs_seen_before(&mut |input| Ok(db.insert_input_seen_before(*input)?)) + .save(persister)?; + self.identify_receiver_outputs( next_receiver_typestate, persister, @@ -437,13 +501,13 @@ impl<'a> PayjoinManager<'a> { receiver: Receiver, persister: &impl SessionPersister, max_fee_rate: FeeRate, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, ) -> Result<(), Error> { - let next_receiver_typestate = receiver.identify_receiver_outputs(&mut |output_script| { - Ok(self.wallet.is_mine(output_script.to_owned())) - }).save(persister).map_err(|e| { - Error::Generic(format!("Error occurred when saving after checking if the outputs in the original proposal are owned by the receiver: {e}")) - })?; + let next_receiver_typestate = receiver + .identify_receiver_outputs(&mut |output_script| { + Ok(self.wallet.is_mine(output_script.to_owned())) + }) + .save(persister)?; self.commit_outputs( next_receiver_typestate, @@ -459,7 +523,7 @@ impl<'a> PayjoinManager<'a> { receiver: Receiver, persister: &impl SessionPersister, max_fee_rate: FeeRate, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, ) -> Result<(), Error> { // This is a typestate to modify existing receiver-owned outputs in case the receiver wants // to do that. This is a very simple implementation of Payjoin so we are just going @@ -483,7 +547,7 @@ impl<'a> PayjoinManager<'a> { receiver: Receiver, persister: &impl SessionPersister, max_fee_rate: FeeRate, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, ) -> Result<(), Error> { let candidate_inputs: Vec = self .wallet @@ -503,18 +567,10 @@ impl<'a> PayjoinManager<'a> { .expect("Failed to create InputPair when contributing outputs to the proposal") }) .collect(); - let selected_input = receiver - .try_preserving_privacy(candidate_inputs) - .map_err(|e| { - Error::Generic(format!( - "Error occurred when trying to pick an unspent UTXO for input contribution: {e}" - )) - })?; + let selected_input = receiver.try_preserving_privacy(candidate_inputs)?; - let next_receiver_typestate = receiver.contribute_inputs(vec![selected_input]) - .map_err(|e| { - Error::Generic(format!("Error occurred when contributing the selected input to the proposal: {e}")) - })?.commit_inputs().save(persister) + let next_receiver_typestate = receiver.contribute_inputs(vec![selected_input])? + .commit_inputs().save(persister) .map_err(|e| { Error::Generic(format!("Error occurred when saving after committing to the inputs after receiver contribution: {e}")) })?; @@ -533,11 +589,11 @@ impl<'a> PayjoinManager<'a> { receiver: Receiver, persister: &impl SessionPersister, max_fee_rate: FeeRate, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, ) -> Result<(), Error> { - let next_receiver_typestate = receiver.apply_fee_range(None, Some(max_fee_rate)).save(persister).map_err(|e| { - Error::Generic(format!("Error occurred when saving after applying the receiver fee range to the transaction: {e}")) - })?; + let next_receiver_typestate = receiver + .apply_fee_range(None, Some(max_fee_rate)) + .save(persister)?; self.finalize_proposal(next_receiver_typestate, persister, blockchain_client) .await } @@ -546,7 +602,7 @@ impl<'a> PayjoinManager<'a> { &mut self, receiver: Receiver, persister: &impl SessionPersister, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, ) -> Result<(), Error> { let next_receiver_typestate = receiver .finalize_proposal(|psbt| { @@ -565,12 +621,7 @@ impl<'a> PayjoinManager<'a> { Ok(psbt_clone) }) - .save(persister) - .map_err(|e| { - Error::Generic(format!( - "Error occurred when saving after signing the Payjoin proposal: {e}" - )) - })?; + .save(persister)?; self.send_payjoin_proposal(next_receiver_typestate, persister, blockchain_client) .await @@ -580,24 +631,23 @@ impl<'a> PayjoinManager<'a> { &mut self, receiver: Receiver, persister: &impl SessionPersister, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, ) -> Result<(), Error> { - let (req, ctx) = receiver.create_post_request( - self.relay_manager - .lock() - .expect("Lock should not be poisoned") - .get_selected_relay() - .expect("A relay should already be selected") - .as_str(), - ).map_err(|e| { + let (req, ctx) = receiver + .create_post_request( + self.unwrap_relay_or_else_fetch(vec![], None::<&str>) + .await? + .as_str(), + ) + .map_err(|e| { Error::Generic(format!("Error occurred when creating a post request for sending final Payjoin proposal: {e}")) })?; let res = self.send_payjoin_post_request(req).await?; let payjoin_psbt = receiver.psbt().clone(); - let next_receiver_typestate = receiver.process_response(&res.bytes().await?, ctx).save(persister).map_err(|e| { - Error::Generic(format!("Error occurred when saving after processing the response to the Payjoin proposal send: {e}")) - })?; + let next_receiver_typestate = receiver + .process_response(&res.bytes().await?, ctx) + .save(persister)?; println!( "Response successful. TXID: {}", payjoin_psbt.extract_tx_unchecked_fee_rate().compute_txid() @@ -607,72 +657,87 @@ impl<'a> PayjoinManager<'a> { .await; } - /// Syncs the blockchain once and then checks whether the Payjoin was broadcasted by the + /// Polls the blockchain periodically and checks whether the Payjoin was broadcasted by the /// sender. /// - /// The currenty implementation does not support checking for the Payjoin broadcast in a loop - /// and returning only when it is detected or if a timeout is reached because the [`sync_wallet`] - /// function consumes the BlockchainClient. BDK CLI supports multiple blockchain clients, and - /// at the time of writing, Kyoto consumes the client since BDK CLI is not designed for long-running - /// tasks. + /// This function syncs the wallet at regular intervals and checks for the Payjoin transaction + /// in a loop until it is detected or a timeout is reached. Since [`sync_wallet`] now accepts + /// a reference to the BlockchainClient, we can call it multiple times in a loop. async fn monitor_payjoin_proposal( &mut self, receiver: Receiver, persister: &impl SessionPersister, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, ) -> Result<(), Error> { - let wait_time_for_sync = 3; - let poll_internal = tokio::time::Duration::from_secs(wait_time_for_sync); + let poll_interval = tokio::time::Duration::from_millis(200); + let sync_interval = tokio::time::Duration::from_secs(3); + let timeout_duration = tokio::time::Duration::from_secs(15); println!( - "Waiting for {wait_time_for_sync} seconds before syncing the blockchain and checking if the transaction has been broadcast..." + "Polling for Payjoin transaction broadcast. This may take up to {} seconds...", + timeout_duration.as_secs() ); - tokio::time::sleep(poll_internal).await; - sync_wallet(blockchain_client, self.wallet).await?; - - let check_result = receiver - .check_payment( - |txid| { - let Some(tx_details) = self.wallet.tx_details(txid) else { - return Err(ImplementationError::from("Cannot find the transaction in the mempool or the blockchain")); - }; - - let is_seen = match tx_details.chain_position { - bdk_wallet::chain::ChainPosition::Confirmed { .. } => true, - bdk_wallet::chain::ChainPosition::Unconfirmed { first_seen: Some(_), .. } => true, - _ => false - }; - - if is_seen { - return Ok(Some(tx_details.tx.as_ref().clone())); + let result = tokio::time::timeout(timeout_duration, async { + let mut poll_timer = tokio::time::interval(poll_interval); + let mut sync_timer = tokio::time::interval(sync_interval); + poll_timer.tick().await; + sync_timer.tick().await; + sync_wallet(blockchain_client, self.wallet).await?; + + loop { + tokio::select! { + _ = poll_timer.tick() => { + // Time to check payment + let check_result = receiver + .check_payment( + |txid| { + let Some(tx_details) = self.wallet.tx_details(txid) else { + return Err(ImplementationError::from("Cannot find the transaction in the mempool or the blockchain")); + }; + + let is_seen = match tx_details.chain_position { + bdk_wallet::chain::ChainPosition::Confirmed { .. } => true, + bdk_wallet::chain::ChainPosition::Unconfirmed { first_seen: Some(_), .. } => true, + _ => false + }; + + if is_seen { + return Ok(Some(tx_details.tx.as_ref().clone())); + } + return Err(ImplementationError::from("Cannot find the transaction in the mempool or the blockchain")); + }, + |outpoint| { + let utxo = self.wallet.get_utxo(outpoint); + match utxo { + Some(_) => Ok(false), + None => Ok(true), + } + } + ) + .save(persister); + + if let Ok(OptionalTransitionOutcome::Progress(_)) = check_result { + println!("Payjoin transaction detected in the mempool!"); + return Ok(()); + } + // For Stasis or Err, continue polling (implicit - falls through to next loop iteration) } - return Err(ImplementationError::from("Cannot find the transaction in the mempool or the blockchain")); - }, - |outpoint| { - let utxo = self.wallet.get_utxo(outpoint); - match utxo { - Some(_) => Ok(false), - None => Ok(true), + _ = sync_timer.tick() => { + // Time to sync wallet + sync_wallet(blockchain_client, self.wallet).await?; } } - ) - .save(persister) - .map_err(|e| { - Error::Generic(format!("Error occurred when saving after checking that sender has broadcasted the Payjoin transaction: {e}")) - }); - - match check_result { - Ok(_) => { - println!("Payjoin transaction detected in the mempool!"); - } - Err(_) => { - println!( - "Transaction was not found in the mempool after {wait_time_for_sync}. Check the state of the transaction manually after running the sync command." - ); } + }) + .await; + + match result { + Ok(ok) => ok, + Err(_) => Err(Error::Generic(format!( + "Timeout waiting for Payjoin transaction broadcast after {:?}. Check the state of the transaction manually after running the sync command.", + timeout_duration + ))), } - - Ok(()) } async fn handle_error( @@ -682,11 +747,8 @@ impl<'a> PayjoinManager<'a> { ) -> Result<(), Error> { let (err_req, err_ctx) = receiver .create_error_request( - self.relay_manager - .lock() - .expect("Lock should not be poisoned") - .get_selected_relay() - .expect("A relay should already be selected") + self.unwrap_relay_or_else_fetch(vec![], None::<&str>) + .await? .as_str(), ) .map_err(|e| { @@ -733,17 +795,33 @@ impl<'a> PayjoinManager<'a> { &self, session: SendSession, persister: &impl SessionPersister, - relay: impl payjoin::IntoUrl, - blockchain_client: BlockchainClient, + ohttp_relays: Vec, + blockchain_client: &BlockchainClient, ) -> Result { match session { SendSession::WithReplyKey(context) => { - self.post_original_proposal(context, relay, persister, blockchain_client) - .await + let relay = self + .unwrap_relay_or_else_fetch(ohttp_relays, Some(context.endpoint())) + .await?; + self.post_original_proposal( + context, + persister, + blockchain_client, + relay.to_string(), + ) + .await } SendSession::PollingForProposal(context) => { - self.get_proposed_payjoin_proposal(context, relay, persister, blockchain_client) - .await + let relay = self + .unwrap_relay_or_else_fetch(ohttp_relays, Some(context.endpoint())) + .await?; + self.get_proposed_payjoin_proposal( + context, + persister, + blockchain_client, + relay.to_string(), + ) + .await } SendSession::Closed(SenderSessionOutcome::Success(psbt)) => { self.process_payjoin_proposal(psbt, blockchain_client).await @@ -752,43 +830,57 @@ impl<'a> PayjoinManager<'a> { } } + async fn unwrap_relay_or_else_fetch( + &self, + ohttp_relays: Vec, + directory: Option, + ) -> Result { + let selected_relay = self + .relay_manager + .lock() + .expect("Lock should not be poisoned") + .get_selected_relay(); + match selected_relay { + Some(relay) => Ok(relay), + None => { + let directory = directory.ok_or_else(|| { + Error::Generic("No directory URL provided and no relay selected".to_string()) + })?; + Ok( + fetch_ohttp_keys(ohttp_relays, directory, self.relay_manager.clone()) + .await? + .relay_url, + ) + } + } + } + async fn post_original_proposal( &self, sender: Sender, - relay: impl payjoin::IntoUrl, persister: &impl SessionPersister, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, + relay: String, ) -> Result { - let (req, ctx) = sender.create_v2_post_request(relay.as_str()).map_err(|e| { - Error::Generic(format!( - "Failed to create a post request for a Payjoin send: {e}" - )) - })?; + let (req, ctx) = sender.create_v2_post_request(relay.as_str())?; let response = self.send_payjoin_post_request(req).await?; let sender = sender .process_response(&response.bytes().await?, ctx) - .save(persister) - .map_err(|e| { - Error::Generic(format!("Failed to persist the Payjoin send after successfully sending original proposal: {e}")) - })?; - self.get_proposed_payjoin_proposal(sender, relay, persister, blockchain_client) + .save(persister)?; + self.get_proposed_payjoin_proposal(sender, persister, blockchain_client, relay) .await } async fn get_proposed_payjoin_proposal( &self, sender: Sender, - relay: impl payjoin::IntoUrl, persister: &impl SessionPersister, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, + relay: String, ) -> Result { let mut sender = sender.clone(); loop { - let (req, ctx) = sender.create_poll_request(relay.as_str()).map_err(|e| { - Error::Generic(format!( - "Failed to create a poll request during a Payjoin send: {e}" - )) - })?; + let (req, ctx) = sender.create_poll_request(relay.as_str())?; let response = self.send_payjoin_post_request(req).await?; let processed_response = sender .process_response(&response.bytes().await?, ctx) @@ -815,7 +907,7 @@ impl<'a> PayjoinManager<'a> { async fn process_payjoin_proposal( &self, mut psbt: Psbt, - blockchain_client: BlockchainClient, + blockchain_client: &BlockchainClient, ) -> Result { if !self.wallet.sign(&mut psbt, SignOptions::default())? { return Err(Error::Generic( @@ -838,4 +930,249 @@ impl<'a> PayjoinManager<'a> { .send() .await } + + /// Resume pending payjoin sessions from the database + pub async fn resume_payjoins( + &mut self, + directory: String, + ohttp_relays: Vec, + session_id: Option, + blockchain_client: &BlockchainClient, + ) -> Result { + let db = self.db.clone(); + let mut recv_session_ids = db.get_recv_session_ids()?; + let mut send_session_ids = db.get_send_session_ids()?; + + if let Some(session_id) = session_id { + recv_session_ids.retain(|id| id.as_i64() == session_id); + send_session_ids.retain(|id| id.as_i64() == session_id); + + if recv_session_ids.is_empty() && send_session_ids.is_empty() { + return Ok(serde_json::to_string_pretty(&json!({ + "message": format!("No active session found for session_id {}.", session_id) + }))?); + } + } + + if recv_session_ids.is_empty() && send_session_ids.is_empty() { + return Ok(serde_json::to_string_pretty(&json!({ + "message": "No sessions to resume." + }))?); + } + + let ohttp_relays: Vec = ohttp_relays + .into_iter() + .map(|s| url::Url::parse(&s)) + .collect::>() + .map_err(|e| Error::Generic(format!("Failed to parse OHTTP URLs: {e}")))?; + + let relay = self + .unwrap_relay_or_else_fetch(ohttp_relays, Some(&directory)) + .await?; + + let max_fee_rate = FeeRate::BROADCAST_MIN; + let total_sessions = recv_session_ids.len() + send_session_ids.len(); + let mut completed = 0usize; + let mut timed_out = 0usize; + let mut failed = 0usize; + + println!("Resuming {} payjoin session(s)...\n", total_sessions); + + // Resume receiver sessions + for session_id in recv_session_ids { + let persister = ReceiverPersister::from_id(db.clone(), session_id.clone()); + match replay_receiver_event_log(&persister) { + Ok((receiver_state, _)) => { + println!("Resuming receiver session {}", session_id); + match tokio::time::timeout( + std::time::Duration::from_secs(30), + self.proceed_receiver_session( + receiver_state, + &persister, + relay.as_str(), + max_fee_rate, + blockchain_client, + ), + ) + .await + { + Ok(Ok(_)) => { + completed += 1; + } + Ok(Err(e)) => { + failed += 1; + println!("Receiver session {} failed: {}", session_id, e); + } + Err(_) => { + timed_out += 1; + println!("Receiver session {} timed out", session_id); + } + } + } + Err(e) => { + failed += 1; + println!("Failed to replay receiver session {}: {:?}", session_id, e); + } + } + } + + // Resume sender sessions + for session_id in send_session_ids { + let persister = SenderPersister::from_id(db.clone(), session_id.clone()); + match replay_sender_event_log(&persister) { + Ok((sender_state, _)) => { + println!("Resuming sender session {}", session_id); + match tokio::time::timeout( + std::time::Duration::from_secs(30), + self.proceed_sender_session( + sender_state, + &persister, + vec![relay.clone()], + blockchain_client, + ), + ) + .await + { + Ok(Ok(_)) => { + completed += 1; + } + Ok(Err(e)) => { + failed += 1; + println!("Sender session {} failed: {}", session_id, e); + } + Err(_) => { + timed_out += 1; + println!("Sender session {} timed out", session_id); + } + } + } + Err(e) => { + failed += 1; + println!("Failed to replay sender session {}: {:?}", session_id, e); + } + } + } + + Ok(serde_json::to_string_pretty(&json!({ + "message": format!("Resumed polling for {} session(s).", total_sessions), + "outcome": format!( + "Completed: {}, timed out: {}, failed: {}.", + completed, timed_out, failed + ) + }))?) + } + + /// Show payjoin session history + pub fn history(datadir: Option) -> Result { + let db = open_payjoin_db(datadir)?; + let mut send_rows: Vec = Vec::new(); + let mut recv_rows: Vec = Vec::new(); + + // Active send sessions + for session_id in db + .get_send_session_ids() + .map_err(|e| Error::Generic(format!("{e}")))? + { + let persister = SenderPersister::from_id(db.clone(), session_id.clone()); + let status = match replay_sender_event_log(&persister) { + Ok((state, _)) => state.status_text().to_string(), + Err(e) => e.to_string(), + }; + send_rows.push(SessionHistoryRow { + id: session_id.to_string(), + role: "Sender", + status, + completed_at: None, + }); + } + + // Active receive sessions + for session_id in db + .get_recv_session_ids() + .map_err(|e| Error::Generic(format!("{e}")))? + { + let persister = ReceiverPersister::from_id(db.clone(), session_id.clone()); + let status = match replay_receiver_event_log(&persister) { + Ok((state, _)) => state.status_text().to_string(), + Err(e) => e.to_string(), + }; + recv_rows.push(SessionHistoryRow { + id: session_id.to_string(), + role: "Receiver", + status, + completed_at: None, + }); + } + + // Completed send sessions + for (session_id, completed_at) in db + .get_inactive_send_session_ids() + .map_err(|e| Error::Generic(format!("{e}")))? + { + let persister = SenderPersister::from_id(db.clone(), session_id.clone()); + let status = match replay_sender_event_log(&persister) { + Ok((state, _)) => state.status_text().to_string(), + Err(e) => e.to_string(), + }; + let completed_at = db + .format_unix_timestamp(completed_at) + .map_err(|e| Error::Generic(format!("{e}")))?; + send_rows.push(SessionHistoryRow { + id: session_id.to_string(), + role: "Sender", + status, + completed_at: Some(completed_at), + }); + } + + // Completed receive sessions + for (session_id, completed_at) in db + .get_inactive_recv_session_ids() + .map_err(|e| Error::Generic(format!("{e}")))? + { + let persister = ReceiverPersister::from_id(db.clone(), session_id.clone()); + let status = match replay_receiver_event_log(&persister) { + Ok((state, _)) => state.status_text().to_string(), + Err(e) => e.to_string(), + }; + let completed_at = db + .format_unix_timestamp(completed_at) + .map_err(|e| Error::Generic(format!("{e}")))?; + recv_rows.push(SessionHistoryRow { + id: session_id.to_string(), + role: "Receiver", + status, + completed_at: Some(completed_at), + }); + } + + let rows: Vec> = send_rows + .iter() + .chain(recv_rows.iter()) + .map(|row| { + vec![ + row.id.as_str().cell(), + row.role.cell(), + row.completed_at + .clone() + .unwrap_or_else(|| "Not Completed".to_string()) + .cell(), + row.status.as_str().cell(), + ] + }) + .collect(); + + let table = rows + .table() + .title(vec![ + "Session ID".cell().bold(true), + "Sender/Receiver".cell().bold(true), + "Completed At".cell().bold(true), + "Status".cell().bold(true), + ]) + .display() + .map_err(|e| Error::Generic(e.to_string()))?; + + Ok(format!("{table}")) + } } diff --git a/src/utils.rs b/src/utils.rs index 73d3453..448675a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -158,7 +158,15 @@ pub(crate) enum BlockchainClient { }, #[cfg(feature = "cbf")] - KyotoClient { client: Box }, + KyotoClient { client: KyotoClientHandle }, +} + +/// Handle for the Kyoto client after the node has been started. +/// Contains only the components needed for sync and broadcast operations. +#[cfg(feature = "cbf")] +pub struct KyotoClientHandle { + pub requester: bdk_kyoto::Requester, + pub update_subscriber: tokio::sync::Mutex, } #[cfg(any( @@ -215,13 +223,32 @@ pub(crate) fn new_blockchain_client( let scan_type = Sync; let builder = Builder::new(_wallet.network()); - let client = builder + let light_client = builder .required_peers(wallet_opts.compactfilter_opts.conn_count) .data_dir(&_datadir) .build_with_wallet(_wallet, scan_type)?; + let LightClient { + requester, + info_subscriber, + warning_subscriber, + update_subscriber, + node, + } = light_client; + + let subscriber = tracing_subscriber::FmtSubscriber::new(); + let _ = tracing::subscriber::set_global_default(subscriber); + + tokio::task::spawn(async move { node.run().await }); + tokio::task::spawn( + async move { trace_logger(info_subscriber, warning_subscriber).await }, + ); + BlockchainClient::KyotoClient { - client: Box::new(client), + client: KyotoClientHandle { + requester, + update_subscriber: tokio::sync::Mutex::new(update_subscriber), + }, } } }; @@ -318,29 +345,17 @@ pub async fn trace_logger( // Handle Kyoto Client sync #[cfg(feature = "cbf")] -pub async fn sync_kyoto_client(wallet: &mut Wallet, client: Box) -> Result<(), Error> { - let LightClient { - requester, - info_subscriber, - warning_subscriber, - mut update_subscriber, - node, - } = *client; - - let subscriber = tracing_subscriber::FmtSubscriber::new(); - tracing::subscriber::set_global_default(subscriber) - .map_err(|e| Error::Generic(format!("SetGlobalDefault error: {e}")))?; - - tokio::task::spawn(async move { node.run().await }); - tokio::task::spawn(async move { trace_logger(info_subscriber, warning_subscriber).await }); - - if !requester.is_running() { +pub async fn sync_kyoto_client( + wallet: &mut Wallet, + handle: &KyotoClientHandle, +) -> Result<(), Error> { + if !handle.requester.is_running() { tracing::error!("Kyoto node is not running"); return Err(Error::Generic("Kyoto node failed to start".to_string())); } tracing::info!("Kyoto node is running"); - let update = update_subscriber.update().await?; + let update = handle.update_subscriber.lock().await.update().await?; tracing::info!("Received update: applying to wallet"); wallet .apply_update(update)