#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2025/3/15 15:47
# @Author  : AngesZhu
# @File    : postgresql_utils.py
# @Desc    : PostgreSQL database helper class
import psycopg2
from psycopg2 import sql, DatabaseError

from utils.logger_utils import logger
from utils.timer_utils import timer


class PostgresHandler:
    """
    Convenience wrapper around one psycopg2 connection and one cursor.

    Outside an explicit transaction, every statement is committed as soon as it
    has run, and it is rolled back if it fails. Between ``begin_transaction()``
    and ``commit_transaction()`` / ``rollback_transaction()``, nothing is
    committed or rolled back automatically. Ending the transaction is left to
    the caller.

    The handler can also be used as a context manager, which closes the
    connection on exit.
    """

    def __init__(self, host, database, user, password, port=5432, **kwargs):
        """
        Open a PostgreSQL connection and a cursor on it.

        :param host: database host
        :param database: database name
        :param user: database user name
        :param password: database password
        :param port: port number, default 5432
        :param kwargs: accepted so a larger config dict can be splatted in;
                       currently ignored (not forwarded to psycopg2.connect)
        :raises DatabaseError: if the connection or cursor cannot be created
        """
        # True between begin_transaction() and commit/rollback_transaction().
        # While set, individual statements are not committed or rolled back.
        self._in_transaction = False
        self.connection = None
        self.cursor = None
        try:
            self.connection = psycopg2.connect(
                host=host,
                database=database,
                user=user,
                password=password,
                port=port,
            )
            self.cursor = self.connection.cursor()
        except DatabaseError as e:
            logger.error(f"Error connecting to PostgreSQL: {e}")
            # Don't leak the connection if only the cursor creation failed.
            if self.connection is not None:
                self.connection.close()
            raise

    def __enter__(self):
        """Return the handler itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the connection on leaving the ``with`` block; never swallow errors."""
        self.close()
        return False

    def _commit_if_not_in_transaction(self):
        """Commit immediately unless an explicit transaction is open."""
        if not self._in_transaction:
            self.connection.commit()

    def _rollback_if_not_in_transaction(self):
        """Roll back immediately unless an explicit transaction is open."""
        if not self._in_transaction:
            self.connection.rollback()

    @staticmethod
    def _insert_query(table, columns):
        """Build ``INSERT INTO table (cols) VALUES (%s, ...)`` for the given column names."""
        if not columns:
            # "INSERT INTO t () VALUES ()" is not valid PostgreSQL syntax.
            return sql.SQL("INSERT INTO {table} DEFAULT VALUES").format(
                table=sql.Identifier(table)
            )
        return sql.SQL("INSERT INTO {table} ({columns}) VALUES ({values})").format(
            table=sql.Identifier(table),
            columns=sql.SQL(', ').join(map(sql.Identifier, columns)),
            values=sql.SQL(', ').join(sql.Placeholder() * len(columns)),
        )

    def execute_query(self, query, params=None):
        """
        Execute one SQL statement.

        :param query: SQL string or a ``psycopg2.sql`` composed query
        :param params: optional sequence/mapping of query parameters
        :return: ``{"columns": [...], "data": [...]}`` when the statement
                 produces a result set (SELECT, ``... RETURNING``); otherwise None
        :raises DatabaseError: re-raised after logging (and rolling back when
                               no explicit transaction is open)
        """
        try:
            with timer():
                # Pass params through unchanged: None is cursor.execute's own
                # default and means "no parameters".
                self.cursor.execute(query, params)
                # description is None for statements that return no rows
                # (plain INSERT/UPDATE/DELETE/DDL). Checking it instead of the
                # query text also works for sql.Composed queries and RETURNING.
                result = None
                if self.cursor.description is not None:
                    result = {
                        "columns": [desc[0] for desc in self.cursor.description],
                        "data": self.cursor.fetchall(),
                    }
            self._commit_if_not_in_transaction()
            return result
        except DatabaseError as e:
            logger.error(f"Error executing query: {query}, Params: {params}, Error: {e}")
            self._rollback_if_not_in_transaction()
            raise

    def insert(self, table, data):
        """
        Insert one row.

        :param table: table name (quoted as a single identifier)
        :param data: dict mapping column name -> value
        :return: the ``data`` dict that was inserted
        :raises DatabaseError: if the insert fails
        """
        columns = list(data)
        query = self._insert_query(table, columns)
        values = tuple(data[column] for column in columns)
        try:
            self.execute_query(query, values or None)
        except DatabaseError as e:
            # execute_query already handled rollback; add table-level context.
            logger.error(f"Error inserting data into {table}: {data}, Error: {e}")
            raise
        return data  # Return the inserted data

    def batch_insert(self, table, data_list):
        """
        Insert many rows with a single prepared INSERT.

        The column list is taken from the first dict. Every row's values are
        looked up by those column names, so dicts may list their keys in any
        order. A row missing one of the columns raises ``KeyError`` before
        anything is sent to the database.

        :param table: table name (quoted as a single identifier)
        :param data_list: list of dicts, one per row
        :return: ``data_list`` (``[]`` if it is empty)
        :raises DatabaseError: if the insert fails
        """
        if not data_list:
            return []
        columns = list(data_list[0])
        query = self._insert_query(table, columns)
        rows = [tuple(row[column] for column in columns) for row in data_list]
        try:
            self.cursor.executemany(query, rows)
            self._commit_if_not_in_transaction()
        except DatabaseError as e:
            logger.error(f"Error batch inserting data into {table}: {data_list}, Error: {e}")
            self._rollback_if_not_in_transaction()
            raise
        return data_list  # Return the list of inserted data

    def update(self, table, data, condition, condition_params=None):
        """
        Update rows matching ``condition``.

        :param table: table name (quoted as a single identifier)
        :param data: dict mapping column name -> new value
        :param condition: WHERE clause as raw SQL, e.g. ``"id = %s"``.
                          It is inserted verbatim, so it must not contain
                          untrusted text; pass values via ``condition_params``.
        :param condition_params: optional sequence of values for the ``%s``
                                 placeholders in ``condition``
        :return: number of affected rows
        :raises DatabaseError: if the update fails
        """
        set_clause = sql.SQL(', ').join(
            sql.SQL('{} = {}').format(sql.Identifier(k), sql.Placeholder())
            for k in data.keys()
        )
        query = sql.SQL("UPDATE {table} SET {set_clause} WHERE {condition}").format(
            table=sql.Identifier(table),
            set_clause=set_clause,
            condition=sql.SQL(condition),
        )
        # SET placeholders come first in the statement, then the WHERE ones.
        params = tuple(data.values()) + tuple(condition_params or ())
        try:
            self.execute_query(query, params)
        except DatabaseError as e:
            logger.error(f"Error updating data in {table}: {data}, Condition: {condition}, Error: {e}")
            raise
        return self.cursor.rowcount  # Return the number of affected rows

    def delete(self, table, condition, condition_params=None):
        """
        Delete rows matching ``condition``.

        :param table: table name (quoted as a single identifier)
        :param condition: WHERE clause as raw SQL, e.g. ``"id = %s"``.
                          It is inserted verbatim, so it must not contain
                          untrusted text; pass values via ``condition_params``.
        :param condition_params: optional sequence of values for the ``%s``
                                 placeholders in ``condition``
        :return: number of affected rows
        :raises DatabaseError: if the delete fails
        """
        query = sql.SQL("DELETE FROM {table} WHERE {condition}").format(
            table=sql.Identifier(table),
            condition=sql.SQL(condition),
        )
        params = tuple(condition_params) if condition_params else None
        try:
            self.execute_query(query, params)
        except DatabaseError as e:
            logger.error(f"Error deleting data from {table}: Condition: {condition}, Error: {e}")
            raise
        return self.cursor.rowcount  # Return the number of affected rows

    def begin_transaction(self):
        """
        Start an explicit transaction.

        Until ``commit_transaction()`` or ``rollback_transaction()`` is called,
        statements are neither committed nor rolled back automatically. If a
        statement fails, the caller should call ``rollback_transaction()``.
        """
        try:
            # psycopg2 refuses to change autocommit while a transaction is
            # open, so only touch it when it actually has to change.
            if self.connection.autocommit:
                self.connection.autocommit = False
        except DatabaseError as e:
            logger.error(f"Error beginning transaction: {e}")
            raise
        self._in_transaction = True

    def commit_transaction(self):
        """Commit the explicit transaction and switch the connection to autocommit."""
        try:
            self.connection.commit()
            self.connection.autocommit = True
        except DatabaseError as e:
            logger.error(f"Error committing transaction: {e}")
            self.connection.rollback()
            raise
        finally:
            # The transaction is over either way (committed or rolled back).
            self._in_transaction = False

    def rollback_transaction(self):
        """Roll back the explicit transaction and switch the connection to autocommit."""
        try:
            self.connection.rollback()
            self.connection.autocommit = True
        except DatabaseError as e:
            logger.error(f"Error rolling back transaction: {e}")
            raise
        finally:
            self._in_transaction = False

    def close(self):
        """Close the cursor and the connection; the connection is closed even if the cursor close fails."""
        try:
            try:
                self.cursor.close()
            finally:
                self.connection.close()
        except DatabaseError as e:
            logger.error(f"Error closing connection: {e}")
            raise


if __name__ == "__main__":
    db_config = {
        "host": "localhost",
        "database": "test_db",
        "user": "postgres",
        "password": "password",
        "port": 5432
    }

    with PostgresHandler(**db_config) as handler:
        # Insert a single row
        inserted_data = handler.insert("users", {"name": "Alice", "age": 25})
        print("Inserted:", inserted_data)

        # Insert several rows at once
        batch_data = [
            {"name": "Bob", "age": 30},
            {"name": "Charlie", "age": 35}
        ]
        inserted_batch = handler.batch_insert("users", batch_data)
        print("Batch Inserted:", inserted_batch)

        # Query rows
        results = handler.execute_query("SELECT * FROM users WHERE age > %s", (20,))
        print("Query Results:", results)

        # Update rows
        rows_affected = handler.update("users", {"age": 26}, "name = %s", ("Alice",))
        print("Rows Affected by Update:", rows_affected)

        # Delete rows
        rows_deleted = handler.delete("users", "name = %s", ("Charlie",))
        print("Rows Deleted:", rows_deleted)