#!/usr/bin/env python
"""
This module provides generic infrastructure for retrieving data from a
relational database that contains immutable data. It can use a local caching
database (sqlite3) to allow for reanalysis on platforms that cannot access the
original remote database or to reduce the load on remote databases.
"""
import warnings
try:
import pymysql
pymysql.install_as_MySQLdb()
except ImportError:
pass
try:
import MySQLdb
except ImportError:
pass
import sqlite3
# Flag values stored in CachingDb.last_hit to record which backend
# (the sqlite3 cache or the remote MySQL database) served the last query.
HIT_CACHE_DB = 1
HIT_REMOTE_DB = 2
class CachingDb(object):
    """Connect to a relational database with an optional caching layer interposed.

    If instantiated with a ``cache_file`` argument, all queries will go to
    that SQLite-based cache database.  Otherwise all queries go out to the
    remote database.

    Attributes:
        saved_results (dict): in-memory data cache, keyed by table names
            and whose values are dictionaries with keys ``rows`` and
            ``schema``.  ``rows`` is a list of row tuples returned from
            earlier queries, and ``schema`` describes the table needed to
            hold ``rows`` (see :meth:`query`).
        last_hit (int): flag indicating whether the last query was served
            by the caching database (``HIT_CACHE_DB``) or the remote
            database (``HIT_REMOTE_DB``)
        cache_file (str): path to the caching database's file
        cache_db (sqlite3.Connection): caching database connection handle
        cache_db_ps (str): paramstyle of the caching database as defined
            by `PEP-0249`_
        remote_db: remote database connection handle
        remote_db_ps (str): paramstyle of the remote database as defined
            by `PEP-0249`_

    .. _PEP-0249: https://www.python.org/dev/peps/pep-0249
    """
    #pylint: disable=too-many-arguments

    def __init__(self, dbhost=None, dbuser=None, dbpassword=None, dbname=None,
                 cache_file=None):
        """Connect to a relational database and/or a caching database.

        If none of the connection arguments (``db*``) are specified, do not
        connect to a remote database and instead rely entirely on the caching
        database or a separate call to :meth:`connect`.

        Arguments:
            dbhost (str, optional): hostname for the remote database
            dbuser (str, optional): username to use when connecting to database
            dbpassword (str, optional): password for authenticating to database
            dbname (str, optional): name of database to use when connecting
            cache_file (str, optional): Path to an SQLite3 database to use as
                a caching layer.
        """
        # In-memory cache of query results; structure:
        #   { 'table1': {'rows': [(row1), (row2), ...],
        #                'schema': {...} or None},
        #     'table2': {...}, ... }
        self.saved_results = {}
        self.last_hit = None
        # Caching (sqlite3) database state
        self.cache_file = None
        self.cache_db = None
        self.cache_db_ps = None
        # Remote database state
        self.remote_db = None
        self.remote_db_ps = None
        if cache_file is not None:
            self.connect_cache(cache_file)
            # TODO: check that the cache db isn't empty and warn if it is;
            # queries against an empty cache will silently return nothing.
        if all(arg is not None for arg in (dbhost, dbuser, dbpassword, dbname)):
            self.connect(dbhost=dbhost,
                         dbuser=dbuser,
                         dbpassword=dbpassword,
                         dbname=dbname)

    def connect(self, dbhost, dbuser, dbpassword, dbname):
        """Establish the remote db connection.

        Connects to a remote MySQL database and defines the connection handle
        and paramstyle attributes.

        Args:
            dbhost (str): hostname for the remote database
            dbuser (str): username to use when connecting to database
            dbpassword (str): password for authenticating to database
            dbname (str): name of database to use when connecting
        """
        if self.cache_db is not None:
            # Can't really do both local and remote dbs; all queries would be
            # run against the cache_db, which is probably not what someone
            # wants.
            warnings.warn("attempting to use both remote and cache db; disabling cache db")
            self.close_cache()
        # NOTE: MySQLdb is only available if the optional import at the top of
        # the module succeeded; this raises NameError otherwise.
        self.remote_db = MySQLdb.connect(host=dbhost,
                                         user=dbuser,
                                         passwd=dbpassword,
                                         db=dbname)
        self.remote_db_ps = get_paramstyle_symbol(MySQLdb.paramstyle)

    def close(self):
        """Close the remote database connection and reset remote attributes.

        Explicitly closes the underlying connection (rather than merely
        dropping the reference) so server-side resources are released
        promptly.
        """
        if self.remote_db is not None:
            self.remote_db.close()
        self.remote_db = None
        self.remote_db_ps = None

    def connect_cache(self, cache_file):
        """Open the cache database file and set the handler attributes.

        Args:
            cache_file (str): Path to the SQLite3 caching database file to be
                used.
        """
        if cache_file is not None:
            self.cache_db = sqlite3.connect(cache_file)
            self.cache_file = cache_file
            self.cache_db_ps = get_paramstyle_symbol(sqlite3.paramstyle)

    def close_cache(self):
        """Close the cache database handle and reset all caching attributes."""
        if self.cache_db is not None:
            self.cache_db.close()
        self.cache_db = None
        self.cache_file = None
        # Also reset the paramstyle for consistency with close()
        self.cache_db_ps = None

    def drop_cache(self, tables=None):
        """Flush saved results from memory.

        If tables are specified, only drop those tables' results.  If no
        tables are provided, flush everything.

        Args:
            tables (list, optional): List of table names (str) to flush.  If
                omitted, flush all tables in cache.
        """
        # Collect first, then delete, to avoid mutating the dict while
        # iterating it.
        doomed = [table for table in self.saved_results
                  if tables is None or table in tables]
        for table in doomed:
            del self.saved_results[table]

    def save_cache(self, cache_file):
        """Commit the in-memory cache to a cache database.

        This method is currently very memory-inefficient and not good for
        caching giant pieces of a database without something wrapping it to
        feed it smaller pieces.

        Note:
            This manipulates the ``cache_db*`` attributes in a dirty way to
            prevent closing and re-opening the original cache db.  If
            ``connect_cache()`` is ever changed to track more state, this
            function must also be updated to retain that state while the old
            cache db state is being temporarily shuffled out.

        Args:
            cache_file (str): Path to the cache file to be used to write out
                the cache contents.  This file will temporarily pre-empt the
                `cache_file` attribute and should be a different file.
        """
        ### Shuffle out the old cache db state (if it exists)
        old_state = {}
        if self.cache_file is not None and self.cache_file != cache_file:
            old_state = {
                'cache_file': self.cache_file,
                'cache_db': self.cache_db,
                'cache_db_ps': self.cache_db_ps,
            }
        ### Open a new cache db connection without closing the old cache db
        self.connect_cache(cache_file)
        ### Commit each table we've retained in memory
        committed = set()
        for table, table_info in self.saved_results.items():
            rows = table_info['rows']
            if not rows:
                warnings.warn("table %s has no rows" % table)
                continue
            ### Verify that the rows we've saved are actually all of the same
            ### length so that they have a hope of being inserted into the
            ### schema.  (The original code had the %-format args
            ### unparenthesized, passing them as extra warn() arguments, and
            ### its `continue` only skipped one row instead of the table.)
            num_fields = len(rows[0])
            bad_row = next((row for row in rows if len(row) != num_fields), None)
            if bad_row is not None:
                warnings.warn(
                    "saved_results[%s] contains non-uniform rows (%d, %d); skipping table" %
                    (table, len(bad_row), num_fields))
                continue
            ### Create the table (if necessary).  This will throw all sorts of
            ### exceptions if 'schema' isn't set correctly by the downstream
            ### application.  'schema' is consumed as a dict with 'columns'
            ### and 'primary_key' lists of SQL fragments.
            if table_info['schema'] is not None:
                self.cache_db.execute(
                    "CREATE TABLE IF NOT EXISTS %s (%s, PRIMARY KEY(%s))" %
                    (table,
                     ', '.join(table_info['schema']['columns']),
                     ', '.join(table_info['schema']['primary_key'])))
            ### INSERT OR REPLACE so that the cache db never wins if a
            ### duplicate primary key is detected
            query_str = "insert or replace into %s values (%s)" % (
                table, ','.join(['?'] * num_fields))
            self.cache_db.executemany(query_str, rows)
            self.cache_db.commit()
            ### Drop committed rows from memory after the loop
            committed.add(table)
        for table in committed:
            del self.saved_results[table]
        self.close_cache()
        ### Shuffle back in the state of the old cache db
        if old_state:
            self.cache_file = old_state['cache_file']
            self.cache_db = old_state['cache_db']
            self.cache_db_ps = old_state['cache_db_ps']

    def query(self, query_str, query_variables=(), table=None, table_schema=None):
        """Pass a query through all layers of cache and return on the first hit.

        If a table is specified, the results of this query are retained in
        memory (``saved_results``) so they can later be committed to a cache
        database via :meth:`save_cache`.

        Args:
            query_str (str): SQL query expressed as a string
            query_variables (tuple): parameters to be substituted into
                `query_str` if `query_str` is a parameterized query
            table (str, optional): name of table in the cache database to save
                the results of the query
            table_schema (dict, optional): when `table` is specified, the
                schema used by :meth:`save_cache` to create the table: a dict
                with ``columns`` and ``primary_key`` lists of SQL fragments.

        Returns:
            Sequence of tuples corresponding to rows of fields as returned
            by the SQL query.

        Raises:
            RuntimeError: if neither a cache nor a remote database is
                connected.
        """
        ### Collapse query string to remove extraneous whitespace
        query_str = ' '.join(query_str.split())
        ### Check the cache database (if available)
        if self.cache_db is not None:
            results = self._query_sqlite3(query_str, query_variables)
            self.last_hit = HIT_CACHE_DB
        ### Check the MySQL database (if available)
        elif self.remote_db is not None:
            results = self._query_mysql(query_str, query_variables)
            self.last_hit = HIT_REMOTE_DB
        else:
            raise RuntimeError('No databases available to query')
        if table is not None:
            ### Initialize the table entry if our intent is to save the
            ### result of this query.
            entry = self.saved_results.setdefault(
                table, {'rows': [], 'schema': None})
            ### Table schema can be defined or re-defined on any query.  It is
            ### up to the downstream application to manage this correctly.
            if table_schema is not None:
                entry['schema'] = table_schema
            ### Append our results
            entry['rows'] += list(results)
        return results

    def _query_sqlite3(self, query_str, query_variables):
        """Run a query against the cache database and return the full output.

        Args:
            query_str (str): SQL query expressed as a string; any literal
                ``%(ps)`` placeholders are substituted with the cache db's
                paramstyle symbol.
            query_variables (tuple): parameters to be substituted into
                `query_str` if `query_str` is a parameterized query
        """
        cursor = self.cache_db.cursor()
        try:
            if '%(ps)' in query_str:
                query_str = query_str % {'ps': self.cache_db_ps}
            cursor.execute(query_str, query_variables)
            return cursor.fetchall()
        finally:
            # Always release the cursor, even if the query raises
            cursor.close()

    def _query_mysql(self, query_str, query_variables):
        """Run a query against the MySQL database and return the full output.

        Args:
            query_str (str): SQL query expressed as a string; any literal
                ``%(ps)`` placeholders are substituted with the remote db's
                paramstyle symbol.
            query_variables (tuple): parameters to be substituted into
                `query_str` if `query_str` is a parameterized query
        """
        cursor = self.remote_db.cursor()
        try:
            if '%(ps)' in query_str:
                query_str = query_str % {'ps': self.remote_db_ps}
            cursor.execute(query_str, query_variables)
            return cursor.fetchall()
        finally:
            # Always release the cursor, even if the query raises
            cursor.close()
def get_paramstyle_symbol(paramstyle):
    """Infer the placeholder symbol for a DB-API 2.0 ``paramstyle``.

    Provides a generic way to determine the parameter placeholder of a
    database connection handle.  See `PEP-0249`_ for more information.

    Args:
        paramstyle (str): Result of a generic database handler's `paramstyle`
            attribute

    Returns:
        str: The placeholder string corresponding to the paramstyle of the
        given database connection.

    Raises:
        ValueError: if `paramstyle` is not one of the supported styles
            (``qmark``, ``format``, ``pyformat``).  ValueError is a subclass
            of Exception, so callers catching the previous generic Exception
            still work.

    .. _PEP-0249: https://www.python.org/dev/peps/pep-0249
    """
    symbols = {
        'qmark': "?",
        'format': "%s",
        'pyformat': "%s",
    }
    try:
        return symbols[paramstyle]
    except KeyError:
        raise ValueError("Unsupported paramstyle %s" % paramstyle)