From 13614ed11b1900781f94c9ef6d5a82255ae356d8 Mon Sep 17 00:00:00 2001 From: mvelasqu Date: Wed, 27 May 2026 12:39:18 -0600 Subject: [PATCH 1/2] fix: agentic repo cleaning --- src/plexosdb/db.py | 68 ++++---- src/plexosdb/db_manager.py | 132 ++++++--------- src/plexosdb/types.py | 10 ++ src/plexosdb/utils.py | 253 ++++++++++++++++++----------- src/plexosdb/xml_handler.py | 22 +-- tests/test_plexosdb_copy_object.py | 16 +- 6 files changed, 265 insertions(+), 236 deletions(-) create mode 100644 src/plexosdb/types.py diff --git a/src/plexosdb/db.py b/src/plexosdb/db.py index fc2e6f8..17caa65 100644 --- a/src/plexosdb/db.py +++ b/src/plexosdb/db.py @@ -140,7 +140,7 @@ def _get_plexos_version(self) -> tuple[int, ...] | None: return None if not result: return None - return tuple(map(int, result[0].split("."))) + return tuple(map(int, cast(str, result[0]).split("."))) @classmethod def from_xml( @@ -216,7 +216,6 @@ def from_xml( xml_tags = set([e.tag for e in xml_handler.root]) # Extract set of valid tags from xml for tag in xml_tags: # Only parse valid schemas that we maintain. - # NOTE: If there are some missing tables, we need to add them to the Enums. schema_enum = str2enum(tag) if not schema_enum: continue @@ -500,7 +499,6 @@ def add_membership( child_object_id = self.get_object_id(child_class_enum, child_object_name) collection_id = self.get_collection_id(collection_enum, parent_class_enum, child_class_enum) - # NOTE: Measure if this is faster than passing the ids query = f""" INSERT INTO {Schema.Memberships.name} (parent_class_id,parent_object_id, collection_id, child_class_id, child_object_id) @@ -925,10 +923,7 @@ def add_properties_from_records( metadata_map = prepared.metadata_map if not params: - msg = f"Failed to parse the properties for the given {collection=} and {object_class=}. " - msg += "Check the function plan_property_inserts" return - # raise PropertyError(msg) has_datafile_text = any(meta.get("datafile_text") for meta in metadata_map.values()) has_timeslice_text = any(meta.get("timeslice") for meta in metadata_map.values()) @@ -1273,7 +1268,6 @@ def add_report( ) raise NameError(msg) - # NOTE: We can migrate this to its own `get_property_report_id` if needed. property_id = self.query( "select property_id from t_property_report where collection_id = ? and name = ?", (collection_id, property), @@ -1370,7 +1364,7 @@ def copy_object( object_id = self.get_object_id(object_class, name=original_object_name) category_id = self.query("SELECT category_id from t_object WHERE object_id = ?", (object_id,)) category = self.query("SELECT name from t_category WHERE category_id = ?", (category_id[0][0],)) - new_object_id = self.add_object(object_class, new_object_name, category=category[0][0]) + new_object_id = self.add_object(object_class, new_object_name, category=cast(str, category[0][0])) membership_mapping = self.copy_object_memberships( object_class=object_class, original_name=original_object_name, new_name=new_object_name ) @@ -1923,7 +1917,7 @@ def get_attribute( *, object_name: str, attribute_name: str, - ) -> Any: + ) -> tuple[int | float | str | bytes | None, ...]: """Get attribute details for a specific object.""" query = """ SELECT @@ -1939,7 +1933,7 @@ def get_attribute( result = self._db.fetchone(query, (attribute_id, object_id)) assert result - return cast(Any, result) + return result def get_attribute_id(self, class_enum: ClassEnum, /, name: str) -> int: """Return the ID for a given attribute. @@ -2315,7 +2309,7 @@ def list_object_memberships( """ query_conditions = [] - params: dict[str, int | float | str] = {} + params: dict[str, int | float | str | bytes | None] = {} # We first add conditions for looking the object. query_conditions.append( @@ -2386,7 +2380,7 @@ def get_memberships_system( """ extra = "" - extra_params: list[Any] = [] + extra_params: list[str] = [] if collection: extra = " AND parent_class.name = ? AND collections.name = ?" extra_params = [ClassEnum.System.value, collection.value] @@ -2522,7 +2516,7 @@ def get_object_data_ids( # assert result if not result: return [] - return [row[0] for row in result] + return [cast(int, row[0]) for row in result] def get_object_properties( self, @@ -2673,7 +2667,7 @@ def get_object_id(self, class_enum: ClassEnum, /, name: str, *, category: str | AND t_class.name = ? """ - params: list[Any] = [name, class_enum] + params: list[str | int] = [name, class_enum] if category: category_id = self.get_category_id(class_enum, category) params.append(category_id) @@ -2705,7 +2699,7 @@ def get_objects_id( """ result = self._db.fetchall(query, tuple(names)) assert result - return [r[0] for r in result] + return [cast(int, r[0]) for r in result] def get_plexos_version(self) -> tuple[int, ...] | None: """Return the version information of the PLEXOS model.""" @@ -2970,7 +2964,7 @@ def iterate_properties( If specified category does not exist """ conditions: list[str] = [] - query_params: list[Any] = [] + query_params: list[str] = [] if class_enum and checks_module.check_class_exists(self, class_enum): conditions.append(f"child_class.name = '{class_enum}'") @@ -3068,7 +3062,7 @@ def list_attributes(self, class_enum: ClassEnum) -> list[str]: """ result = self._db.fetchall_dict(query, (class_enum,)) assert result - return [row["name"] for row in result] + return [cast(str, row["name"]) for row in result] def list_categories(self, class_enum: ClassEnum) -> list[str]: """Get all categories for a specific class. @@ -3116,7 +3110,7 @@ def list_categories(self, class_enum: ClassEnum) -> list[str]: """ result = self._db.fetchall_dict(query, (class_enum,)) assert result - return [row["name"] for row in result] + return [cast(str, row["name"]) for row in result] def list_child_objects( self, @@ -3221,7 +3215,7 @@ def list_child_objects( WHERE parent_obj.object_id = ? """ - params: list[Any] = [parent_object_id] + params: list[int | str] = [parent_object_id] if child_class is not None: query += " AND child_class.name = ?" @@ -3262,7 +3256,7 @@ def list_classes(self) -> list[str]: query_string = f"SELECT name from {Schema.Class.name}" result = self.query(query_string) assert result - return [d[0] for d in result] + return [cast(str, d[0]) for d in result] def list_collections( self, @@ -3364,7 +3358,7 @@ def list_objects_by_class(self, class_enum: ClassEnum, /, *, category: str | Non """ class_id = self.get_class_id(class_enum) - params: Sequence[Any] + params: Sequence[int | str] if category is None: query = f"SELECT name FROM {Schema.Objects.name} WHERE class_id = ? ORDER BY name" params = (class_id,) @@ -3386,7 +3380,7 @@ def list_objects_by_class(self, class_enum: ClassEnum, /, *, category: str | Non result = self._db.query(query, params) assert result is not None - return [row[0] for row in result] + return [cast(str, row[0]) for row in result] def list_parent_objects( self, @@ -3487,7 +3481,7 @@ def list_parent_objects( WHERE child_obj.object_id = ? """ - params: list[Any] = [child_object_id] + params: list[int | str] = [child_object_id] if parent_class is not None: query += " AND parent_class.name = ?" @@ -3548,7 +3542,7 @@ def list_scenarios(self) -> list[str]: """ result = self.query(query_string, (ClassEnum.Scenario,)) assert result - return [d[0] for d in result] + return [cast(str, d[0]) for d in result] def list_models(self) -> list[str]: """Return all models in the database. @@ -3585,7 +3579,7 @@ def list_models(self) -> list[str]: """ result = self.query(query_string, (ClassEnum.Model,)) assert result - return [d[0] for d in result] + return [cast(str, d[0]) for d in result] def list_scenarios_by_model(self, model_name: str) -> list[str]: """ @@ -3625,14 +3619,14 @@ def list_scenarios_by_model(self, model_name: str) -> list[str]: WHERE membership.parent_object_id = ? and t_class.name = ? """ result = self.query(query, (parent_object_id, ClassEnum.Scenario)) - return [row[0] for row in result] if result else [] + return [cast(str, row[0]) for row in result] if result else [] def list_units(self) -> list[dict[int, str]]: """List all available units in the database.""" query_string = "SELECT unit_id, value from t_unit" result = self.query(query_string) assert result - return [{d[0]: d[1]} for d in result] + return [{cast(int, d[0]): cast(str, d[1])} for d in result] def list_valid_properties( self, @@ -3687,7 +3681,7 @@ def list_valid_properties( query = "SELECT name from t_property where collection_id = ?" result = self.query(query, (collection_id,)) assert result - return [d[0] for d in result] + return [cast(str, d[0]) for d in result] def list_valid_properties_report( self, @@ -3730,9 +3724,15 @@ def list_valid_properties_report( query = "SELECT name from t_property_report where collection_id = ?" result = self.query(query, (collection_id,)) assert result - return [d[0] for d in result] + return [cast(str, d[0]) for d in result] - def query(self, query_string: str, params: tuple[Any, ...] | dict[str, Any] | None = None) -> list[Any]: + def query( + self, + query_string: str, + params: tuple[int | float | str | bytes | None, ...] + | dict[str, int | float | str | bytes | None] + | None = None, + ) -> list[tuple[int | float | str | bytes | None, ...]]: """Execute a read-only query and return all results. Executes a SQL SELECT query against the database and returns the results. @@ -3858,9 +3858,11 @@ def to_xml(self, target_path: str | Path) -> bool: if not rows: continue column_types_tuples = self.query(f"SELECT name, type FROM pragma_table_info('{table_name}')") - column_types: dict[str, str] = {key: value for key, value in column_types_tuples} + column_types: dict[str, str] = { + cast(str, key): cast(str, value) for key, value in column_types_tuples + } logger.trace("Adding {} to {}", table_name, target_path) - xml_handler.create_table_element(rows, column_types, table_name) + xml_handler.create_table_element(rows, column_types, cast(str, table_name)) xml_handler.to_xml(target_path) @@ -3946,7 +3948,7 @@ def update_object( object_id = self.get_object_id(class_enum, object_name) set_clauses = ["name = ?"] - params: list[Any] = [new_name] + params: list[str | int] = [new_name] if new_category is not None: category_id = self.get_category_id(class_enum, new_category) diff --git a/src/plexosdb/db_manager.py b/src/plexosdb/db_manager.py index 7dbca7e..8f98734 100644 --- a/src/plexosdb/db_manager.py +++ b/src/plexosdb/db_manager.py @@ -1,13 +1,15 @@ """SQLite database manager.""" import sqlite3 -from collections.abc import Callable, Generator, Iterator +from collections.abc import Callable, Generator, Iterator, Sequence from contextlib import contextmanager from dataclasses import dataclass +from loguru import logger from pathlib import Path -from typing import Any, cast, overload +from types import TracebackType +from typing import TypeAlias, cast, overload -from loguru import logger +_SQLiteParam: TypeAlias = int | float | str | bytes | None @dataclass(slots=True) @@ -219,7 +221,9 @@ def close(self) -> None: # Always null the connection reference self._con = None - def execute(self, query: str, params: tuple[Any, ...] | dict[str, Any] | None = None) -> bool: + def execute( + self, query: str, params: tuple[_SQLiteParam, ...] | dict[str, _SQLiteParam] | None = None + ) -> bool: """Execute a SQL statement that doesn't return results. Each execution is its own transaction unless used within a transaction context. @@ -263,7 +267,9 @@ def execute(self, query: str, params: tuple[Any, ...] | dict[str, Any] | None = logger.error(f"Rollback error: {rb_error}") return False - def executemany(self, query: str, params_seq: list[tuple[Any, ...]] | list[dict[str, Any]]) -> bool: + def executemany( + self, query: str, params_seq: Sequence[tuple[_SQLiteParam, ...]] | Sequence[dict[str, _SQLiteParam]] + ) -> bool: """Execute a SQL statement with multiple parameter sets. Parameters @@ -342,9 +348,9 @@ def executescript(self, script: str) -> bool: def iter_query( self, query: str, - params: tuple[Any, ...] | dict[str, Any] | None = None, + params: tuple[_SQLiteParam, ...] | dict[str, _SQLiteParam] | None = None, batch_size: int = 1000, - ) -> Iterator[tuple[Any, ...]]: + ) -> Iterator[tuple[_SQLiteParam, ...]]: """Execute a read-only query and return an iterator of results. This is memory-efficient for large result sets. Use only for SELECT statements. @@ -414,7 +420,7 @@ def last_insert_rowid(self) -> int: def list_table_names(self) -> list[str]: """Return a list of current table names on the database.""" sql = "SELECT name FROM sqlite_master WHERE type ='table'" - return [r[0] for r in self.fetchall(sql)] + return [cast(str, r[0]) for r in self.fetchall(sql)] def optimize(self) -> bool: """Run optimization routines on the database. @@ -454,41 +460,30 @@ def _validate_query_type(self, query: str) -> None: if first_word in write_keywords: raise ValueError(f"Use execute() for {first_word} statements, not query()") - # Add generic type support for query results @overload - def query(self, query: str, params: None = None) -> list[tuple[Any, ...]]: - """Read-only queries without parameters use the default binding.""" - ... + def query(self, query: str, params: None = None) -> list[tuple[_SQLiteParam, ...]]: ... @overload - def query(self, query: str, params: tuple[Any, ...] | dict[str, Any]) -> list[tuple[Any, ...]]: - """Read-only queries that bind positional or named parameters.""" - ... - def query( - self, query: str, params: tuple[Any, ...] | dict[str, Any] | None = None - ) -> list[tuple[Any, ...]]: - """Execute a read-only query and return all results. + self, query: str, params: tuple[_SQLiteParam, ...] | dict[str, _SQLiteParam] + ) -> list[tuple[_SQLiteParam, ...]]: ... - Note: This method should ONLY be used for SELECT statements. - For INSERT/UPDATE/DELETE, use execute() instead. + def query( + self, query: str, params: tuple[_SQLiteParam, ...] | dict[str, _SQLiteParam] | None = None + ) -> list[tuple[_SQLiteParam, ...]]: + """Execute a read-only SELECT query and return all results. Parameters ---------- query : str - SQL query to execute (SELECT statements only) + SQL SELECT statement. params : tuple or dict, optional - Parameters to bind to the query + Parameters to bind to the query. Returns ------- - list - Query results (tuples or named tuples based on initialization) - - Raises - ------ - sqlite3.Error - If a database error occurs + list[tuple] + All result rows. """ self._validate_query_type(query) cursor = self.connection.cursor() @@ -496,54 +491,19 @@ def query( cursor.execute(query, params or tuple()) return cursor.fetchall() except sqlite3.Error: - # Let the caller handle database errors raise finally: cursor.close() def fetchall( - self, query: str, params: tuple[Any, ...] | dict[str, Any] | None = None - ) -> list[tuple[Any, ...]]: - """Execute a query and return all results as a list of rows. - - This method is a standard DB-API style alias for query(). - - Parameters - ---------- - query : str - SQL query to execute (SELECT statements only) - params : tuple or dict, optional - Parameters to bind to the query - - Returns - ------- - list - All rows (as tuples or named tuples based on row_factory setting) - - Raises - ------ - sqlite3.Error - If a database error occurs - - See Also - -------- - query : Equivalent method with PlexosDB-specific naming - fetchone : Get only the first row of results - fetchall_dict : Return results as dictionaries - - Examples - -------- - >>> db = SQLiteManager() - >>> db.execute("CREATE TABLE test (id INTEGER, name TEXT)") - >>> db.execute("INSERT INTO test VALUES (1, 'Alice')") - >>> db.fetchall("SELECT * FROM test") - [(1, 'Alice')] - """ + self, query: str, params: tuple[_SQLiteParam, ...] | dict[str, _SQLiteParam] | None = None + ) -> list[tuple[_SQLiteParam, ...]]: + """Alias for query(); returns all rows as a list of tuples.""" return self.query(query, params) def fetchall_dict( - self, query: str, params: tuple[Any, ...] | dict[str, Any] | None = None - ) -> list[dict[str, Any]]: + self, query: str, params: tuple[_SQLiteParam, ...] | dict[str, _SQLiteParam] | None = None + ) -> list[dict[str, _SQLiteParam]]: """Execute a query and return all results as a list of dictionaries. Parameters @@ -555,7 +515,7 @@ def fetchall_dict( Returns ------- - list[dict[str, Any]] + list[dict[str, _SQLiteParam]] All rows as dictionaries with column names as keys Raises @@ -591,8 +551,11 @@ def fetchall_dict( cursor.close() def fetchmany( - self, query: str, size: int = 1000, params: tuple[Any, ...] | dict[str, Any] | None = None - ) -> list[tuple[Any, ...]]: + self, + query: str, + size: int = 1000, + params: tuple[_SQLiteParam, ...] | dict[str, _SQLiteParam] | None = None, + ) -> list[tuple[_SQLiteParam, ...]]: """Execute a query and return a specified number of rows. Parameters @@ -640,7 +603,9 @@ def fetchmany( finally: cursor.close() - def fetchone(self, query: str, params: tuple[Any, ...] | dict[str, Any] | None = None) -> Any | None: + def fetchone( + self, query: str, params: tuple[_SQLiteParam, ...] | dict[str, _SQLiteParam] | None = None + ) -> tuple[_SQLiteParam, ...] | None: """Execute a query and return only the first result row. Parameters @@ -688,8 +653,8 @@ def fetchone(self, query: str, params: tuple[Any, ...] | dict[str, Any] | None = cursor.close() def fetchone_dict( - self, query: str, params: tuple[Any, ...] | dict[str, Any] | None = None - ) -> dict[str, Any] | None: + self, query: str, params: tuple[_SQLiteParam, ...] | dict[str, _SQLiteParam] | None = None + ) -> dict[str, _SQLiteParam] | None: """Execute a query and return only the first result row as a dictionary. Parameters @@ -701,7 +666,7 @@ def fetchone_dict( Returns ------- - dict[str, Any] or None + dict[str, _SQLiteParam] or None First row as dictionary with column names as keys, or None if no results Raises @@ -740,9 +705,9 @@ def fetchone_dict( def iter_dicts( self, query: str, - params: tuple[Any, ...] | dict[str, Any] | None = None, + params: tuple[_SQLiteParam, ...] | dict[str, _SQLiteParam] | None = None, batch_size: int = 1000, - ) -> Iterator[dict[str, Any]]: + ) -> Iterator[dict[str, _SQLiteParam]]: """Execute a read-only query and yield results as dictionaries. This is memory-efficient for large result sets. Each row is returned @@ -759,7 +724,7 @@ def iter_dicts( Yields ------ - dict[str, Any] + dict[str, _SQLiteParam] One database row at a time as a dictionary Raises @@ -833,7 +798,7 @@ def transaction(self) -> Generator["SQLiteManager", None, None]: def insert_records( self, table_name: str, - records: dict[str, Any] | list[dict[str, Any]], + records: dict[str, _SQLiteParam] | list[dict[str, _SQLiteParam]], ) -> bool: """Insert records into a table using dictionaries with column names as keys. @@ -891,7 +856,10 @@ def __enter__(self) -> "SQLiteManager": return self def __exit__( - self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: """Automatically close connection when exiting context.""" self.close() diff --git a/src/plexosdb/types.py b/src/plexosdb/types.py new file mode 100644 index 0000000..4d383d1 --- /dev/null +++ b/src/plexosdb/types.py @@ -0,0 +1,10 @@ +"""Shared type aliases for the property insertion pipeline.""" + +from datetime import datetime +from typing import TypeAlias + +PropValue: TypeAlias = int | float | str | None +MetadataValue: TypeAlias = str | int | float | datetime | None +PropertyParams: TypeAlias = list[tuple[int, int, PropValue]] +MetadataMap: TypeAlias = dict[tuple[int, int, PropValue, int], dict[str, MetadataValue]] +DataIdMap: TypeAlias = dict[tuple[int, int, PropValue, int], tuple[int, str]] diff --git a/src/plexosdb/utils.py b/src/plexosdb/utils.py index 445ff87..efa0687 100644 --- a/src/plexosdb/utils.py +++ b/src/plexosdb/utils.py @@ -8,12 +8,13 @@ from datetime import datetime from importlib.resources import files from itertools import islice -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar, cast from loguru import logger from .enums import ClassEnum from .exceptions import NotFoundError +from .types import DataIdMap, MetadataMap, PropValue, PropertyParams SQLITE_INT64_MIN = -(2**63) SQLITE_INT64_MAX = 2**63 - 1 @@ -27,14 +28,17 @@ class PreparedPropertiesResult: """Prepared inputs for bulk property insertion.""" - params: list[tuple[int, int, Any]] + params: PropertyParams collection_properties: list[tuple[str, int]] - metadata_map: dict[tuple[int, int, Any, int], dict[str, Any]] + metadata_map: MetadataMap normalized_records: list[dict[str, Any]] deprecated_format_used: bool -def batched(iterable: Iterable[Any], n: int) -> Iterator[tuple[Any, ...]]: +_T = TypeVar("_T") + + +def batched(iterable: Iterable[_T], n: int) -> Iterator[tuple[_T, ...]]: """Implement batched iterator. https://docs.python.org/3/library/itertools.html#itertools.batched @@ -43,7 +47,7 @@ def batched(iterable: Iterable[Any], n: int) -> Iterator[tuple[Any, ...]]: return iter(lambda: tuple(islice(it, n)), ()) -def validate_string(value: str) -> Any: +def validate_string(value: str) -> int | float | str | bool | None: """Validate string and convert it to python object. This function also tries to parse floats or ints. @@ -64,9 +68,7 @@ def validate_string(value: str) -> Any: parsed_int = int(value) if SQLITE_INT64_MIN <= parsed_int <= SQLITE_INT64_MAX: return parsed_int - # Keep oversized integers as strings to avoid sqlite3 OverflowError - # during executemany bindings. - return str(parsed_int) + return value # oversized integer: keep as string for SQLite safety except ValueError: pass try: @@ -78,7 +80,7 @@ def validate_string(value: str) -> Any: if value == "false" or value == "FALSE": return False try: - value = ast.literal_eval(value) + return cast("int | float | str | bool | None", ast.literal_eval(value)) except Exception: logger.trace("Could not parse {}", value) return value @@ -111,12 +113,13 @@ def normalize_names(*args: str | Iterable[str]) -> list[str]: ValueError If the input is neither a string nor an iterable of strings """ - names: Iterable[Any] if len(args) == 1 and hasattr(args[0], "__iter__") and not isinstance(args[0], str): - names = args[0] - else: - names = args - return list(set(str(name) for name in names if name is not None)) + maybe_iter = args[0] + if not isinstance(maybe_iter, Iterable): + raise ValueError("Input must be a string or an iterable of strings") + return list({str(name) for name in maybe_iter if name is not None}) + + return list({str(name) for name in args if name is not None}) def get_sql_query(query_name: str) -> str: @@ -138,9 +141,9 @@ def get_sql_query(query_name: str) -> str: def prepare_sql_data_params( records: list[dict[str, float]], - memberships: list[dict[str, Any]], + memberships: list[dict[str, int | str]], property_mapping: list[tuple[str, int]], -) -> list[tuple[int, int, Any]]: +) -> PropertyParams: """Create list of tuples for data ingestion. Parameters @@ -155,19 +158,24 @@ def prepare_sql_data_params( Returns ------- - list[tuple[int, int, Any]] + list[tuple[int, int, PropValue]] List of tuples containing (membership_id, property_id, value) for database insertion """ property_id_map = {prop: pid for prop, pid in property_mapping} - name_to_membership = {membership["name"]: membership["membership_id"] for membership in memberships} - return [ - (name_to_membership[record["name"]], property_id_map[prop], value) - for record in records - if record["name"] in name_to_membership - for prop, value in record.items() - if prop != "name" and prop in property_id_map - ] + name_to_membership: dict[str, int] = { + cast(str, membership["name"]): cast(int, membership["membership_id"]) for membership in memberships + } + return cast( + PropertyParams, + [ + (name_to_membership[cast(str, record["name"])], property_id_map[prop], value) + for record in records + if cast(str, record["name"]) in name_to_membership + for prop, value in record.items() + if prop != "name" and prop in property_id_map + ], + ) def create_membership_record( @@ -310,26 +318,24 @@ def plan_property_inserts( def _fetch_collection_properties(db: PlexosDB, *, collection_id: int) -> list[tuple[str, int]]: """Fetch property rows for a collection as (name, id) tuples.""" - return db.query(f"select name, property_id from t_property where collection_id={collection_id}") + return cast( + list[tuple[str, int]], + db.query(f"select name, property_id from t_property where collection_id={collection_id}"), + ) -def _resolve_membership_map( +def _fetch_memberships( db: PlexosDB, - normalized_records: list[dict[str, Any]], *, + component_names: tuple[str, ...], object_class: ClassEnum, parent_class: ClassEnum, collection: CollectionEnum, -) -> dict[str, int]: - """Resolve membership ids for each object name.""" - component_names = tuple({d["name"] for d in normalized_records if d.get("name") is not None}) - if not component_names: - return {} - - memberships: list[tuple[str, int]] | list[dict[str, Any]] +) -> list[tuple[Any, ...]] | list[dict[str, Any]]: + """Fetch memberships for component names under the requested hierarchy.""" if parent_class == ClassEnum.System: try: - memberships = db.get_memberships_system( + return db.get_memberships_system( component_names, object_class=object_class, collection=collection, @@ -340,24 +346,65 @@ def _resolve_membership_map( f"Objects not found: {missing}. Add them with `add_object` or `add_objects` before " "adding properties." ) from exc - else: - collection_id = db.get_collection_id( - collection, parent_class_enum=parent_class, child_class_enum=object_class - ) - parent_class_id = db.get_class_id(parent_class) - child_class_id = db.get_class_id(object_class) - placeholders = ",".join("?" for _ in component_names) - query = f""" - SELECT child_object.name, mem.membership_id - FROM t_membership AS mem - INNER JOIN t_object AS child_object ON child_object.object_id = mem.child_object_id - WHERE mem.parent_class_id = ? - AND mem.child_class_id = ? - AND mem.collection_id = ? - AND child_object.name IN ({placeholders}) - """ - params: tuple[Any, ...] = (parent_class_id, child_class_id, collection_id, *component_names) - memberships = db._db.fetchall(query, params) + + collection_id = db.get_collection_id( + collection, parent_class_enum=parent_class, child_class_enum=object_class + ) + parent_class_id = db.get_class_id(parent_class) + child_class_id = db.get_class_id(object_class) + placeholders = ",".join("?" for _ in component_names) + query = f""" + SELECT child_object.name, mem.membership_id + FROM t_membership AS mem + INNER JOIN t_object AS child_object ON child_object.object_id = mem.child_object_id + WHERE mem.parent_class_id = ? + AND mem.child_class_id = ? + AND mem.collection_id = ? + AND child_object.name IN ({placeholders}) + """ + params = (parent_class_id, child_class_id, collection_id, *component_names) + return db._db.fetchall(query, params) + + +def _membership_entry(membership: tuple[Any, ...] | dict[str, Any]) -> tuple[str, int] | None: + """Parse a membership row into a validated (object_name, membership_id) pair.""" + if isinstance(membership, dict): + object_name = membership["name"] + membership_id = membership["membership_id"] + if isinstance(object_name, str) and isinstance(membership_id, int): + return object_name, membership_id + return None + + if ( + isinstance(membership, tuple) + and len(membership) > 1 + and isinstance(membership[0], str) + and isinstance(membership[1], int) + ): + return membership[0], membership[1] + return None + + +def _resolve_membership_map( + db: PlexosDB, + normalized_records: list[dict[str, Any]], + *, + object_class: ClassEnum, + parent_class: ClassEnum, + collection: CollectionEnum, +) -> dict[str, int]: + """Resolve membership ids for each object name.""" + component_names = tuple({d["name"] for d in normalized_records if d.get("name") is not None}) + if not component_names: + return {} + + memberships = _fetch_memberships( + db, + component_names=component_names, + object_class=object_class, + parent_class=parent_class, + collection=collection, + ) if not memberships: missing = ", ".join(sorted(name for name in component_names if name)) @@ -369,12 +416,11 @@ def _resolve_membership_map( name_to_membership: dict[str, int] = {} ambiguous_objects: set[str] = set() for membership in memberships: - if isinstance(membership, dict): - object_name = membership["name"] - membership_id = membership["membership_id"] - else: - object_name = membership[0] - membership_id = membership[1] + parsed_membership = _membership_entry(membership) + if parsed_membership is None: + continue + + object_name, membership_id = parsed_membership existing_membership_id = name_to_membership.get(object_name) if existing_membership_id is not None and existing_membership_id != membership_id: ambiguous_objects.add(object_name) @@ -490,10 +536,10 @@ def _build_property_rows( *, name_to_membership: dict[str, int], property_id_map: dict[str, int], -) -> tuple[list[tuple[int, int, Any]], dict[tuple[int, int, Any, int], dict[str, Any]]]: +) -> tuple[PropertyParams, MetadataMap]: """Build parameter tuples and metadata for normalized records.""" - params: list[tuple[int, int, Any]] = [] - metadata_map: dict[tuple[int, int, Any, int], dict[str, Any]] = {} + params: PropertyParams = [] + metadata_map: MetadataMap = {} for record in normalized_records: membership_id = name_to_membership.get(record["name"]) @@ -523,10 +569,10 @@ def _build_property_rows( def insert_property_values( db: PlexosDB, - params: list[tuple[int, int, Any]], + params: PropertyParams, *, - metadata_map: dict[tuple[int, int, Any, int], dict[str, Any]] | None = None, -) -> dict[tuple[int, int, Any, int], tuple[int, str]]: + metadata_map: MetadataMap | None = None, +) -> DataIdMap: """Insert property data and return mapping of data IDs to object names. Parameters @@ -571,9 +617,9 @@ def insert_property_values( chunk, ) for mid, name in rows: - membership_to_name[mid] = name + membership_to_name[cast(int, mid)] = cast(str, name) - data_id_map: dict[tuple[int, int, Any, int], tuple[int, str]] = {} + data_id_map: DataIdMap = {} for i, (membership_id, property_id, value) in enumerate(params): data_id_map[(membership_id, property_id, value, i)] = ( first_id + i, @@ -588,12 +634,12 @@ def insert_property_values( def apply_scenario_tags( db: PlexosDB, - params: list[tuple[int, int, Any]], + params: PropertyParams, /, *, scenario: str, chunksize: int, - data_id_map: dict[tuple[int, int, Any, int], tuple[int, str]] | None = None, + data_id_map: DataIdMap | None = None, ) -> None: """Insert scenario tags for property data. @@ -625,32 +671,32 @@ def apply_scenario_tags( ) if key in data_id_map ] - for batch in batched(tag_rows, chunksize): + for tag_batch in batched(tag_rows, chunksize): db._db.executemany( "INSERT INTO t_tag(data_id, object_id) VALUES (?, ?)", - list(batch), + list(tag_batch), ) else: - for batch in batched(params, chunksize): + for data_batch in batched(params, chunksize): scenario_query = f""" INSERT into t_tag(data_id, object_id) SELECT d.data_id, {scenario_id} FROM t_data d WHERE d.membership_id = ? AND d.property_id = ? AND d.value = ? """ - db._db.executemany(scenario_query, list(batch)) + db._db.executemany(scenario_query, list(data_batch)) def insert_property_texts( db: PlexosDB, - params: list[tuple[int, int, Any]], + params: PropertyParams, /, *, - data_id_map: dict[tuple[int, int, Any, int], tuple[int, str]], + data_id_map: DataIdMap, records: list[dict[str, Any]], field_name: str, text_class: ClassEnum, - metadata_map: dict[tuple[int, int, Any, int], dict[str, Any]] | None = None, + metadata_map: MetadataMap | None = None, ) -> None: """Add text data for properties from specified field. @@ -692,8 +738,8 @@ def insert_property_texts( def _persist_metadata_for_data( db: PlexosDB, *, - metadata_map: dict[tuple[int, int, Any, int], dict[str, Any]], - data_id_map: dict[tuple[int, int, Any, int], tuple[int, str]], + metadata_map: MetadataMap, + data_id_map: DataIdMap, ) -> None: """Attach band and date metadata for inserted data rows.""" bands_to_insert: list[tuple[int, int]] = [] @@ -711,10 +757,14 @@ def _persist_metadata_for_data( date_to = metadata.get("date_to") if band is not None: - bands_to_insert.append((data_id, band)) + bands_to_insert.append((data_id, cast(int, band))) - _append_date_if_present(dates_from_to_insert, data_id, date_value=date_from, label="date_from") - _append_date_if_present(dates_to_to_insert, data_id, date_value=date_to, label="date_to") + _append_date_if_present( + dates_from_to_insert, data_id, date_value=cast("datetime | None", date_from), label="date_from" + ) + _append_date_if_present( + dates_to_to_insert, data_id, date_value=cast("datetime | None", date_to), label="date_to" + ) if bands_to_insert: db._db.executemany("INSERT INTO t_band(data_id, band_id) VALUES (?, ?)", bands_to_insert) @@ -737,9 +787,9 @@ def _append_date_if_present( def _build_text_lookup( records: list[dict[str, Any]], *, field_name: str -) -> dict[tuple[str, str | None], Any]: +) -> dict[tuple[str, str | None], str | None]: """Create a lookup of object/property combinations to text values.""" - text_map: dict[tuple[str, str | None], Any] = {} + text_map: dict[tuple[str, str | None], str | None] = {} for rec in records: obj_name = rec.get("name") if obj_name is None: @@ -759,16 +809,16 @@ def _build_text_lookup( def _collect_text_rows( - params: list[tuple[int, int, Any]], - data_id_map: dict[tuple[int, int, Any, int], tuple[int, str]], + params: PropertyParams, + data_id_map: DataIdMap, *, - metadata_map: dict[tuple[int, int, Any, int], dict[str, Any]] | None, - text_map: dict[tuple[str, str | None], Any], + metadata_map: MetadataMap | None, + text_map: dict[tuple[str, str | None], str | None], class_id: int, field_name: str, -) -> list[tuple[int, int, Any]]: +) -> list[tuple[int, int, str]]: """Convert params and metadata into t_text insert rows.""" - texts_to_insert: list[tuple[int, int, Any]] = [] + texts_to_insert: list[tuple[int, int, str]] = [] for i, (membership_id, property_id, value) in enumerate(params): row_key = (membership_id, property_id, value, i) @@ -780,22 +830,26 @@ def _collect_text_rows( if metadata_map: row_text = metadata_map.get(row_key, {}).get(field_name) if row_text is not None: - texts_to_insert.append((data_id, class_id, row_text)) + texts_to_insert.append((data_id, class_id, cast(str, row_text))) continue - property_name = metadata_map.get(row_key, {}).get("property_name") if metadata_map else None - lookup_keys = [(obj_name, property_name), (obj_name, None)] + property_name = ( + cast("str | None", metadata_map.get(row_key, {}).get("property_name")) if metadata_map else None + ) + lookup_keys: list[tuple[str, str | None]] = [(obj_name, property_name), (obj_name, None)] for lookup in lookup_keys: if lookup in text_map: - texts_to_insert.append((data_id, class_id, text_map[lookup])) + text_val = text_map[lookup] + if text_val is not None: + texts_to_insert.append((data_id, class_id, text_val)) break return texts_to_insert def build_data_id_map( - db: SQLiteManager, params: list[tuple[int, int, Any]] -) -> dict[tuple[int, int, Any], tuple[int, str]]: + db: SQLiteManager, params: PropertyParams +) -> dict[tuple[int, int, PropValue], tuple[int, str]]: """Build mapping of (membership_id, property_id, value) to (data_id, obj_name). Parameters @@ -817,11 +871,14 @@ def build_data_id_map( JOIN t_object o ON m.child_object_id = o.object_id WHERE d.membership_id = ? AND d.property_id = ? AND d.value = ? """ - data_id_map = {} + data_id_map: dict[tuple[int, int, PropValue], tuple[int, str]] = {} for membership_id, property_id, value in params: result = db.fetchone(data_ids_query, (membership_id, property_id, value)) if result: - data_id_map[(membership_id, property_id, value)] = (result[0], result[1]) + data_id = result[0] + obj_name = result[1] + if isinstance(data_id, int) and isinstance(obj_name, str): + data_id_map[(membership_id, property_id, value)] = (data_id, obj_name) return data_id_map diff --git a/src/plexosdb/xml_handler.py b/src/plexosdb/xml_handler.py index 19655cc..20db199 100644 --- a/src/plexosdb/xml_handler.py +++ b/src/plexosdb/xml_handler.py @@ -88,7 +88,7 @@ def parse( def create_table_element( self, - rows: list[tuple[Any, ...]], + rows: list[tuple[int | float | str | bytes | None, ...]], column_types: dict[str, str], table_name: str, ) -> bool: @@ -113,9 +113,9 @@ def create_table_element( def get_records( self, element_enum: Schema, - *elements: Iterable[str | int], + *elements: str | int, rename_dict: dict[str, str] | None = None, - **tag_elements: Any, + **tag_elements: str | int, ) -> list[dict[str, Any]]: """Return a given element(s) as list of dictionaries.""" if rename_dict is None: @@ -137,9 +137,9 @@ def get_records( def iter( self, element_type: Schema, - *elements: Iterable[str | int], - label: str | None = None, - **tags: Any, + *elements: str | int, + label: str | int | None = None, + **tags: str | int, ) -> Iterable[ET.Element]: """Return elements from the XML based on the type. @@ -200,14 +200,14 @@ def to_xml(self, fpath: str | PathLike[str]) -> bool: return True - def _cache_iter(self, element_type: Schema, **tag_elements: Any) -> Iterator[ET.Element]: + def _cache_iter(self, element_type: Schema, **tag_elements: str | int) -> Iterator[ET.Element]: """Return iterator over cached XML elements matching filters. Parameters ---------- element_type : Schema Schema enum describing the cached element type. - **tag_elements : Any + **tag_elements : str | int Optional tag filters (usually by label) to narrow the results. Returns @@ -232,8 +232,8 @@ def _cache_iter(self, element_type: Schema, **tag_elements: Any) -> Iterator[ET. def _iter_elements( self, element_type: str, - *elements: Any, - **tag_elements: Any, + *elements: str | int, + **tag_elements: str | int, ) -> Iterator[ET.Element]: """Iterate over the xml file. @@ -269,7 +269,7 @@ def _remove_namespace(self, namespace: str) -> None: elem.tag = elem.tag[nsl:] -def xml_query(element_name: str, *tags: Any, **tag_elements: Any) -> str: +def xml_query(element_name: str, *tags: str | int, **tag_elements: str | int) -> str: """Construct XPath query for extracting data from a XML with no namespace. Parameters diff --git a/tests/test_plexosdb_copy_object.py b/tests/test_plexosdb_copy_object.py index 51ccbe8..c2263bf 100644 --- a/tests/test_plexosdb_copy_object.py +++ b/tests/test_plexosdb_copy_object.py @@ -60,12 +60,8 @@ def test_copy_object_copies_date_from_and_date_to(db_base: PlexosDB): date_to=date_to, ) - original_date_from = db.query( - "SELECT date FROM t_date_from WHERE data_id = ?", (original_data_id,) - ) - original_date_to = db.query( - "SELECT date FROM t_date_to WHERE data_id = ?", (original_data_id,) - ) + original_date_from = db.query("SELECT date FROM t_date_from WHERE data_id = ?", (original_data_id,)) + original_date_to = db.query("SELECT date FROM t_date_to WHERE data_id = ?", (original_data_id,)) assert original_date_from == [(date_from.isoformat(),)] assert original_date_to == [(date_to.isoformat(),)] @@ -77,12 +73,8 @@ def test_copy_object_copies_date_from_and_date_to(db_base: PlexosDB): new_data_id = new_data_ids[0] assert new_data_id != original_data_id - copied_date_from = db.query( - "SELECT date FROM t_date_from WHERE data_id = ?", (new_data_id,) - ) - copied_date_to = db.query( - "SELECT date FROM t_date_to WHERE data_id = ?", (new_data_id,) - ) + copied_date_from = db.query("SELECT date FROM t_date_from WHERE data_id = ?", (new_data_id,)) + copied_date_to = db.query("SELECT date FROM t_date_to WHERE data_id = ?", (new_data_id,)) assert copied_date_from == [(date_from.isoformat(),)] assert copied_date_to == [(date_to.isoformat(),)] From 1207f71fb2421d15344a074c286098395db8bc75 Mon Sep 17 00:00:00 2001 From: mvelasqu Date: Wed, 27 May 2026 14:24:44 -0600 Subject: [PATCH 2/2] fix: add missing docstring to formatting --- src/plexosdb/db_manager.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/plexosdb/db_manager.py b/src/plexosdb/db_manager.py index 8f98734..de734ad 100644 --- a/src/plexosdb/db_manager.py +++ b/src/plexosdb/db_manager.py @@ -461,12 +461,16 @@ def _validate_query_type(self, query: str) -> None: raise ValueError(f"Use execute() for {first_word} statements, not query()") @overload - def query(self, query: str, params: None = None) -> list[tuple[_SQLiteParam, ...]]: ... + def query(self, query: str, params: None = None) -> list[tuple[_SQLiteParam, ...]]: + """Execute a read-only SQL query without bound parameters.""" + ... @overload def query( self, query: str, params: tuple[_SQLiteParam, ...] | dict[str, _SQLiteParam] - ) -> list[tuple[_SQLiteParam, ...]]: ... + ) -> list[tuple[_SQLiteParam, ...]]: + """Execute a read-only SQL query with tuple or named parameters.""" + ... def query( self, query: str, params: tuple[_SQLiteParam, ...] | dict[str, _SQLiteParam] | None = None