From d27814263abf7fa5a7088fac1378208063e277fb Mon Sep 17 00:00:00 2001 From: Jackson Newhouse Date: Wed, 12 Mar 2025 16:56:06 -0700 Subject: [PATCH] feat(processing_engine): proper escaping of LineBuilder arguments. --- influxdb3/tests/cli/mod.rs | 135 ++++++++++++++++++++++++++++++ influxdb3_py_api/src/system_py.rs | 24 ++++-- 2 files changed, 154 insertions(+), 5 deletions(-) diff --git a/influxdb3/tests/cli/mod.rs b/influxdb3/tests/cli/mod.rs index 24fd6cabd1..5387afea85 100644 --- a/influxdb3/tests/cli/mod.rs +++ b/influxdb3/tests/cli/mod.rs @@ -1257,6 +1257,141 @@ async fn distinct_cache_create_and_delete() { assert_contains!(result, "distinct cache deleted successfully"); } +#[test_log::test(tokio::test)] +async fn test_linebuilder_escaping() { + use crate::server::ConfigProvider; + + // Create plugin file with all test cases + let plugin_file = create_plugin_file( + r#" +def process_writes(influxdb3_local, table_batches, args=None): + # Basic test with all field types + basic_line = LineBuilder("metrics")\ + .tag("host", "server01")\ + .tag("region", "us-west")\ + .int64_field("cpu", 42)\ + .uint64_field("memory_bytes", 8589934592)\ + .float64_field("load", 86.5)\ + .string_field("status", "online")\ + .bool_field("healthy", True)\ + .time_ns(1609459200000000000) + influxdb3_local.write(basic_line) + + # Test escaping spaces in tag values + spaces_line = LineBuilder("system_metrics")\ + .tag("server", "app server 1")\ + .tag("datacenter", "us west")\ + .int64_field("count", 1) + influxdb3_local.write(spaces_line) + + # Test escaping commas in tag values + commas_line = LineBuilder("network")\ + .tag("servers", "web,app,db")\ + .tag("location", "floor1,rack3")\ + .int64_field("connections", 256) + influxdb3_local.write(commas_line) + + # Test escaping equals signs in tag values + equals_line = LineBuilder("formulas")\ + .tag("equation", "y=mx+b")\ + .tag("result", "a=b=c")\ + .float64_field("value", 3.14159) + influxdb3_local.write(equals_line) + + # Test escaping backslashes in tag values + backslash_line = LineBuilder("paths")\ + .tag("windows_path", "C:\\Program Files\\App")\ + .tag("regex", "\\d+\\w+")\ + .string_field("description", "Windows\\Unix paths")\ + .int64_field("count", 42) + influxdb3_local.write(backslash_line) + + # Test escaping quotes in string fields + quotes_line = LineBuilder("messages")\ + .tag("type", "notification")\ + .string_field("content", "User said \"Hello World\"")\ + .string_field("json", "{\"key\": \"value\"}")\ + .int64_field("priority", 1) + influxdb3_local.write(quotes_line) + + # Test a complex case with multiple escape characters + complex_line = LineBuilder("complex,measurement")\ + .tag("location", "New York, USA")\ + .tag("details", "floor=5, room=3")\ + .tag("path", "C:\\Users\\Admin\\Documents")\ + .string_field("message", "Error in line: \"x = y + z\"")\ + .string_field("query", "SELECT * FROM table WHERE id=\"abc\"")\ + .float64_field("value", 123.456)\ + .time_ns(1609459200000000000) + influxdb3_local.write(complex_line) + + # Test writing to a specific database + specific_db_line = LineBuilder("memory_stats")\ + .tag("host", "server1")\ + .int64_field("usage", 75) + influxdb3_local.write_to_db("metrics_db", specific_db_line) + + # Test multiple chained methods + chained_line = LineBuilder("sensor_data")\ + .tag("device", "thermostat").tag("room", "living room").tag("floor", "1")\ + .float64_field("temperature", 72.5).int64_field("humidity", 45).string_field("mode", "auto") + influxdb3_local.write(chained_line) + + influxdb3_local.info("All LineBuilder tests completed")"#, + ); + + let plugin_dir = plugin_file.path().parent().unwrap().to_str().unwrap(); + let plugin_name = plugin_file.path().file_name().unwrap().to_str().unwrap(); + + let server = TestServer::configure() + .with_plugin_dir(plugin_dir) + .spawn() + .await; + let server_addr = server.client_addr(); + + let db_name = "test_db"; + + // Run the test + let result = run_with_confirmation(&[ + "test", + "wal_plugin", + "--database", + db_name, + "--host", + &server_addr, + "--lp", + "test_input,tag1=tag1_value field1=1i 500", + "--input-arguments", + "arg1=test", + plugin_name, + ]); + + let res = serde_json::from_str::(&result).unwrap(); + + let expected_result = r#"{ + "log_lines": [ + "INFO: All LineBuilder tests completed" + ], + "database_writes": { + "test_db": [ + "metrics,host=server01,region=us-west cpu=42i,memory_bytes=8589934592u,load=86.5,status=\"online\",healthy=t 1609459200000000000", + "system_metrics,server=app\\ server\\ 1,datacenter=us\\ west count=1i", + "network,servers=web\\,app\\,db,location=floor1\\,rack3 connections=256i", + "formulas,equation=y\\=mx+b,result=a\\=b\\=c value=3.14159", + "paths,windows_path=C:\\\\Program\\ Files\\\\App,regex=\\\\d+\\\\w+ description=\"Windows\\\\Unix paths\",count=42i", + "messages,type=notification content=\"User said \\\"Hello World\\\"\",json=\"{\\\"key\\\": \\\"value\\\"}\",priority=1i", + "complex\\,measurement,location=New\\ York\\,\\ USA,details=floor\\=5\\,\\ room\\=3,path=C:\\\\Users\\\\Admin\\\\Documents message=\"Error in line: \\\"x = y + z\\\"\",query=\"SELECT * FROM table WHERE id=\\\"abc\\\"\",value=123.456 1609459200000000000", + "sensor_data,device=thermostat,room=living\\ room,floor=1 temperature=72.5,humidity=45i,mode=\"auto\"" + ], + "metrics_db": [ + "memory_stats,host=server1 usage=75i" + ] + }, + "errors": [] +}"#; + let expected_result = serde_json::from_str::(expected_result).unwrap(); + assert_eq!(res, expected_result); +} #[test_log::test(tokio::test)] async fn test_wal_plugin_test() { diff --git a/influxdb3_py_api/src/system_py.rs b/influxdb3_py_api/src/system_py.rs index 04b6d6e5d0..064f8dc037 100644 --- a/influxdb3_py_api/src/system_py.rs +++ b/influxdb3_py_api/src/system_py.rs @@ -366,6 +366,18 @@ class LineBuilder: if '=' in key: raise InvalidKeyError(f"{key_type} key '{key}' cannot contain equals signs") + def _escape_measurement(self, value: str) -> str: + """Escape characters in measurement names according to line protocol.""" + return value.replace(',', '\\,').replace(' ', '\\ ') + + def _escape_tag_value(self, value: str) -> str: + """Escape characters in tag values according to line protocol.""" + return value.replace('\\', '\\\\').replace(',', '\\,').replace('=', '\\=').replace(' ', '\\ ') + + def _escape_field_key(self, value: str) -> str: + """Escape characters in field keys according to line protocol.""" + return value.replace('\\', '\\\\').replace(',', '\\,').replace('=', '\\=').replace(' ', '\\ ') + def tag(self, key: str, value: str) -> 'LineBuilder': """Add a tag to the line protocol.""" self._validate_key(key, "tag") @@ -397,7 +409,7 @@ class LineBuilder: """Add a string field to the line protocol.""" self._validate_key(key, "field") # Escape quotes and backslashes in string values - escaped_value = value.replace('"', '\\"').replace('\\', '\\\\') + escaped_value = value.replace('\\', '\\\\').replace('"', '\\"') self.fields[key] = f'"{escaped_value}"' return self @@ -414,13 +426,14 @@ class LineBuilder: def build(self) -> str: """Build the line protocol string.""" - # Start with measurement name (escape commas only) - line = self.measurement.replace(',', '\\,') + # Start with measurement name (escape commas and spaces) + line = self._escape_measurement(self.measurement) # Add tags if present if self.tags: tags_str = ','.join( - f"{k}={v}" for k, v in self.tags.items() + f"{key}={self._escape_tag_value(value)}" + for key, value in self.tags.items() ) line += f",{tags_str}" @@ -429,7 +442,8 @@ class LineBuilder: raise InvalidLineError(f"At least one field is required: {line}") fields_str = ','.join( - f"{k}={v}" for k, v in self.fields.items() + f"{self._escape_field_key(key)}={value}" + for key, value in self.fields.items() ) line += f" {fields_str}"