feat(processing_engine): proper escaping of LineBuilder arguments.

processing_engine/escape_line_builder
Jackson Newhouse 2025-03-12 16:56:06 -07:00
parent 97b8c471f8
commit d27814263a
2 changed files with 154 additions and 5 deletions

View File

@ -1257,6 +1257,141 @@ async fn distinct_cache_create_and_delete() {
assert_contains!(result, "distinct cache deleted successfully");
}
#[test_log::test(tokio::test)]
async fn test_linebuilder_escaping() {
use crate::server::ConfigProvider;
// Create plugin file with all test cases
let plugin_file = create_plugin_file(
r#"
def process_writes(influxdb3_local, table_batches, args=None):
# Basic test with all field types
basic_line = LineBuilder("metrics")\
.tag("host", "server01")\
.tag("region", "us-west")\
.int64_field("cpu", 42)\
.uint64_field("memory_bytes", 8589934592)\
.float64_field("load", 86.5)\
.string_field("status", "online")\
.bool_field("healthy", True)\
.time_ns(1609459200000000000)
influxdb3_local.write(basic_line)
# Test escaping spaces in tag values
spaces_line = LineBuilder("system_metrics")\
.tag("server", "app server 1")\
.tag("datacenter", "us west")\
.int64_field("count", 1)
influxdb3_local.write(spaces_line)
# Test escaping commas in tag values
commas_line = LineBuilder("network")\
.tag("servers", "web,app,db")\
.tag("location", "floor1,rack3")\
.int64_field("connections", 256)
influxdb3_local.write(commas_line)
# Test escaping equals signs in tag values
equals_line = LineBuilder("formulas")\
.tag("equation", "y=mx+b")\
.tag("result", "a=b=c")\
.float64_field("value", 3.14159)
influxdb3_local.write(equals_line)
# Test escaping backslashes in tag values
backslash_line = LineBuilder("paths")\
.tag("windows_path", "C:\\Program Files\\App")\
.tag("regex", "\\d+\\w+")\
.string_field("description", "Windows\\Unix paths")\
.int64_field("count", 42)
influxdb3_local.write(backslash_line)
# Test escaping quotes in string fields
quotes_line = LineBuilder("messages")\
.tag("type", "notification")\
.string_field("content", "User said \"Hello World\"")\
.string_field("json", "{\"key\": \"value\"}")\
.int64_field("priority", 1)
influxdb3_local.write(quotes_line)
# Test a complex case with multiple escape characters
complex_line = LineBuilder("complex,measurement")\
.tag("location", "New York, USA")\
.tag("details", "floor=5, room=3")\
.tag("path", "C:\\Users\\Admin\\Documents")\
.string_field("message", "Error in line: \"x = y + z\"")\
.string_field("query", "SELECT * FROM table WHERE id=\"abc\"")\
.float64_field("value", 123.456)\
.time_ns(1609459200000000000)
influxdb3_local.write(complex_line)
# Test writing to a specific database
specific_db_line = LineBuilder("memory_stats")\
.tag("host", "server1")\
.int64_field("usage", 75)
influxdb3_local.write_to_db("metrics_db", specific_db_line)
# Test multiple chained methods
chained_line = LineBuilder("sensor_data")\
.tag("device", "thermostat").tag("room", "living room").tag("floor", "1")\
.float64_field("temperature", 72.5).int64_field("humidity", 45).string_field("mode", "auto")
influxdb3_local.write(chained_line)
influxdb3_local.info("All LineBuilder tests completed")"#,
);
let plugin_dir = plugin_file.path().parent().unwrap().to_str().unwrap();
let plugin_name = plugin_file.path().file_name().unwrap().to_str().unwrap();
let server = TestServer::configure()
.with_plugin_dir(plugin_dir)
.spawn()
.await;
let server_addr = server.client_addr();
let db_name = "test_db";
// Run the test
let result = run_with_confirmation(&[
"test",
"wal_plugin",
"--database",
db_name,
"--host",
&server_addr,
"--lp",
"test_input,tag1=tag1_value field1=1i 500",
"--input-arguments",
"arg1=test",
plugin_name,
]);
let res = serde_json::from_str::<Value>(&result).unwrap();
let expected_result = r#"{
"log_lines": [
"INFO: All LineBuilder tests completed"
],
"database_writes": {
"test_db": [
"metrics,host=server01,region=us-west cpu=42i,memory_bytes=8589934592u,load=86.5,status=\"online\",healthy=t 1609459200000000000",
"system_metrics,server=app\\ server\\ 1,datacenter=us\\ west count=1i",
"network,servers=web\\,app\\,db,location=floor1\\,rack3 connections=256i",
"formulas,equation=y\\=mx+b,result=a\\=b\\=c value=3.14159",
"paths,windows_path=C:\\\\Program\\ Files\\\\App,regex=\\\\d+\\\\w+ description=\"Windows\\\\Unix paths\",count=42i",
"messages,type=notification content=\"User said \\\"Hello World\\\"\",json=\"{\\\"key\\\": \\\"value\\\"}\",priority=1i",
"complex\\,measurement,location=New\\ York\\,\\ USA,details=floor\\=5\\,\\ room\\=3,path=C:\\\\Users\\\\Admin\\\\Documents message=\"Error in line: \\\"x = y + z\\\"\",query=\"SELECT * FROM table WHERE id=\\\"abc\\\"\",value=123.456 1609459200000000000",
"sensor_data,device=thermostat,room=living\\ room,floor=1 temperature=72.5,humidity=45i,mode=\"auto\""
],
"metrics_db": [
"memory_stats,host=server1 usage=75i"
]
},
"errors": []
}"#;
let expected_result = serde_json::from_str::<Value>(expected_result).unwrap();
assert_eq!(res, expected_result);
}
#[test_log::test(tokio::test)]
async fn test_wal_plugin_test() {

View File

@ -366,6 +366,18 @@ class LineBuilder:
if '=' in key:
raise InvalidKeyError(f"{key_type} key '{key}' cannot contain equals signs")
def _escape_measurement(self, value: str) -> str:
"""Escape characters in measurement names according to line protocol."""
return value.replace(',', '\\,').replace(' ', '\\ ')
def _escape_tag_value(self, value: str) -> str:
"""Escape characters in tag values according to line protocol."""
return value.replace('\\', '\\\\').replace(',', '\\,').replace('=', '\\=').replace(' ', '\\ ')
def _escape_field_key(self, value: str) -> str:
"""Escape characters in field keys according to line protocol."""
return value.replace('\\', '\\\\').replace(',', '\\,').replace('=', '\\=').replace(' ', '\\ ')
def tag(self, key: str, value: str) -> 'LineBuilder':
"""Add a tag to the line protocol."""
self._validate_key(key, "tag")
@ -397,7 +409,7 @@ class LineBuilder:
"""Add a string field to the line protocol."""
self._validate_key(key, "field")
# Escape quotes and backslashes in string values
escaped_value = value.replace('"', '\\"').replace('\\', '\\\\')
escaped_value = value.replace('\\', '\\\\').replace('"', '\\"')
self.fields[key] = f'"{escaped_value}"'
return self
@ -414,13 +426,14 @@ class LineBuilder:
def build(self) -> str:
"""Build the line protocol string."""
# Start with measurement name (escape commas only)
line = self.measurement.replace(',', '\\,')
# Start with measurement name (escape commas and spaces)
line = self._escape_measurement(self.measurement)
# Add tags if present
if self.tags:
tags_str = ','.join(
f"{k}={v}" for k, v in self.tags.items()
f"{key}={self._escape_tag_value(value)}"
for key, value in self.tags.items()
)
line += f",{tags_str}"
@ -429,7 +442,8 @@ class LineBuilder:
raise InvalidLineError(f"At least one field is required: {line}")
fields_str = ','.join(
f"{k}={v}" for k, v in self.fields.items()
f"{self._escape_field_key(key)}={value}"
for key, value in self.fields.items()
)
line += f" {fields_str}"