-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathsample_project.py
More file actions
520 lines (456 loc) · 21 KB
/
sample_project.py
File metadata and controls
520 lines (456 loc) · 21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
import hashlib
import json
import logging
import os
import re
from abc import ABC, abstractmethod
from copy import deepcopy
from datetime import datetime
import psycopg2
from psycopg2 import sql
from django.conf import settings
from django.core.files.base import File
from jinja2 import Environment, FileSystemLoader
from backend.application.context.application import ApplicationContext
from backend.application.context.connection import ConnectionContext
from backend.application.utils import get_filter
from backend.core.models.config_models import ConfigModels
from backend.core.models.connection_models import ConnectionDetails
from backend.core.models.dependent_models import DependentModels
from backend.core.models.project_details import ProjectDetails
from backend.errors import SampleProjectConnectionFailed, MasterDbNotExist
from backend.errors.exceptions import SampleProjectLimitExceed
from backend.server.settings.base import SAMPLE_CONNECTION
from backend.utils.tenant_context import get_current_tenant, get_current_user
class SampleProject(ABC):
    """Abstract base for provisioning a sample project backed by its own Postgres database.

    Concrete subclasses supply the naming (database, user, project, connection),
    the location of the bundled sample assets (``base_path``) and the lists of
    seed CSV files and models. ``load_sample_project`` is the entry point that
    drives the whole provisioning sequence.
    """

    def __init__(self):
        # Lazily populated caches; see the matching properties below.
        self._template_environment = None
        self._project_context = None
        self._application_context = None
        self._sample_seed_path = None
        self._sample_model_path = None
        self._sample_py_path = None
        self._sample_template_path = None
        # Private copy: create_project_connection() rewrites these credentials in place.
        self.sample_connection = deepcopy(SAMPLE_CONNECTION)
        self._postgres_connection = None
        self._org_id: str = self._current_org_id()
        self._user_id: str = re.sub(r"[^A-Za-z0-9_]", "_", get_current_user().get("username")).strip("_")
        self.password = self.create_password(self._org_id)
        # True -> clone the master template database; flipped to False when the template is missing.
        self._clone_db = True
        self.timestamp_str = datetime.now().strftime("%Y-%m-%d %H-%M-%S")
        self.project_limit = settings.SAMPLE_PROJECT_LIMIT
        # Used by check_project_limit() as a name prefix; subclasses are expected to set it.
        self.project_base_name = None

    @staticmethod
    def _current_org_id() -> str:
        """Return the current tenant as a lowercase identifier-safe string.

        Falls back to ``"default_org"`` when the tenant is empty or ``None``.
        The fallback is applied *before* ``.lower()`` so a ``None`` tenant no
        longer raises ``AttributeError``.
        """
        tenant = get_current_tenant() or "default_org"
        return re.sub(r"[^A-Za-z0-9_]", "_", tenant.lower()).strip("_")

    @property
    @abstractmethod
    def master_db_name(self):
        """Name of the template database that new sample databases are cloned from."""
        pass

    @property
    @abstractmethod
    def base_path(self):
        """Directory that contains the sample project's asset sub-directories."""
        pass

    @property
    def seed_path(self):
        """Path of the ``seed_files`` directory under ``base_path`` (computed once)."""
        if not self._sample_seed_path:
            self._sample_seed_path = os.path.join(self.base_path, "seed_files")
        return self._sample_seed_path

    @property
    def model_path(self):
        """Path of the ``model_files`` directory under ``base_path`` (computed once)."""
        if not self._sample_model_path:
            self._sample_model_path = os.path.join(self.base_path, "model_files")
        return self._sample_model_path

    @property
    def model_py_path(self):
        """Path of the ``model_py_files`` directory under ``base_path`` (computed once)."""
        if not self._sample_py_path:
            self._sample_py_path = os.path.join(self.base_path, "model_py_files")
        return self._sample_py_path

    @property
    def model_template_path(self):
        """Path of the ``model_templates`` directory under ``base_path`` (computed once)."""
        if not self._sample_template_path:
            self._sample_template_path = os.path.join(self.base_path, "model_templates")
        return self._sample_template_path

    @property
    def template_environment(self):
        """Jinja2 environment that loads templates from ``model_template_path`` (created once)."""
        if not self._template_environment:
            self._template_environment = Environment(loader=FileSystemLoader(self.model_template_path))
        return self._template_environment

    @property
    def org_id(self):
        """Identifier-safe id of the *current* tenant, recomputed on every access."""
        return self._current_org_id()

    @property
    @abstractmethod
    def database_name(self) -> str:
        """Name of the database created for this sample project."""
        pass

    @property
    @abstractmethod
    def user_name(self) -> str:
        """Name of the Postgres role that is granted access to the sample database."""
        pass

    @staticmethod
    def create_password(org_id: str):
        """Derive the database password: first 10 hex characters of SHA-256(``org_id``).

        NOTE(review): the value is fully determined by the org id, so anyone who
        knows the org id can recompute it. Left unchanged because existing
        databases were created with this scheme.
        """
        hash_object = hashlib.sha256(org_id.encode())
        hex_dig = hash_object.hexdigest()
        return hex_dig[:10]

    @property
    @abstractmethod
    def project_name(self) -> str:
        """Name stored on the created ``ProjectDetails`` row."""
        pass

    @property
    @abstractmethod
    def project_description(self) -> str:
        """Description stored on the created ``ProjectDetails`` row."""
        pass

    @property
    @abstractmethod
    def postgres_connection_details(self) -> dict[str, str]:
        """Payload passed to ``ConnectionContext.create_connection``; must contain a ``"name"`` key."""
        pass

    @property
    @abstractmethod
    def connection_name(self):
        """Connection name used by ``clear_existing_db`` to look up ``ConnectionDetails``."""
        pass

    @property
    @abstractmethod
    def csv_files(self):
        """Iterable of seed CSV file names located under ``seed_path``."""
        pass

    @property
    @abstractmethod
    def model_list(self):
        """Iterable of model names; each has a ``.json`` definition and a ``.jinja`` template."""
        pass

    @property
    def schema_name(self):
        """Schema used for the connection and for seeding."""
        return "raw"

    @property
    def postgres_connection(self):
        """Lazily opened autocommit connection built from ``self.sample_connection``."""
        if not self._postgres_connection:
            self._postgres_connection = psycopg2.connect(
                host=self.sample_connection["host"],
                port=self.sample_connection["port"],
                user=self.sample_connection["user"],
                password=self.sample_connection["passw"],
                dbname=self.sample_connection["dbname"],
            )
            # Required to execute DROP DATABASE
            self._postgres_connection.autocommit = True
        return self._postgres_connection

    def execute_sql_queries(self, statements: list):
        """Execute ``statements`` in order on ``postgres_connection``.

        Raises:
            MasterDbNotExist: a database error whose text mentions ``master_db_name``.
            SampleProjectConnectionFailed: any other ``psycopg2.Error``.
        """
        try:
            # The context manager closes the cursor even when a statement fails.
            with self.postgres_connection.cursor() as cursor:
                for statement in statements:
                    cursor.execute(statement)
        except psycopg2.Error as e:
            if self.master_db_name in str(e):
                raise MasterDbNotExist() from e
            logging.error(f"Error on querying the database --> {e}")
            raise SampleProjectConnectionFailed() from e

    def _grant_schema_permissions_on_new_db(self):
        """Grant schema permissions on the newly cloned database.

        This must be executed on the NEW database, not the admin database, so a
        dedicated short-lived connection is opened with the current
        ``sample_connection`` credentials and ``database_name`` as the target.

        Raises:
            SampleProjectConnectionFailed: on any ``psycopg2.Error``.
        """
        new_db_connection = None
        try:
            new_db_connection = psycopg2.connect(
                host=self.sample_connection["host"],
                port=self.sample_connection["port"],
                user=self.sample_connection["user"],
                password=self.sample_connection["passw"],
                dbname=self.database_name,  # Connect to the NEW database
            )
            new_db_connection.autocommit = True
            user_ident = sql.Identifier(self.user_name)
            with new_db_connection.cursor() as cursor:
                for schema in ["raw", "dev", "stg", "prod"]:
                    schema_ident = sql.Identifier(schema)
                    cursor.execute(sql.SQL("GRANT USAGE ON SCHEMA {} TO {};").format(schema_ident, user_ident))
                    cursor.execute(
                        sql.SQL("GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA {} TO {};").format(
                            schema_ident, user_ident
                        )
                    )
                    cursor.execute(
                        sql.SQL(
                            "ALTER DEFAULT PRIVILEGES IN SCHEMA {} "
                            "GRANT SELECT, INSERT, UPDATE, DELETE ON TABLES TO {};"
                        ).format(schema_ident, user_ident)
                    )
                # Transfer schema ownership for non-source schemas so the project user
                # can DROP and recreate tables during transformation runs.
                # Schema owners can drop any object within their schema in PostgreSQL.
                # raw schema stays read-only (source data from template).
                for schema in ["dev", "stg", "prod"]:
                    cursor.execute(
                        sql.SQL("ALTER SCHEMA {} OWNER TO {};").format(sql.Identifier(schema), user_ident)
                    )
            logging.info(f"Schema permissions granted on database {self.database_name}")
        except psycopg2.Error as e:
            logging.error(f"Error granting schema permissions on new database: {e}")
            raise SampleProjectConnectionFailed() from e
        finally:
            if new_db_connection:
                new_db_connection.close()

    def close_postgres_connection(self):
        """Close the cached connection, if any, and forget it so the next access reconnects."""
        if self._postgres_connection:
            try:
                self._postgres_connection.close()
                logging.info("Postgres connection closed successfully")
            except psycopg2.Error as e:
                logging.warning(f"Error closing postgres connection: {e}")
            finally:
                self._postgres_connection = None

    def load_app_context(self):
        """Build the ``ApplicationContext`` for the sample project, if that project exists."""
        if self.project_context:
            self._application_context = ApplicationContext(project_id=self.project_context.project_id)

    @property
    def app_context(self) -> "ApplicationContext":
        """Application context for the sample project; ``None`` while the project does not exist."""
        if not self._application_context:
            self.load_app_context()
        return self._application_context

    @property
    def project_context(self) -> "ProjectDetails":
        """The sample ``ProjectDetails`` row matching ``project_name``; ``None`` until it exists.

        The lookup is repeated on each access until a row is found, then cached.
        """
        if not self._project_context:
            project_filter = {"project_name": self.project_name, "is_sample": True}
            project_filter.update(get_filter())
            pd: ProjectDetails = ProjectDetails.objects.filter(**project_filter).first()
            if pd:
                self._project_context = pd
        return self._project_context

    def load_sample_project(self) -> dict[str, str]:
        """Provision the sample project end to end and return its project details.

        Steps: enforce the per-user limit, create the database, create the
        connection and project rows, create schemas, upload seed CSVs and load
        the models. Existing projects/databases are NOT cleared here. The
        admin connection is always closed on exit.

        Raises:
            SampleProjectLimitExceed: the sample project limit is already reached.
        """
        self.check_project_limit()
        try:
            self.create_new_database()
            sample_project_details = self.create_project_connection()
            self.create_schemas()
            self.upload_and_run_csv()
            self.create_and_load_models()
            logging.info("sample project created successfully")
            return sample_project_details
        finally:
            self.close_postgres_connection()

    def clear_existing_project(self):
        """Delete the existing sample project row named ``project_name``, if there is one."""
        filter_criteria = {"project_name": self.project_name, "is_sample": True}
        filter_criteria.update(get_filter())
        pd: ProjectDetails = ProjectDetails.objects.filter(**filter_criteria).first()
        if pd:
            # .get(): a missing key must not abort the deletion just because of a log line.
            logging.info(
                f"Sample project with name {self.project_name} for User "
                f"{filter_criteria.get('created_by__username')} is being deleted"
            )
            pd.delete()
            logging.info("existing sample project deleted")

    def clear_existing_db(self):
        """Delete the stored connection, then drop the sample database and its role.

        Open sessions on the database are terminated first so the drop can proceed.
        """
        connection_filter = {"connection_name": self.connection_name}
        connection_filter.update(get_filter())
        cd: ConnectionDetails = ConnectionDetails.objects.filter(**connection_filter).first()
        if cd:
            cd.delete()

        db_ident = sql.Identifier(self.database_name)
        user_ident = sql.Identifier(self.user_name)
        terminate_session = sql.SQL(
            "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = {};"
        ).format(sql.Literal(self.database_name))
        drop_database = sql.SQL("DROP DATABASE IF EXISTS {};").format(db_ident)
        drop_user = sql.SQL("DROP USER IF EXISTS {};").format(user_ident)
        self.execute_sql_queries(statements=[terminate_session, drop_database, drop_user])
        logging.info("existing sample db and user deleted")

    def create_new_database(self):
        """Create the sample database and role, cloning the master template when possible.

        On ``MasterDbNotExist`` during a clone attempt, ``_clone_db`` is switched
        off and the creation is retried once as an empty database.

        Raises:
            MasterDbNotExist: raised again when the non-clone attempt fails the same way.
            SampleProjectConnectionFailed: any other database failure.
        """
        db_ident = sql.Identifier(self.database_name)
        user_ident = sql.Identifier(self.user_name)
        if self._clone_db:
            logging.info(f"creating(cloning) new sample db with the name - {self.database_name}")
            create_db_query = sql.SQL("CREATE DATABASE {} TEMPLATE {};").format(
                db_ident, sql.Identifier(self.master_db_name)
            )
        else:
            create_db_query = sql.SQL("CREATE DATABASE {};").format(db_ident)

        statements = [create_db_query]
        if not self.user_exist():
            statements.append(
                sql.SQL("CREATE USER {} WITH ENCRYPTED PASSWORD {};").format(user_ident, sql.Literal(self.password))
            )
        statements.append(sql.SQL("GRANT ALL PRIVILEGES ON DATABASE {} TO {};").format(db_ident, user_ident))

        try:
            self.execute_sql_queries(statements=statements)
            logging.info(
                f"new sample db and user created with the name - {self.user_name} database - {self.database_name}"
            )
            # GRANT schema permissions on the NEW database (not the admin DB)
            if self._clone_db:
                self._grant_schema_permissions_on_new_db()
        except MasterDbNotExist as e:
            if not self._clone_db:
                logging.info(f"Error creating database {self.database_name}: {e}. All retries completed")
                raise
            logging.info(f"Error creating database {self.database_name}: {e}. Retrying with empty template...")
            self._clone_db = False
            self.create_new_database()

    def user_exist(self) -> bool:
        """Return True when a role named ``user_name`` exists.

        A database error is logged and reported as ``False``.
        """
        try:
            with self.postgres_connection.cursor() as cursor:
                cursor.execute("SELECT 1 FROM pg_roles WHERE rolname = %s", (self.user_name,))
                user_exist = cursor.fetchone() is not None
            logging.info(f"checking user already exist, result : {user_exist}")
            return user_exist
        except psycopg2.Error as e:
            logging.critical(f"Exception while checking user: {e}")
            return False

    def grant_permissions(self):
        """Create the role and grant it all privileges on the sample database."""
        user_ident = sql.Identifier(self.user_name)
        db_ident = sql.Identifier(self.database_name)
        statements = [
            sql.SQL("CREATE USER {} WITH ENCRYPTED PASSWORD {};").format(user_ident, sql.Literal(self.password)),
            sql.SQL("GRANT ALL PRIVILEGES ON DATABASE {} TO {};").format(db_ident, user_ident),
        ]
        self.execute_sql_queries(statements=statements)
        logging.info(f"new sample db and user created with the name - {self.user_name} database - {self.database_name}")

    def check_project_limit(self):
        """Raise when the number of sample projects prefixed with ``project_base_name`` hits the limit.

        Raises:
            SampleProjectLimitExceed: the count is greater than or equal to ``project_limit``.
        """
        filter_criteria = {"is_sample": True, "project_name__startswith": self.project_base_name}
        filter_criteria.update(get_filter())
        sample_project_count = ProjectDetails.objects.filter(**filter_criteria).count()
        if sample_project_count >= self.project_limit:
            raise SampleProjectLimitExceed(
                project_base_name=self.project_base_name,
                sample_project_count=sample_project_count,
                sample_project_limit=self.project_limit,
            )

    def create_project_connection(self):
        """Create the connection and project rows and return the project details.

        Side effect: ``self.sample_connection`` is rewritten to the new
        database, role, password and schema, and the cached admin connection is
        closed, so later ``postgres_connection`` accesses connect to the new
        database as the project user.
        """
        self.sample_connection["passw"] = self.password
        self.sample_connection["dbname"] = self.database_name
        self.sample_connection["user"] = self.user_name
        self.sample_connection["schema"] = self.schema_name

        connection_filter = {
            "connection_name": self.postgres_connection_details["name"],
        }
        connection_filter.update(get_filter())
        con_context = ConnectionContext()
        # Replace any connection left over under the same name.
        con_exist = ConnectionDetails.objects.filter(**connection_filter).first()
        if con_exist:
            logging.info(f"sample connection {con_exist} exists, proceeding to delete ")
            con_exist.delete()
            logging.info(f"existing sample connection {con_exist} deleted")

        connection_data = con_context.create_connection(connection_details=self.postgres_connection_details)
        connection_instance_id = connection_data.get("id")
        connection_instance = ConnectionDetails.objects.filter(connection_id=connection_instance_id).first()
        logging.info("Create connection is success")

        pd = ProjectDetails(
            project_name=self.project_name,
            project_description=self.project_description,
            connection_model=connection_instance,
            created_by=get_current_user(),
            is_sample=True,
        )
        pd.save()
        project_details = self.app_context.get_project_details()
        self.close_postgres_connection()
        return project_details

    def create_schemas(self):
        """Create the ``raw``/``dev``/``stg``/``prod`` schemas when the database was not cloned.

        Best-effort: failures are logged and not propagated.
        """
        if self._clone_db:
            # A cloned database already carries the schemas from the template.
            return
        try:
            statements = []
            for schema in ["raw", "dev", "stg", "prod"]:
                schema_ident = sql.Identifier(schema)
                statements.append(sql.SQL("DROP SCHEMA IF EXISTS {} CASCADE;").format(schema_ident))
                statements.append(sql.SQL("CREATE SCHEMA IF NOT EXISTS {};").format(schema_ident))
            self.execute_sql_queries(statements=statements)
            logging.info("schemas created successfully")
        except Exception as e:
            # Deliberately not re-raised, but surfaced at error level instead of info.
            logging.error(f"Error creating schemas: {e}")

    def upload_and_run_csv(self):
        """Upload every seed CSV and, for a non-cloned database, run the seed command.

        A seed file that is missing on disk is logged and skipped.
        """
        for csv_file in self.csv_files:
            file = os.path.join(self.seed_path, csv_file)
            if not os.path.exists(file):
                # Previously this only logged and then crashed in open() below.
                logging.warning(f"File not found: {file}")
                continue
            with open(file, "rb") as file_content:
                # Extract the filename from the path
                file_name = os.path.basename(file)
                # Wrap the file content in Django's File class to handle it efficiently
                uploaded_file = File(file_content)
                # Upload the file using the app's method
                self.app_context.upload_a_file(file_name=file_name, file_content=uploaded_file)
        logging.info("CSV files uploaded successfully")
        if not self._clone_db:
            # A cloned database already contains the seeded tables.
            seed = {"runAll": True, "schema_name": self.schema_name}
            self.app_context.execute_visitran_seed_command(seed)
            logging.info("seed command is executed")

    def create_and_load_models(self):
        """Persist every model in ``model_list`` and register it in the model graph.

        For each model: read ``<model_path>/<name>.json``, render
        ``<name>.jinja``, save the ``ConfigModels`` row and its
        ``DependentModels`` rows, store the rendered content and update the
        graph. Finally the project is marked completed.
        """
        for model_name in self.model_list:
            logging.info(f"starting execution for {model_name}")
            model_path = os.path.join(self.model_path, model_name)
            with open(f"{model_path}.json", encoding="utf-8") as model_file:
                config_model_data = json.load(model_file)

            template = self.template_environment.get_template(f"{model_name}.jinja")
            # Go through the property so the project is loaded if it has not been yet.
            output = template.render({"model_path": self.project_context.project_py_name})

            config_model = ConfigModels(
                project_instance=self.project_context,
                model_name=config_model_data["model_name"],
                model_data=config_model_data["model_data"],
            )
            config_model.save()
            self.app_context.session.update_model_content(model_name=model_name, model_content=output)

            for dependent_data in config_model_data["dependent_models"]:
                dp = DependentModels(
                    project_instance=self.project_context,
                    model=config_model,
                    transformation_id=dependent_data["transformation_id"],
                    model_data=dependent_data["model_data"],
                )
                dp.save()

            # add node to model graph and update model reference
            self.app_context.add_node_to_model_graph(model_name)
            self.app_context.update_sample_project_model_graph(config_model_data["model_data"], model_name)

        # IMPORTANT: Use app_context's project_instance to preserve the model graph
        # that was saved during add_node_to_model_graph/update_sample_project_model_graph.
        # Using self.project_context.save() would overwrite the graph with stale data.
        self.app_context.session.project_instance.is_completed = True
        self.app_context.session.project_instance.save()
        logging.info("Created and loaded all models successfully")

    def sanitize_name(self, name):
        """Return a lowercase identifier derived from ``name``, at most 63 characters long.

        Characters outside ``[a-zA-Z0-9_]`` become ``_``; a leading ``_`` is
        added when the result does not start with a letter or underscore; an
        8-character SHA-256 suffix of the original ``name`` keeps distinct
        inputs distinct after truncation. The 63 limit presumably mirrors
        PostgreSQL's identifier length limit.
        """
        sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", name).lower()
        if not re.match(r"^[a-zA-Z_]", sanitized):
            sanitized = f"_{sanitized}"
        # Create a short hash from the full original input (8 chars)
        suffix = hashlib.sha256(name.encode()).hexdigest()[:8]
        # Leave space for underscore and hash suffix (9 total)
        base = sanitized[: 63 - 9]  # 63 - 1 (_) - 8 (hash)
        return f"{base}_{suffix}"

    def project_exist(self, project_name: str):
        """Return True when a sample project named ``project_name`` exists.

        Any lookup failure is logged and reported as ``False``.
        """
        try:
            project_filter = {"project_name": project_name, "is_sample": True}
            project_filter.update(get_filter())
            return ProjectDetails.objects.filter(**project_filter).exists()
        except Exception as e:
            logging.warning(f"Error checking whether sample project {project_name} exists: {e}")
            return False