diff --git a/Cargo.lock b/Cargo.lock index cb6ccfbce8..b07d13098c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4756,7 +4756,10 @@ dependencies = [ "async-trait", "datafusion 0.1.0", "iox_query", + "metric", + "parking_lot 0.12.1", "predicate", + "tokio", "workspace-hack", ] @@ -4816,6 +4819,7 @@ dependencies = [ "observability_deps", "panic_logging", "parking_lot 0.12.1", + "pin-project", "predicate", "prost", "query_functions", diff --git a/querier/src/database.rs b/querier/src/database.rs index b63376874a..470cf4afff 100644 --- a/querier/src/database.rs +++ b/querier/src/database.rs @@ -11,7 +11,7 @@ use iox_query::exec::Executor; use parquet_file::storage::ParquetStorage; use service_common::QueryDatabaseProvider; use std::sync::Arc; -use tokio::sync::Semaphore; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; /// The number of entries to store in the circular query buffer log. /// @@ -60,6 +60,13 @@ impl QueryDatabaseProvider for QuerierDatabase { async fn db(&self, name: &str) -> Option> { self.namespace(name).await } + + async fn acquire_semaphore(&self) -> OwnedSemaphorePermit { + Arc::clone(&self.query_execution_semaphore) + .acquire_owned() + .await + .expect("Semaphore should not be closed by anyone") + } } impl QuerierDatabase { @@ -111,12 +118,6 @@ impl QuerierDatabase { /// This will await the internal namespace semaphore. Existence of namespaces is checked AFTER a semaphore permit /// was acquired since this lowers the chance that we obtain stale data. pub async fn namespace(&self, name: &str) -> Option> { - // get the permit first - let permit = Arc::clone(&self.query_execution_semaphore) - .acquire_owned() - .await - .expect("Semaphore should NOT be closed by now"); - let name = Arc::from(name.to_owned()); let schema = self .catalog_cache @@ -130,7 +131,6 @@ impl QuerierDatabase { Arc::clone(&self.exec), Arc::clone(&self.ingester_connection), Arc::clone(&self.query_log), - permit, ))) } @@ -158,10 +158,7 @@ impl QuerierDatabase { #[cfg(test)] mod tests { - use std::{future::Future, time::Duration}; - use iox_tests::util::TestCatalog; - use tokio::pin; use crate::create_ingester_connection_for_testing; @@ -243,79 +240,4 @@ mod tests { assert_eq!(namespaces[0].name, "ns1"); assert_eq!(namespaces[1].name, "ns2"); } - - #[tokio::test] - async fn test_query_execution_semaphore() { - let catalog = TestCatalog::new(); - - let catalog_cache = Arc::new(CatalogCache::new( - catalog.catalog(), - catalog.time_provider(), - catalog.metric_registry(), - usize::MAX, - )); - let db = QuerierDatabase::new( - catalog_cache, - catalog.metric_registry(), - ParquetStorage::new(catalog.object_store()), - catalog.exec(), - create_ingester_connection_for_testing(), - 2, - ); - - catalog.create_namespace("ns1").await; - catalog.create_namespace("ns2").await; - catalog.create_namespace("ns3").await; - - // consume all semaphore permits - let ns1 = db.namespace("ns1").await.unwrap(); - let ns2 = db.namespace("ns2").await.unwrap(); - - // cannot get any new namespace, even when we already have a namespace for the same name - let fut3 = db.namespace("ns3"); - let fut1 = db.namespace("ns1"); - let fut9 = db.namespace("ns9"); - let fut2 = db.namespace("ns2"); - pin!(fut3); - pin!(fut1); - pin!(fut9); - pin!(fut2); - assert_fut_pending(&mut fut3).await; - assert_fut_pending(&mut fut1).await; - assert_fut_pending(&mut fut9).await; - assert_fut_pending(&mut fut2).await; - - // dropping the newest namespace frees a permit - drop(ns2); - let ns3 = fut3.await.unwrap(); - assert_fut_pending(&mut fut1).await; - assert_fut_pending(&mut fut9).await; - assert_fut_pending(&mut fut2).await; - - // dropping the newest namespace frees a permit - drop(ns3); - let _ns1b = fut1.await.unwrap(); - assert_fut_pending(&mut fut9).await; - assert_fut_pending(&mut fut2).await; - - // dropping the oldest namespace frees a permit - drop(ns1); - assert!(fut9.await.is_none()); - // because "ns9" did not exist, we immediately get a new permit - fut2.await.unwrap(); - } - - /// Assert that given future is pending. - /// - /// This will try to poll the future a bit to ensure that it is not stuck in tokios task preemption. - async fn assert_fut_pending(fut: &mut F) - where - F: Future + Send + Unpin, - F::Output: std::fmt::Debug, - { - tokio::select! { - x = fut => panic!("future is not pending, yielded: {x:?}"), - _ = tokio::time::sleep(Duration::from_millis(10)) => {}, - }; - } } diff --git a/querier/src/namespace/mod.rs b/querier/src/namespace/mod.rs index 922320ed06..d0b2bfa672 100644 --- a/querier/src/namespace/mod.rs +++ b/querier/src/namespace/mod.rs @@ -9,7 +9,6 @@ use iox_query::exec::Executor; use parquet_file::storage::ParquetStorage; use schema::Schema; use std::{collections::HashMap, sync::Arc}; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; mod query_access; @@ -42,11 +41,6 @@ pub struct QuerierNamespace { /// Query log. query_log: Arc, - - /// Permit that limits the number of concurrent active namespaces (and thus - /// also queries) - #[allow(dead_code)] - permit: OwnedSemaphorePermit, } impl QuerierNamespace { @@ -58,7 +52,6 @@ impl QuerierNamespace { exec: Arc, ingester_connection: Arc, query_log: Arc, - permit: OwnedSemaphorePermit, ) -> Self { let tables: HashMap<_, _> = schema .tables @@ -90,7 +83,6 @@ impl QuerierNamespace { exec, catalog_cache: Arc::clone(chunk_adapter.catalog_cache()), query_log, - permit, } } @@ -114,9 +106,6 @@ impl QuerierNamespace { )); let query_log = Arc::new(QueryLog::new(10, time_provider)); - let semaphore = Arc::new(Semaphore::new(1)); - let permit = semaphore.try_acquire_owned().unwrap(); - Self::new( chunk_adapter, schema, @@ -124,7 +113,6 @@ impl QuerierNamespace { exec, ingester_connection, query_log, - permit, ) } diff --git a/service_common/Cargo.toml b/service_common/Cargo.toml index 837f5ef09c..176efa86c4 100644 --- a/service_common/Cargo.toml +++ b/service_common/Cargo.toml @@ -10,6 +10,9 @@ edition = "2021" datafusion = { path = "../datafusion" } predicate = { path = "../predicate" } iox_query = { path = "../iox_query" } +metric = { path = "../metric" } +parking_lot = "0.12" +tokio = { version = "1.18", features = ["macros", "parking_lot", "rt-multi-thread", "sync", "time"] } workspace-hack = { path = "../workspace-hack"} # Crates.io dependencies, in alphabetical order diff --git a/service_common/src/lib.rs b/service_common/src/lib.rs index 0d417a4e92..f941c5de79 100644 --- a/service_common/src/lib.rs +++ b/service_common/src/lib.rs @@ -1,11 +1,13 @@ //! Common methods for RPC service implementations pub mod planner; +pub mod test_util; use std::sync::Arc; use async_trait::async_trait; use iox_query::{exec::ExecutionContextProvider, QueryDatabase}; +use tokio::sync::OwnedSemaphorePermit; /// Trait that allows the query engine (which includes flight and storage/InfluxRPC) to access a virtual set of /// databases. @@ -18,4 +20,7 @@ pub trait QueryDatabaseProvider: std::fmt::Debug + Send + Sync + 'static { /// Get database if it exists. async fn db(&self, name: &str) -> Option>; + + /// Acquire concurrency-limiting sempahore + async fn acquire_semaphore(&self) -> OwnedSemaphorePermit; } diff --git a/service_common/src/test_util.rs b/service_common/src/test_util.rs new file mode 100644 index 0000000000..46b99f6534 --- /dev/null +++ b/service_common/src/test_util.rs @@ -0,0 +1,71 @@ +use std::{collections::BTreeMap, sync::Arc}; + +use async_trait::async_trait; +use iox_query::{exec::Executor, test::TestDatabase}; +use parking_lot::Mutex; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; + +use crate::QueryDatabaseProvider; + +#[derive(Debug)] +pub struct TestDatabaseStore { + databases: Mutex>>, + executor: Arc, + pub metric_registry: Arc, + pub query_semaphore: Arc, +} + +impl TestDatabaseStore { + pub fn new() -> Self { + Self::default() + } + + pub fn new_with_semaphore_size(semaphore_size: usize) -> Self { + Self { + query_semaphore: Arc::new(Semaphore::new(semaphore_size)), + ..Default::default() + } + } + + pub async fn db_or_create(&self, name: &str) -> Arc { + let mut databases = self.databases.lock(); + + if let Some(db) = databases.get(name) { + Arc::clone(db) + } else { + let new_db = Arc::new(TestDatabase::new(Arc::clone(&self.executor))); + databases.insert(name.to_string(), Arc::clone(&new_db)); + new_db + } + } +} + +impl Default for TestDatabaseStore { + fn default() -> Self { + Self { + databases: Mutex::new(BTreeMap::new()), + executor: Arc::new(Executor::new(1)), + metric_registry: Default::default(), + query_semaphore: Arc::new(Semaphore::new(u16::MAX as usize)), + } + } +} + +#[async_trait] +impl QueryDatabaseProvider for TestDatabaseStore { + type Db = TestDatabase; + + /// Retrieve the database specified name + async fn db(&self, name: &str) -> Option> { + let databases = self.databases.lock(); + + databases.get(name).cloned() + } + + async fn acquire_semaphore(&self) -> OwnedSemaphorePermit { + Arc::clone(&self.query_semaphore) + .acquire_owned() + .await + .unwrap() + } +} diff --git a/service_grpc_flight/src/lib.rs b/service_grpc_flight/src/lib.rs index eaa46b97e2..e5aa053b34 100644 --- a/service_grpc_flight/src/lib.rs +++ b/service_grpc_flight/src/lib.rs @@ -23,7 +23,7 @@ use serde::Deserialize; use service_common::{planner::Planner, QueryDatabaseProvider}; use snafu::{ResultExt, Snafu}; use std::{fmt::Debug, pin::Pin, sync::Arc, task::Poll}; -use tokio::task::JoinHandle; +use tokio::{sync::OwnedSemaphorePermit, task::JoinHandle}; use tonic::{Request, Response, Streaming}; #[allow(clippy::enum_variant_names)] @@ -196,6 +196,8 @@ where } }; + let permit = self.server.acquire_semaphore().await; + let database = DatabaseName::new(&read_info.database_name).context(InvalidDatabaseNameSnafu)?; @@ -218,6 +220,7 @@ where physical_plan, read_info.database_name, query_completed_token, + permit, ) .await?; @@ -286,6 +289,8 @@ struct GetStream { rx: futures::channel::mpsc::Receiver>, join_handle: JoinHandle<()>, done: bool, + #[allow(dead_code)] + permit: OwnedSemaphorePermit, } impl GetStream { @@ -294,6 +299,7 @@ impl GetStream { physical_plan: Arc, database_name: String, mut query_completed_token: QueryCompletedToken, + permit: OwnedSemaphorePermit, ) -> Result { // setup channel let (mut tx, rx) = futures::channel::mpsc::channel::>(1); @@ -382,6 +388,7 @@ impl GetStream { rx, join_handle, done: false, + permit, }) } } diff --git a/service_grpc_influxrpc/Cargo.toml b/service_grpc_influxrpc/Cargo.toml index 1d6519aef3..9051be1385 100644 --- a/service_grpc_influxrpc/Cargo.toml +++ b/service_grpc_influxrpc/Cargo.toml @@ -19,6 +19,7 @@ service_common = { path = "../service_common" } arrow = { version = "14.0.0", features = ["prettyprint"] } async-trait = "0.1" futures = "0.3" +pin-project = "1.0" prost = "0.10" regex = "1.5.6" serde = { version = "1.0", features = ["derive"] } diff --git a/service_grpc_influxrpc/src/service.rs b/service_grpc_influxrpc/src/service.rs index f07cacb200..53568b1106 100644 --- a/service_grpc_influxrpc/src/service.rs +++ b/service_grpc_influxrpc/src/service.rs @@ -12,6 +12,7 @@ use crate::{ StorageService, }; use data_types::{org_and_bucket_to_database, DatabaseName}; +use futures::Stream; use generated_types::{ google::protobuf::Empty, literal_or_regex::Value as RegexOrLiteralValue, offsets_response::PartitionOffsetResponse, storage_server::Storage, tag_key_predicate, @@ -31,13 +32,14 @@ use iox_query::{ QueryDatabase, QueryText, }; use observability_deps::tracing::{error, info, trace}; +use pin_project::pin_project; use service_common::{planner::Planner, QueryDatabaseProvider}; use snafu::{OptionExt, ResultExt, Snafu}; use std::{ collections::{BTreeSet, HashMap}, sync::Arc, }; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, OwnedSemaphorePermit}; use tokio_stream::wrappers::ReceiverStream; use tonic::Status; @@ -219,7 +221,8 @@ impl Storage for StorageService where T: QueryDatabaseProvider + 'static, { - type ReadFilterStream = futures::stream::Iter>>; + type ReadFilterStream = + StreamWithPermit>>>; async fn read_filter( &self, @@ -228,6 +231,7 @@ where let span_ctx = req.extensions().get().cloned(); let req = req.into_inner(); + let permit = self.db_store.acquire_semaphore().await; let db_name = get_database_name(&req)?; info!(%db_name, ?req.range, predicate=%req.predicate.loggable(), "read filter"); @@ -250,10 +254,14 @@ where query_completed_token.set_success(); } - Ok(tonic::Response::new(futures::stream::iter(results))) + Ok(tonic::Response::new(StreamWithPermit::new( + futures::stream::iter(results), + permit, + ))) } - type ReadGroupStream = futures::stream::Iter>>; + type ReadGroupStream = + StreamWithPermit>>>; async fn read_group( &self, @@ -261,6 +269,7 @@ where ) -> Result, Status> { let span_ctx = req.extensions().get().cloned(); let req = req.into_inner(); + let permit = self.db_store.acquire_semaphore().await; let db_name = get_database_name(&req)?; let db = self @@ -306,11 +315,14 @@ where query_completed_token.set_success(); } - Ok(tonic::Response::new(futures::stream::iter(results))) + Ok(tonic::Response::new(StreamWithPermit::new( + futures::stream::iter(results), + permit, + ))) } type ReadWindowAggregateStream = - futures::stream::Iter>>; + StreamWithPermit>>>; async fn read_window_aggregate( &self, @@ -318,6 +330,7 @@ where ) -> Result, Status> { let span_ctx = req.extensions().get().cloned(); let req = req.into_inner(); + let permit = self.db_store.acquire_semaphore().await; let db_name = get_database_name(&req)?; let db = self @@ -361,10 +374,13 @@ where query_completed_token.set_success(); } - Ok(tonic::Response::new(futures::stream::iter(results))) + Ok(tonic::Response::new(StreamWithPermit::new( + futures::stream::iter(results), + permit, + ))) } - type TagKeysStream = ReceiverStream>; + type TagKeysStream = StreamWithPermit>>; async fn tag_keys( &self, @@ -374,6 +390,7 @@ where let (tx, rx) = mpsc::channel(4); let req = req.into_inner(); + let permit = self.db_store.acquire_semaphore().await; let db_name = get_database_name(&req)?; let db = self @@ -414,10 +431,13 @@ where .await .expect("sending tag_keys response to server"); - Ok(tonic::Response::new(ReceiverStream::new(rx))) + Ok(tonic::Response::new(StreamWithPermit::new( + ReceiverStream::new(rx), + permit, + ))) } - type TagValuesStream = ReceiverStream>; + type TagValuesStream = StreamWithPermit>>; async fn tag_values( &self, @@ -427,6 +447,7 @@ where let (tx, rx) = mpsc::channel(4); let req = req.into_inner(); + let permit = self.db_store.acquire_semaphore().await; let db_name = get_database_name(&req)?; let db = self @@ -500,11 +521,15 @@ where .await .expect("sending tag_values response to server"); - Ok(tonic::Response::new(ReceiverStream::new(rx))) + Ok(tonic::Response::new(StreamWithPermit::new( + ReceiverStream::new(rx), + permit, + ))) } - type TagValuesGroupedByMeasurementAndTagKeyStream = - futures::stream::Iter>>; + type TagValuesGroupedByMeasurementAndTagKeyStream = StreamWithPermit< + futures::stream::Iter>>, + >; async fn tag_values_grouped_by_measurement_and_tag_key( &self, @@ -513,6 +538,7 @@ where let span_ctx = req.extensions().get().cloned(); let req = req.into_inner(); + let permit = self.db_store.acquire_semaphore().await; let db_name = get_database_name(&req)?; let db = self @@ -542,7 +568,10 @@ where query_completed_token.set_success(); } - Ok(tonic::Response::new(futures::stream::iter(results))) + Ok(tonic::Response::new(StreamWithPermit::new( + futures::stream::iter(results), + permit, + ))) } type ReadSeriesCardinalityStream = ReceiverStream>; @@ -592,7 +621,8 @@ where Ok(tonic::Response::new(caps)) } - type MeasurementNamesStream = ReceiverStream>; + type MeasurementNamesStream = + StreamWithPermit>>; async fn measurement_names( &self, @@ -602,6 +632,7 @@ where let (tx, rx) = mpsc::channel(4); let req = req.into_inner(); + let permit = self.db_store.acquire_semaphore().await; let db_name = get_database_name(&req)?; let db = self @@ -634,10 +665,14 @@ where .await .expect("sending measurement names response to server"); - Ok(tonic::Response::new(ReceiverStream::new(rx))) + Ok(tonic::Response::new(StreamWithPermit::new( + ReceiverStream::new(rx), + permit, + ))) } - type MeasurementTagKeysStream = ReceiverStream>; + type MeasurementTagKeysStream = + StreamWithPermit>>; async fn measurement_tag_keys( &self, @@ -647,6 +682,7 @@ where let (tx, rx) = mpsc::channel(4); let req = req.into_inner(); + let permit = self.db_store.acquire_semaphore().await; let db_name = get_database_name(&req)?; let db = self @@ -689,10 +725,14 @@ where .await .expect("sending measurement_tag_keys response to server"); - Ok(tonic::Response::new(ReceiverStream::new(rx))) + Ok(tonic::Response::new(StreamWithPermit::new( + ReceiverStream::new(rx), + permit, + ))) } - type MeasurementTagValuesStream = ReceiverStream>; + type MeasurementTagValuesStream = + StreamWithPermit>>; async fn measurement_tag_values( &self, @@ -702,6 +742,7 @@ where let (tx, rx) = mpsc::channel(4); let req = req.into_inner(); + let permit = self.db_store.acquire_semaphore().await; let db_name = get_database_name(&req)?; let db = self @@ -746,10 +787,14 @@ where .await .expect("sending measurement_tag_values response to server"); - Ok(tonic::Response::new(ReceiverStream::new(rx))) + Ok(tonic::Response::new(StreamWithPermit::new( + ReceiverStream::new(rx), + permit, + ))) } - type MeasurementFieldsStream = ReceiverStream>; + type MeasurementFieldsStream = + StreamWithPermit>>; async fn measurement_fields( &self, @@ -759,6 +804,7 @@ where let (tx, rx) = mpsc::channel(4); let req = req.into_inner(); + let permit = self.db_store.acquire_semaphore().await; let db_name = get_database_name(&req)?; let db = self @@ -806,7 +852,10 @@ where .await .expect("sending measurement_fields response to server"); - Ok(tonic::Response::new(ReceiverStream::new(rx))) + Ok(tonic::Response::new(StreamWithPermit::new( + ReceiverStream::new(rx), + permit, + ))) } async fn offsets( @@ -1386,10 +1435,39 @@ impl ErrorLogger for Result { } } +/// Helper to keep a semaphore permit attached to a stream. +#[pin_project] +pub struct StreamWithPermit { + #[pin] + stream: S, + #[allow(dead_code)] + permit: OwnedSemaphorePermit, +} + +impl StreamWithPermit { + fn new(stream: S, permit: OwnedSemaphorePermit) -> Self { + Self { stream, permit } + } +} + +impl Stream for StreamWithPermit +where + S: Stream, +{ + type Item = S::Item; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.project(); + this.stream.poll_next(cx) + } +} + #[cfg(test)] mod tests { use super::*; - use async_trait::async_trait; use data_types::ChunkId; use datafusion::logical_plan::{col, lit, Expr}; use generated_types::{i_ox_testing_client::IOxTestingClient, tag_key_predicate::Value}; @@ -1398,22 +1476,18 @@ mod tests { generated_types::*, Client as StorageClient, OrgAndBucket, }; - use iox_query::{ - exec::Executor, - test::{TestChunk, TestDatabase}, - }; + use iox_query::test::TestChunk; use metric::{Attributes, Metric, U64Counter}; use panic_logging::SendPanicsToTracing; - use parking_lot::Mutex; use predicate::{PredicateBuilder, PredicateMatch}; - use service_common::QueryDatabaseProvider; + use service_common::test_util::TestDatabaseStore; use std::{ - collections::BTreeMap, net::{IpAddr, Ipv4Addr, SocketAddr}, num::NonZeroU64, sync::Arc, }; use test_helpers::{assert_contains, tracing::TracingCapture}; + use tokio::task::JoinHandle; use tokio_stream::wrappers::TcpListenerStream; fn to_str_vec(s: &[&str]) -> Vec { @@ -2929,13 +3003,18 @@ mod tests { iox_client: IOxTestingClient, storage_client: StorageClient, test_storage: Arc, + join_handle: JoinHandle<()>, } impl Fixture { /// Start up a test storage server listening on `port`, returning /// a fixture with the test server and clients async fn new() -> Result { - let test_storage = Arc::new(TestDatabaseStore::new()); + Self::new_with_semaphore_size(u16::MAX as usize).await + } + + async fn new_with_semaphore_size(semaphore_size: usize) -> Result { + let test_storage = Arc::new(TestDatabaseStore::new_with_semaphore_size(semaphore_size)); // Get a random port from the kernel by asking for port 0. let bind_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); @@ -2970,9 +3049,10 @@ mod tests { .serve_with_incoming(stream) .await .log_if_error("Running Tonic Server") + .ok(); }; - tokio::task::spawn(server); + let join_handle = tokio::task::spawn(server); let conn = ConnectionBuilder::default() .connect_timeout(std::time::Duration::from_secs(30)) @@ -2988,6 +3068,7 @@ mod tests { iox_client, storage_client, test_storage, + join_handle, }) } @@ -3017,50 +3098,9 @@ mod tests { } } - #[derive(Debug)] - pub struct TestDatabaseStore { - databases: Mutex>>, - executor: Arc, - pub metric_registry: Arc, - } - - impl TestDatabaseStore { - pub fn new() -> Self { - Self::default() - } - - async fn db_or_create(&self, name: &str) -> Arc { - let mut databases = self.databases.lock(); - - if let Some(db) = databases.get(name) { - Arc::clone(db) - } else { - let new_db = Arc::new(TestDatabase::new(Arc::clone(&self.executor))); - databases.insert(name.to_string(), Arc::clone(&new_db)); - new_db - } - } - } - - impl Default for TestDatabaseStore { - fn default() -> Self { - Self { - databases: Mutex::new(BTreeMap::new()), - executor: Arc::new(Executor::new(1)), - metric_registry: Default::default(), - } - } - } - - #[async_trait] - impl QueryDatabaseProvider for TestDatabaseStore { - type Db = TestDatabase; - - /// Retrieve the database specified name - async fn db(&self, name: &str) -> Option> { - let databases = self.databases.lock(); - - databases.get(name).cloned() + impl Drop for Fixture { + fn drop(&mut self) { + self.join_handle.abort(); } } }