Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,20 @@ on:
- master

jobs:
lint:
runs-on: ubuntu-24.04

steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.13"

- name: Lint
run: make lint CHECK=1

tests:
runs-on: ubuntu-24.04

Expand Down
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.12
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
10 changes: 8 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,14 @@ install-test: install ## Install test dependencies in local virtualenv
install-lint: install ## Install lint dependencies in local virtualenv
($(VENV_RUN); $(PIP_CMD) install -r $(LINT_REQS))

lint: install-lint ## Format code with ruff
$(VENV_DIR)/bin/ruff format postgresql_proxy tests plugins
install-pre-commit: install-lint ## Install and register the pre-commit hook
$(VENV_DIR)/bin/pre-commit install

CHECK ?=

lint: install-lint ## Format code with ruff (use CHECK=1 to check without modifying)
$(VENV_DIR)/bin/ruff format $(if $(CHECK),--check,) postgresql_proxy tests plugins
$(VENV_DIR)/bin/ruff check $(if $(CHECK),,--fix) postgresql_proxy tests plugins

start-postgres: ## Start local PostgreSQL test container and wait until ready
@set -euo pipefail; \
Expand Down
40 changes: 25 additions & 15 deletions plugins/tableau_hll/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,19 @@
import re

# The field to replace
field_pattern = re.compile('(?<=[^\w])count\(distinct (?:cast\()?("[^"]+")\.("[^"]+")(?: as text\))?\)', re.IGNORECASE)
field_pattern = re.compile(
'(?<=[^\w])count\(distinct (?:cast\()?("[^"]+")\.("[^"]+")(?: as text\))?\)',
re.IGNORECASE,
)
# Table name
table_pattern = re.compile('from ([^\(\)]+?)\s*\)? (?:AS )?("[^"]+")', re.IGNORECASE | re.DOTALL)
table_pattern = re.compile(
'from ([^\(\)]+?)\s*\)? (?:AS )?("[^"]+")', re.IGNORECASE | re.DOTALL
)


def rewrite_query(query, context):
original_table = ''
table_alias = ''
original_table = ""
table_alias = ""

# cache only works on current query. Mainly because there's no way to tell if the table has been modified between
# 2 different requests.
Expand All @@ -26,8 +32,8 @@ def replace(match):
hll_column_candidate = match.group(2).strip()

# need to know which columns are hll
if not hll_table.lower() in column_cache:
db_conn_info = context['instance_config'].redirect
if hll_table.lower() not in column_cache:
db_conn_info = context["instance_config"].redirect
conn = None
try:
conn = psycopg2.connect(
Expand All @@ -36,17 +42,17 @@ def replace(match):
db_conn_info.host,
db_conn_info.port,
# Get auth information from the proxied request
context['connect_params']['database'],
context['connect_params']['user']
context["connect_params"]["database"],
context["connect_params"]["user"],
)
)

hll_type_code = None
cur = conn.cursor()
try:
cur.execute("SELECT oid FROM pg_type WHERE typname='hll';")
hll_type_code, = cur.fetchone()
except:
(hll_type_code,) = cur.fetchone()
except Exception:
pass
finally:
cur.close()
Expand All @@ -65,7 +71,7 @@ def replace(match):
hll_columns.add(desc.name.lower())

column_cache[hll_table.lower()] = hll_columns
except:
except Exception:
pass
finally:
cur.close()
Expand All @@ -76,13 +82,17 @@ def replace(match):
conn.close()

# Replace
if hll_column_candidate.strip('"').lower() in column_cache[hll_table.lower()]:
return ' hll_cardinality(hll_union_agg({}.{})) :: BIGINT '.format(match.group(1), match.group(2))
if (
hll_column_candidate.strip('"').lower()
in column_cache[hll_table.lower()]
):
return " hll_cardinality(hll_union_agg({}.{})) :: BIGINT ".format(
match.group(1), match.group(2)
)

# Don't replace
return match.group(0)


# Matches this string. The 2 groups are `schema.table` and `"alias"`
# FROM schema.table) "alias"
table_result = table_pattern.search(query)
Expand Down
98 changes: 47 additions & 51 deletions postgresql_proxy/config_schema.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
"""This class is used to validate the config"""

import logging

''' This class is used to validate the config
'''

class Schema:
def _validate(self):
pass

def __hyphen_to_underscore(self, k):
return k.replace('-', '_')
return k.replace("-", "_")

def _populate(self, data, definition):
try:
for (k, v) in data.items():
for k, v in data.items():
k = self.__hyphen_to_underscore(k)
if k in definition:
vtype = definition[k]
Expand All @@ -30,59 +31,59 @@ def _populate(self, data, definition):

def _assert_non_empty(self, *attrs):
for attr in attrs:
assert len(getattr(self, attr)) > 0, "{}.{} must not be empty".format(type(self).__name__, attr)
assert len(getattr(self, attr)) > 0, "{}.{} must not be empty".format(
type(self).__name__, attr
)

def _assert_non_null(self, *attrs):
for attr in attrs:
assert getattr(self, attr) is not None, "{}.{} must not be None".format(type(self).__name__, attr)
assert getattr(self, attr) is not None, "{}.{} must not be None".format(
type(self).__name__, attr
)


class InterceptQuerySettings(Schema):
def __init__(self, data):
self.plugin = None
self.function = None

self._populate(data, {
'plugin': str,
'function': str
})
self._populate(data, {"plugin": str, "function": str})

def _validate(self):
self._assert_non_null('plugin', 'function')
self._assert_non_empty('plugin', 'function')
self._assert_non_null("plugin", "function")
self._assert_non_empty("plugin", "function")


class InterceptCommandSettings(Schema):
def __init__(self, data):
self.queries = []
self.connects = None

self._populate(data, {
'queries': [InterceptQuerySettings],
'connects': str
})
self._populate(data, {"queries": [InterceptQuerySettings], "connects": str})


class InterceptResponseSettings(Schema):
def __init__(self, data):
self.parameter_responses = []
self.connects = None

self._populate(data, {
'parameter_status': [InterceptQuerySettings],
'connects': str
})
self._populate(
data, {"parameter_status": [InterceptQuerySettings], "connects": str}
)


class InterceptSettings(Schema):
def __init__(self, data):
self.commands = None
self.responses = None

self._populate(data, {
'commands': InterceptCommandSettings,
'responses': InterceptResponseSettings,
})
self._populate(
data,
{
"commands": InterceptCommandSettings,
"responses": InterceptResponseSettings,
},
)


class Connection(Schema):
Expand All @@ -91,15 +92,11 @@ def __init__(self, data):
self.host = None
self.port = None

self._populate(data, {
'name': str,
'host': str,
'port': int
})
self._populate(data, {"name": str, "host": str, "port": int})

def _validate(self):
self._assert_non_null('name', 'host', 'port')
self._assert_non_empty('name')
self._assert_non_null("name", "host", "port")
self._assert_non_empty("name")


class InstanceSettings(Schema):
Expand All @@ -108,15 +105,17 @@ def __init__(self, data):
self.redirect = None
self.intercept = None

self._populate(data, {
'listen': Connection,
'redirect': Connection,
'intercept': InterceptSettings
})

self._populate(
data,
{
"listen": Connection,
"redirect": Connection,
"intercept": InterceptSettings,
},
)

def _validate(self):
self._assert_non_null('listen', 'redirect')
self._assert_non_null("listen", "redirect")


class Settings(Schema):
Expand All @@ -125,15 +124,13 @@ def __init__(self, data):
self.intercept_log = None
self.general_log = None

self._populate(data, {
'log_level': str,
'intercept_log': str,
'general_log': str
})
self._populate(
data, {"log_level": str, "intercept_log": str, "general_log": str}
)

def _validate(self):
self._assert_non_null('log_level', 'intercept_log', 'general_log')
self._assert_non_empty('log_level', 'intercept_log', 'general_log')
self._assert_non_null("log_level", "intercept_log", "general_log")
self._assert_non_empty("log_level", "intercept_log", "general_log")


class Config(Schema):
Expand All @@ -142,11 +139,10 @@ def __init__(self, data):
self.settings = None
self.instances = []

self._populate(data, {
'plugins' : [str],
'settings' : Settings,
'instances' : [InstanceSettings]
})
self._populate(
data,
{"plugins": [str], "settings": Settings, "instances": [InstanceSettings]},
)

def _validate(self):
self._assert_non_empty('instances')
self._assert_non_empty("instances")
24 changes: 14 additions & 10 deletions postgresql_proxy/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,28 @@ def __init__(self, sock, address, name, events, context):
self.context = context
self.interceptor = None
self.redirect_conn: Optional[Connection] = None
self.out_bytes = b''
self.in_bytes = b''
self.out_bytes = b""
self.in_bytes = b""
self.terminated = False

def parse_length(self, length_bytes):
return int.from_bytes(length_bytes, 'big')
return int.from_bytes(length_bytes, "big")

def encode_length(self, length):
return length.to_bytes(4, byteorder='big')
return length.to_bytes(4, byteorder="big")

def received(self, in_bytes):
self.in_bytes += in_bytes
# Read packet from byte array while there are enough bytes to make up a packet.
# Otherwise wait for more bytes to be received (break and exit)
while True:
ptype = self.in_bytes[0:1]
if ptype == b'\x00':
if ptype == b"\x00":
if len(self.in_bytes) < 4:
break
header_length = 4
body_length = self.parse_length(self.in_bytes[0:4]) - 4
elif ptype == b'N':
elif ptype == b"N":
header_length = 1
body_length = 0
else:
Expand All @@ -52,18 +52,22 @@ def received(self, in_bytes):
self.in_bytes = self.in_bytes[length:]

def process_inbound_packet(self, header, body):
if header != b'N':
if header != b"N":
packet_type = header[0:-4]
_logger.info("intercepting packet of type '%s' from %s", packet_type, self.name)
_logger.info(
"intercepting packet of type '%s' from %s", packet_type, self.name
)
body = self.interceptor.intercept(packet_type, body)
header = packet_type + self.encode_length(len(body) + 4)
if packet_type == b'X':
if packet_type == b"X":
# this a termination packet, it will indicate that the proxied client wants to close the
# postgres connection properly
self.terminated = True

message = header + body
_logger.debug("Received message. Relaying. Speaker: %s, message:\n%s", self.name, message)
_logger.debug(
"Received message. Relaying. Speaker: %s, message:\n%s", self.name, message
)

if self.redirect_conn:
# redirect_conn might not be set (anymore) at this stage
Expand Down
1 change: 0 additions & 1 deletion postgresql_proxy/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

ALLOWED_CONNECTION_PARAMETERS = [
"host",
"hostaddr",
Expand Down
Loading
Loading