Source code for laforge.distros

import logging
import math
import re
from importlib import import_module
from pathlib import Path
from urllib import parse

import sqlalchemy as sa

from .sql import Script, Table

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Emit the module name at debug level when the module is imported.
logger.debug(__name__)


class SQLDistroNotFound(Exception):
    """Raised when a requested name cannot be matched to a known SQL distro."""
class Distro:
    """Base class for SQL Distros

    Instantiating ``Distro(name)`` returns an instance of the direct subclass
    whose ``name`` equals ``name`` (case-insensitively) or, failing that, of
    the single subclass whose ``regex`` matches it.

    ..note::
        http://troels.arvin.dk/db/rdbms/
    """

    driver = None
    name = "n/a"
    human_name = "n/a"

    # Largest magnitude accepted for each integer type (slightly below the
    # true signed limit of the type).
    NUMERIC_RANGES = {
        sa.types.SMALLINT: 2 ** 15 - 101,
        sa.types.INT: 2 ** 31 - 101,
        sa.types.BIGINT: 2 ** 63 - 101,
    }
    # Observed values are multiplied by this factor before being compared
    # against NUMERIC_RANGES, leaving headroom for later, larger values.
    NUMERIC_PADDING_FACTOR = 10

    find_template = """--ANSI.find()
    select table_schema, table_name
    from information_schema.tables
    where table_schema like '{schema_pattern}'
        and table_name like '{object_pattern}';
    """

    untouchable_identifiers = []

    def __init__(self, _):
        """Try to import the distro's driver; only warn if it is missing."""
        try:
            import_module(self.driver)
        except ModuleNotFoundError:
            logger.warning(
                "No driver (%s) to support distro %s", self.driver, self.name
            )

    def __new__(cls, name):
        """Create an instance of the subclass matching ``name``.

        :raises SQLDistroNotFound: if no subclass can be matched.
        """
        retrieved_subclass = cls._get_subclass_given(name)
        return super().__new__(retrieved_subclass)

    @classmethod
    def _get_subclass_given(cls, name):
        """Return the subclass for ``name``: exact match first, then regex.

        ..note ::
            These pop; a user-defined subclass will come up last.
        """
        distro_exact = cls._get_exact(name)
        if distro_exact:
            return distro_exact.pop()
        distro_fuzzy = cls._get_fuzzy(name)
        # A fuzzy match is only trusted when it is unambiguous.
        if len(distro_fuzzy) != 1:
            cls._fail_to_find_requested_distro(name)
        return distro_fuzzy.pop()

    @classmethod
    def _get_exact(cls, s):
        """List direct subclasses whose ``name`` equals ``s``, ignoring case."""
        return [x for x in cls.__subclasses__() if x.name.lower() == str(s).lower()]

    @classmethod
    def _get_fuzzy(cls, s):
        """List direct subclasses whose ``regex`` matches ``s``, ignoring case."""
        return [
            x
            for x in cls.__subclasses__()
            if re.match(x.regex, str(s), flags=re.IGNORECASE)
        ]

    @classmethod
    def _fail_to_find_requested_distro(cls, given_name):
        """Raise :class:`SQLDistroNotFound` listing the known distros."""
        err_msg = (
            f"Given input `{given_name}`, "
            + f"could not match distribution to known: {cls.known()}"
        )
        raise SQLDistroNotFound(err_msg)

    @classmethod
    def known(cls):
        """Map ``name`` to ``human_name`` for every direct subclass."""
        return {x.name: x.human_name for x in cls.__subclasses__()}

    def determine_dtypes(self, df):
        """Suggest SQL column types for the columns of ``df``.

        :param df: DataFrame to inspect.
        :return: dict mapping column name to a SQLAlchemy type; columns for
            which no suggestion could be made are omitted.
        """
        new_dtypes = {}
        for column in df.columns:
            sql_specification = self._determine_dtype(df, column)
            if sql_specification:
                new_dtypes[column] = sql_specification
        return new_dtypes

    def _determine_dtype(self, df, column):
        """Dispatch on the column's dtype; return a SQL type or None."""
        if df[column].dtype in ("object", "unicode_", "string_"):
            return self._get_varchar_spec(df, column)
        if df[column].dtype in ("float64",):
            return self._check_float_spec(df, column)
        if df[column].dtype in ("int64",):
            return self._check_integer_spec(df, column)
        return None

    @classmethod
    def _check_float_spec(cls, df, column):
        """Treat a float column as integer if every value is integral."""
        if not df[column].apply(float.is_integer).all():
            return None
        logger.debug("Demoting column [%s] from float...", column)
        return cls._check_integer_spec(df, column)

    @classmethod
    def _check_integer_spec(cls, df, column):
        """Return the smallest integer type comfortably holding the column.

        :return: one of SMALLINT, INT, BIGINT, or None when the column has no
            observable range or does not fit even BIGINT with padding.
        """
        observed_range = [df[column].min(), df[column].max()]
        # An empty column yields NaN bounds, which cannot be converted to
        # int; there is nothing to base a suggestion on.
        if any(math.isnan(bound) for bound in observed_range):
            return None
        for sqltype in (sa.types.SMALLINT, sa.types.INT, sa.types.BIGINT):
            if cls._well_within_range(observed_range, sqltype):
                logger.debug(
                    "Column [%s] numeric type determined to be %s.", column, sqltype
                )
                return sqltype
        return None

    @classmethod
    def _well_within_range(cls, observed, sqltype):
        """Check that the padded observed magnitude is below the type limit."""
        limit = cls.NUMERIC_RANGES[sqltype]
        # Convert to Python's int because numpy's int64 will overflow
        max_observed = int(max(abs(x) for x in observed))
        test_level = max_observed * cls.NUMERIC_PADDING_FACTOR
        return test_level < limit

    @classmethod
    def _get_varchar_spec(cls, df, column):
        """Suggest a VARCHAR sized to the longest UTF-8 encoded value.

        Returns None when the column does not support the ``.str`` accessor.
        """
        try:
            stringed = df[column].str.encode(encoding="utf-8").str
        except AttributeError:
            return None
        observed_len = stringed.len().max()
        return cls._create_varchar_spec(observed_len)

    @staticmethod
    def _create_varchar_spec(max_length):
        """Build a VARCHAR whose length is ``max_length`` rounded up to 50s.

        Returns None when rounding raises ValueError (e.g. a NaN length).
        """
        try:
            rounded = round_up(max_length, nearest=50)
        except ValueError:
            return None
        return sa.VARCHAR(rounded)

    def find(self, channel, object_pattern="%", schema_pattern="%"):
        """List tables whose schema and name match the given LIKE patterns.

        :param channel: channel the query is run against.
        :param object_pattern: SQL LIKE pattern for the table name.
        :param schema_pattern: SQL LIKE pattern for the schema name.
        :return: list of :class:`Table` objects.
        """
        # Lone % causes ValueError on unsupported format character 0x27
        object_pattern = object_pattern.replace(r"%", r"%%")
        schema_pattern = schema_pattern.replace(r"%", r"%%")
        object_query = self.find_template.format(
            schema_pattern=schema_pattern, object_pattern=object_pattern
        )
        object_script = Script(object_query, channel=channel)
        df = object_script.to_table()[["table_schema", "table_name"]]
        result_list = [
            Table(x.table_name, schema=x.table_schema, channel=channel)
            for x in df.itertuples(index=False)
        ]
        return result_list

    def create_spec(self, *, server, database, engine_kwargs):
        """Return ``(url, engine_kwargs)``; must be implemented per distro."""
        raise NotImplementedError

    def create_engine(self, *, server, database, engine_kwargs):
        """Build a SQLAlchemy engine from the spec given by ``create_spec``."""
        url, final_engine_kwargs = self.create_spec(
            server=server, database=database, engine_kwargs=engine_kwargs
        )
        return self._create_engine(url, **final_engine_kwargs)

    @classmethod
    def _create_engine(cls, url, **engine_kwargs):
        """Thin wrapper over ``sa.create_engine`` that subclasses may override."""
        return sa.create_engine(url, **engine_kwargs)

    @property
    def resolver(self):
        # Must be implemented per-distro by subclass
        raise NotImplementedError

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        # Equality is by hash, and the hash is that of ``name``; a distro
        # therefore also compares equal to its own name string.
        return hash(self) == hash(other)

    def __str__(self):
        return self.name

    def __repr__(self):
        return f"Distro('{self.name}')"
class MySQL(Distro):
    """MySQL/MariaDB distro, connected to through the ``pymysql`` driver."""

    name = "mysql"
    human_name = "MySQL/MariaDB"
    regex = "^(my|maria).*"
    driver = "pymysql"
    resolver = "{schema}.`{name}`"

    def create_spec(self, *, server, database, engine_kwargs):
        """Build the connection URL and remaining engine keyword arguments.

        :param server: host portion of the URL.
        :param database: database name.
        :param engine_kwargs: dict that must contain ``username`` and
            ``password``; both are popped from it.
        :return: tuple ``(url, engine_kwargs)``.
        :raises KeyError: if ``username`` or ``password`` is missing.
        """
        # Percent-encode credentials so characters such as "@", "/" or ":"
        # cannot be mistaken for URL delimiters.
        username = parse.quote(str(engine_kwargs.pop("username")), safe="")
        password = parse.quote(str(engine_kwargs.pop("password")), safe="")
        url = (
            f"{self.name}+{self.driver}:"
            + f"//{username}:{password}@"
            + f"{server}/{database}"
            + "?charset=utf8mb4"
        )
        return (url, engine_kwargs)
class PostgresQL(Distro):
    """PostgreSQL distro, connected to through the ``psycopg2`` driver."""

    name = "postgresql"
    human_name = "PostgreSQL"
    regex = r"^post.*"
    driver = "psycopg2"
    resolver = "{schema}.{name}"

    def create_spec(self, *, server, database, engine_kwargs):
        """Build the connection URL and remaining engine keyword arguments.

        :param server: host portion of the URL.
        :param database: database name.
        :param engine_kwargs: dict that must contain ``username`` and
            ``password``; both are popped from it.
        :return: tuple ``(url, engine_kwargs)``.
        :raises KeyError: if ``username`` or ``password`` is missing.
        """
        # Percent-encode credentials so characters such as "@", "/" or ":"
        # cannot be mistaken for URL delimiters.
        username = parse.quote(str(engine_kwargs.pop("username")), safe="")
        password = parse.quote(str(engine_kwargs.pop("password")), safe="")
        url = f"{self.name}+{self.driver}://{username}:{password}@{server}/{database}"
        return (url, engine_kwargs)
class MSSQL(Distro):
    """Microsoft SQL Server distro, connected to through ``pyodbc``."""

    name = "mssql"
    human_name = "Microsoft SQL Server"
    regex = r"(^(mss|ms s|micro).*)|(.*server)"
    driver = "pyodbc"
    resolver = "[{database}].[{schema}].[{name}]"

    find_template = """--MSSQL.find()
    select sch.name as [schema], obj.name as [name], type_desc,
        obj.create_date, obj.modify_date, object_id
    from {database}.sys.schemas sch
    left join {database}.sys.objects obj
        on sch.schema_id = obj.schema_id
    where sch.name like '{schema_pattern}'
        and obj.name like '{object_pattern}'
        and type_desc not like '%constraint%'
        and type_desc not in ('sql_stored_procedure');
    """

    def create_spec(self, *, server, database, engine_kwargs):
        """Build an ``odbc_connect`` URL and the engine keyword arguments.

        ``driver``, ``username`` and ``password`` are consumed from
        ``engine_kwargs`` and folded into the ODBC connection string; when
        username and password are not both supplied a trusted connection is
        requested instead. Remaining entries are passed through to the
        engine, on top of a default ``encoding`` of ``latin1``.

        :return: tuple ``(url, engine_kwargs)``.
        """
        # https://docs.sqlalchemy.org/en/13/dialects/mssql.html
        spec_dict = {
            "server": server,
            "database": database,
            "driver": "SQL Server",
            "fast_executemany": "yes",
            "autocommit": "yes",
        }
        if "driver" in engine_kwargs:
            spec_dict["driver"] = engine_kwargs.pop("driver")
        if "username" in engine_kwargs and "password" in engine_kwargs:
            spec_dict["UID"] = engine_kwargs.pop("username")
            spec_dict["PWD"] = engine_kwargs.pop("password")
        else:
            spec_dict["trusted_connection"] = "yes"
            # A lone username or password is unusable here and is not an
            # engine option either, so keep it away from create_engine.
            engine_kwargs.pop("username", None)
            engine_kwargs.pop("password", None)

        # Each value is wrapped in braces: key={value};key={value}
        spec_string = ";".join(f"{k}={{{v}}}" for k, v in spec_dict.items())
        engine_inputs = parse.quote_plus(spec_string)
        url = f"{self.name}+{self.driver}:///?odbc_connect={engine_inputs}"

        # Start from the defaults and let caller-supplied options override
        # them, rather than discarding the caller's options.
        final_engine_kwargs = {"encoding": "latin1"}
        final_engine_kwargs.update(engine_kwargs)
        return (url, final_engine_kwargs)

    @classmethod
    def _create_engine(cls, url, **engine_kwargs):
        """Create the engine and attach the fast-executemany listener."""
        engine = sa.create_engine(url, **engine_kwargs)
        cls.add_fast_executemany(engine)
        return engine

    @staticmethod
    def add_fast_executemany(engine: sa.engine.Engine):
        """
        Dramatically improve pyodbc upload performance

        Theoretically, just "fast_executemany": "True" should be sufficient
        in newer versions of the driver.

        .. note ::

            Improved 1m row upload from over 7 minutes to less than 1
            under pyodbc==4.0.26, SQLAlchemy==1.3.1, pandas==0.24.2.
        """

        # pylint: disable=unused-argument, unused-variable
        @sa.event.listens_for(engine, "before_cursor_execute")
        def receive_before_cursor_execute(
            conn, cursor, statement, params, context, executemany
        ):
            if executemany:
                cursor.fast_executemany = True
                cursor.commit()

    def find(self, channel, object_pattern="%", schema_pattern="%"):
        """List objects whose schema and name match the given LIKE patterns.

        The query is run against ``channel.database``; constraints and
        stored procedures are excluded by the template.

        :return: list of :class:`Table` objects.
        """
        object_query = self.find_template.format(
            schema_pattern=schema_pattern,
            object_pattern=object_pattern,
            server=channel.server,
            database=channel.database,
        )
        object_script = Script(object_query, channel=channel)
        df = object_script.to_table()[["name", "schema"]]
        result_list = [
            Table(x.name, schema=x.schema, channel=channel)
            for x in df.itertuples(index=False)
        ]
        return result_list
class SQLite(Distro):
    """SQLite distro, backed by the standard-library ``sqlite3`` driver."""

    name = "sqlite"
    human_name = "SQLite"
    regex = r"^.*lite\d?"
    driver = "sqlite3"
    resolver = "{name}"

    # Filenames have wholly different semantics from other SQL identifiers
    untouchable_identifiers = ["database"]

    find_template = """--SQLite.find()
    select name as table_name
    from sqlite_master
    where type = 'table'
        and name like '{object_pattern}';
    """

    def create_spec(self, *, server, database, engine_kwargs):
        """Return ``(url, engine_kwargs)`` for a file or in-memory database.

        A falsy ``database``, or one matching the "memory" pattern, selects
        an in-memory database; otherwise the path is expanded, resolved and
        touched so that the file exists.
        """
        # pylint: disable=unused-argument
        # server unneeded
        in_memory = not database or re.match(
            "[^a-z]*memory[^a-z]*", str(database).lower()
        )
        if in_memory:
            target = ":memory:"
        else:
            db_path = Path(database).expanduser().resolve()
            db_path.touch()
            target = str(db_path)
        return (f"{self.name}:///{target}", engine_kwargs)

    def find(self, channel, object_pattern="%", schema_pattern="%"):
        """List tables whose name matches ``object_pattern``.

        ``schema_pattern`` is accepted for signature compatibility but is
        not used in the query.
        """
        query = self.find_template.format(object_pattern=object_pattern)
        names = Script(query, channel=channel).to_table()[["table_name"]]
        return [
            Table(row.table_name, channel=channel)
            for row in names.itertuples(index=False)
        ]

    def determine_dtypes(self, df):
        """SQlite does not make gradations in integers or text, so don't try."""
        return None
def round_up(n, nearest=1):
    """Round up ``n`` to the nearest ``nearest``.

    :param n: number to round up.
    :param nearest: multiple to round up to (Default value = 1)
    :return: the smallest multiple of ``nearest`` that is >= ``n``.
    :raises ValueError: if ``n`` is NaN.
    :raises ZeroDivisionError: if ``nearest`` is zero.
    """
    if isinstance(n, int) and isinstance(nearest, int):
        # Exact integer ceiling division: going through float division
        # would lose precision for magnitudes above 2**53.
        return nearest * -(-n // nearest)
    return nearest * math.ceil(n / nearest)
""" Copyright 2019 Matt VanEseltine. This file is part of laforge. laforge is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. laforge is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with laforge. If not, see <https://www.gnu.org/licenses/>. """