data_factory/utils/postgresql_utils.py

235 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/3/15 15:47
# @Author : AngesZhu
# @File : postgresql_utils.py
# @Desc : pg数据库操作类
import psycopg2
from psycopg2 import sql, DatabaseError
from utils.logger_utils import logger
from utils.timer_utils import timer
class PostgresHandler:
    """Convenience wrapper around one psycopg2 connection and one cursor.

    Outside an explicit transaction every statement is committed as soon as
    it succeeds. Between ``begin_transaction()`` and
    ``commit_transaction()`` / ``rollback_transaction()`` the per-statement
    commit is skipped so the statements can be committed or undone together.
    """

    # Class-level default so the flag exists even on instances that were
    # created without running __init__; set per instance by the
    # transaction methods.
    _in_transaction = False

    def __init__(self, host, database, user, password, port=5432, **kwargs):
        """
        Open a PostgreSQL connection and create a cursor on it.

        :param host: database host
        :param database: database name
        :param user: database user name
        :param password: database password
        :param port: server port, 5432 by default
        :param kwargs: accepted so a larger config dict can be splatted in;
            the extra keys are ignored
        :raises DatabaseError: if the connection cannot be established
        """
        try:
            self.connection = psycopg2.connect(
                host=host,
                database=database,
                user=user,
                password=password,
                port=port
            )
            self.cursor = self.connection.cursor()
        except DatabaseError as e:
            logger.error(f"Error connecting to PostgreSQL: {e}")
            raise

    @staticmethod
    def _fetch_result(cursor):
        """Return ``{"columns", "data"}`` for a row-returning statement, else None."""
        description = cursor.description
        # description is None when the last statement produced no result set
        # (plain INSERT/UPDATE/DELETE), so there is nothing to fetch.
        if description is None:
            return None
        return {
            "columns": [column[0] for column in description],
            "data": cursor.fetchall()
        }

    @staticmethod
    def _ordered_values(columns, row):
        """Return the values of ``row`` as a tuple in ``columns`` order."""
        if len(row) != len(columns) or any(column not in row for column in columns):
            raise ValueError(
                f"Row keys {list(row)} do not match expected columns {list(columns)}"
            )
        return tuple(row[column] for column in columns)

    def execute_query(self, query, params=None):
        """
        Execute one SQL statement.

        :param query: SQL text or a ``psycopg2.sql`` composable
        :param params: optional parameters bound to the statement's placeholders
        :return: ``{"columns": [...], "data": [...]}`` when the statement
            returns rows (SELECT, ``... RETURNING``), otherwise None
        :raises DatabaseError: on failure, after logging and rolling back
        """
        try:
            # timer() comes from utils.timer_utils; presumably it measures the
            # elapsed time of the wrapped block (not visible in this file).
            with timer():
                self.cursor.execute(query, params or ())
                result = self._fetch_result(self.cursor)
            # Commit after every statement, SELECT included, so no implicit
            # transaction is left open; skipped inside an explicit transaction.
            if not self._in_transaction:
                self.connection.commit()
            return result
        except DatabaseError as e:
            logger.error(f"Error executing query: {query}, Params: {params}, Error: {e}")
            self.connection.rollback()
            raise

    def insert(self, table, data):
        """
        Insert a single row.

        :param table: table name
        :param data: dict mapping column names to values
        :return: the ``data`` dict that was inserted
        :raises DatabaseError: if the statement fails
        """
        try:
            columns = sql.SQL(', ').join(map(sql.Identifier, data.keys()))
            values = sql.SQL(', ').join(sql.Placeholder() * len(data))
            query = sql.SQL("INSERT INTO {table} ({columns}) VALUES ({values})").format(
                table=sql.Identifier(table),
                columns=columns,
                values=values
            )
            self.execute_query(query, tuple(data.values()))
            return data
        except DatabaseError as e:
            logger.error(f"Error inserting data into {table}: {data}, Error: {e}")
            self.connection.rollback()
            raise

    def batch_insert(self, table, data_list):
        """
        Insert several rows with one ``executemany`` call.

        The column list is taken from the first row; every other row must
        have exactly the same keys (in any order).

        :param table: table name
        :param data_list: list of dicts, one per row
        :return: ``data_list`` as given, or ``[]`` when it is empty
        :raises ValueError: if a row's keys differ from the first row's
        :raises DatabaseError: if the statement fails
        """
        try:
            if not data_list:
                return []
            column_names = list(data_list[0].keys())
            columns = sql.SQL(', ').join(map(sql.Identifier, column_names))
            values_template = sql.SQL(', ').join(sql.Placeholder() * len(column_names))
            query = sql.SQL("INSERT INTO {table} ({columns}) VALUES ({values})").format(
                table=sql.Identifier(table),
                columns=columns,
                values=values_template
            )
            # Pull values by column name so rows whose dicts were built in a
            # different key order still line up with the column list.
            values = [self._ordered_values(column_names, row) for row in data_list]
            self.cursor.executemany(query.as_string(self.connection), values)
            if not self._in_transaction:
                self.connection.commit()
            return data_list
        except DatabaseError as e:
            logger.error(f"Error batch inserting data into {table}: {data_list}, Error: {e}")
            self.connection.rollback()
            raise

    def update(self, table, data, condition, condition_params=None):
        """
        Update rows matching a condition.

        :param table: table name
        :param data: dict mapping column names to their new values
        :param condition: WHERE clause as trusted SQL text, e.g. ``"id = %s"``;
            it is inserted into the statement verbatim, so never build it
            from untrusted input
        :param condition_params: values for the placeholders in ``condition``
        :return: number of affected rows
        :raises DatabaseError: if the statement fails
        """
        try:
            set_clause = sql.SQL(', ').join(
                sql.SQL('{} = {}').format(sql.Identifier(k), sql.Placeholder()) for k in data.keys()
            )
            query = sql.SQL("UPDATE {table} SET {set_clause} WHERE {condition}").format(
                table=sql.Identifier(table),
                set_clause=set_clause,
                condition=sql.SQL(condition)
            )
            # SET placeholders come first in the statement, then the WHERE ones.
            params = tuple(data.values()) + tuple(condition_params or ())
            self.execute_query(query, params)
            return self.cursor.rowcount
        except DatabaseError as e:
            logger.error(f"Error updating data in {table}: {data}, Condition: {condition}, Error: {e}")
            self.connection.rollback()
            raise

    def delete(self, table, condition, condition_params=None):
        """
        Delete rows matching a condition.

        :param table: table name
        :param condition: WHERE clause as trusted SQL text, e.g. ``"id = %s"``;
            it is inserted into the statement verbatim, so never build it
            from untrusted input
        :param condition_params: values for the placeholders in ``condition``
        :return: number of affected rows
        :raises DatabaseError: if the statement fails
        """
        try:
            query = sql.SQL("DELETE FROM {table} WHERE {condition}").format(
                table=sql.Identifier(table),
                condition=sql.SQL(condition)
            )
            self.execute_query(query, tuple(condition_params or ()))
            return self.cursor.rowcount
        except DatabaseError as e:
            logger.error(f"Error deleting data from {table}: Condition: {condition}, Error: {e}")
            self.connection.rollback()
            raise

    def begin_transaction(self):
        """
        Start an explicit transaction.

        Until commit_transaction() or rollback_transaction() is called,
        statements run through this handler are not committed individually.
        """
        try:
            self.connection.autocommit = False
            self._in_transaction = True
        except DatabaseError as e:
            logger.error(f"Error beginning transaction: {e}")
            raise

    def commit_transaction(self):
        """
        Commit the current transaction and switch the connection to autocommit.

        :raises DatabaseError: if the commit fails (a rollback is attempted first)
        """
        try:
            self.connection.commit()
            self.connection.autocommit = True
        except DatabaseError as e:
            logger.error(f"Error committing transaction: {e}")
            self.connection.rollback()
            raise
        finally:
            # Whether the commit succeeded or was rolled back, the explicit
            # transaction is over.
            self._in_transaction = False

    def rollback_transaction(self):
        """
        Roll back the current transaction and switch the connection to autocommit.

        :raises DatabaseError: if the rollback fails
        """
        try:
            self.connection.rollback()
            self.connection.autocommit = True
        except DatabaseError as e:
            logger.error(f"Error rolling back transaction: {e}")
            raise
        finally:
            self._in_transaction = False

    def close(self):
        """
        Close the cursor and the connection.

        :raises DatabaseError: if closing fails
        """
        try:
            try:
                self.cursor.close()
            finally:
                # Release the connection even when closing the cursor fails.
                self.connection.close()
        except DatabaseError as e:
            logger.error(f"Error closing connection: {e}")
            raise
if __name__ == "__main__":
    # Manual smoke test: runs one of each operation against a "users" table
    # (columns "name" and "age") on a local PostgreSQL server.
    db_config = dict(
        host="localhost",
        database="test_db",
        user="postgres",
        password="password",
        port=5432,
    )
    handler = PostgresHandler(**db_config)
    try:
        # Insert a single row.
        single_row = {"name": "Alice", "age": 25}
        print("Inserted:", handler.insert("users", single_row))

        # Insert several rows at once.
        batch_rows = [
            {"name": "Bob", "age": 30},
            {"name": "Charlie", "age": 35},
        ]
        print("Batch Inserted:", handler.batch_insert("users", batch_rows))

        # Query with a bound parameter.
        matching = handler.execute_query("SELECT * FROM users WHERE age > %s", (20,))
        print("Query Results:", matching)

        # Update rows selected by a parameterized condition.
        updated_count = handler.update("users", {"age": 26}, "name = %s", ("Alice",))
        print("Rows Affected by Update:", updated_count)

        # Delete rows selected by a parameterized condition.
        deleted_count = handler.delete("users", "name = %s", ("Charlie",))
        print("Rows Deleted:", deleted_count)
    finally:
        # Always release the cursor and connection, even if a step failed.
        handler.close()