# Source code for explorer.models

import logging
import numbers
import uuid
from time import time

from django.conf import settings
from django.core.exceptions import ValidationError
from django.db import DatabaseError, models, transaction
from django.urls import reverse
from django.utils.translation import gettext_lazy as _

from explorer import app_settings
from explorer.telemetry import Stat, StatNames
from explorer.utils import (
    extract_params, get_params_for_url, get_s3_bucket, passes_blacklist, s3_url,
    shared_dict_update, swap_params,
)
from explorer.ee.db_connections.utils import default_db_connection


# Issue #618. All models must be imported so that Django understands how to manage migrations for the app
from explorer.ee.db_connections.models import DatabaseConnection  # noqa
from explorer.assistant.models import PromptLog, TableDescription  # noqa

logger = logging.getLogger(name=__name__)

# %-style template for the ValidationError raised when a query's final SQL
# contains blacklisted words; the comma-joined words are interpolated.
MSG_FAILED_BLACKLIST = "Query failed the SQL blacklist: %s"


class Query(models.Model):
    """A saved SQL query that can be parameterized, executed and logged.

    Besides the model fields, instances carry ``params``: an optional mapping
    of runtime parameter values (passed as the ``params`` constructor keyword)
    that is merged over the defaults extracted from ``sql``.
    """

    title = models.CharField(max_length=255)
    sql = models.TextField(blank=False, null=False)
    description = models.TextField(blank=True)
    created_by_user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        null=True,
        blank=True,
        on_delete=models.CASCADE
    )
    created_at = models.DateTimeField(auto_now_add=True)
    last_run_date = models.DateTimeField(auto_now=True)
    snapshot = models.BooleanField(
        default=False,
        help_text=_("Include in snapshot task (if enabled)")
    )
    # NOTE this field is deprecated in favor of database_connection and no longer in use.
    # It is present in the 6.0 release to preserve backwards compatibility in case there is need for a rollback.
    # It will be removed in a future release (e.g. v6.1)
    # NOTE(review): the help_text literals below concatenate to "...this query.Will use..."
    # (no space). Left unchanged on purpose: editing help_text changes the migration state.
    connection = models.CharField(
        blank=True,
        max_length=128,
        default="",
        help_text=_(
            "Name of DB connection (as specified in settings) to use for "
            "this query."
            "Will use EXPLORER_DEFAULT_CONNECTION if left blank"
        )
    )
    database_connection = models.ForeignKey(to=DatabaseConnection, on_delete=models.SET_NULL, null=True)
    few_shot = models.BooleanField(default=False, help_text=_(
        "Will be included as a good example of SQL in assistant queries that use relevant tables"))

    def __init__(self, *args, **kwargs):
        # ``params`` is not a model field; remove it before Model.__init__ sees it.
        self.params = kwargs.pop("params", None)
        super().__init__(*args, **kwargs)

    class Meta:
        ordering = ["title"]
        verbose_name = _("Query")
        verbose_name_plural = _("Queries")

    def __str__(self):
        return str(self.title)

    def get_run_count(self):
        """Return how many QueryLog rows reference this query."""
        return self.querylog_set.count()

    def last_run_log(self):
        """Return the first QueryLog in default ordering (newest ``run_at`` first).

        If the query has never been logged, return an unsaved successful
        placeholder stamped with ``created_at``.
        """
        ql = self.querylog_set.first()
        return ql or QueryLog(success=True, run_at=self.created_at)

    def avg_duration_display(self):
        """Return the average run duration formatted as ``%10.3f``, or ``""``."""
        # Compute once: each call to avg_duration() issues an aggregate query.
        d = self.avg_duration()
        if d:
            return f"{d:10.3f}"
        return ""

    def avg_duration(self):
        """Return the mean ``duration`` over this query's logs (None if there are none)."""
        return self.querylog_set.aggregate(
            models.Avg("duration")
        )["duration__avg"]

    def passes_blacklist(self):
        """Return ``(passes, failing_words)`` for the parameter-substituted SQL."""
        return passes_blacklist(self.final_sql())

    def final_sql(self):
        """Return ``sql`` with the available parameter values swapped in."""
        return swap_params(self.sql, self.available_params())

    def execute_query_only(self):
        """Run the final SQL and return a QueryResult without post-processing.

        :raises ValidationError: (code ``"InvalidSql"``) if the SQL fails the blacklist.
        """
        # Check the blacklist on every run: parameter substitution can change
        # the SQL that actually executes.
        passes_blacklist_flag, failing_words = self.passes_blacklist()
        if not passes_blacklist_flag:
            raise ValidationError(
                MSG_FAILED_BLACKLIST % ", ".join(failing_words),
                code="InvalidSql"
            )
        conn = self.database_connection or default_db_connection()
        return QueryResult(self.final_sql(), conn.as_django_connection())

    def execute_with_logging(self, executing_user):
        """Execute the query and record the run in a QueryLog.

        :return: ``(QueryResult, QueryLog)``
        :raises ValidationError: if the SQL fails the blacklist (the log is marked failed).
        :raises DatabaseError: if execution fails (the log is marked failed).
        """
        ql = self.log(executing_user)  # log() already persists the row
        try:
            ret = self.execute()
        except ValidationError as e:
            # Without this, the already-saved log row would keep success=True
            # for a query that never ran.
            ql.success = False
            ql.error = "; ".join(e.messages)
            ql.save()
            raise
        except DatabaseError as e:
            ql.success = False
            ql.error = str(e)
            ql.save()
            raise
        ql.duration = ret.duration
        ql.save()
        Stat(StatNames.QUERY_RUN, {"sql_len": len(ql.sql), "duration": ql.duration}).track()
        return ret, ql

    def execute(self):
        """Execute the query and post-process its result (summaries, transforms)."""
        ret = self.execute_query_only()
        ret.process()
        return ret

    def available_params(self):
        """
        Merge parameter values into a dictionary of available parameters

        :return: A merged dictionary of parameter names and values. Values of non-existent parameters are removed.
        :rtype: dict
        """
        p = extract_params(self.sql)
        p2 = {k: v["default"] for k, v in p.items()}
        if self.params:
            shared_dict_update(p2, self.params)
        return p2

    def available_params_w_labels(self):
        """
        Merge parameter values into a dictionary of available parameters with their labels

        :return: A merged dictionary of parameter names and values/labels. Values of non-existent parameters are removed.
        :rtype: dict
        """
        p = extract_params(self.sql)
        return {
            k: {
                # Fall back to the parameter name when no label was given.
                "label": v["label"] or k,
                "val": self.params[k] if self.params and k in self.params else v["default"]
            } for k, v in p.items()
        }

    def get_absolute_url(self):
        return reverse("query_detail", kwargs={"query_id": self.id})

    @property
    def params_for_url(self):
        return get_params_for_url(self)

    def log(self, user=None):
        """Create, save and return a QueryLog for the current final SQL.

        Anonymous users are recorded as ``None``.
        """
        if user and user.is_anonymous:
            user = None
        ql = QueryLog(
            sql=self.final_sql(),
            query_id=self.id,
            run_by_user=user,
            database_connection=self.database_connection,
        )
        ql.save()
        return ql

    @property
    def shared(self):
        """True if this query's id appears in any list of the configured user query views."""
        views = app_settings.EXPLORER_GET_USER_QUERY_VIEWS()
        # Membership test per list; avoids the quadratic sum(lists, []) concatenation.
        return any(self.id in ids for ids in views.values())

    @property
    def snapshots(self):
        """List this query's S3 snapshots, oldest first; None when tasks are disabled."""
        if not app_settings.ENABLE_TASKS:
            return None
        b = get_s3_bucket()
        objects = b.objects.filter(Prefix=f"query-{self.id}/snap-")
        return [
            SnapShot(s3_url(b, o.key), o.last_modified)
            for o in sorted(objects, key=lambda obj: obj.last_modified)
        ]

    def is_favorite(self, user):
        """Return whether an authenticated ``user`` has favorited this query."""
        if not user.is_authenticated:
            return False
        return self.favorites.filter(user_id=user.id).exists()
class SnapShot:
    """A stored query snapshot: its S3 URL and last-modified timestamp."""

    def __init__(self, url, last_modified):
        self.url = url
        self.last_modified = last_modified


class QueryLog(models.Model):
    sql = models.TextField(blank=True)
    query = models.ForeignKey(
        Query,
        null=True,
        blank=True,
        on_delete=models.SET_NULL
    )
    run_by_user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        null=True,
        blank=True,
        on_delete=models.CASCADE
    )
    run_at = models.DateTimeField(auto_now_add=True)
    duration = models.FloatField(blank=True, null=True)  # milliseconds
    # NOTE this field is deprecated in favor of database_connection and no longer in use.
    # It is present in the 6.0 release to preserve backwards compatibility in case there is need for a rollback.
    # It will be removed in a future release (e.g. v6.1)
    connection = models.CharField(blank=True, max_length=128, default="")
    database_connection = models.ForeignKey(to=DatabaseConnection, on_delete=models.SET_NULL, null=True)
    success = models.BooleanField(default=True)
    error = models.TextField(blank=True, null=True)

    @property
    def is_playground(self):
        """True for ad-hoc runs that are not attached to a saved Query."""
        return self.query_id is None

    class Meta:
        ordering = ["-run_at"]


class QueryFavorite(models.Model):
    query = models.ForeignKey(
        Query,
        on_delete=models.CASCADE,
        related_name="favorites"
    )
    user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        related_name="favorites"
    )

    class Meta:
        unique_together = ["query", "user"]


class QueryResult:
    """Run SQL on a connection and hold the fetched rows and column headers.

    Execution happens in ``__init__``. ``process()`` then adds numeric column
    summaries and applies the configured EXPLORER_TRANSFORMS to the rows.
    """

    def __init__(self, sql, connection):
        self.sql = sql
        self.connection = connection
        cursor, duration = self.execute_query()
        try:
            self._description = cursor.description or []
            self._data = [list(r) for r in cursor.fetchall()]
        finally:
            # Release the cursor even if fetching the rows fails.
            cursor.close()
        self.duration = duration
        self._headers = self._get_headers()
        self._summary = {}

    @property
    def data(self):
        return self._data or []

    @property
    def headers(self):
        return self._headers or []

    @property
    def header_strings(self):
        return [str(h) for h in self.headers]

    def _get_headers(self):
        if not self._description:
            # No result-set description: show a single placeholder column.
            return [ColumnHeader("--")]
        return [ColumnHeader(d[0]) for d in self._description]

    @staticmethod
    def _is_numeric_value(value):
        """Return True if ``value`` is a real number that can be summed and compared."""
        # bool subclasses int but is not a quantity; complex is unordered and
        # cannot be converted with float(). Decimal registers as numbers.Number.
        return isinstance(value, numbers.Number) and not isinstance(value, (bool, complex))

    def _get_numerics(self):
        """Return the indexes of columns that should get a numeric summary."""
        database = self.connection.Database
        if hasattr(database, "NUMBER"):
            # The driver exposes a NUMBER type object: trust the cursor's type codes.
            return [
                ix for ix, c in enumerate(self._description)
                if hasattr(c, "type_code") and c.type_code in database.NUMBER.values
            ]
        if self.data:
            # No driver type information: infer from the values in the first row.
            # (The previous str(v).isnumeric() check missed negatives, floats and Decimals.)
            first_row = self.data[0]
            return [
                ix for ix, value in enumerate(first_row)
                if self._is_numeric_value(value)
            ]
        return []

    def _get_transforms(self):
        """Pair each column index with its EXPLORER_TRANSFORMS format template, if any."""
        transforms = dict(app_settings.EXPLORER_TRANSFORMS)
        return [
            (ix, transforms[str(h)])
            for ix, h in enumerate(self.headers)
            if str(h) in transforms
        ]

    def column(self, ix):
        return [r[ix] for r in self.data]

    def process(self):
        start_time = time()
        self.process_columns()
        self.process_rows()
        logger.info("Explorer Query Processing took %sms.", (time() - start_time) * 1000)

    def process_columns(self):
        for ix in self._get_numerics():
            self.headers[ix].add_summary(self.column(ix))

    def process_rows(self):
        transforms = self._get_transforms()
        if not transforms:
            return
        for row in self.data:
            for ix, template in transforms:
                row[ix] = template.format(str(row[ix]))

    def execute_query(self):
        """Execute ``self.sql`` inside an atomic block.

        :return: ``(cursor, duration_ms)``; the caller must close the cursor.
        """
        cursor = self.connection.cursor()
        start_time = time()
        try:
            with transaction.atomic(self.connection.alias):
                cursor.execute(self.sql)
        except Exception:
            # The caller never receives the cursor on failure, so close it here,
            # whatever the driver raised.
            cursor.close()
            raise
        return cursor, ((time() - start_time) * 1000)


class ColumnHeader:
    """A result column title plus an optional ColumnSummary."""

    def __init__(self, title):
        self.title = title.strip()
        self.summary = None

    def add_summary(self, column):
        self.summary = ColumnSummary(self, column)

    def __str__(self):
        return self.title


class ColumnStat:
    """A named statistic; calling it with column data stores the rounded ``value``."""

    def __init__(self, label, statfn, precision=2, handles_null=False):
        self.label = label
        self.statfn = statfn
        self.precision = precision
        self.handles_null = handles_null

    def __call__(self, coldata):
        # Empty input yields 0 rather than calling statfn (min/max would raise).
        self.value = round(
            float(self.statfn(coldata)), self.precision
        ) if coldata else 0

    def __str__(self):
        return self.label


class ColumnSummary:
    """Sum/Avg/Min/Max/NULL-count statistics for one result column."""

    def __init__(self, header, col):
        self._header = header
        self._stats = [
            ColumnStat("Sum", sum),
            ColumnStat("Avg", lambda x: float(sum(x)) / float(len(x))),
            ColumnStat("Min", min),
            ColumnStat("Max", max),
            ColumnStat("NUL", lambda x: sum(1 for y in x if y is None), 0, True),
        ]
        # Drop NULLs (as SQL aggregates do). Previously they were replaced by 0,
        # which skewed Avg and could make Min/Max report a 0 absent from the data.
        without_nulls = [v for v in col if v is not None]
        for stat in self._stats:
            stat(col if stat.handles_null else without_nulls)

    @property
    def stats(self):
        return {c.label: c.value for c in self._stats}

    def __str__(self):
        return str(self._header)


class ExplorerValueManager(models.Manager):

    def get_uuid(self):
        """Return the install UUID, creating it if missing or regenerating it if blank."""
        uuid_obj, _created = self.get_or_create(
            key=ExplorerValue.INSTALL_UUID,
            defaults={"value": str(uuid.uuid4())}
        )
        # A freshly created row already holds a UUID from ``defaults``; only a
        # pre-existing row with a None/empty value needs a new one.
        if not uuid_obj.value:
            uuid_obj.value = str(uuid.uuid4())
            uuid_obj.save()
        return uuid_obj.value

    def get_startup_last_send(self):
        """Return the stored startup-metric send time as a float Unix timestamp, or None."""
        try:
            timestamp = self.get(key=ExplorerValue.STARTUP_METRIC_LAST_SEND).value
        except ExplorerValue.DoesNotExist:
            return None
        return float(timestamp) if timestamp else None

    def set_startup_last_send(self, ts):
        """Store ``ts`` (stringified) as the startup-metric last-send time."""
        self.update_or_create(
            key=ExplorerValue.STARTUP_METRIC_LAST_SEND,
            defaults={"value": str(ts)}
        )

    def get_item(self, key):
        return self.filter(key=key).first()


class ExplorerValue(models.Model):
    INSTALL_UUID = "UUID"
    STARTUP_METRIC_LAST_SEND = "SMLS"
    ASSISTANT_SYSTEM_PROMPT = "ASP"
    EXPLORER_SETTINGS_CHOICES = [
        (INSTALL_UUID, "Install Unique ID"),
        (STARTUP_METRIC_LAST_SEND, "Startup metric last send"),
        (ASSISTANT_SYSTEM_PROMPT, "System prompt for SQL Assistant"),
    ]

    key = models.CharField(max_length=5, choices=EXPLORER_SETTINGS_CHOICES, unique=True)
    value = models.TextField(null=True, blank=True)

    objects = ExplorerValueManager()