223 lines
7.0 KiB
Python
223 lines
7.0 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
# @Time : 2025/3/15 15:53
|
||
# @Author : AngesZhu
|
||
# @File : mysql_utils.py
|
||
# @Desc : mysql数据库操作类
|
||
import pymysql
|
||
from pymysql.err import MySQLError
|
||
|
||
|
||
class MySQLHandler:
    """Convenience wrapper around a single pymysql connection and cursor.

    Rows are returned as dicts (``DictCursor``) unless another ``cursorclass``
    is passed to the constructor. Outside an explicit transaction every
    statement is committed as soon as it succeeds. Between
    ``begin_transaction()`` and ``commit_transaction()`` /
    ``rollback_transaction()`` nothing is committed, so those statements
    succeed or fail as a unit. A failed statement is rolled back and its
    error is re-raised.

    Table names, column names and ``condition`` strings are pasted into the
    SQL text verbatim (identifiers cannot be bound as parameters), so they
    must come from trusted code. Values always travel as bound parameters.
    """

    # Class-level default: an instance is outside a transaction until
    # begin_transaction() is called.
    _in_transaction = False

    def __init__(self, host, database, user, password, port=3306, **kwargs):
        """
        Open the MySQL connection and a default cursor.

        :param host: database host
        :param database: database (schema) name
        :param user: database user name
        :param password: database password
        :param port: port number, default 3306
        :param kwargs: extra keyword arguments forwarded to ``pymysql.connect``
            (e.g. ``charset``, ``connect_timeout``); ``cursorclass`` defaults
            to ``pymysql.cursors.DictCursor``
        :raises MySQLError: if the connection cannot be established
        """
        kwargs.setdefault("cursorclass", pymysql.cursors.DictCursor)
        try:
            self.connection = pymysql.connect(
                host=host,
                database=database,
                user=user,
                password=password,
                port=port,
                **kwargs
            )
            self.cursor = self.connection.cursor()
        except MySQLError as e:
            print(f"Error connecting to MySQL: {e}")
            raise
        self._in_transaction = False

    def _commit_unless_in_transaction(self):
        """Commit the statement just executed unless a transaction is open."""
        if not self._in_transaction:
            self.connection.commit()

    def _rollback_after_error(self):
        """Roll back after a failed statement and leave transaction mode.

        A failure of the rollback itself is printed rather than raised. The
        caller then re-raises the original error instead of the rollback error.
        """
        self._in_transaction = False
        try:
            self.connection.rollback()
        except MySQLError as e:
            print(f"Error rolling back: {e}")

    def execute_query(self, query, params=None):
        """
        Execute one SQL statement.

        :param query: SQL text, with ``%s`` (or ``%(name)s``) placeholders
            for values
        :param params: optional values for the placeholders. When omitted the
            text is sent without %-formatting, so a literal ``%`` (as in
            ``LIKE 'a%'``) needs no escaping.
        :return: list of rows if the statement produced a result set
            (SELECT, SHOW, WITH ... SELECT, ...), otherwise ``None``
        :raises MySQLError: after the statement has been rolled back
        """
        try:
            self.cursor.execute(query, params)
            # cursor.description is only set when the statement returned a
            # result set; this is sturdier than sniffing for a leading "select".
            if self.cursor.description is not None:
                rows = self.cursor.fetchall()
            else:
                rows = None
            # Commit after reads as well. This ends the implicit transaction the
            # read opened (pymysql defaults to autocommit off), so later reads
            # on this long-lived connection are not pinned to an old snapshot.
            self._commit_unless_in_transaction()
            return rows
        except MySQLError as e:
            print(f"Error executing query: {query}, Error: {e}")
            self._rollback_after_error()
            raise

    def insert(self, table, data):
        """
        Insert one row.

        :param table: table name (trusted; inserted verbatim)
        :param data: dict mapping column name -> value
        :return: ``data`` itself, unchanged
        :raises MySQLError: if the insert fails (it has been rolled back)
        """
        columns = ", ".join(data)
        placeholders = ", ".join(["%s"] * len(data))
        query = f"INSERT INTO {table} ({columns}) VALUES ({placeholders})"
        try:
            self.execute_query(query, tuple(data.values()))
        except MySQLError as e:
            print(f"Error inserting data into {table}: {data}, Error: {e}")
            raise
        return data

    def batch_insert(self, table, data_list):
        """
        Insert many rows with a single ``executemany`` call.

        :param table: table name (trusted; inserted verbatim)
        :param data_list: list of dicts, one per row. Every dict must have
            the same keys as the first one; key order may differ.
        :return: ``data_list`` itself, or ``[]`` when it is empty
        :raises ValueError: if a row's keys differ from the first row's
        :raises MySQLError: if the insert fails (it has been rolled back)
        """
        if not data_list:
            return []

        column_names = list(data_list[0])
        expected_keys = data_list[0].keys()
        # Pull values by column name rather than by each dict's own order.
        # A row written as {"age": 35, "name": "x"} then still lines up with
        # the (name, age) column list taken from the first row.
        values = []
        for index, row in enumerate(data_list):
            if row.keys() != expected_keys:
                raise ValueError(
                    f"Row {index} has columns {list(row)}, expected {column_names}"
                )
            values.append(tuple(row[name] for name in column_names))

        columns = ", ".join(column_names)
        placeholders = ", ".join(["%s"] * len(column_names))
        query = f"INSERT INTO {table} ({columns}) VALUES ({placeholders})"
        try:
            self.cursor.executemany(query, values)
            self._commit_unless_in_transaction()
        except MySQLError as e:
            print(f"Error batch inserting data into {table}: {data_list}, Error: {e}")
            self._rollback_after_error()
            raise
        return data_list

    def update(self, table, data, condition, params=None):
        """
        Update rows.

        :param table: table name (trusted; inserted verbatim)
        :param data: dict mapping column name -> new value
        :param condition: WHERE clause text, e.g. ``"id = %s"`` (trusted;
            inserted verbatim). Values are always bound here, so a literal
            ``%`` in it must be written ``%%``.
        :param params: value(s) for the placeholders in ``condition``: a
            list/tuple, or a single value. They are bound after the SET values.
        :return: number of affected rows (``cursor.rowcount``)
        :raises MySQLError: if the update fails (it has been rolled back)
        """
        if params is None:
            condition_args = ()
        elif isinstance(params, (list, tuple)):
            condition_args = tuple(params)
        else:
            condition_args = (params,)

        set_clause = ", ".join(f"{key} = %s" for key in data)
        query = f"UPDATE {table} SET {set_clause} WHERE {condition}"
        # SET placeholders precede the WHERE placeholders in the SQL text.
        args = tuple(data.values()) + condition_args
        try:
            self.execute_query(query, args)
        except MySQLError as e:
            print(f"Error updating data in {table}: {data}, Condition: {condition}, Error: {e}")
            raise
        return self.cursor.rowcount

    def delete(self, table, condition, params=None):
        """
        Delete rows.

        :param table: table name (trusted; inserted verbatim)
        :param condition: WHERE clause text, e.g. ``"id = %s"`` (trusted;
            inserted verbatim)
        :param params: optional values for the placeholders in ``condition``
            (a list/tuple, or a dict for ``%(name)s`` placeholders), passed
            straight to ``cursor.execute``
        :return: number of affected rows (``cursor.rowcount``)
        :raises MySQLError: if the delete fails (it has been rolled back)
        """
        query = f"DELETE FROM {table} WHERE {condition}"
        try:
            self.execute_query(query, params)
        except MySQLError as e:
            print(f"Error deleting data from {table}: Condition: {condition}, Error: {e}")
            raise
        return self.cursor.rowcount

    def begin_transaction(self):
        """
        Start an explicit transaction.

        Until ``commit_transaction()`` or ``rollback_transaction()`` is called,
        the query/write helpers stop committing after each statement.

        :raises MySQLError: if the transaction cannot be started
        """
        try:
            self.connection.begin()
        except MySQLError as e:
            print(f"Error beginning transaction: {e}")
            raise
        self._in_transaction = True

    def commit_transaction(self):
        """
        Commit the current transaction and return to per-statement commits.

        :raises MySQLError: if the commit fails (a rollback is attempted first)
        """
        try:
            self.connection.commit()
        except MySQLError as e:
            print(f"Error committing transaction: {e}")
            self._rollback_after_error()
            raise
        self._in_transaction = False

    def rollback_transaction(self):
        """
        Roll back the current transaction and return to per-statement commits.

        :raises MySQLError: if the rollback fails
        """
        self._in_transaction = False
        try:
            self.connection.rollback()
        except MySQLError as e:
            print(f"Error rolling back transaction: {e}")
            raise

    def get_tables(self):
        """
        Return the DDL of every table listed by ``SHOW TABLES``.

        :return: dict mapping table name -> second column of its
            ``SHOW CREATE TABLE`` result (the CREATE statement)
        """
        tables = {}
        # Use a separate plain (tuple-row) cursor. The shared cursor yields
        # dict rows, so row[0] / row[1] would raise KeyError. Closing the
        # shared cursor here would also break every later call on this handler.
        with self.connection.cursor(pymysql.cursors.Cursor) as cur:
            cur.execute("SHOW TABLES")
            names = [row[0] for row in cur.fetchall()]
            for name in names:
                # Backtick-quote (doubling embedded backticks) so reserved
                # words and unusual characters in table names work.
                quoted = "`" + name.replace("`", "``") + "`"
                cur.execute(f"SHOW CREATE TABLE {quoted}")
                tables[name] = cur.fetchone()[1]
        return tables

    def close(self):
        """
        Close the cursor and the connection.

        Calling this more than once is safe. An open transaction is not
        committed.

        :raises MySQLError: if closing fails
        """
        try:
            try:
                self.cursor.close()
            finally:
                # Close the connection even if closing the cursor failed.
                # Skip it when it is already closed, so close() is idempotent.
                if self.connection.open:
                    self.connection.close()
        except MySQLError as e:
            print(f"Error closing connection: {e}")
            raise

    def __enter__(self):
        """Allow ``with MySQLHandler(...) as db:``; returns the handler."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the handler on leaving the ``with`` block; never suppresses errors."""
        self.close()
        return False
|
||
|
||
|
||
if __name__ == "__main__":
    # Example connection settings for a local test database.
    db_config = dict(
        host="localhost",
        database="test_db",
        user="root",
        password="password",
        port=3306,
    )

    handler = MySQLHandler(**db_config)

    try:
        # Insert a single row.
        print("Inserted:", handler.insert("users", {"name": "Alice", "age": 25}))

        # Insert several rows in one call.
        people = [{"name": "Bob", "age": 30}, {"name": "Charlie", "age": 35}]
        print("Batch Inserted:", handler.batch_insert("users", people))

        # Query rows with a bound parameter.
        print(
            "Query Results:",
            handler.execute_query("SELECT * FROM users WHERE age > %s", (20,)),
        )

        # Update rows matching a parameterized condition.
        print(
            "Rows Affected by Update:",
            handler.update("users", {"age": 26}, "name = %s", ("Alice",)),
        )

        # Delete rows matching a parameterized condition.
        print("Rows Deleted:", handler.delete("users", "name = %s", ("Charlie",)))
    finally:
        # Always release the connection, even if a step above failed.
        handler.close()
|