Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ bigquery = [
# pinned an older SQLGlot which is incompatible with SQLMesh
bigframes = ["bigframes>=1.32.0"]
clickhouse = ["clickhouse-connect"]
databricks = ["databricks-sql-connector[pyarrow]"]
databricks = ["databricks-sql-connector[pyarrow]>=4.2.6"]
dev = [
"agate",
"beautifulsoup4",
Expand Down
56 changes: 56 additions & 0 deletions sqlmesh/core/engine_adapter/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,43 @@
logger = logging.getLogger(__name__)


def _query_tags(
query_tags: t.Optional[t.Union[exp.Expr, str, int, float, bool]],
) -> t.Optional[t.Dict[str, t.Optional[str]]]:
if not query_tags:
return None

if not isinstance(query_tags, exp.Map):
raise SQLMeshError("Invalid value for `session_properties.query_tags`. Must be a map.")

keys = query_tags.args.get("keys")
values = query_tags.args.get("values")
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
raise SQLMeshError(
"Invalid value for `session_properties.query_tags`. Must be a map with array "
"keys and array values."
)

tags: t.Dict[str, t.Optional[str]] = {}
for key, value in zip(keys.expressions, values.expressions):
if not isinstance(key, exp.Literal) or not key.is_string:
raise SQLMeshError(
"Invalid key in `session_properties.query_tags`. Keys must be string literals."
)

if isinstance(value, exp.Null):
tags[key.this] = None
elif isinstance(value, exp.Literal) and value.is_string:
tags[key.this] = value.this
else:
raise SQLMeshError(
"Invalid value in `session_properties.query_tags`. Values must be string "
"literals or NULL."
)

return tags


class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin):
DIALECT = "databricks"
INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE
Expand Down Expand Up @@ -98,6 +135,12 @@ def _use_spark_session(self) -> bool:
def is_spark_session_connection(self) -> bool:
    """Whether the underlying connection is a `SparkSessionConnection`."""
    connection = self.connection
    return isinstance(connection, SparkSessionConnection)

@property
def _is_databricks_sql_connector_connection(self) -> bool:
    """Whether queries go through the databricks-sql-connector cursor.

    Excludes Spark session connections and connections that have been switched
    over to the Spark engine adapter for the current session.
    """
    if self.is_spark_session_connection:
        return False
    return not self._connection_pool.get_attribute("use_spark_engine_adapter")

def _set_spark_engine_adapter_if_needed(self) -> None:
self._spark_engine_adapter = None

Expand Down Expand Up @@ -181,10 +224,23 @@ def _begin_session(self, properties: SessionProperties) -> t.Any:
"""Begin a new session."""
# Align the different possible connectors to a single catalog
self.set_current_catalog(self.default_catalog) # type: ignore
self._connection_pool.set_attribute("query_tags", _query_tags(properties.get("query_tags")))

def _end_session(self) -> None:
    """End the session, resetting session-scoped connection attributes."""
    # Clear both session attributes so they don't leak into the next session.
    for attribute, value in (("query_tags", None), ("use_spark_engine_adapter", False)):
        self._connection_pool.set_attribute(attribute, value)

def _execute(self, sql: str, track_rows_processed: bool = False, **kwargs: t.Any) -> None:
    """Execute `sql`, forwarding session-level query tags when applicable.

    Tags are only passed to the databricks-sql-connector cursor, and an
    explicitly supplied `query_tags` kwarg always takes precedence over the
    session-level ones.
    """
    session_tags = self._connection_pool.get_attribute("query_tags")
    should_apply_tags = (
        session_tags
        and "query_tags" not in kwargs
        and self._is_databricks_sql_connector_connection
    )
    if should_apply_tags:
        kwargs["query_tags"] = session_tags

    return super()._execute(sql, track_rows_processed, **kwargs)

def _df_to_source_queries(
self,
df: DF,
Expand Down
120 changes: 120 additions & 0 deletions tests/core/engine_adapter/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,23 @@
from sqlmesh.core.engine_adapter import DatabricksEngineAdapter
from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType
from sqlmesh.core.node import IntervalUnit
from sqlmesh.utils.errors import SQLMeshError
from tests.core.engine_adapter import to_sql_calls

pytestmark = [pytest.mark.databricks, pytest.mark.engine]


def _query_tags_map(*items: t.Optional[str]) -> exp.Map:
    """Build an `exp.Map` from alternating key/value arguments.

    Even-positioned arguments become string-literal keys; odd-positioned
    arguments become string-literal values, with ``None`` rendered as SQL NULL.
    """
    key_literals = [exp.Literal.string(key) for key in items[::2]]
    value_literals = [
        exp.Null() if value is None else exp.Literal.string(value) for value in items[1::2]
    ]
    return exp.Map(
        keys=exp.Array(expressions=key_literals),
        values=exp.Array(expressions=value_literals),
    )


def test_replace_query_not_exists(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):
mocker.patch(
"sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.table_exists",
Expand Down Expand Up @@ -117,6 +129,114 @@ def test_set_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.
assert to_sql_calls(adapter) == ["USE CATALOG `test_catalog2`"]


def test_session_query_tags(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):
    """Session-level query tags are attached to executions inside the session
    and are no longer applied once the session has ended."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")

    tags = _query_tags_map("team", "data-eng", "app", "sqlmesh")
    with adapter.session({"query_tags": tags}):
        adapter.execute("SELECT 1")

    adapter.cursor.execute.assert_called_with(
        "SELECT 1", query_tags={"team": "data-eng", "app": "sqlmesh"}
    )

    # Outside the session the tags must not be forwarded anymore.
    adapter.execute("SELECT 2")
    adapter.cursor.execute.assert_called_with("SELECT 2")


def test_session_query_tags_allow_none_values(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """A NULL tag value is converted to Python ``None`` and still forwarded."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")

    tags = _query_tags_map("team", "data-eng", "feature", None)
    with adapter.session({"query_tags": tags}):
        adapter.execute("SELECT 1")

    adapter.cursor.execute.assert_called_with(
        "SELECT 1", query_tags={"team": "data-eng", "feature": None}
    )


def test_session_query_tags_do_not_override_explicit_query_tags(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """Query tags passed explicitly to `execute` win over session-level tags."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")

    session_tags = _query_tags_map("team", "data-eng")
    explicit_tags = {"team": "analytics"}
    with adapter.session({"query_tags": session_tags}):
        adapter.execute("SELECT 1", query_tags=explicit_tags)

    adapter.cursor.execute.assert_called_with("SELECT 1", query_tags=explicit_tags)


def test_session_query_tags_not_applied_to_spark_session_connection(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """Spark session connections never receive the `query_tags` kwarg."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")
    # Force the adapter to report a Spark session connection.
    mocker.patch.object(
        DatabricksEngineAdapter,
        "is_spark_session_connection",
        new_callable=mocker.PropertyMock,
        return_value=True,
    )

    tags = _query_tags_map("team", "data-eng")
    with adapter.session({"query_tags": tags}):
        adapter.execute("SELECT 1")

    adapter.cursor.execute.assert_called_with("SELECT 1")


def test_session_query_tags_not_applied_to_spark_engine_adapter(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """When execution is routed through the Spark engine adapter, no
    `query_tags` kwarg is passed to its cursor."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")
    spark_cursor = mocker.Mock()
    adapter._spark_engine_adapter = mocker.Mock(cursor=spark_cursor)
    adapter._connection_pool.set_attribute("use_spark_engine_adapter", True)

    tags = _query_tags_map("team", "data-eng")
    with adapter.session({"query_tags": tags}):
        # Re-set inside the session since session setup may touch this attribute.
        adapter._connection_pool.set_attribute("use_spark_engine_adapter", True)
        adapter.execute("SELECT 1")

    spark_cursor.execute.assert_called_with("SELECT 1")


@pytest.mark.parametrize(
    "query_tags",
    [
        # Not a map at all.
        "team:data-eng",
        # Non-string-literal key.
        exp.Map(
            keys=exp.Array(expressions=[exp.Literal.number(1)]),
            values=exp.Array(expressions=[exp.Literal.string("data-eng")]),
        ),
        # Non-string, non-NULL value.
        exp.Map(
            keys=exp.Array(expressions=[exp.Literal.string("team")]),
            values=exp.Array(expressions=[exp.Literal.number(1)]),
        ),
    ],
)
def test_session_query_tags_invalid(query_tags, make_mocked_engine_adapter: t.Callable):
    """Invalid `query_tags` values raise a SQLMeshError as the session begins."""
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")

    with pytest.raises(SQLMeshError, match="session_properties.query_tags"), adapter.session(
        {"query_tags": query_tags}
    ):
        pass


def test_get_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):
mocker.patch(
"sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
Expand Down