Source code for laforge.sql

#!/usr/bin/env python3
"""SQL utilities for mid-level interaction. Inspired by pathlib; powered by SQLALchemy.

.. note::

    Supported: MSSQL, MariaDB/MySQL, PostgreSQL, SQLite.
    Supportable: Firebird, Oracle, Sybase.

"""

import logging
import re
import textwrap
from keyword import kwlist
from pathlib import Path

import pandas as pd
import pyparsing
import sqlalchemy as sa
import yaml

from .toolbox import flatten

logger = logging.getLogger(__name__)
logger.debug(__name__)

RESERVED_WORD_FILE = Path(__file__).parent / "data" / "reserved_words.yaml"
RESERVED_WORDS = {x.lower() for x in yaml.safe_load(RESERVED_WORD_FILE.read_text())}
RESERVED_WORDS.update(kwlist)


def is_reserved_word(s):
    try:
        return s.lower() in RESERVED_WORDS
    except AttributeError:
        return False


class SQLTableNotFound(Exception):
    pass


class SQLChannelNotFound(Exception):
    pass


class SQLIdentifierProblem(ValueError):
    pass


class Channel:
    """Abstraction from Engine, other static details."""

    known_engines = {}
    known_channels = {}

    def __init__(
        self, distro, *, server=None, database=None, schema=None, **engine_kwargs
    ):
        """Anything else is captured in engine_kwargs for create_spec"""
        from .distros import Distro

        self.distro = Distro(distro)
        self.server = server
        self.database = database
        self.schema = schema
        self.engine = self._construct_engine(**engine_kwargs)
        self.save_engine()
        self.metadata = sa.MetaData(bind=self.engine, schema=self.schema)
        if self.metadata.bind.url.database:
            self.database = self.metadata.bind.url.database
        self.inspector = sa.inspect(self.engine)

    @classmethod
    def grab(cls):
        try:
            *_, last_channel = iter(cls.known_channels.values())
        except ValueError:
            raise SQLChannelNotFound("No known SQL channels exist.")
        return last_channel

    def _construct_engine(self, **engine_kwargs):
        existing_engine = self.retrieve_engine()
        if existing_engine:
            return existing_engine
        return self.distro.create_engine(
            server=self.server, database=self.database, engine_kwargs=engine_kwargs
        )

    def retrieve_engine(self):
        return self._retrieve_engine(repr(self))

    @classmethod
    def _retrieve_engine(cls, key):
        return cls.known_engines.get(key)

    def save_engine(self):
        self._save_engine(self, repr(self), self.engine)

    @classmethod
    def _save_engine(cls, channel, key, engine):
        if key not in cls.known_engines:
            cls.known_engines[key] = engine
        if engine not in cls.known_channels:
            cls.known_channels[engine] = channel

    def execute_statement(self, statement, fetch=False):
        """Execute SQL (core method)

        .. todo:: De-messify

        """
        statement = self.clean_up_statement(statement)
        assert fetch in ("df", "tuples", False)
        if fetch == "df":
            try:
                return pd.read_sql(statement, con=self.engine)
            except Exception as err:
                logger.error(
                    "Error reading SQL to DF using %s\nExecuting:\n\n%s\n\n",
                    self.engine,
                    statement,
                )
                raise err
        final_result = None
        with self.engine.connect() as cnxn:
            with cnxn.begin() as transxn:
                try:
                    result = cnxn.execution_options(autocommit=True).execute(statement)
                    if fetch == "tuples":
                        final_result = result.fetchall()
                finally:
                    transxn.commit()
        return final_result

    @staticmethod
    def clean_up_statement(s):
        s = s.strip()
        for quote_char in ('"', "'"):
            while s.startswith(quote_char) and s.endswith(quote_char):
                s = s.strip(quote_char)
        return s

    def find(self, object_pattern="%", schema_pattern="%"):
        return self.distro.find(
            channel=self, object_pattern=object_pattern, schema_pattern=schema_pattern
        )

    def __hash__(self):
        return hash(repr(self))

    def __eq__(self, other):
        return hash(self) == hash(other)

    def __repr__(self):
        pieces = ";".join(
            str(x) for x in (self.distro, self.server, self.database, self.schema) if x
        )
        return f"{self.__class__.__name__}({pieces})"
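
# Usage sketch (illustrative, not part of the module; the distro string and
# in-memory database are assumptions -- Distro resolution lives in .distros).
# A Channel bundles a SQLAlchemy Engine with server/database/schema details and
# caches engines/channels so a later Channel.grab() can reuse the most recent one.
#
#     channel = Channel("sqlite", database=":memory:")
#     channel.execute_statement("CREATE TABLE t (x INTEGER)")
#     rows = channel.execute_statement("SELECT x FROM t", fetch="tuples")
#     df = channel.execute_statement("SELECT x FROM t", fetch="df")
#     assert Channel.grab() is channel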


def execute(statement, fetch=False, channel=None):
    """Convenience method, autofetches Channel if possible"""
    if not channel:
        channel = Channel.grab()
    return channel.execute_statement(statement, fetch=fetch)
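
# Illustrative calls to the module-level wrapper (the statements are hypothetical):
#     execute("DELETE FROM staging")                    # uses Channel.grab()
#     execute("SELECT 1", fetch="tuples")               # list of row tuples
#     execute("SELECT 1", fetch="df", channel=channel)  # explicit channel, DataFrame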


class Script:
    """SQL query string, parsable by 'go' separation and execute()able."""

    BATCH_TERMINATOR = "go"
    _batch_borders = re.compile(r"\n[^\w;\d-]*go\W*?\n", flags=re.IGNORECASE)
    _terminating_batch_terminator = re.compile(
        r"(?<=\W)(go\W*)+$", flags=re.IGNORECASE
    )
    _terminating_semicolon = re.compile(r"[\s;]+$")

    def __init__(self, query, channel=None):
        if not channel:
            channel = Channel.grab()
        self.channel = channel
        self.query = query
        self._query_string = self._remove_comments(query)
        self.parsed = self._parse(self._query_string)

    @classmethod
    def _remove_comments(cls, input_string):
        lines = cls._remove_multiline_sql_comment(input_string).splitlines()
        new_lines = (cls._remove_double_dash_sql_comment(l) for l in lines)
        return "\n".join(new_lines)

    def _parse(self, query):
        flattened_group = flatten(self._construct_statements(query))
        return [s for s in flattened_group if self._is_useful_statement(s)]

    def _construct_statements(self, query):
        return [
            self._normalize_batch_end(batch)
            for batch in self._break_into_batches(query)
        ]

    def _break_into_batches(self, query):
        """.. todo::

            Parse multiple go in a row more elegantly, avoid injecting
            line breaks. Cf. _format_statement

        """
        batches = self._batch_borders.split(query)
        return [batch.strip() for batch in batches]

    @classmethod
    def _normalize_batch_end(cls, batch):
        batch = batch.strip()
        batch = cls._terminating_semicolon.sub("", batch)
        batch = cls._terminating_batch_terminator.sub(r"", batch)
        batch = cls._terminating_semicolon.sub("", batch)
        return batch + ";"

    @staticmethod
    def _is_useful_statement(s):
        return re.findall("[A-Za-z]", s)

    @staticmethod
    def _remove_multiline_sql_comment(input_string):
        cleansed = (
            pyparsing.nestedExpr("/*", "*/").suppress().transformString(input_string)
        )
        return cleansed

    @staticmethod
    def _remove_double_dash_sql_comment(single_line):
        double_dash_pattern = re.compile(r"\s*?--.*")
        return double_dash_pattern.sub("", single_line)

    # Public API

    def execute(self, statements=None):
        """Execute itsel(f|ves)"""
        statements = statements or self.parsed
        for i, statement in enumerate(statements):
            logger.debug(
                "Executing statement %s of %s, %s lines long: %s",
                i + 1,
                len(statements),
                statement.count("\n") + 1,
                textwrap.shorten(statement, 80),
            )
            self.channel.execute_statement(statement)

    def to_table(self):
        """Execute all statements and try to return the final query's result as a DataFrame.

        This is one of two ways that laforge retrieves tables.

        .. warning::

            This is limited by the capacity of Pandas to retrieve only the
            final result. For Microsoft SQL Server, if a lengthy set of
            queries is desired, the most reliable approach appears to be a
            single final query after a 'go' as a batch terminator.

        .. warning::

            This will rename columns that do not conform to naming standards.

        """
        assert isinstance(self.query, str)
        logger.debug("Executing SQL: %s", textwrap.shorten(self.query, 80))
        logger.debug(
            "%s queries prior to retrieval; final for df is %s chars long.",
            len(self.parsed[:-1]),
            len(self.parsed[-1]),
        )
        for stmt in self.parsed[:-1]:
            self.channel.execute_statement(stmt, fetch=False)
        df = self.channel.execute_statement(self.parsed[-1], fetch="df")
        rows, cols = df.shape
        logger.debug("Received %s rows, %s columns.", rows, cols)
        df = fix_bad_columns(df)
        return df

    def read(self):
        return self.to_table()

    def __len__(self):
        return len(self.parsed)

    def __str__(self):
        return "Script: " + "\n".join(q for q in self.parsed)

    def __repr__(self):
        return "<{} of {} statement{}>".format(
            self.__class__.__name__, len(self), ["s", ""][len(self) == 1]
        )
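
# Usage sketch (illustrative; the SQL and channel below are hypothetical). A
# Script strips comments, splits its input on 'go' batch terminators, and
# normalizes each batch to end with a single semicolon before execution.
#
#     sql = """
#     /* staging */
#     CREATE TABLE tmp (x INT)
#     go
#     SELECT x FROM tmp  -- the final batch becomes the DataFrame in to_table()
#     """
#     script = Script(sql, channel=channel)
#     len(script)            # 2 parsed statements
#     df = script.to_table()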


class Table:
    """Represents a SQL table, featuring methods to read/write DataFrames.

    .. todo::

        Factor out to superclass to allow views

    """

    def __init__(self, name, channel=None, **kwargs):
        self.channel = channel if channel else Channel.grab()
        self.metadata = self.channel.metadata
        self.distro = self.channel.distro
        identifiers = self._parse_args(name, kwargs)
        try:
            self.__name = identifiers["name"]
        except KeyError:
            raise SQLIdentifierProblem("Must provide table name.")
        self.__schema = identifiers.get("schema", self.channel.schema)
        self.__database = identifiers.get("database", self.channel.database)
        self.__server = self.channel.server
        self.__metal = None

    @property
    def metal(self):
        if self.__metal is None:
            self.__metal = sa.Table(
                self.name,
                self.metadata,
                autoload=True,
                autoload_with=self.channel.engine,
                extend_existing=True,
            )
        return self.__metal

    @property
    def identifiers(self):
        return {
            "server": self.server,
            "database": self.database,
            "schema": self.schema,
            "name": self.name,
        }

    @property
    def server(self):
        return self.__server

    @property
    def database(self):
        return self.__database

    @property
    def schema(self):
        return self.__schema

    @property
    def name(self):
        return self.__name

    def _parse_args(self, name, kwargs):
        id_dict = {
            "schema": kwargs.get("schema", self.channel.schema),
            "database": kwargs.get("database", self.channel.database),
            "server": kwargs.get("server", self.channel.server),
        }
        parts_dict = {
            k: v
            for k, v in zip(
                ["name", "schema", "database", "server"], reversed(name.split("."))
            )
            if v
        }
        id_dict.update(parts_dict)
        if "name" in id_dict:
            Identifier(id_dict["name"]).check()
        return id_dict

    # API

    def exists(self):
        insp = sa.inspect(self.channel.engine)
        tables = insp.get_table_names(schema=self.schema or None)
        return self.name in tables

    def resolve(self, strict=False):
        if strict and not self.exists():
            raise SQLTableNotFound("{} does not exist.".format(self))
        return self.distro.resolver.format(**self.identifiers)

    def write(self, df, if_exists="replace"):
        """From DataFrame, create a new table and fill it with values"""
        try:
            if df.empty:
                raise RuntimeError("DataFrame to write is empty!")
        except AttributeError:
            raise RuntimeError(f"Can only write DataFrame, not {type(df)}")
        if "" in df.columns:
            df = fix_bad_columns(df)
        dtypes = self.distro.determine_dtypes(df)
        df.to_sql(
            name=self.name,
            con=self.channel.engine,
            schema=self.schema or None,  # sqlite can't use "" or it craps out
            if_exists=if_exists,
            index=False,
            dtype=dtypes,
        )

    def read(self):
        """Return the full table as a DataFrame"""
        select_all = sa.select([self.metal])
        return pd.read_sql(select_all, con=self.metadata.bind)

    def drop(self, ignore_existence=False):
        """Delete the table within SQL"""
        if self.exists():
            self.metal.drop()
        elif not ignore_existence:
            raise SQLTableNotFound(self)
        assert not self.exists()
        logger.debug("%s dropped.", self)

    def __len__(self):
        count_query = sa.select([sa.func.count()]).select_from(self.metal)
        return int(Scalar(self.metadata.bind.execute(count_query)))

    def __str__(self):
        return self.resolve(strict=False)

    def __repr__(self):
        return f"Table('{self}')"

    def __eq__(self, other):
        return hash(self) == hash(other)

    def __hash__(self):
        return hash(self.metal)
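
# Usage sketch (illustrative; table and column names are hypothetical). Table is
# the write-side counterpart to Script.to_table(): write a DataFrame out, read it
# back, drop it.
#
#     nums = pd.DataFrame({"n": [1, 2, 3]})
#     t = Table("demo_numbers", channel=channel)
#     t.write(nums)           # CREATE + INSERT via DataFrame.to_sql
#     assert t.exists()
#     assert len(t) == 3
#     round_trip = t.read()   # full table back as a DataFrame
#     t.drop()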


class Scalar:
    """Little helper to produce a clearly typed single (upper left) ResultProxy result."""

    def __init__(self, prox):
        self.item = prox.first()[0]
        prox.close()

    def __int__(self):
        return int(self.item)

    def __str__(self):
        return str(self.item)
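
# Illustrative use (names hypothetical; assumes SQLAlchemy 1.x-style engine.execute):
#     n = int(Scalar(channel.engine.execute("SELECT COUNT(*) FROM demo_numbers")))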


class Identifier:
    """Single standardized variable/database/schema/table/column/anything identifier.

    .. todo::

        class InvalidIdentifierError
        relay_id_problem(identifier, action, reason=None, replacement=None)

    """

    # A-Z, a-z, 0-9, @ # $ _ mostly okay for table name following first letter
    VALID_CHARACTERS_AFTER_FIRST = r"[\w@_#$]+"
    # Unless it's the first character, which can't be a number or $
    # So a complete name needs to be a block starting with a proper lead character
    # and (possibly) continuing with valid characters
    VALID_NAME_PATTERN = r"[A-Za-z@_#][\w@_#$]*"

    WHITELIST = [":memory:", "tables"]
    BLACKLIST = ["?column?", "", None]

    def __init__(self, user_input, extra=None):
        """
        :param user_input: Something that can be converted into a useful string.
        :param extra: Additional something that could be usefully
            appended/exchanged with the native identifier.

        .. todo:: Can this be fully idempotent?

        .. todo:: Re-validate fallbacks?

        """
        self.original = user_input
        if self.original in self.WHITELIST:
            self.normalized = self.original
        elif self.original in self.BLACKLIST:
            self.normalized = self._normalize("", leading_underscore=False, extra=extra)
        else:
            stringed_input = str(self.original).strip()
            self.normalized = self._normalize(
                stringed_input,
                leading_underscore=stringed_input.strip().startswith("_"),
                extra=extra,
            )

    def check(self):
        if self.normalized == self.original:
            return True
        logger.debug(
            "Identifier [%s] suggested normalization: [%s].",
            self.original,
            self.normalized,
        )
        return False

    def _normalize(self, working, leading_underscore, extra):
        working = self._replace_characters(working)
        working = self._fix(working, extra)
        working = self._stylize(working, leading_underscore)
        working = self._shorten(working)
        working = self._amend(working)
        return working

    @classmethod
    def _replace_characters(cls, attempt, replacement="_"):
        # Strip out non-valid characters, replace with replacement
        attempt = replacement.join(
            re.findall(cls.VALID_CHARACTERS_AFTER_FIRST, attempt)
        )
        return attempt

    def _fix(self, s, extra=None):
        hit = re.search(self.VALID_NAME_PATTERN, s)
        if hit:
            return hit.group(0)
        # Lack of match: no usable first character (could be blank/all specials)
        # But we do need something to work with
        if not s and extra is None:
            raise SQLIdentifierProblem("Empty identifier requires extra input")
        fixed = f"column_{s or extra}"
        logger.warning(f"No useful name from: [{s}], replaced with: [{fixed}]")
        return fixed

    @staticmethod
    def _stylize(attempt, leading_underscore):
        # Don't add a leading underscore if it wasn't there already (junk replacement)
        if not leading_underscore:
            attempt = attempt.lstrip("_")
        return attempt

    @staticmethod
    def _shorten(s, max_length=62, warning_length=255):
        """Cut off lengthy identifiers.

        .. note::

            SQL Server allows 128 (116 for temp); PostgreSQL 63; MySQL 64.
            This currently uses 62 as a lowest common denominator.

        :param s:
        :param max_length: (Default value = 62)
        :param warning_length: (Default value = 255)
        """
        shortened = s[:max_length]
        if shortened == s:
            return s
        display_wings = (max_length + 6) // 2
        display = f"{s[:display_wings]}[...]{s[-display_wings:]}"
        if len(s) >= warning_length:
            logger.warning("Too long to be a reasonable identifier name: %s", display)
        logger.warning("Truncated [%s] into [%s].", display, shortened)
        return shortened

    @classmethod
    def _amend(cls, s, suffix="_"):
        """
        :param s:
        :param suffix: (Default value = "_")
        """
        initial_attempt = s
        while is_reserved_word(s):
            s = s + suffix
            assert len(s) > len(initial_attempt)
        if initial_attempt != s:
            logger.debug("Reserved word '%s' amended to '%s'", initial_attempt, s)
        return s

    def __str__(self):
        return self.normalized
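
# Illustrative normalizations (the reserved-word example assumes 'select' appears
# in reserved_words.yaml):
#     Identifier("order date").normalized         # 'order_date' (space replaced)
#     Identifier("select").normalized             # 'select_' (reserved word amended)
#     Identifier("?column?", extra=3).normalized  # 'column_3' (blacklisted name rebuilt)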


def fix_bad_columns(df):
    badnames = set(Identifier.BLACKLIST).intersection(df.columns)
    leading_numbers = any(x for x in df.columns if x[:1].isdigit())
    if not badnames and not leading_numbers:
        return df
    new_columns = [Identifier(c, extra=i).normalized for i, c in enumerate(df.columns)]
    new_df = df.set_axis(labels=new_columns, axis="columns", inplace=False)
    return new_df


"""
Copyright 2019 Matt VanEseltine.

This file is part of laforge.

laforge is free software: you can redistribute it and/or modify it under
the terms of the GNU Affero General Public License as published by the Free
Software Foundation, either version 3 of the License, or (at your option)
any later version.

laforge is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
more details.

You should have received a copy of the GNU Affero General Public License
along with laforge. If not, see <https://www.gnu.org/licenses/>.
"""