Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update default channel pool options when creating grpc channels #887

Merged
merged 3 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@

from __future__ import annotations

import ipaddress
import re
import sys
from threading import Lock
from types import TracebackType
from typing import (
Dict,
Literal,
Optional,
Type,
TYPE_CHECKING,
)
from typing import TYPE_CHECKING, Dict, Literal, Optional, Type
from urllib.parse import urlparse

import grpc

Expand Down Expand Up @@ -57,9 +54,7 @@ def get_channel(self, target: str) -> grpc.Channel:
with self._lock:
if target not in self._channel_cache:
self._lock.release()
new_channel = grpc.insecure_channel(target)
if ClientLogger.is_enabled():
new_channel = grpc.intercept_channel(new_channel, ClientLogger())
new_channel = self._create_channel(target)
self._lock.acquire()
if target not in self._channel_cache:
self._channel_cache[target] = new_channel
Expand All @@ -78,3 +73,44 @@ def close(self) -> None:
for channel in self._channel_cache.values():
channel.close()
self._channel_cache.clear()

def _create_channel(self, target: str) -> grpc.Channel:
options = [
("grpc.max_receive_message_length", -1),
("grpc.max_send_message_length", -1),
]
if self._is_local(target):
options.append(("grpc.enable_http_proxy", 0))
channel = grpc.insecure_channel(target, options)
if ClientLogger.is_enabled():
channel = grpc.intercept_channel(channel, ClientLogger())
return channel

def _is_local(self, target: str) -> bool:
hostname = ""
# First, check if the target string is in URL format
parse_result = urlparse(target)
if parse_result.scheme and parse_result.hostname and parse_result.port:
hostname = parse_result.hostname
else:
# Next, check for target string in <host_name>:<port> format
match = re.match(r"^(.*):(\d+)$", target)
if match:
hostname = match.group(1)

if not hostname:
return False
if hostname == "localhost" or hostname == "LOCALHOST":
return True

# IPv6 addresses don't support parsing with leading/trailing brackets
# so we need to remove them.
match = re.match(r"^\[(.*)\]$", hostname)
if match:
hostname = match.group(1)

try:
address = ipaddress.ip_address(hostname)
return address.is_loopback
except ValueError:
return False
1 change: 1 addition & 0 deletions packages/service/tests/unit/grpc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unit tests for ni_measurement_plugin_sdk_service.grpc."""
1 change: 1 addition & 0 deletions packages/service/tests/unit/grpc/channelpool/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unit tests for ni_measurement_plugin_sdk_service.grpc.channelpool."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest

from ni_measurement_plugin_sdk_service.grpc.channelpool import GrpcChannelPool


@pytest.mark.parametrize(
"target,expected_result",
[
("127.0.0.1", False), # Port must be specified explicitly
("[::1]", False), # Port must be specified explicitly
("localhost", False), # Port must be specified explicitly
("127.0.0.1:100", True),
("[::1]:100", True),
("localhost:100", True),
("http://127.0.0.1", False), # Port must be specified explicitly
("http://[::1]", False), # Port must be specified explicitly
("http://localhost", False), # Port must be specified explicitly
("http://127.0.0.1:100", True),
("http://[::1]:100", True),
("http://localhost:100", True),
("1.1.1.1:100", False),
("http://www.google.com:80", False),
],
)
def test___channel_pool___is_local___returns_expected_result(
target: str, expected_result: bool
) -> None:
channel_pool = GrpcChannelPool()

result = channel_pool._is_local(target)

assert result == expected_result