1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
|
try:
import mysql.connector
except ImportError:
pass
try:
import sqlite3
except ImportError:
pass
import aurweb.config
# Module-level SQLAlchemy engine shared by this module. It stays None until
# get_engine() creates it on first use and stores it here.
engine = None # See get_engine
def get_sqlalchemy_url():
    """
    Build the SQLAlchemy URL matching the aurweb database configuration.

    The backend is read from the `backend` option of the `database` section.
    For `mysql`, the URL uses the `mysql+mysqlconnector` driver together with
    the configured user, password, host, database name and unix socket. For
    `sqlite`, only the database name is used.

    Returns a `sqlalchemy.engine.url.URL` suitable for `create_engine`.
    Raises ValueError when the configured backend is neither of the two.
    """
    import sqlalchemy

    make_url = sqlalchemy.engine.url.URL
    backend = aurweb.config.get('database', 'backend')

    if backend == 'sqlite':
        return make_url(
            'sqlite',
            database=aurweb.config.get('database', 'name'),
        )

    if backend == 'mysql':
        # Options are read in the same order as the URL fields are listed.
        url_fields = {
            'username': aurweb.config.get('database', 'user'),
            'password': aurweb.config.get('database', 'password'),
            'host': aurweb.config.get('database', 'host'),
            'database': aurweb.config.get('database', 'name'),
            'query': {
                'unix_socket': aurweb.config.get('database', 'socket'),
                'buffered': True,
            },
        }
        return make_url('mysql+mysqlconnector', **url_fields)

    raise ValueError('unsupported database backend')
def get_engine():
    """
    Return the global SQLAlchemy engine.

    The engine is created on the first call to get_engine and then stored in
    the `engine` global variable for the next calls.

    Raises ValueError (propagated from get_sqlalchemy_url) when the
    configured database backend is unsupported.
    """
    from sqlalchemy import create_engine
    global engine

    if engine is None:
        url = get_sqlalchemy_url()

        connect_args = {}
        if aurweb.config.get('database', 'backend') == 'sqlite':
            # check_same_thread is for a SQLite technicality
            # https://fastapi.tiangolo.com/tutorial/sql-databases/#note
            # It is an argument of sqlite3.connect() only, so it must not be
            # forwarded to other drivers (mysql.connector rejects unknown
            # connection arguments).
            connect_args['check_same_thread'] = False

        engine = create_engine(url, connect_args=connect_args)
    return engine
def connect():
    """
    Open and return a connection from the global SQLAlchemy engine.

    The engine is obtained through get_engine(), so it is created on demand.
    Connections are usually pooled. See
    <https://docs.sqlalchemy.org/en/13/core/connections.html>.

    SQLAlchemy connections are also context managers, so prefer using the
    result with Python's `with` statement, or through FastAPI's dependency
    injection.
    """
    shared_engine = get_engine()
    return shared_engine.connect()
class Connection:
    """
    Thin wrapper around a raw DB-API connection.

    The driver (mysql.connector or sqlite3) is selected from the aurweb
    configuration when the object is created. Queries passed to execute()
    use `?` as the parameter placeholder regardless of the driver.
    """

    # Underlying driver connection; assigned by __init__.
    _conn = None
    # DB-API paramstyle advertised by the driver module in use.
    _paramstyle = None

    def __init__(self):
        """
        Connect to the database described by the `database` config section.

        Raises ValueError when the configured backend is neither `mysql`
        nor `sqlite`.
        """
        backend = aurweb.config.get('database', 'backend')

        if backend == 'mysql':
            settings = {
                'host': aurweb.config.get('database', 'host'),
                'db': aurweb.config.get('database', 'name'),
                'user': aurweb.config.get('database', 'user'),
                'passwd': aurweb.config.get('database', 'password'),
                'unix_socket': aurweb.config.get('database', 'socket'),
                'buffered': True,
            }
            self._conn = mysql.connector.connect(**settings)
            self._paramstyle = mysql.connector.paramstyle
            return

        if backend == 'sqlite':
            db_path = aurweb.config.get('database', 'name')
            self._conn = sqlite3.connect(db_path)
            self._paramstyle = sqlite3.paramstyle
            return

        raise ValueError('unsupported database backend')

    def execute(self, query, params=()):
        """
        Run `query` with `params` on a new cursor and return that cursor.

        For drivers whose paramstyle is `format` or `pyformat`, every `%` in
        the query is doubled and every `?` is turned into `%s` (this is a
        plain text substitution over the whole query string). For `qmark`
        drivers the query is passed through unchanged.

        Raises ValueError for any other paramstyle.
        """
        style = self._paramstyle
        if style in ('format', 'pyformat'):
            # Escape literal percent signs first so that the %s placeholders
            # introduced next are the only format markers left.
            escaped = query.replace('%', '%%')
            query = escaped.replace('?', '%s')
        elif style != 'qmark':
            raise ValueError('unsupported paramstyle')

        cursor = self._conn.cursor()
        cursor.execute(query, params)
        return cursor

    def commit(self):
        """Commit the current transaction on the underlying connection."""
        self._conn.commit()

    def close(self):
        """Close the underlying connection."""
        self._conn.close()
|