diff --git a/service_grpc_flight/src/request.rs b/service_grpc_flight/src/request.rs index 70f1219b1f..d6afbdfa48 100644 --- a/service_grpc_flight/src/request.rs +++ b/service_grpc_flight/src/request.rs @@ -3,6 +3,7 @@ use arrow_flight::Ticket; use bytes::Bytes; use flightsql::FlightSQLCommand; +use generated_types::google::protobuf::Any; use generated_types::influxdata::iox::querier::v1 as proto; use generated_types::influxdata::iox::querier::v1::read_info::QueryType; use observability_deps::tracing::trace; @@ -24,6 +25,20 @@ pub enum Error { } pub type Result = std::result::Result; +/// AnyError is an internal error that contains the result of attempting +/// to decode a protobuf "Any" message. This is separate from Error so +/// that an error resulting from attempting to decode the value can be +/// embedded as a source. +#[derive(Debug, Snafu)] +enum AnyError { + #[snafu(display("Invalid Protobuf: {}", source))] + DecodeAny { source: prost::DecodeError }, + #[snafu(display("Unknown type_url: {}", type_url))] + UnknownTypeURL { type_url: String }, + #[snafu(display("Invalid value: {}", source))] + InvalidValue { source: Error }, +} + /// Request structure of the "opaque" tickets used for IOx Arrow /// Flight DoGet endpoint. /// @@ -115,6 +130,8 @@ impl Display for RunQuery { } impl IoxGetRequest { + const READ_INFO_TYPE_URL: &str = "type.googleapis.com/influxdata.iox.querier.v1.ReadInfo"; + /// Create a new request to run the specified query pub fn new(database: impl Into, query: RunQuery, is_debug: bool) -> Self { Self { @@ -127,7 +144,19 @@ impl IoxGetRequest { /// try to decode a ReadInfo structure from a Token pub fn try_decode(ticket: Ticket) -> Result { // decode ticket - IoxGetRequest::decode_protobuf(ticket.ticket.clone()) + IoxGetRequest::decode_protobuf_any(ticket.ticket.clone()) + .or_else(|e| { + match e { + // If the ticket decoded as an Any with a type_url that was recognised + // don't attempt to fall back to ReadInfo it will almost certainly + // succeed, but with invalid parameters. + AnyError::InvalidValue { source } => Err(source), + e => { + trace!(%e, "Error decoding ticket as Any, trying as ReadInfo"); + IoxGetRequest::decode_protobuf(ticket.ticket.clone()) + } + } + }) .or_else(|e| { trace!(%e, ticket=%String::from_utf8_lossy(&ticket.ticket), "Error decoding ticket as ProtoBuf, trying as JSON"); @@ -175,7 +204,11 @@ impl IoxGetRequest { }, }; - let ticket = read_info.encode_to_vec(); + let any = Any { + type_url: Self::READ_INFO_TYPE_URL.to_string(), + value: read_info.encode_to_vec().into(), + }; + let ticket = any.encode_to_vec(); Ok(Ticket { ticket: ticket.into(), @@ -227,6 +260,19 @@ impl IoxGetRequest { }) } + /// Decode a ReadInfo ticket wrapped in a protobuf Any message. + fn decode_protobuf_any(ticket: Bytes) -> Result { + let any = Any::decode(ticket).context(DecodeAnySnafu)?; + if any.type_url == Self::READ_INFO_TYPE_URL { + Self::decode_protobuf(any.value).context(InvalidValueSnafu) + } else { + UnknownTypeURLSnafu { + type_url: any.type_url, + } + .fail() + } + } + /// See comments on [`IoxGetRequest`] for details of this format fn decode_protobuf(ticket: Bytes) -> Result { let read_info = proto::ReadInfo::decode(ticket).context(DecodeSnafu)?; @@ -628,6 +674,123 @@ mod tests { assert_matches!(e, Error::Invalid); } + #[test] + fn any_ticket_decoding_unspecified() { + let ticket = make_any_wrapped_proto_ticket(&proto::ReadInfo { + database: "_".to_string(), + sql_query: "SELECT 1".to_string(), + query_type: QueryType::Unspecified.into(), + flightsql_command: vec![], + is_debug: false, + }); + + // Reverts to default (unspecified) for invalid query_type enumeration, and thus SQL + let ri = IoxGetRequest::try_decode(ticket).unwrap(); + assert_eq!(ri.database, "_"); + assert_matches!(ri.query, RunQuery::Sql(query) => assert_eq!(query, "SELECT 1")); + } + + #[test] + fn any_ticket_decoding_sql() { + let ticket = make_any_wrapped_proto_ticket(&proto::ReadInfo { + database: "_".to_string(), + sql_query: "SELECT 1".to_string(), + query_type: QueryType::Sql.into(), + flightsql_command: vec![], + is_debug: false, + }); + + let ri = IoxGetRequest::try_decode(ticket).unwrap(); + assert_eq!(ri.database, "_"); + assert_matches!(ri.query, RunQuery::Sql(query) => assert_eq!(query, "SELECT 1")); + } + + #[test] + fn any_ticket_decoding_influxql() { + let ticket = make_any_wrapped_proto_ticket(&proto::ReadInfo { + database: "_".to_string(), + sql_query: "SELECT 1".to_string(), + query_type: QueryType::InfluxQl.into(), + flightsql_command: vec![], + is_debug: false, + }); + + let ri = IoxGetRequest::try_decode(ticket).unwrap(); + assert_eq!(ri.database, "_"); + assert_matches!(ri.query, RunQuery::InfluxQL(query) => assert_eq!(query, "SELECT 1")); + } + + #[test] + fn any_ticket_decoding_too_new() { + let ticket = make_any_wrapped_proto_ticket(&proto::ReadInfo { + database: "_".to_string(), + sql_query: "SELECT 1".into(), + query_type: 42, // not a known query type + flightsql_command: vec![], + is_debug: false, + }); + + // Reverts to default (unspecified) for invalid query_type enumeration, and thus SQL + let ri = IoxGetRequest::try_decode(ticket).unwrap(); + assert_eq!(ri.database, "_"); + assert_matches!(ri.query, RunQuery::Sql(query) => assert_eq!(query, "SELECT 1")); + } + + #[test] + fn any_ticket_decoding_sql_too_many_fields() { + let ticket = make_any_wrapped_proto_ticket(&proto::ReadInfo { + database: "_".to_string(), + sql_query: "SELECT 1".to_string(), + query_type: QueryType::Sql.into(), + // can't have both sql_query and flightsql + flightsql_command: vec![1, 2, 3], + is_debug: false, + }); + + let e = IoxGetRequest::try_decode(ticket).unwrap_err(); + assert_matches!(e, Error::Invalid); + } + + #[test] + fn any_ticket_decoding_influxql_too_many_fields() { + let ticket = make_any_wrapped_proto_ticket(&proto::ReadInfo { + database: "_".to_string(), + sql_query: "SELECT 1".to_string(), + query_type: QueryType::InfluxQl.into(), + // can't have both sql_query and flightsql + flightsql_command: vec![1, 2, 3], + is_debug: false, + }); + + let e = IoxGetRequest::try_decode(ticket).unwrap_err(); + assert_matches!(e, Error::Invalid); + } + + #[test] + fn any_ticket_decoding_flightsql_too_many_fields() { + let ticket = make_any_wrapped_proto_ticket(&proto::ReadInfo { + database: "_".to_string(), + sql_query: "SELECT 1".to_string(), + query_type: QueryType::FlightSqlMessage.into(), + // can't have both sql_query and flightsql + flightsql_command: vec![1, 2, 3], + is_debug: false, + }); + + let e = IoxGetRequest::try_decode(ticket).unwrap_err(); + assert_matches!(e, Error::Invalid); + } + + #[test] + fn any_ticket_decoding_error() { + let ticket = Ticket { + ticket: b"invalid ticket".to_vec().into(), + }; + + let e = IoxGetRequest::try_decode(ticket).unwrap_err(); + assert_matches!(e, Error::Invalid); + } + #[test] fn round_trip_sql() { let request = IoxGetRequest { @@ -693,6 +856,16 @@ mod tests { assert_eq!(request, roundtripped) } + fn make_any_wrapped_proto_ticket(read_info: &proto::ReadInfo) -> Ticket { + let any = Any { + type_url: IoxGetRequest::READ_INFO_TYPE_URL.to_string(), + value: read_info.encode_to_vec().into(), + }; + Ticket { + ticket: any.encode_to_vec().into(), + } + } + fn make_proto_ticket(read_info: &proto::ReadInfo) -> Ticket { Ticket { ticket: read_info.encode_to_vec().into(),