data_factory/utils/mysql_utils.py

223 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2025/3/15 15:53
# @Author : AngesZhu
# @File : mysql_utils.py
# @Desc : mysql数据库操作类
import pymysql
from pymysql.err import MySQLError
class MySQLHandler:
    """Convenience wrapper around a single PyMySQL connection.

    Rows are returned as dictionaries (``DictCursor``) unless a different
    ``cursorclass`` is passed through ``**kwargs``.

    Commit semantics: outside an explicit transaction every statement run
    through this class (reads included) is committed right after it executes.
    This means a long-lived handler does not keep reading from one old
    snapshot when autocommit is off on the connection. Between
    ``begin_transaction()`` and ``commit_transaction()`` /
    ``rollback_transaction()`` nothing is committed, so those statements
    succeed or fail together. A failing statement is rolled back, which also
    ends any open explicit transaction, and the error is re-raised.

    Table names, column names and ``condition`` strings are inserted into the
    SQL text verbatim, so they must come from trusted code. Only *values* are
    sent as driver parameters.

    The handler can be used as a context manager; the connection is closed on
    exit.
    """

    def __init__(self, host, database, user, password, port=3306, **kwargs):
        """
        Open a MySQL connection.

        :param host: database host
        :param database: database (schema) name
        :param user: database user name
        :param password: database password
        :param port: port number, defaults to 3306
        :param kwargs: extra keyword arguments forwarded to ``pymysql.connect``
            (e.g. ``charset``, ``connect_timeout``); may override ``cursorclass``
        :raises MySQLError: if the connection cannot be established
        """
        # True between begin_transaction() and commit/rollback_transaction();
        # while set, individual statements are not committed.
        self._in_transaction = False
        # DictCursor is the default, but callers may override it via kwargs
        # (building a dict avoids a duplicate-keyword TypeError).
        connect_kwargs = {"cursorclass": pymysql.cursors.DictCursor, **kwargs}
        try:
            self.connection = pymysql.connect(
                host=host,
                database=database,
                user=user,
                password=password,
                port=port,
                **connect_kwargs,
            )
            self.cursor = self.connection.cursor()
        except MySQLError as e:
            print(f"Error connecting to MySQL: {e}")
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    @staticmethod
    def _as_tuple(params):
        """Normalize placeholder values: None -> (), list/tuple -> tuple, scalar -> 1-tuple."""
        if params is None:
            return ()
        if isinstance(params, (list, tuple)):
            return tuple(params)
        # A bare scalar (including a str, which tuple() would split into
        # characters) is treated as a single placeholder value.
        return (params,)

    @staticmethod
    def _row_values(row):
        """Return a row's column values in column order, for dict or tuple cursors."""
        return list(row.values()) if isinstance(row, dict) else list(row)

    def _commit_unless_in_transaction(self):
        """Commit the current statement unless an explicit transaction is open."""
        if not self._in_transaction:
            self.connection.commit()

    def _abort(self):
        """Roll back after a failure without letting a rollback error mask the original one."""
        self._in_transaction = False
        try:
            self.connection.rollback()
        except MySQLError as e:
            print(f"Error rolling back transaction: {e}")

    def execute_query(self, query, params=None):
        """
        Execute a single SQL statement.

        :param query: SQL text; use ``%s`` placeholders for values
        :param params: optional sequence (or mapping) of placeholder values.
            When omitted, the query is sent unchanged, so literal ``%``
            characters (e.g. ``LIKE 'a%'``) need no escaping.
        :return: all rows for statements that produce a result set (SELECT,
            SHOW, DESCRIBE, WITH ..., ...), otherwise ``None``
        :raises MySQLError: after printing the error and rolling back
        """
        try:
            # Passing None (not ()) lets PyMySQL skip %-formatting entirely.
            self.cursor.execute(query, params)
            rows = None
            # description is set only when the statement produced a result set,
            # which is more reliable than checking for a leading "select".
            if self.cursor.description is not None:
                rows = self.cursor.fetchall()
            self._commit_unless_in_transaction()
            return rows
        except MySQLError as e:
            print(f"Error executing query: {query}, Error: {e}")
            self._abort()
            raise

    def insert(self, table, data):
        """
        Insert one row.

        :param table: table name (trusted identifier)
        :param data: dict mapping column name -> value
        :return: the ``data`` dict that was inserted
        :raises MySQLError: after printing the error and rolling back
        """
        columns = ', '.join(data.keys())
        placeholders = ', '.join(['%s'] * len(data))
        query = f"INSERT INTO {table} ({columns}) VALUES ({placeholders})"
        try:
            # execute_query already rolls back on failure.
            self.execute_query(query, tuple(data.values()))
        except MySQLError as e:
            print(f"Error inserting data into {table}: {data}, Error: {e}")
            raise
        return data

    def batch_insert(self, table, data_list):
        """
        Insert many rows with a single ``executemany`` call.

        :param table: table name (trusted identifier)
        :param data_list: list of dicts, one per row; every dict must contain
            the columns of the first dict
        :return: ``data_list`` (or ``[]`` when it is empty)
        :raises KeyError: if a row lacks a column present in the first row
        :raises MySQLError: after printing the error and rolling back
        """
        if not data_list:
            return []
        keys = list(data_list[0].keys())
        columns = ', '.join(keys)
        placeholders = ', '.join(['%s'] * len(keys))
        query = f"INSERT INTO {table} ({columns}) VALUES ({placeholders})"
        # Align every row to the first row's column order; dicts holding the
        # same keys in a different insertion order would otherwise be
        # written into the wrong columns.
        values = [tuple(row[key] for key in keys) for row in data_list]
        try:
            self.cursor.executemany(query, values)
            self._commit_unless_in_transaction()
        except MySQLError as e:
            print(f"Error batch inserting data into {table}: {data_list}, Error: {e}")
            self._abort()
            raise
        return data_list

    def update(self, table, data, condition, condition_params=None):
        """
        Update rows.

        :param table: table name (trusted identifier)
        :param data: dict mapping column name -> new value
        :param condition: WHERE clause text, e.g. ``"id = %s"`` (trusted; any
            literal ``%`` must be written as ``%%``)
        :param condition_params: value(s) for the placeholders in ``condition``;
            a sequence, or a single scalar value
        :return: number of affected rows
        :raises MySQLError: after printing the error and rolling back
        """
        set_clause = ', '.join(f"{key} = %s" for key in data)
        query = f"UPDATE {table} SET {set_clause} WHERE {condition}"
        # SET placeholders come first in the SQL text, then the WHERE ones.
        params = tuple(data.values()) + self._as_tuple(condition_params)
        try:
            self.execute_query(query, params)
        except MySQLError as e:
            print(f"Error updating data in {table}: {data}, Condition: {condition}, Error: {e}")
            raise
        return self.cursor.rowcount

    def delete(self, table, condition, condition_params=None):
        """
        Delete rows.

        :param table: table name (trusted identifier)
        :param condition: WHERE clause text, e.g. ``"id = %s"`` (trusted)
        :param condition_params: value(s) for the placeholders in ``condition``;
            a sequence, or a single scalar value. When omitted, ``condition``
            is sent unchanged.
        :return: number of affected rows
        :raises MySQLError: after printing the error and rolling back
        """
        query = f"DELETE FROM {table} WHERE {condition}"
        params = self._as_tuple(condition_params) or None
        try:
            self.execute_query(query, params)
        except MySQLError as e:
            print(f"Error deleting data from {table}: Condition: {condition}, Error: {e}")
            raise
        return self.cursor.rowcount

    def begin_transaction(self):
        """
        Start an explicit transaction.

        Until commit_transaction() or rollback_transaction() is called,
        statements run through this handler are not committed.
        """
        try:
            self.connection.begin()
        except MySQLError as e:
            print(f"Error beginning transaction: {e}")
            raise
        self._in_transaction = True

    def commit_transaction(self):
        """
        Commit the current transaction.

        :raises MySQLError: after printing the error and rolling back
        """
        try:
            self.connection.commit()
        except MySQLError as e:
            print(f"Error committing transaction: {e}")
            self._abort()
            raise
        self._in_transaction = False

    def rollback_transaction(self):
        """
        Roll back the current transaction.

        :raises MySQLError: if the rollback itself fails
        """
        # Leave transaction mode even if the rollback fails (e.g. lost connection).
        self._in_transaction = False
        try:
            self.connection.rollback()
        except MySQLError as e:
            print(f"Error rolling back transaction: {e}")
            raise

    def get_tables(self):
        """
        Return the DDL of every table (and view) in the current database.

        :return: dict mapping table name -> its ``SHOW CREATE TABLE`` text
        """
        tables = {}
        # Use a dedicated cursor: closing the shared self.cursor here would
        # break every later call on this handler.
        with self.connection.cursor() as cur:
            cur.execute("SHOW TABLES")
            names = [self._row_values(row)[0] for row in cur.fetchall()]
            for name in names:
                # Backtick-quote so names with special characters still work.
                quoted = "`" + str(name).replace("`", "``") + "`"
                cur.execute(f"SHOW CREATE TABLE {quoted}")
                # Column 0 is the name; column 1 is the CREATE statement.
                tables[name] = self._row_values(cur.fetchone())[1]
        return tables

    def close(self):
        """
        Close the cursor and the connection. Calling this more than once is safe.

        :raises MySQLError: if closing fails
        """
        try:
            try:
                self.cursor.close()
            finally:
                # Always try to release the connection, even if the cursor
                # close failed; skip it when it is already closed.
                if self.connection.open:
                    self.connection.close()
        except MySQLError as e:
            print(f"Error closing connection: {e}")
            raise
if __name__ == "__main__":
    import os

    # Read connection settings from the environment so real credentials never
    # have to live in source control; the defaults match a local test setup.
    db_config = {
        "host": os.environ.get("MYSQL_HOST", "localhost"),
        "database": os.environ.get("MYSQL_DATABASE", "test_db"),
        "user": os.environ.get("MYSQL_USER", "root"),
        "password": os.environ.get("MYSQL_PASSWORD", "password"),
        "port": int(os.environ.get("MYSQL_PORT", "3306")),
    }
    handler = MySQLHandler(**db_config)
    try:
        # Insert a single row
        inserted_data = handler.insert("users", {"name": "Alice", "age": 25})
        print("Inserted:", inserted_data)
        # Insert several rows at once
        batch_data = [
            {"name": "Bob", "age": 30},
            {"name": "Charlie", "age": 35},
        ]
        inserted_batch = handler.batch_insert("users", batch_data)
        print("Batch Inserted:", inserted_batch)
        # Query rows
        results = handler.execute_query("SELECT * FROM users WHERE age > %s", (20,))
        print("Query Results:", results)
        # Update rows; the last argument supplies the WHERE placeholder values
        rows_affected = handler.update("users", {"age": 26}, "name = %s", ("Alice",))
        print("Rows Affected by Update:", rows_affected)
        # Delete rows; the last argument supplies the WHERE placeholder values
        rows_deleted = handler.delete("users", "name = %s", ("Charlie",))
        print("Rows Deleted:", rows_deleted)
    finally:
        handler.close()