diff --git a/dash-spv/src/network/manager.rs b/dash-spv/src/network/manager.rs index ee7153ef1..ae58749cb 100644 --- a/dash-spv/src/network/manager.rs +++ b/dash-spv/src/network/manager.rs @@ -6,7 +6,7 @@ use std::path::PathBuf; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; -use tokio::sync::{broadcast, Mutex}; +use tokio::sync::{broadcast, Mutex, RwLock}; use tokio::task::JoinSet; use tokio::time; @@ -791,6 +791,15 @@ impl PeerNetworkManager { } }); } + Some(NetworkRequest::SendMessageToPeer(addr, msg)) => { + log::debug!("Request processor: sending {} to {}", msg.cmd(), addr); + let this = this.clone(); + tokio::spawn(async move { + if let Err(e) = this.send_to_specific_peer(addr, msg).await { + log::error!("Request processor: failed to send to {}: {}", addr, e); + } + }); + } None => { log::info!("Request processor: channel closed"); break; @@ -1068,54 +1077,7 @@ impl PeerNetworkManager { .find(|(a, _)| *a == selected_peer) .ok_or_else(|| NetworkError::ConnectionFailed("Selected peer not found".to_string()))?; - // Upgrade GetHeaders to GetHeaders2 if this specific peer supports it and not disabled - let peer_supports_headers2 = { - let peer_guard = peer.read().await; - peer_guard.can_request_headers2() - }; - let message = match message { - NetworkMessage::GetHeaders(get_headers) - if !self.headers2_disabled.lock().await.contains(addr) - && peer_supports_headers2 => - { - log::debug!( - "Upgrading GetHeaders to GetHeaders2 for peer {}: {:?}", - addr, - get_headers - ); - NetworkMessage::GetHeaders2(get_headers) - } - other => other, - }; - // Reduce verbosity for common sync messages - match &message { - NetworkMessage::GetHeaders(_) - | NetworkMessage::GetCFilters(_) - | NetworkMessage::GetCFHeaders(_) => { - log::debug!("Sending {} to {}", message.cmd(), addr); - } - NetworkMessage::GetHeaders2(gh2) => { - log::info!("📤 Sending GetHeaders2 to {} - version: {}, locator_count: {}, locator: {:?}, stop: {}", - addr, - gh2.version, - gh2.locator_hashes.len(), - gh2.locator_hashes.iter().take(2).collect::>(), - gh2.stop_hash - ); - } - NetworkMessage::SendHeaders2 => { - log::info!("🤝 Sending SendHeaders2 to {} - requesting compressed headers", addr); - } - _ => { - log::trace!("Sending {:?} to {}", message.cmd(), addr); - } - } - - let mut peer_guard = peer.write().await; - peer_guard - .send_message(message) - .await - .map_err(|e| NetworkError::ProtocolError(format!("Failed to send to {}: {}", addr, e))) + self.send_message_to_peer(addr, peer, message).await } /// Send a message distributed across connected peers using round-robin selection. @@ -1185,14 +1147,40 @@ impl PeerNetworkManager { let idx = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % selected_peers.len(); let (addr, peer) = &selected_peers[idx]; - // Upgrade GetHeaders to GetHeaders2 if peer supports it + log::debug!( + "Distributing {} request to peer {} (round-robin idx {})", + message.cmd(), + addr, + idx + ); + + self.send_message_to_peer(addr, peer, message).await + } + + /// Send a message to a specific peer by address. + async fn send_to_specific_peer( + &self, + addr: SocketAddr, + message: NetworkMessage, + ) -> NetworkResult<()> { + let peer = self.pool.get_peer(&addr).await.ok_or_else(|| { + NetworkError::ConnectionFailed(format!("Peer {} not connected", addr)) + })?; + self.send_message_to_peer(&addr, &peer, message).await + } + + /// Send a message to the given peer. + /// For GetHeaders messages upgrade to GetHeaders2 if the peer supports it. + async fn send_message_to_peer( + &self, + addr: &SocketAddr, + peer: &Arc>, + message: NetworkMessage, + ) -> NetworkResult<()> { let message = match message { NetworkMessage::GetHeaders(get_headers) => { - let peer_supports_headers2 = { - let peer_guard = peer.read().await; - peer_guard.can_request_headers2() - }; - if peer_supports_headers2 && !self.headers2_disabled.lock().await.contains(addr) { + let supports_headers2 = peer.read().await.can_request_headers2(); + if supports_headers2 && !self.headers2_disabled.lock().await.contains(addr) { log::debug!("Upgrading GetHeaders to GetHeaders2 for peer {}", addr); NetworkMessage::GetHeaders2(get_headers) } else { @@ -1202,13 +1190,6 @@ impl PeerNetworkManager { other => other, }; - log::debug!( - "Distributing {} request to peer {} (round-robin idx {})", - message.cmd(), - addr, - idx - ); - let mut peer_guard = peer.write().await; peer_guard .send_message(message) diff --git a/dash-spv/src/network/mod.rs b/dash-spv/src/network/mod.rs index 70a12b477..2f6d9882b 100644 --- a/dash-spv/src/network/mod.rs +++ b/dash-spv/src/network/mod.rs @@ -43,8 +43,10 @@ const FILTER_TYPE_DEFAULT: u8 = 0; /// Request to send to network. #[derive(Debug)] pub enum NetworkRequest { - /// Send a message to the network. + /// Send a message to the network (distributed across peers). SendMessage(NetworkMessage), + /// Send a message to a specific peer by address. + SendMessageToPeer(SocketAddr, NetworkMessage), } /// Handle for managers to queue outgoing network requests. @@ -68,6 +70,13 @@ impl RequestSender { .map_err(|e| NetworkError::ProtocolError(e.to_string())) } + /// Queue a message to be sent to a specific peer by address. + fn send_message_to_peer(&self, address: SocketAddr, msg: NetworkMessage) -> NetworkResult<()> { + self.tx + .send(NetworkRequest::SendMessageToPeer(address, msg)) + .map_err(|e| NetworkError::ProtocolError(e.to_string())) + } + pub fn request_inventory(&self, inventory: Vec) -> NetworkResult<()> { self.send_message(NetworkMessage::GetData(inventory)) } @@ -79,6 +88,20 @@ impl RequestSender { ))) } + pub fn request_block_headers_from_peer( + &self, + start_hash: BlockHash, + address: SocketAddr, + ) -> NetworkResult<()> { + self.send_message_to_peer( + address, + NetworkMessage::GetHeaders(GetHeadersMessage::new( + vec![start_hash], + BlockHash::all_zeros(), + )), + ) + } + pub fn request_filter_headers( &self, start_height: u32, diff --git a/dash-spv/src/network/tests.rs b/dash-spv/src/network/tests.rs index ffb992fd1..8f4ba5f1f 100644 --- a/dash-spv/src/network/tests.rs +++ b/dash-spv/src/network/tests.rs @@ -37,3 +37,38 @@ mod pool_tests { // Verify pool limits indirectly through methods; avoid constant assertions } } + +#[cfg(test)] +mod request_sender_tests { + use crate::network::{NetworkRequest, RequestSender}; + use dashcore::network::message::NetworkMessage; + use tokio::sync::mpsc; + + #[test] + fn test_send_message_to_peer_queues_correct_variant() { + let (tx, mut rx) = mpsc::unbounded_channel(); + let sender = RequestSender::new(tx); + let addr = "192.168.1.1:9999".parse().unwrap(); + let msg = NetworkMessage::Verack; + + sender.send_message_to_peer(addr, msg).unwrap(); + + let request = rx.try_recv().unwrap(); + let NetworkRequest::SendMessageToPeer(recv_addr, recv_msg) = request else { + panic!("Expected SendMessageToPeer variant"); + }; + assert_eq!(recv_addr, addr); + assert!(matches!(recv_msg, NetworkMessage::Verack)); + } + + #[test] + fn test_send_message_to_peer_returns_error_on_closed_channel() { + let (tx, rx) = mpsc::unbounded_channel(); + let sender = RequestSender::new(tx); + drop(rx); + + let addr = "192.168.1.1:9999".parse().unwrap(); + let result = sender.send_message_to_peer(addr, NetworkMessage::Verack); + assert!(result.is_err()); + } +} diff --git a/dash-spv/src/sync/filters/pipeline.rs b/dash-spv/src/sync/filters/pipeline.rs index 6906afef1..e83e0ad65 100644 --- a/dash-spv/src/sync/filters/pipeline.rs +++ b/dash-spv/src/sync/filters/pipeline.rs @@ -785,7 +785,9 @@ mod tests { // Verify message was sent let request = rx.try_recv().unwrap(); - let NetworkRequest::SendMessage(msg) = request; + let NetworkRequest::SendMessage(msg) = request else { + panic!("Expected SendMessage"); + }; if let NetworkMessage::GetCFilters(gcf) = msg { assert_eq!(gcf.start_height, 0); assert_eq!(gcf.filter_type, 0);