From c274a8987771d755889da441efde7a0d94467a3d Mon Sep 17 00:00:00 2001 From: Jason Reding Date: Fri, 13 Sep 2024 18:22:00 -0500 Subject: [PATCH 1/3] Updating channel pool so that it doesn't limit message size for a channel and will bypass proxies when target address is on the localhost. --- .../grpc/channelpool.py | 48 +++++++++++++++---- packages/service/tests/unit/grpc/__init__.py | 1 + .../tests/unit/grpc/channelpool/__init__.py | 1 + .../grpc/channelpool/test_channel_pool.py | 32 +++++++++++++ 4 files changed, 74 insertions(+), 8 deletions(-) create mode 100644 packages/service/tests/unit/grpc/__init__.py create mode 100644 packages/service/tests/unit/grpc/channelpool/__init__.py create mode 100644 packages/service/tests/unit/grpc/channelpool/test_channel_pool.py diff --git a/packages/service/ni_measurement_plugin_sdk_service/grpc/channelpool.py b/packages/service/ni_measurement_plugin_sdk_service/grpc/channelpool.py index 1ec44e21c..2e982d661 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/grpc/channelpool.py +++ b/packages/service/ni_measurement_plugin_sdk_service/grpc/channelpool.py @@ -2,16 +2,13 @@ from __future__ import annotations +import ipaddress +import re import sys from threading import Lock from types import TracebackType -from typing import ( - Dict, - Literal, - Optional, - Type, - TYPE_CHECKING, -) +from typing import TYPE_CHECKING, Dict, Literal, Optional, Type +from urllib.parse import urlparse import grpc @@ -57,7 +54,13 @@ def get_channel(self, target: str) -> grpc.Channel: with self._lock: if target not in self._channel_cache: self._lock.release() - new_channel = grpc.insecure_channel(target) + options = [ + ("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1), + ] + if self._is_local(target): + options.append(("grpc.enable_http_proxy", 0)) + new_channel = grpc.insecure_channel(target, options) if ClientLogger.is_enabled(): new_channel = grpc.intercept_channel(new_channel, ClientLogger()) self._lock.acquire() @@ -78,3 +81,32 @@ def close(self) -> None: for channel in self._channel_cache.values(): channel.close() self._channel_cache.clear() + + def _is_local(self, target: str) -> bool: + hostname = "" + # First, check if the target string is in URL format + parse_result = urlparse(target) + if parse_result.scheme and parse_result.hostname and parse_result.port: + hostname = parse_result.hostname + else: + # Next, check for target string in : format + match = re.match(r"^(.*):(\d+)$", target) + if match: + hostname = match.group(1) + + if not hostname: + return False + if hostname == "localhost" or hostname == "LOCALHOST": + return True + + # IPv6 addresses don't support parsing with leading/trailing brackets so we need to remove them. + match = re.match(r"^\[(.*)\]$", hostname) + if match: + hostname = match.group(1) + + try: + address = ipaddress.ip_address(hostname) + return address.is_loopback + except: + pass + return False diff --git a/packages/service/tests/unit/grpc/__init__.py b/packages/service/tests/unit/grpc/__init__.py new file mode 100644 index 000000000..c8b601ecb --- /dev/null +++ b/packages/service/tests/unit/grpc/__init__.py @@ -0,0 +1 @@ +"""Unit tests for ni_measurement_plugin_sdk_service.grpc.""" diff --git a/packages/service/tests/unit/grpc/channelpool/__init__.py b/packages/service/tests/unit/grpc/channelpool/__init__.py new file mode 100644 index 000000000..48a9614f8 --- /dev/null +++ b/packages/service/tests/unit/grpc/channelpool/__init__.py @@ -0,0 +1 @@ +"""Unit tests for ni_measurement_plugin_sdk_service.grpc.channelpool.""" diff --git a/packages/service/tests/unit/grpc/channelpool/test_channel_pool.py b/packages/service/tests/unit/grpc/channelpool/test_channel_pool.py new file mode 100644 index 000000000..8b30bae68 --- /dev/null +++ b/packages/service/tests/unit/grpc/channelpool/test_channel_pool.py @@ -0,0 +1,32 @@ +import pytest + +from ni_measurement_plugin_sdk_service.grpc.channelpool import GrpcChannelPool + + +@pytest.mark.parametrize( + "target,expected_result", + [ + ("127.0.0.1", False), + ("[::1]", False), + ("localhost", False), + ("127.0.0.1:100", True), + ("[::1]:100", True), + ("localhost:100", True), + ("http://127.0.0.1", False), + ("http://[::1]", False), + ("http://localhost", False), + ("http://127.0.0.1:100", True), + ("http://[::1]:100", True), + ("http://localhost:100", True), + ("1.1.1.1:100", False), + ("http://www.google.com:80", False), + ], +) +def test___channel_pool___is_local___returns_expected_result( + target: str, expected_result: bool +) -> None: + channel_pool = GrpcChannelPool() + + result = channel_pool._is_local(target) + + assert result == expected_result From a6ec4d67a9ecc60b1fadfa61e1b2fd2ad396d6ad Mon Sep 17 00:00:00 2001 From: Jason Reding Date: Mon, 16 Sep 2024 17:40:41 -0500 Subject: [PATCH 2/3] Brad's feedback. --- .../grpc/channelpool.py | 29 ++++++++++--------- .../grpc/channelpool/test_channel_pool.py | 12 ++++---- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/packages/service/ni_measurement_plugin_sdk_service/grpc/channelpool.py b/packages/service/ni_measurement_plugin_sdk_service/grpc/channelpool.py index 2e982d661..3840d6154 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/grpc/channelpool.py +++ b/packages/service/ni_measurement_plugin_sdk_service/grpc/channelpool.py @@ -54,15 +54,7 @@ def get_channel(self, target: str) -> grpc.Channel: with self._lock: if target not in self._channel_cache: self._lock.release() - options = [ - ("grpc.max_receive_message_length", -1), - ("grpc.max_send_message_length", -1), - ] - if self._is_local(target): - options.append(("grpc.enable_http_proxy", 0)) - new_channel = grpc.insecure_channel(target, options) - if ClientLogger.is_enabled(): - new_channel = grpc.intercept_channel(new_channel, ClientLogger()) + new_channel = self._create_channel(target) self._lock.acquire() if target not in self._channel_cache: self._channel_cache[target] = new_channel @@ -82,6 +74,18 @@ def close(self) -> None: channel.close() self._channel_cache.clear() + def _create_channel(self, target: str) -> grpc.Channel: + options = [ + ("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1), + ] + if self._is_local(target): + options.append(("grpc.enable_http_proxy", 0)) + channel = grpc.insecure_channel(target, options) + if ClientLogger.is_enabled(): + channel = grpc.intercept_channel(channel, ClientLogger()) + return channel + def _is_local(self, target: str) -> bool: hostname = "" # First, check if the target string is in URL format @@ -98,7 +102,7 @@ def _is_local(self, target: str) -> bool: return False if hostname == "localhost" or hostname == "LOCALHOST": return True - + # IPv6 addresses don't support parsing with leading/trailing brackets so we need to remove them. match = re.match(r"^\[(.*)\]$", hostname) if match: @@ -107,6 +111,5 @@ def _is_local(self, target: str) -> bool: try: address = ipaddress.ip_address(hostname) return address.is_loopback - except: - pass - return False + except ValueError: + return False diff --git a/packages/service/tests/unit/grpc/channelpool/test_channel_pool.py b/packages/service/tests/unit/grpc/channelpool/test_channel_pool.py index 8b30bae68..e79295c6c 100644 --- a/packages/service/tests/unit/grpc/channelpool/test_channel_pool.py +++ b/packages/service/tests/unit/grpc/channelpool/test_channel_pool.py @@ -6,15 +6,15 @@ @pytest.mark.parametrize( "target,expected_result", [ - ("127.0.0.1", False), - ("[::1]", False), - ("localhost", False), + ("127.0.0.1", False), # Port must be specified explicitly + ("[::1]", False), # Port must be specified explicitly + ("localhost", False), # Port must be specified explicitly ("127.0.0.1:100", True), ("[::1]:100", True), ("localhost:100", True), - ("http://127.0.0.1", False), - ("http://[::1]", False), - ("http://localhost", False), + ("http://127.0.0.1", False), # Port must be specified explicitly + ("http://[::1]", False), # Port must be specified explicitly + ("http://localhost", False), # Port must be specified explicitly ("http://127.0.0.1:100", True), ("http://[::1]:100", True), ("http://localhost:100", True), From 56d048e62a30bba1f1a459db73a9cdaafcd680a8 Mon Sep 17 00:00:00 2001 From: Jason Reding Date: Mon, 16 Sep 2024 18:50:17 -0500 Subject: [PATCH 3/3] Formatting to conform to style guide. --- .../grpc/channelpool.py | 3 ++- .../tests/unit/grpc/channelpool/test_channel_pool.py | 12 ++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/packages/service/ni_measurement_plugin_sdk_service/grpc/channelpool.py b/packages/service/ni_measurement_plugin_sdk_service/grpc/channelpool.py index 3840d6154..a6719ef5a 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/grpc/channelpool.py +++ b/packages/service/ni_measurement_plugin_sdk_service/grpc/channelpool.py @@ -103,7 +103,8 @@ def _is_local(self, target: str) -> bool: if hostname == "localhost" or hostname == "LOCALHOST": return True - # IPv6 addresses don't support parsing with leading/trailing brackets so we need to remove them. + # IPv6 addresses don't support parsing with leading/trailing brackets + # so we need to remove them. match = re.match(r"^\[(.*)\]$", hostname) if match: hostname = match.group(1) diff --git a/packages/service/tests/unit/grpc/channelpool/test_channel_pool.py b/packages/service/tests/unit/grpc/channelpool/test_channel_pool.py index e79295c6c..8d402fc14 100644 --- a/packages/service/tests/unit/grpc/channelpool/test_channel_pool.py +++ b/packages/service/tests/unit/grpc/channelpool/test_channel_pool.py @@ -6,15 +6,15 @@ @pytest.mark.parametrize( "target,expected_result", [ - ("127.0.0.1", False), # Port must be specified explicitly - ("[::1]", False), # Port must be specified explicitly - ("localhost", False), # Port must be specified explicitly + ("127.0.0.1", False), # Port must be specified explicitly + ("[::1]", False), # Port must be specified explicitly + ("localhost", False), # Port must be specified explicitly ("127.0.0.1:100", True), ("[::1]:100", True), ("localhost:100", True), - ("http://127.0.0.1", False), # Port must be specified explicitly - ("http://[::1]", False), # Port must be specified explicitly - ("http://localhost", False), # Port must be specified explicitly + ("http://127.0.0.1", False), # Port must be specified explicitly + ("http://[::1]", False), # Port must be specified explicitly + ("http://localhost", False), # Port must be specified explicitly ("http://127.0.0.1:100", True), ("http://[::1]:100", True), ("http://localhost:100", True),