Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 35 additions & 33 deletions src/plexosdb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _get_plexos_version(self) -> tuple[int, ...] | None:
return None
if not result:
return None
return tuple(map(int, result[0].split(".")))
return tuple(map(int, cast(str, result[0]).split(".")))

@classmethod
def from_xml(
Expand Down Expand Up @@ -216,7 +216,6 @@ def from_xml(
xml_tags = set([e.tag for e in xml_handler.root]) # Extract set of valid tags from xml
for tag in xml_tags:
# Only parse valid schemas that we maintain.
# NOTE: If there are some missing tables, we need to add them to the Enums.
schema_enum = str2enum(tag)
if not schema_enum:
continue
Expand Down Expand Up @@ -500,7 +499,6 @@ def add_membership(
child_object_id = self.get_object_id(child_class_enum, child_object_name)
collection_id = self.get_collection_id(collection_enum, parent_class_enum, child_class_enum)

# NOTE: Measure if this is faster than passing the ids
query = f"""
INSERT INTO {Schema.Memberships.name}
(parent_class_id,parent_object_id, collection_id, child_class_id, child_object_id)
Expand Down Expand Up @@ -925,10 +923,7 @@ def add_properties_from_records(
metadata_map = prepared.metadata_map

if not params:
msg = f"Failed to parse the properties for the given {collection=} and {object_class=}. "
msg += "Check the function plan_property_inserts"
return
# raise PropertyError(msg)

has_datafile_text = any(meta.get("datafile_text") for meta in metadata_map.values())
has_timeslice_text = any(meta.get("timeslice") for meta in metadata_map.values())
Expand Down Expand Up @@ -1273,7 +1268,6 @@ def add_report(
)
raise NameError(msg)

# NOTE: We can migrate this to its own `get_property_report_id` if needed.
property_id = self.query(
"select property_id from t_property_report where collection_id = ? and name = ?",
(collection_id, property),
Expand Down Expand Up @@ -1370,7 +1364,7 @@ def copy_object(
object_id = self.get_object_id(object_class, name=original_object_name)
category_id = self.query("SELECT category_id from t_object WHERE object_id = ?", (object_id,))
category = self.query("SELECT name from t_category WHERE category_id = ?", (category_id[0][0],))
new_object_id = self.add_object(object_class, new_object_name, category=category[0][0])
new_object_id = self.add_object(object_class, new_object_name, category=cast(str, category[0][0]))
membership_mapping = self.copy_object_memberships(
object_class=object_class, original_name=original_object_name, new_name=new_object_name
)
Expand Down Expand Up @@ -1923,7 +1917,7 @@ def get_attribute(
*,
object_name: str,
attribute_name: str,
) -> Any:
) -> tuple[int | float | str | bytes | None, ...]:
"""Get attribute details for a specific object."""
query = """
SELECT
Expand All @@ -1939,7 +1933,7 @@ def get_attribute(

result = self._db.fetchone(query, (attribute_id, object_id))
assert result
return cast(Any, result)
return result

def get_attribute_id(self, class_enum: ClassEnum, /, name: str) -> int:
"""Return the ID for a given attribute.
Expand Down Expand Up @@ -2315,7 +2309,7 @@ def list_object_memberships(
"""

query_conditions = []
params: dict[str, int | float | str] = {}
params: dict[str, int | float | str | bytes | None] = {}

# We first add conditions for looking the object.
query_conditions.append(
Expand Down Expand Up @@ -2386,7 +2380,7 @@ def get_memberships_system(
"""

extra = ""
extra_params: list[Any] = []
extra_params: list[str] = []
if collection:
extra = " AND parent_class.name = ? AND collections.name = ?"
extra_params = [ClassEnum.System.value, collection.value]
Expand Down Expand Up @@ -2522,7 +2516,7 @@ def get_object_data_ids(
# assert result
if not result:
return []
return [row[0] for row in result]
return [cast(int, row[0]) for row in result]

def get_object_properties(
self,
Expand Down Expand Up @@ -2673,7 +2667,7 @@ def get_object_id(self, class_enum: ClassEnum, /, name: str, *, category: str |
AND
t_class.name = ?
"""
params: list[Any] = [name, class_enum]
params: list[str | int] = [name, class_enum]
if category:
category_id = self.get_category_id(class_enum, category)
params.append(category_id)
Expand Down Expand Up @@ -2705,7 +2699,7 @@ def get_objects_id(
"""
result = self._db.fetchall(query, tuple(names))
assert result
return [r[0] for r in result]
return [cast(int, r[0]) for r in result]

def get_plexos_version(self) -> tuple[int, ...] | None:
"""Return the version information of the PLEXOS model."""
Expand Down Expand Up @@ -2970,7 +2964,7 @@ def iterate_properties(
If specified category does not exist
"""
conditions: list[str] = []
query_params: list[Any] = []
query_params: list[str] = []

if class_enum and checks_module.check_class_exists(self, class_enum):
conditions.append(f"child_class.name = '{class_enum}'")
Expand Down Expand Up @@ -3068,7 +3062,7 @@ def list_attributes(self, class_enum: ClassEnum) -> list[str]:
"""
result = self._db.fetchall_dict(query, (class_enum,))
assert result
return [row["name"] for row in result]
return [cast(str, row["name"]) for row in result]

def list_categories(self, class_enum: ClassEnum) -> list[str]:
"""Get all categories for a specific class.
Expand Down Expand Up @@ -3116,7 +3110,7 @@ def list_categories(self, class_enum: ClassEnum) -> list[str]:
"""
result = self._db.fetchall_dict(query, (class_enum,))
assert result
return [row["name"] for row in result]
return [cast(str, row["name"]) for row in result]

def list_child_objects(
self,
Expand Down Expand Up @@ -3221,7 +3215,7 @@ def list_child_objects(
WHERE parent_obj.object_id = ?
"""

params: list[Any] = [parent_object_id]
params: list[int | str] = [parent_object_id]

if child_class is not None:
query += " AND child_class.name = ?"
Expand Down Expand Up @@ -3262,7 +3256,7 @@ def list_classes(self) -> list[str]:
query_string = f"SELECT name from {Schema.Class.name}"
result = self.query(query_string)
assert result
return [d[0] for d in result]
return [cast(str, d[0]) for d in result]

def list_collections(
self,
Expand Down Expand Up @@ -3364,7 +3358,7 @@ def list_objects_by_class(self, class_enum: ClassEnum, /, *, category: str | Non
"""
class_id = self.get_class_id(class_enum)

params: Sequence[Any]
params: Sequence[int | str]
if category is None:
query = f"SELECT name FROM {Schema.Objects.name} WHERE class_id = ? ORDER BY name"
params = (class_id,)
Expand All @@ -3386,7 +3380,7 @@ def list_objects_by_class(self, class_enum: ClassEnum, /, *, category: str | Non

result = self._db.query(query, params)
assert result is not None
return [row[0] for row in result]
return [cast(str, row[0]) for row in result]

def list_parent_objects(
self,
Expand Down Expand Up @@ -3487,7 +3481,7 @@ def list_parent_objects(
WHERE child_obj.object_id = ?
"""

params: list[Any] = [child_object_id]
params: list[int | str] = [child_object_id]

if parent_class is not None:
query += " AND parent_class.name = ?"
Expand Down Expand Up @@ -3548,7 +3542,7 @@ def list_scenarios(self) -> list[str]:
"""
result = self.query(query_string, (ClassEnum.Scenario,))
assert result
return [d[0] for d in result]
return [cast(str, d[0]) for d in result]

def list_models(self) -> list[str]:
"""Return all models in the database.
Expand Down Expand Up @@ -3585,7 +3579,7 @@ def list_models(self) -> list[str]:
"""
result = self.query(query_string, (ClassEnum.Model,))
assert result
return [d[0] for d in result]
return [cast(str, d[0]) for d in result]

def list_scenarios_by_model(self, model_name: str) -> list[str]:
"""
Expand Down Expand Up @@ -3625,14 +3619,14 @@ def list_scenarios_by_model(self, model_name: str) -> list[str]:
WHERE membership.parent_object_id = ? and t_class.name = ?
"""
result = self.query(query, (parent_object_id, ClassEnum.Scenario))
return [row[0] for row in result] if result else []
return [cast(str, row[0]) for row in result] if result else []

def list_units(self) -> list[dict[int, str]]:
"""List all available units in the database."""
query_string = "SELECT unit_id, value from t_unit"
result = self.query(query_string)
assert result
return [{d[0]: d[1]} for d in result]
return [{cast(int, d[0]): cast(str, d[1])} for d in result]

def list_valid_properties(
self,
Expand Down Expand Up @@ -3687,7 +3681,7 @@ def list_valid_properties(
query = "SELECT name from t_property where collection_id = ?"
result = self.query(query, (collection_id,))
assert result
return [d[0] for d in result]
return [cast(str, d[0]) for d in result]

def list_valid_properties_report(
self,
Expand Down Expand Up @@ -3730,9 +3724,15 @@ def list_valid_properties_report(
query = "SELECT name from t_property_report where collection_id = ?"
result = self.query(query, (collection_id,))
assert result
return [d[0] for d in result]
return [cast(str, d[0]) for d in result]

def query(self, query_string: str, params: tuple[Any, ...] | dict[str, Any] | None = None) -> list[Any]:
def query(
self,
query_string: str,
params: tuple[int | float | str | bytes | None, ...]
| dict[str, int | float | str | bytes | None]
| None = None,
) -> list[tuple[int | float | str | bytes | None, ...]]:
"""Execute a read-only query and return all results.

Executes a SQL SELECT query against the database and returns the results.
Expand Down Expand Up @@ -3858,9 +3858,11 @@ def to_xml(self, target_path: str | Path) -> bool:
if not rows:
continue
column_types_tuples = self.query(f"SELECT name, type FROM pragma_table_info('{table_name}')")
column_types: dict[str, str] = {key: value for key, value in column_types_tuples}
column_types: dict[str, str] = {
cast(str, key): cast(str, value) for key, value in column_types_tuples
}
logger.trace("Adding {} to {}", table_name, target_path)
xml_handler.create_table_element(rows, column_types, table_name)
xml_handler.create_table_element(rows, column_types, cast(str, table_name))

xml_handler.to_xml(target_path)

Expand Down Expand Up @@ -3946,7 +3948,7 @@ def update_object(
object_id = self.get_object_id(class_enum, object_name)

set_clauses = ["name = ?"]
params: list[Any] = [new_name]
params: list[str | int] = [new_name]

if new_category is not None:
category_id = self.get_category_id(class_enum, new_category)
Expand Down
Loading
Loading