diff --git a/drift/core/communication/types.py b/drift/core/communication/types.py index 99bd8ad..d5c4545 100644 --- a/drift/core/communication/types.py +++ b/drift/core/communication/types.py @@ -82,9 +82,8 @@ def _python_to_value(value: Any) -> Any: from betterproto.lib.google.protobuf import ListValue, Value if value is None: - from betterproto.lib.google.protobuf import NullValue - - return Value(null_value=NullValue.NULL_VALUE) # type: ignore[arg-type] + # betterproto 2.0.0b7 uses integer 0 for null value (NullValue.NULL_VALUE doesn't exist) + return Value(null_value=0) # type: ignore[arg-type] elif isinstance(value, bool): return Value(bool_value=value) elif isinstance(value, (int, float)): diff --git a/drift/core/span_serialization.py b/drift/core/span_serialization.py index 05eb6fb..80cff92 100644 --- a/drift/core/span_serialization.py +++ b/drift/core/span_serialization.py @@ -37,9 +37,8 @@ def _value_to_proto(value: Any) -> ProtoValue: proto_value = ProtoValue() if value is None: - from betterproto.lib.google.protobuf import NullValue - - proto_value.null_value = NullValue.NULL_VALUE # type: ignore[assignment] + # betterproto 2.0.0b7 uses integer 0 for null value (NullValue.NULL_VALUE doesn't exist) + proto_value.null_value = 0 # type: ignore[assignment] elif isinstance(value, bool): proto_value.bool_value = value elif isinstance(value, (int, float)): diff --git a/drift/core/tracing/adapters/api.py b/drift/core/tracing/adapters/api.py index 727e2f0..d8c0329 100644 --- a/drift/core/tracing/adapters/api.py +++ b/drift/core/tracing/adapters/api.py @@ -239,9 +239,8 @@ def _dict_to_struct(data: dict[str, Any]) -> Struct: def value_to_proto(val: Any) -> Value: """Convert a Python value to protobuf Value.""" if val is None: - from betterproto.lib.google.protobuf import NullValue - - return Value(null_value=NullValue.NULL_VALUE) # type: ignore[arg-type] + # betterproto 2.0.0b7 uses integer 0 for null value (NullValue.NULL_VALUE doesn't exist) + return Value(null_value=0) # type: ignore[arg-type] elif isinstance(val, bool): return Value(bool_value=val) elif isinstance(val, (int, float)): diff --git a/drift/instrumentation/redis/e2e-tests/src/app.py b/drift/instrumentation/redis/e2e-tests/src/app.py index 7caf226..112ab11 100644 --- a/drift/instrumentation/redis/e2e-tests/src/app.py +++ b/drift/instrumentation/redis/e2e-tests/src/app.py @@ -89,6 +89,162 @@ def redis_keys(pattern): return jsonify({"error": str(e)}), 500 +@app.route("/test/mget-mset", methods=["GET"]) +def test_mget_mset(): + """Test MGET/MSET - multiple key operations.""" + try: + # MSET multiple keys + redis_client.mset({"test:mset:key1": "value1", "test:mset:key2": "value2", "test:mset:key3": "value3"}) + # MGET multiple keys + result = redis_client.mget(["test:mset:key1", "test:mset:key2", "test:mset:key3", "test:mset:nonexistent"]) + # Clean up + redis_client.delete("test:mset:key1", "test:mset:key2", "test:mset:key3") + return jsonify({"success": True, "result": result}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/pipeline-basic", methods=["GET"]) +def test_pipeline_basic(): + """Test basic pipeline operations.""" + try: + pipe = redis_client.pipeline() + pipe.set("test:pipe:key1", "value1") + pipe.set("test:pipe:key2", "value2") + pipe.get("test:pipe:key1") + pipe.get("test:pipe:key2") + results = pipe.execute() + # Clean up + redis_client.delete("test:pipe:key1", "test:pipe:key2") + return jsonify({"success": True, "results": results}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/pipeline-no-transaction", methods=["GET"]) +def test_pipeline_no_transaction(): + """Test pipeline with transaction=False.""" + try: + pipe = redis_client.pipeline(transaction=False) + pipe.set("test:pipe:notx:key1", "value1") + pipe.incr("test:pipe:notx:counter") + pipe.get("test:pipe:notx:key1") + results = pipe.execute() + # Clean up + redis_client.delete("test:pipe:notx:key1", "test:pipe:notx:counter") + return jsonify({"success": True, "results": results}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/async-pipeline", methods=["GET"]) +def test_async_pipeline(): + """Test async pipeline operations using asyncio.""" + import asyncio + + import redis.asyncio as aioredis + + async def run_async_pipeline(): + # Create async Redis client + async_client = aioredis.Redis( + host=os.getenv("REDIS_HOST", "redis"), + port=int(os.getenv("REDIS_PORT", "6379")), + db=0, + decode_responses=True, + ) + + try: + # Create async pipeline + pipe = async_client.pipeline() + pipe.set("test:async:pipe:key1", "async_value1") + pipe.set("test:async:pipe:key2", "async_value2") + pipe.get("test:async:pipe:key1") + pipe.get("test:async:pipe:key2") + results = await pipe.execute() + + # Clean up + await async_client.delete("test:async:pipe:key1", "test:async:pipe:key2") + + return results + finally: + await async_client.aclose() + + try: + results = asyncio.run(run_async_pipeline()) + return jsonify({"success": True, "results": results}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/binary-data", methods=["GET"]) +def test_binary_data(): + """Test binary data that cannot be decoded as UTF-8.""" + try: + # Create a Redis client without decode_responses for binary data + binary_client = redis.Redis( + host=os.getenv("REDIS_HOST", "redis"), + port=int(os.getenv("REDIS_PORT", "6379")), + db=0, + decode_responses=False, + ) + + # Binary data that cannot be decoded as UTF-8 + binary_value = bytes([0x80, 0x81, 0x82, 0xFF, 0xFE, 0xFD]) + + # Set binary data + binary_client.set("test:binary:key", binary_value) + + # Get binary data back + retrieved = binary_client.get("test:binary:key") + + # Clean up + binary_client.delete("test:binary:key") + + return jsonify( + { + "success": True, + "original_hex": binary_value.hex(), + "retrieved_hex": retrieved.hex() if retrieved else None, + "match": binary_value == retrieved, + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/transaction-watch", methods=["GET"]) +def test_transaction_watch(): + """Test transaction with WATCH pattern. + + This tests whether WATCH/MULTI/EXEC transaction pattern works correctly. + """ + try: + # Set initial value + redis_client.set("test:watch:counter", "10") + + # Start a watched transaction + pipe = redis_client.pipeline(transaction=True) + pipe.watch("test:watch:counter") + + # Get current value (this happens outside the transaction) + current = int(redis_client.get("test:watch:counter")) + + # Start the transaction + pipe.multi() + pipe.set("test:watch:counter", str(current + 5)) + pipe.get("test:watch:counter") + + # Execute + results = pipe.execute() + + # Clean up + redis_client.delete("test:watch:counter") + + return jsonify({"success": True, "initial_value": 10, "expected_final": 15, "results": results}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + if __name__ == "__main__": sdk.mark_app_as_ready() app.run(host="0.0.0.0", port=8000, debug=False) diff --git a/drift/instrumentation/redis/e2e-tests/src/test_requests.py b/drift/instrumentation/redis/e2e-tests/src/test_requests.py index 3ee9606..9cc893c 100644 --- a/drift/instrumentation/redis/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/redis/e2e-tests/src/test_requests.py @@ -33,8 +33,7 @@ def make_request(method, endpoint, **kwargs): # Get operations make_request("GET", "/redis/get/test_key") make_request("GET", "/redis/get/test_key_expiry") - # TODO: figure out why this test fails during replay - # make_request("GET", "/redis/get/nonexistent_key") + make_request("GET", "/redis/get/nonexistent_key") # Increment operations make_request("POST", "/redis/incr/counter") @@ -49,4 +48,18 @@ def make_request(method, endpoint, **kwargs): make_request("DELETE", "/redis/delete/test_key") make_request("DELETE", "/redis/delete/counter") + make_request("GET", "/test/mget-mset") + + # Pipeline operations + make_request("GET", "/test/pipeline-basic") + make_request("GET", "/test/pipeline-no-transaction") + + # Async Pipeline operations + make_request("GET", "/test/async-pipeline") + + # Binary data handling + make_request("GET", "/test/binary-data") + + make_request("GET", "/test/transaction-watch") + print("\nAll requests completed successfully") diff --git a/drift/instrumentation/redis/instrumentation.py b/drift/instrumentation/redis/instrumentation.py index 1689865..0b671da 100644 --- a/drift/instrumentation/redis/instrumentation.py +++ b/drift/instrumentation/redis/instrumentation.py @@ -108,6 +108,29 @@ def patched_pipeline_execute(pipeline_self, *args, **kwargs): Pipeline.execute = patched_pipeline_execute logger.debug("redis.client.Pipeline.execute instrumented") + + # Patch Pipeline.immediate_execute_command for WATCH and other immediate commands + if hasattr(Pipeline, "immediate_execute_command"): + original_immediate = Pipeline.immediate_execute_command + self._original_pipeline_immediate_execute = original_immediate + + def patched_pipeline_immediate_execute(pipeline_self, *args, **kwargs): + """Patched Pipeline.immediate_execute_command method.""" + sdk = TuskDrift.get_instance() + + if sdk.mode == TuskDriftMode.DISABLED: + return original_immediate(pipeline_self, *args, **kwargs) + + return instrumentation._traced_pipeline_immediate_execute( + pipeline_self, + original_immediate, + sdk, + args, + kwargs, + ) + + Pipeline.immediate_execute_command = patched_pipeline_immediate_execute + logger.debug("redis.client.Pipeline.immediate_execute_command instrumented") except ImportError: logger.debug("redis.client.Pipeline not available") @@ -137,6 +160,55 @@ async def patched_async_execute_command(redis_self, *args, **kwargs): async_redis_class.execute_command = patched_async_execute_command logger.debug("redis.asyncio.Redis.execute_command instrumented") + + # Patch async Pipeline.execute + try: + from redis.asyncio.client import Pipeline as AsyncPipeline + + if hasattr(AsyncPipeline, "execute"): + original_async_pipeline_execute = AsyncPipeline.execute + + async def patched_async_pipeline_execute(pipeline_self, *args, **kwargs): + """Patched async Pipeline.execute method.""" + sdk = TuskDrift.get_instance() + + if sdk.mode == TuskDriftMode.DISABLED: + return await original_async_pipeline_execute(pipeline_self, *args, **kwargs) + + return await instrumentation._traced_async_pipeline_execute( + pipeline_self, + original_async_pipeline_execute, + sdk, + args, + kwargs, + ) + + AsyncPipeline.execute = patched_async_pipeline_execute + logger.debug("redis.asyncio.client.Pipeline.execute instrumented") + + # Patch async Pipeline.immediate_execute_command for WATCH and other immediate commands + if hasattr(AsyncPipeline, "immediate_execute_command"): + original_async_immediate = AsyncPipeline.immediate_execute_command + + async def patched_async_pipeline_immediate_execute(pipeline_self, *args, **kwargs): + """Patched async Pipeline.immediate_execute_command method.""" + sdk = TuskDrift.get_instance() + + if sdk.mode == TuskDriftMode.DISABLED: + return await original_async_immediate(pipeline_self, *args, **kwargs) + + return await instrumentation._traced_async_pipeline_immediate_execute( + pipeline_self, + original_async_immediate, + sdk, + args, + kwargs, + ) + + AsyncPipeline.immediate_execute_command = patched_async_pipeline_immediate_execute + logger.debug("redis.asyncio.client.Pipeline.immediate_execute_command instrumented") + except ImportError: + logger.debug("redis.asyncio.client.Pipeline not available") except ImportError: logger.debug("redis.asyncio not available") @@ -169,6 +241,58 @@ def original_call(): span_kind=OTelSpanKind.CLIENT, ) + def _traced_pipeline_immediate_execute( + self, pipeline: Any, original_execute: Any, sdk: TuskDrift, args: tuple, kwargs: dict + ) -> Any: + """Traced Pipeline.immediate_execute_command method for WATCH and other immediate commands.""" + if sdk.mode == TuskDriftMode.DISABLED: + return original_execute(pipeline, *args, **kwargs) + + command_name = args[0] if args else "UNKNOWN" + command_str = self._format_command(args) + + def original_call(): + return original_execute(pipeline, *args, **kwargs) + + if sdk.mode == TuskDriftMode.REPLAY: + return handle_replay_mode( + replay_mode_handler=lambda: self._replay_execute_command(sdk, command_name, command_str, args), + no_op_request_handler=lambda: self._get_default_response(command_name), + is_server_request=False, + ) + + # RECORD mode + return handle_record_mode( + original_function_call=original_call, + record_mode_handler=lambda is_pre_app_start: self._record_execute_command( + pipeline, original_execute, sdk, args, kwargs, command_name, command_str, is_pre_app_start + ), + span_kind=OTelSpanKind.CLIENT, + ) + + async def _traced_async_pipeline_immediate_execute( + self, pipeline: Any, original_execute: Any, sdk: TuskDrift, args: tuple, kwargs: dict + ) -> Any: + """Traced async Pipeline.immediate_execute_command method for WATCH and other immediate commands.""" + if sdk.mode == TuskDriftMode.DISABLED: + return await original_execute(pipeline, *args, **kwargs) + + command_name = args[0] if args else "UNKNOWN" + command_str = self._format_command(args) + + # For REPLAY mode, use sync mocking (mocks are retrieved synchronously) + if sdk.mode == TuskDriftMode.REPLAY: + return handle_replay_mode( + replay_mode_handler=lambda: self._replay_execute_command(sdk, command_name, command_str, args), + no_op_request_handler=lambda: self._get_default_response(command_name), + is_server_request=False, + ) + + # RECORD mode with async execution + return await self._record_async_execute_command( + pipeline, original_execute, sdk, args, kwargs, command_name, command_str + ) + def _replay_execute_command(self, sdk: TuskDrift, command_name: str, command_str: str, args: tuple) -> Any: """Handle REPLAY mode for execute_command.""" span_name = f"redis.{command_name}" @@ -194,8 +318,11 @@ def _replay_execute_command(self, sdk: TuskDrift, command_name: str, command_str raise RuntimeError("Error creating span in replay mode") with SpanUtils.with_span(span_info): + # Build input_value using shared helper + input_value = self._build_command_input_value(command_str, args) + mock_result = self._try_get_mock( - sdk, command_name, command_str, args, span_info.trace_id, span_info.span_id, span_info.parent_span_id + sdk, command_name, input_value, span_info.trace_id, span_info.span_id, span_info.parent_span_id ) if mock_result is None: @@ -369,6 +496,83 @@ def original_call(): span_kind=OTelSpanKind.CLIENT, ) + async def _traced_async_pipeline_execute( + self, pipeline: Any, original_execute: Any, sdk: TuskDrift, args: tuple, kwargs: dict + ) -> Any: + """Traced async Pipeline.execute method.""" + if sdk.mode == TuskDriftMode.DISABLED: + return await original_execute(pipeline, *args, **kwargs) + + # Get commands from pipeline + command_stack = self._get_pipeline_commands(pipeline) + command_str = self._format_pipeline_commands(command_stack) + + if sdk.mode == TuskDriftMode.REPLAY: + return handle_replay_mode( + replay_mode_handler=lambda: self._replay_pipeline_execute(sdk, command_str, command_stack), + no_op_request_handler=lambda: [], # Empty list for pipeline + is_server_request=False, + ) + + # RECORD mode with async execution + return await self._record_async_pipeline_execute( + pipeline, original_execute, sdk, args, kwargs, command_str, command_stack + ) + + async def _record_async_pipeline_execute( + self, + pipeline: Any, + original_execute: Any, + sdk: TuskDrift, + args: tuple, + kwargs: dict, + command_str: str, + command_stack: list, + ) -> Any: + """Handle async RECORD mode for pipeline execute.""" + is_pre_app_start = not sdk.app_ready + span_name = "redis.pipeline" + + # Create span using SpanUtils + span_info = SpanUtils.create_span( + CreateSpanOptions( + name=span_name, + kind=OTelSpanKind.CLIENT, + attributes={ + TdSpanAttributes.NAME: span_name, + TdSpanAttributes.PACKAGE_NAME: "redis", + TdSpanAttributes.INSTRUMENTATION_NAME: "RedisInstrumentation", + TdSpanAttributes.SUBMODULE_NAME: "pipeline", + TdSpanAttributes.PACKAGE_TYPE: PackageType.REDIS.name, + TdSpanAttributes.IS_PRE_APP_START: is_pre_app_start, + }, + is_pre_app_start=is_pre_app_start, + ) + ) + + if not span_info: + # Fallback to original call if span creation fails + return await original_execute(pipeline, *args, **kwargs) + + error = None + result = None + + with SpanUtils.with_span(span_info): + try: + result = await original_execute(pipeline, *args, **kwargs) + return result + except Exception as e: + error = e + raise + finally: + self._finalize_pipeline_span( + span_info.span, + command_str, + command_stack, + result if error is None else None, + error, + ) + def _replay_pipeline_execute(self, sdk: TuskDrift, command_str: str, command_stack: list) -> Any: """Handle REPLAY mode for pipeline execute.""" span_name = "redis.pipeline" @@ -394,11 +598,13 @@ def _replay_pipeline_execute(self, sdk: TuskDrift, command_str: str, command_sta raise RuntimeError("Error creating span in replay mode") with SpanUtils.with_span(span_info): + # Build input_value the same way as _finalize_pipeline_span + input_value = self._build_pipeline_input_value(command_str, command_stack) + mock_result = self._try_get_mock( sdk, "pipeline", - command_str, - command_stack, + input_value, span_info.trace_id, span_info.span_id, span_info.parent_span_id, @@ -516,25 +722,34 @@ def _format_pipeline_commands(self, command_stack: list) -> str: return "PIPELINE: " + " ".join(commands) + def _build_command_input_value(self, command_str: str, args: tuple) -> dict[str, Any]: + """Build input_value for single commands (used by both record and replay).""" + input_value: dict[str, Any] = {"command": command_str.strip()} + if args is not None: + input_value["arguments"] = self._serialize_args(args) + return input_value + + def _build_pipeline_input_value(self, command_str: str, command_stack: list) -> dict[str, Any]: + """Build input_value for pipeline operations (used by both record and replay).""" + serialized_commands = [ + self._serialize_args(cmd.args if hasattr(cmd, "args") else cmd[0]) for cmd in command_stack + ] + return { + "command": command_str, + "commands": serialized_commands, + } + def _try_get_mock( self, sdk: TuskDrift, command_name: str, - command_str: str, - args: Any, + input_value: dict[str, Any], trace_id: str, span_id: str, parent_span_id: str | None, ) -> dict[str, Any] | None: """Try to get a mocked response from CLI.""" try: - # Build input value - input_value = { - "command": command_str.strip(), - } - if args is not None: - input_value["arguments"] = self._serialize_args(args) - # Generate schema and hashes for CLI matching input_result = JsonSchemaHelper.generate_schema_and_hash(input_value, {}) @@ -577,6 +792,7 @@ def _try_get_mock( outbound_span=mock_span, ) + command_str = input_value.get("command", "") logger.debug(f"Requesting mock from CLI for command: {command_str[:50]}...") mock_response_output = sdk.request_mock_sync(mock_request) logger.debug(f"CLI returned: found={mock_response_output.found}") @@ -601,12 +817,8 @@ def _finalize_command_span( ) -> None: """Finalize span with command data.""" try: - # Build input value - input_value = { - "command": command.strip(), - } - if args is not None: - input_value["arguments"] = self._serialize_args(args) + # Build input value using shared helper + input_value = self._build_command_input_value(command, args) # Build output value output_value = {} @@ -653,14 +865,8 @@ def _finalize_pipeline_span( ) -> None: """Finalize span with pipeline data.""" try: - # Build input value - serialized_commands = [ - self._serialize_args(cmd.args if hasattr(cmd, "args") else cmd[0]) for cmd in command_stack - ] - input_value: dict[str, Any] = { - "command": command_str, - "commands": serialized_commands, - } + # Build input value using shared helper + input_value = self._build_pipeline_input_value(command_str, command_stack) # Build output value output_value = {} @@ -707,9 +913,10 @@ def _serialize_value(self, value: Any) -> Any: """Serialize a single value for JSON.""" if isinstance(value, bytes): try: - return value.decode("utf-8") + decoded = value.decode("utf-8") + return {"__bytes__": True, "encoding": "utf8", "value": decoded} except UnicodeDecodeError: - return value.hex() + return {"__bytes__": True, "encoding": "hex", "value": value.hex()} elif isinstance(value, (str, int, float, bool, type(None))): return value elif isinstance(value, (list, tuple)): @@ -725,6 +932,24 @@ def _serialize_response(self, response: Any) -> Any: """Serialize Redis response for recording.""" return self._serialize_value(response) + def _deserialize_value(self, value: Any) -> Any: + """Deserialize a value, converting typed wrappers back to original types.""" + if isinstance(value, dict): + # Check for bytes wrapper + if value.get("__bytes__") is True: + encoding = value.get("encoding") + data = value.get("value", "") + if encoding == "utf8": + return data.encode("utf-8") + elif encoding == "hex": + return bytes.fromhex(data) + return data # fallback + # Recursively deserialize dict values + return {k: self._deserialize_value(v) for k, v in value.items()} + elif isinstance(value, list): + return [self._deserialize_value(v) for v in value] + return value + def _deserialize_response(self, mock_data: dict[str, Any]) -> Any: """Deserialize mocked response data from CLI. @@ -735,9 +960,9 @@ def _deserialize_response(self, mock_data: dict[str, Any]) -> Any: if isinstance(mock_data, dict): if "result" in mock_data: - return mock_data["result"] + return self._deserialize_value(mock_data["result"]) elif "results" in mock_data: - return mock_data["results"] + return [self._deserialize_value(r) for r in mock_data["results"]] logger.warning(f"Could not deserialize mock_data structure: {mock_data}") return None diff --git a/scripts/generate_manifest.py b/scripts/generate_manifest.py index 5a1a71c..482f09a 100644 --- a/scripts/generate_manifest.py +++ b/scripts/generate_manifest.py @@ -176,6 +176,7 @@ def main() -> int: manifest = { "sdkVersion": sdk_version, "language": "python", + "pythonVersion": ">=3.12", "generatedAt": datetime.now(UTC).isoformat().replace("+00:00", "Z"), "instrumentations": instrumentations, }