235 lines
7.8 KiB
Python
235 lines
7.8 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
# @Time : 2025/3/15 15:47
|
||
# @Author : AngesZhu
|
||
# @File : postgresql_utils.py
|
||
# @Desc : pg数据库操作类
|
||
import psycopg2
|
||
from psycopg2 import sql, DatabaseError
|
||
from utils.logger_utils import logger
|
||
from utils.timer_utils import timer
|
||
|
||
|
||
class PostgresHandler:
    """Convenience wrapper around one psycopg2 connection and one cursor.

    Outside an explicit transaction, every statement issued through
    ``execute_query``/``insert``/``batch_insert``/``update``/``delete`` is
    committed as soon as it succeeds. Between ``begin_transaction()`` and
    ``commit_transaction()``/``rollback_transaction()`` those automatic commits
    are suppressed, so the statements succeed or fail together.

    The handler can also be used as a context manager. The connection is
    closed on exit, and psycopg2 discards any uncommitted work at that point.
    """

    def __init__(self, host, database, user, password, port=5432, **kwargs):
        """Open a PostgreSQL connection and a cursor on it.

        :param host: database host
        :param database: database name
        :param user: database user name
        :param password: database password
        :param port: port number, default 5432
        :param kwargs: extra connection parameters forwarded verbatim to
            ``psycopg2.connect`` (e.g. ``connect_timeout``, ``sslmode``).
            Previously these were accepted but silently ignored.
        :raises psycopg2.DatabaseError: if the connection cannot be opened
        """
        # True between begin_transaction() and commit/rollback_transaction().
        # While set, writes are not committed one by one.
        self._in_transaction = False
        self.connection = None
        self.cursor = None
        try:
            self.connection = psycopg2.connect(
                host=host,
                database=database,
                user=user,
                password=password,
                port=port,
                **kwargs,
            )
            self.cursor = self.connection.cursor()
        except DatabaseError as e:
            logger.error(f"Error connecting to PostgreSQL: {e}")
            # Don't leak a connection that opened but could not create a cursor.
            if self.connection is not None:
                self.connection.close()
            raise

    def __enter__(self):
        """Return the handler itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection. Exceptions from the ``with`` body propagate."""
        self.close()
        return False

    def _commit_unless_in_transaction(self):
        """Commit the current implicit transaction unless an explicit one is open."""
        if not self._in_transaction:
            self.connection.commit()

    def execute_query(self, query, params=None):
        """Execute a single SQL statement.

        :param query: SQL statement, either a plain string or a
            ``psycopg2.sql.Composable`` (as built by insert/update/delete)
        :param params: optional query parameters. ``None`` means no parameter
            interpolation at all, so literal ``%`` characters in ``query`` are
            sent unchanged.
        :return: ``{"columns": [...], "data": [...]}`` when the statement
            produces a result set (SELECT, ``... RETURNING``), otherwise ``None``
        :raises psycopg2.DatabaseError: after logging it and rolling back the
            current transaction (including an explicit one, which PostgreSQL
            would have aborted anyway)
        """
        try:
            with timer():
                # Pass params through unchanged: substituting () for None would
                # make psycopg2 treat every '%' in the query as a placeholder.
                self.cursor.execute(query, params)

            # cursor.description is None for statements that return no rows
            # (plain INSERT/UPDATE/DELETE/DDL). Checking it instead of the query
            # text also works for sql.Composed queries and WITH/RETURNING.
            result = None
            if self.cursor.description is not None:
                result = {
                    "columns": [desc[0] for desc in self.cursor.description],
                    "data": self.cursor.fetchall(),
                }
            self._commit_unless_in_transaction()
            return result
        except DatabaseError as e:
            logger.error(f"Error executing query: {query}, Params: {params}, Error: {e}")
            self.connection.rollback()
            raise

    def insert(self, table, data):
        """Insert one row.

        :param table: table name (quoted as an identifier)
        :param data: dict mapping column names to values
        :return: ``data``, unchanged
        :raises psycopg2.DatabaseError: if the insert fails
        """
        try:
            query = sql.SQL("INSERT INTO {table} ({columns}) VALUES ({values})").format(
                table=sql.Identifier(table),
                columns=sql.SQL(', ').join(map(sql.Identifier, data.keys())),
                values=sql.SQL(', ').join(sql.Placeholder() * len(data)),
            )
            self.execute_query(query, tuple(data.values()))
            return data
        except DatabaseError as e:
            # execute_query has already rolled back; only add context here.
            logger.error(f"Error inserting data into {table}: {data}, Error: {e}")
            raise

    def batch_insert(self, table, data_list):
        """Insert several rows with one ``executemany`` call.

        :param table: table name (quoted as an identifier)
        :param data_list: list of dicts, one per row. Every dict must have
            the same set of keys as the first one (key order may differ).
        :return: ``data_list``, unchanged (``[]`` for empty input)
        :raises ValueError: if a row's keys differ from the first row's. This
            is checked before anything is sent to the database.
        :raises psycopg2.DatabaseError: if the insert fails (rolled back)
        """
        if not data_list:
            return []

        # The first row fixes the column list. Every row is read in that order,
        # so dicts with the same keys in a different insertion order still line
        # up with the right columns.
        column_keys = list(data_list[0].keys())
        expected_keys = set(column_keys)
        for index, row in enumerate(data_list):
            if row.keys() != expected_keys:
                raise ValueError(
                    f"batch_insert row {index} has columns {list(row.keys())}, "
                    f"expected the same columns as row 0: {column_keys}"
                )
        rows = [tuple(row[key] for key in column_keys) for row in data_list]

        try:
            query = sql.SQL("INSERT INTO {table} ({columns}) VALUES ({values})").format(
                table=sql.Identifier(table),
                columns=sql.SQL(', ').join(map(sql.Identifier, column_keys)),
                values=sql.SQL(', ').join(sql.Placeholder() * len(column_keys)),
            )
            self.cursor.executemany(query.as_string(self.connection), rows)
            self._commit_unless_in_transaction()
            return data_list
        except DatabaseError as e:
            logger.error(f"Error batch inserting data into {table}: {data_list}, Error: {e}")
            self.connection.rollback()
            raise

    def update(self, table, data, condition, params=None):
        """Update rows matching a WHERE condition.

        :param table: table name (quoted as an identifier)
        :param data: dict mapping column names to their new values
        :param condition: WHERE clause body as raw SQL, e.g. ``"id = %s"``.
            It is inserted verbatim, so never build it from untrusted input.
            Pass the values through ``params`` instead.
        :param params: optional sequence of values for the ``%s`` placeholders
            in ``condition``
        :return: number of affected rows
        :raises psycopg2.DatabaseError: if the update fails
        """
        try:
            set_clause = sql.SQL(', ').join(
                sql.SQL('{} = {}').format(sql.Identifier(column), sql.Placeholder())
                for column in data.keys()
            )
            query = sql.SQL("UPDATE {table} SET {set_clause} WHERE {condition}").format(
                table=sql.Identifier(table),
                set_clause=set_clause,
                condition=sql.SQL(condition),
            )
            # SET placeholders come before the WHERE placeholders in the
            # statement, so their values must come first as well.
            values = tuple(data.values()) + tuple(params or ())
            self.execute_query(query, values)
            return self.cursor.rowcount
        except DatabaseError as e:
            # execute_query has already rolled back; only add context here.
            logger.error(f"Error updating data in {table}: {data}, Condition: {condition}, Error: {e}")
            raise

    def delete(self, table, condition, params=None):
        """Delete rows matching a WHERE condition.

        :param table: table name (quoted as an identifier)
        :param condition: WHERE clause body as raw SQL, e.g. ``"id = %s"``.
            It is inserted verbatim, so never build it from untrusted input.
            Pass the values through ``params`` instead.
        :param params: optional sequence of values for the ``%s`` placeholders
            in ``condition``
        :return: number of affected rows
        :raises psycopg2.DatabaseError: if the delete fails
        """
        try:
            query = sql.SQL("DELETE FROM {table} WHERE {condition}").format(
                table=sql.Identifier(table),
                condition=sql.SQL(condition),
            )
            self.execute_query(query, params)
            return self.cursor.rowcount
        except DatabaseError as e:
            # execute_query has already rolled back; only add context here.
            logger.error(f"Error deleting data from {table}: Condition: {condition}, Error: {e}")
            raise

    def begin_transaction(self):
        """Start an explicit transaction.

        Until commit_transaction() or rollback_transaction() is called, writes
        made through this handler are not committed one by one.
        """
        try:
            # psycopg2 refuses to change autocommit while a transaction is open,
            # so only touch it when it actually needs to change.
            if self.connection.autocommit:
                self.connection.autocommit = False
            self._in_transaction = True
        except DatabaseError as e:
            logger.error(f"Error beginning transaction: {e}")
            raise

    def commit_transaction(self):
        """Commit the explicit transaction and return to per-statement commits.

        If the commit fails, the transaction is rolled back and the error is
        re-raised.
        """
        try:
            self.connection.commit()
            self.connection.autocommit = True
        except DatabaseError as e:
            logger.error(f"Error committing transaction: {e}")
            self.connection.rollback()
            raise
        finally:
            # The explicit transaction is over whether or not the commit worked.
            self._in_transaction = False

    def rollback_transaction(self):
        """Roll back the explicit transaction and return to per-statement commits."""
        try:
            self.connection.rollback()
            self.connection.autocommit = True
        except DatabaseError as e:
            logger.error(f"Error rolling back transaction: {e}")
            raise
        finally:
            self._in_transaction = False

    def close(self):
        """Close the cursor and the connection.

        The connection is closed even if closing the cursor fails.
        """
        try:
            try:
                self.cursor.close()
            finally:
                self.connection.close()
        except DatabaseError as e:
            logger.error(f"Error closing connection: {e}")
            raise
|
||
|
||
|
||
if __name__ == "__main__":
    import os

    # Demo connection settings. Each one can be overridden through the
    # standard libpq environment variables, so real credentials never need to
    # be hard-coded here. The fallbacks match the original demo values.
    db_config = {
        "host": os.environ.get("PGHOST", "localhost"),
        "database": os.environ.get("PGDATABASE", "test_db"),
        "user": os.environ.get("PGUSER", "postgres"),
        "password": os.environ.get("PGPASSWORD", "password"),
        "port": int(os.environ.get("PGPORT", "5432")),
    }

    handler = PostgresHandler(**db_config)

    try:
        # Insert a single row
        inserted_data = handler.insert("users", {"name": "Alice", "age": 25})
        print("Inserted:", inserted_data)

        # Insert several rows at once
        batch_data = [
            {"name": "Bob", "age": 30},
            {"name": "Charlie", "age": 35},
        ]
        inserted_batch = handler.batch_insert("users", batch_data)
        print("Batch Inserted:", inserted_batch)

        # Query rows
        results = handler.execute_query("SELECT * FROM users WHERE age > %s", (20,))
        print("Query Results:", results)

        # Update rows; the condition's values are passed separately from its SQL
        rows_affected = handler.update("users", {"age": 26}, "name = %s", ("Alice",))
        print("Rows Affected by Update:", rows_affected)

        # Delete rows; the condition's values are passed separately from its SQL
        rows_deleted = handler.delete("users", "name = %s", ("Charlie",))
        print("Rows Deleted:", rows_deleted)

    finally:
        handler.close()
|