PyAPplus64/src/PyAPplus64/applus_db.py

173 lines
5.6 KiB
Python

# Copyright (c) 2023 Thomas Tuerk (kontakt@thomas-tuerk.de)
#
# This file is part of PyAPplus64 (see https://www.thomas-tuerk.de/de/pyapplus64).
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
import pyodbc # type: ignore
import logging
from .sql_utils import SqlStatement
from . import sql_utils
from typing import List, Dict, Set, Any, Optional, Callable, Sequence
logger = logging.getLogger(__name__)
class APplusDBSettings:
"""
Einstellungen, mit welcher DB sich verbunden werden soll.
"""
def __init__(self, server: str, database: str, user: str, password: str):
self.server = server
self.database = database
self.user = user
self.password = password
def getConnectionString(self) -> str:
"""Liefert den ODBC Connection-String für die Verbindung.
:return: den Connection-String
"""
return ("Driver={SQL Server Native Client 11.0};"
"Server="+self.server+";"
"Database="+self.database+";"
"UID="+self.user+";"
"PWD="+self.password + ";")
def connect(self) -> pyodbc.Connection:
"""Stellt eine neue Verbindung her und liefert diese zurück.
"""
return pyodbc.connect(self.getConnectionString())
def row_to_dict(row: pyodbc.Row) -> Dict[str, Any]:
"""Konvertiert eine Zeile in ein Dictionary"""
return dict(zip([t[0] for t in row.cursor_description], row))
def _logSQLWithArgs(sql: SqlStatement, *args: Any) -> None:
if args:
logger.debug("executing '{}' with args {}".format(str(sql), str(args)))
else:
logger.debug("executing '{}'".format(str(sql)))
def rawQueryAll(
cnxn: pyodbc.Connection,
sql: SqlStatement,
*args: Any,
apply: Optional[Callable[[pyodbc.Row], Any]] = None) -> Sequence[Any]:
"""
Führt eine SQL Query direkt aus und liefert alle Zeilen zurück.
Wenn apply gesetzt ist, wird die Funktion auf jeder Zeile ausgeführt und das Ergebnis ausgeben, die nicht None sind.
"""
_logSQLWithArgs(sql, *args)
with cnxn.cursor() as cursor:
cursor.execute(str(sql), *args)
rows = cursor.fetchall()
if apply is None:
return rows
else:
res = []
for r in rows:
rr = apply(r)
if not (rr is None):
res.append(rr)
return res
def rawQuery(cnxn: pyodbc.Connection, sql: sql_utils.SqlStatement, f: Callable[[pyodbc.Row], None], *args: Any) -> None:
"""Führt eine SQL Query direkt aus und führt für jede Zeile die übergeben Funktion aus."""
_logSQLWithArgs(sql, *args)
with cnxn.cursor() as cursor:
cursor.execute(str(sql), *args)
for row in cursor:
f(row)
def rawQuerySingleRow(cnxn: pyodbc.Connection, sql: SqlStatement, *args: Any) -> Optional[pyodbc.Row]:
"""Führt eine SQL Query direkt aus, die maximal eine Zeile zurückliefern soll. Diese Zeile wird geliefert."""
_logSQLWithArgs(sql, *args)
with cnxn.cursor() as cursor:
cursor.execute(str(sql), *args)
return cursor.fetchone()
def rawQuerySingleValue(cnxn: pyodbc.Connection, sql: SqlStatement, *args: Any) -> Any:
"""Führt eine SQL Query direkt aus, die maximal einen Wert zurückliefern soll. Dieser Wert oder None wird geliefert."""
_logSQLWithArgs(sql, *args)
with cnxn.cursor() as cursor:
cursor.execute(str(sql), *args)
row = cursor.fetchone()
if row:
return row[0]
else:
return None
def getUniqueFieldsOfTable(cnxn: pyodbc.Connection, table: str) -> Dict[str, List[str]]:
"""
Liefert alle Spalten einer Tabelle, die eindeutig sein müssen.
Diese werden als Dictionary gruppiert nach Index-Namen geliefert.
Jeder Eintrag enthält eine Liste von Feldern, die zusammen eindeutig sein
müssen.
"""
sql = sql_utils.SqlStatementSelect("sys.indexes AS i")
join = sql.addInnerJoin("sys.index_columns AS ic")
join.on.addCondition("i.OBJECT_ID = ic.OBJECT_ID")
join.on.addCondition("i.index_id = ic.index_id")
sql.where.addConditionFieldEq("OBJECT_NAME(ic.OBJECT_ID)", table)
sql.where.addConditionFieldEq("i.is_unique", True)
sql.addFields("i.name AS INDEX_NAME", "COL_NAME(ic.OBJECT_ID,ic.column_id) AS COL")
_logSQLWithArgs(sql)
indices: Dict[str, List[str]] = {}
with cnxn.cursor() as cursor:
cursor.execute(str(sql))
for row in cursor:
cols = indices.get(row.INDEX_NAME, [])
cols.append(sql_utils.normaliseDBfield(row.COL))
indices[row.INDEX_NAME] = cols
return indices
class DBTableIDs():
"""Klasse, die Mengen von IDs gruppiert nach Tabellen speichert"""
def __init__(self) -> None:
self.data: Dict[str, Set[int]] = {}
def add(self, table: str, *ids: int) -> None:
"""
fügt Eintrag hinzu
:param table: die Tabelle
:type table: str
:param id: die ID
"""
table = table.upper()
if not (table in self.data):
self.data[table] = set(ids)
else:
self.data[table].update(ids)
def getTable(self, table: str) -> Set[int]:
"""
Liefert die Menge der IDs für eine bestimmte Tabelle.
:param table: die Tabelle
:type table: str
:return: die IDs
"""
table = table.upper()
return self.data.get(table, set())
def __str__(self) -> str:
return str(self.data)