From 17b3af7c55c049ad89b91fc31ce5a6344b30495d Mon Sep 17 00:00:00 2001 From: Christian Troelsen Date: Fri, 15 May 2026 14:23:33 +0000 Subject: [PATCH] Support Databricks query tags from session properties Signed-off-by: Christian Troelsen --- pyproject.toml | 2 +- sqlmesh/core/engine_adapter/databricks.py | 56 +++++++++ tests/core/engine_adapter/test_databricks.py | 120 +++++++++++++++++++ 3 files changed, 177 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bcc69c667e..c8059445f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ bigquery = [ # pinned an older SQLGlot which is incompatible with SQLMesh bigframes = ["bigframes>=1.32.0"] clickhouse = ["clickhouse-connect"] -databricks = ["databricks-sql-connector[pyarrow]"] +databricks = ["databricks-sql-connector[pyarrow]>=4.2.6"] dev = [ "agate", "beautifulsoup4", diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py index e3d029a17d..2c8ab901b7 100644 --- a/sqlmesh/core/engine_adapter/databricks.py +++ b/sqlmesh/core/engine_adapter/databricks.py @@ -30,6 +30,43 @@ logger = logging.getLogger(__name__) +def _query_tags( + query_tags: t.Optional[t.Union[exp.Expr, str, int, float, bool]], +) -> t.Optional[t.Dict[str, t.Optional[str]]]: + if not query_tags: + return None + + if not isinstance(query_tags, exp.Map): + raise SQLMeshError("Invalid value for `session_properties.query_tags`. Must be a map.") + + keys = query_tags.args.get("keys") + values = query_tags.args.get("values") + if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): + raise SQLMeshError( + "Invalid value for `session_properties.query_tags`. Must be a map with array " + "keys and array values." + ) + + tags: t.Dict[str, t.Optional[str]] = {} + for key, value in zip(keys.expressions, values.expressions): + if not isinstance(key, exp.Literal) or not key.is_string: + raise SQLMeshError( + "Invalid key in `session_properties.query_tags`. Keys must be string literals." + ) + + if isinstance(value, exp.Null): + tags[key.this] = None + elif isinstance(value, exp.Literal) and value.is_string: + tags[key.this] = value.this + else: + raise SQLMeshError( + "Invalid value in `session_properties.query_tags`. Values must be string " + "literals or NULL." + ) + + return tags + + class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin): DIALECT = "databricks" INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE @@ -98,6 +135,12 @@ def _use_spark_session(self) -> bool: def is_spark_session_connection(self) -> bool: return isinstance(self.connection, SparkSessionConnection) + @property + def _is_databricks_sql_connector_connection(self) -> bool: + return not self.is_spark_session_connection and not self._connection_pool.get_attribute( + "use_spark_engine_adapter" + ) + def _set_spark_engine_adapter_if_needed(self) -> None: self._spark_engine_adapter = None @@ -181,10 +224,23 @@ def _begin_session(self, properties: SessionProperties) -> t.Any: """Begin a new session.""" # Align the different possible connectors to a single catalog self.set_current_catalog(self.default_catalog) # type: ignore + self._connection_pool.set_attribute("query_tags", _query_tags(properties.get("query_tags"))) def _end_session(self) -> None: + self._connection_pool.set_attribute("query_tags", None) self._connection_pool.set_attribute("use_spark_engine_adapter", False) + def _execute(self, sql: str, track_rows_processed: bool = False, **kwargs: t.Any) -> None: + query_tags = self._connection_pool.get_attribute("query_tags") + if ( + query_tags + and "query_tags" not in kwargs + and self._is_databricks_sql_connector_connection + ): + kwargs["query_tags"] = query_tags + + return super()._execute(sql, track_rows_processed, **kwargs) + def _df_to_source_queries( self, df: DF, diff --git a/tests/core/engine_adapter/test_databricks.py b/tests/core/engine_adapter/test_databricks.py index de91fd3b70..510f9c7349 100644 --- a/tests/core/engine_adapter/test_databricks.py +++ b/tests/core/engine_adapter/test_databricks.py @@ -10,11 +10,23 @@ from sqlmesh.core.engine_adapter import DatabricksEngineAdapter from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType from sqlmesh.core.node import IntervalUnit +from sqlmesh.utils.errors import SQLMeshError from tests.core.engine_adapter import to_sql_calls pytestmark = [pytest.mark.databricks, pytest.mark.engine] +def _query_tags_map(*items: t.Optional[str]) -> exp.Map: + return exp.Map( + keys=exp.Array(expressions=[exp.Literal.string(item) for item in items[::2]]), + values=exp.Array( + expressions=[ + exp.Null() if item is None else exp.Literal.string(item) for item in items[1::2] + ] + ), + ) + + def test_replace_query_not_exists(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): mocker.patch( "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.table_exists", @@ -117,6 +129,114 @@ def test_set_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t. assert to_sql_calls(adapter) == ["USE CATALOG `test_catalog2`"] +def test_session_query_tags(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") + + with adapter.session({"query_tags": _query_tags_map("team", "data-eng", "app", "sqlmesh")}): + adapter.execute("SELECT 1") + + adapter.cursor.execute.assert_called_with( + "SELECT 1", query_tags={"team": "data-eng", "app": "sqlmesh"} + ) + + adapter.execute("SELECT 2") + + adapter.cursor.execute.assert_called_with("SELECT 2") + + +def test_session_query_tags_allow_none_values( + mocker: MockFixture, make_mocked_engine_adapter: t.Callable +): + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") + + with adapter.session({"query_tags": _query_tags_map("team", "data-eng", "feature", None)}): + adapter.execute("SELECT 1") + + adapter.cursor.execute.assert_called_with( + "SELECT 1", query_tags={"team": "data-eng", "feature": None} + ) + + +def test_session_query_tags_do_not_override_explicit_query_tags( + mocker: MockFixture, make_mocked_engine_adapter: t.Callable +): + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") + + with adapter.session({"query_tags": _query_tags_map("team", "data-eng")}): + adapter.execute("SELECT 1", query_tags={"team": "analytics"}) + + adapter.cursor.execute.assert_called_with("SELECT 1", query_tags={"team": "analytics"}) + + +def test_session_query_tags_not_applied_to_spark_session_connection( + mocker: MockFixture, make_mocked_engine_adapter: t.Callable +): + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") + mocker.patch.object( + DatabricksEngineAdapter, + "is_spark_session_connection", + new_callable=mocker.PropertyMock, + return_value=True, + ) + + with adapter.session({"query_tags": _query_tags_map("team", "data-eng")}): + adapter.execute("SELECT 1") + + adapter.cursor.execute.assert_called_with("SELECT 1") + + +def test_session_query_tags_not_applied_to_spark_engine_adapter( + mocker: MockFixture, make_mocked_engine_adapter: t.Callable +): + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") + spark_cursor = mocker.Mock() + adapter._spark_engine_adapter = mocker.Mock(cursor=spark_cursor) + adapter._connection_pool.set_attribute("use_spark_engine_adapter", True) + + with adapter.session({"query_tags": _query_tags_map("team", "data-eng")}): + adapter._connection_pool.set_attribute("use_spark_engine_adapter", True) + adapter.execute("SELECT 1") + + spark_cursor.execute.assert_called_with("SELECT 1") + + +@pytest.mark.parametrize( + "query_tags", + [ + "team:data-eng", + exp.Map( + keys=exp.Array(expressions=[exp.Literal.number(1)]), + values=exp.Array(expressions=[exp.Literal.string("data-eng")]), + ), + exp.Map( + keys=exp.Array(expressions=[exp.Literal.string("team")]), + values=exp.Array(expressions=[exp.Literal.number(1)]), + ), + ], +) +def test_session_query_tags_invalid(query_tags, make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") + + with pytest.raises(SQLMeshError, match="session_properties.query_tags"): + with adapter.session({"query_tags": query_tags}): + pass + + def test_get_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): mocker.patch( "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"