Compare commits

...

3 Commits

Author SHA1 Message Date
qiangyanwen 25a096e781 add test 2023-01-07 18:24:15 +08:00
qiangyanwen f1e48984d7 add test 2023-01-07 18:22:22 +08:00
qiangyanwen c3534dff3c add test 2022-12-26 18:21:03 +08:00
7 changed files with 43 additions and 14 deletions

View File

@ -20,7 +20,6 @@ async def websocket_endpoint(
websocket: WebSocket websocket: WebSocket
): ):
await manager.connect(websocket) await manager.connect(websocket)
logger.info("ws开始连接....")
logger.info("websocket client ip==>{} port===>{}".format(websocket.client.host, websocket.client.port)) logger.info("websocket client ip==>{} port===>{}".format(websocket.client.host, websocket.client.port))
try: try:
while True: while True:
@ -32,5 +31,4 @@ async def websocket_endpoint(
except asyncio.TimeoutError: except asyncio.TimeoutError:
await manager.send_json(message, websocket) await manager.send_json(message, websocket)
except (WebSocketDisconnect, RuntimeError): except (WebSocketDisconnect, RuntimeError):
logger.info("ws开始关闭连接.....")
manager.disconnect(websocket) manager.disconnect(websocket)

View File

@ -6,13 +6,13 @@
from fastapi import APIRouter, Body, Depends, Query from fastapi import APIRouter, Body, Depends, Query
from typing_extensions import Annotated from typing_extensions import Annotated
from utils.jwt_token import parse_token
from config import settings from config import settings
from config.factory import AutomationResponse from config.factory import AutomationResponse
from entity.user_entity import UserFrom from entity.user_entity import UserFrom, ActiveUser
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from config.db_session import get_db from config.db_session import get_db
from repository.user_repository import login_user from repository.user_repository import login_user, active_user
from service.user_service import user_service from service.user_service import user_service
from utils.jwt_token import create_token from utils.jwt_token import create_token
from utils.response import success_200, error_211 from utils.response import success_200, error_211
@ -24,7 +24,7 @@ router = APIRouter(prefix="/api", tags=["用户模块"])
@router.post("/register/user", summary="注册用户", name="注册用户") @router.post("/register/user", summary="注册用户", name="注册用户")
async def register(user: Annotated[UserFrom, Body(...)], db: Session = Depends(get_db)): async def register(user: Annotated[UserFrom, Body(...)], db: Session = Depends(get_db)):
user = user_service(db, user) user = await user_service(db, user)
if isinstance(user, dict): if isinstance(user, dict):
return success_200(data=user, message="用户注册成功") return success_200(data=user, message="用户注册成功")
return error_211(message=user) return error_211(message=user)
@ -32,10 +32,24 @@ async def register(user: Annotated[UserFrom, Body(...)], db: Session = Depends(g
@router.post("/login", summary="用户登录", name="用户登录") @router.post("/login", summary="用户登录", name="用户登录")
async def login(user: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)): async def login(user: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
user = login_user(db, user.username, user.password) user = await login_user(db, user.username, user.password)
if user[0]: if user[0]:
expire_time = timedelta(minutes=settings.ACCESS.ACCESS_TOKEN_EXPIRE_MINUTES) expire_time = timedelta(minutes=settings.ACCESS.ACCESS_TOKEN_EXPIRE_MINUTES)
token = create_token(AutomationResponse.model_to_dict(user[1]), expire_time) token = create_token(AutomationResponse.model_to_dict(user[1]), expire_time)
return {"code": 200, "message": "登录成功", "access_token": token, return {"code": 200, "message": "登录成功", "access_token": token,
"user": AutomationResponse.model_to_dict(user[1], "password"), "token_type": "Bearer"} "user": AutomationResponse.model_to_dict(user[1], "password"), "token_type": "Bearer"}
return error_211(message=user[1]) return error_211(message=user[1])
@router.post("/user/active", summary="用户启用和禁用", name="用户启用和禁用")
async def active(user: Annotated[ActiveUser, Body(...)], db: Session = Depends(get_db), user_id: int = Depends(parse_token)):
status = 0
if user.status == 0: # 正在启用
status = 1
if user.status == 1: # 正在禁用
status = 0
await active_user(db, user_id=user.id, status=status)
if status == 0: # 正在启用
return success_200(data=dict(status=1), message="禁用成功")
if status == 1: # 正在停用
return success_200(data=dict(status=0), message="启用成功")

View File

@ -10,3 +10,10 @@ class UserFrom(BaseModel):
username: str username: str
password: str password: str
email: EmailStr email: EmailStr
status: int = 1
class ActiveUser(BaseModel):
id: int
status: int

View File

@ -4,7 +4,7 @@
# @Author :qiangyanwen # @Author :qiangyanwen
# @File :model.py # @File :model.py
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, INT, String, DATETIME from sqlalchemy import Column, INT, String, DATETIME,Integer
from config.database import DatabaseModel from config.database import DatabaseModel
@ -17,6 +17,7 @@ class User(DatabaseModel):
username = Column(String(16), unique=True, index=True, comment="用户名") username = Column(String(16), unique=True, index=True, comment="用户名")
password = Column(String(256), comment="密码") password = Column(String(256), comment="密码")
email = Column(String(64), unique=True, nullable=False, comment="邮箱") email = Column(String(64), unique=True, nullable=False, comment="邮箱")
status = Column(Integer,nullable=False,comment="用户状态,0是禁用1是启用",default=1)
created_time = Column(DATETIME, comment='创建时间') created_time = Column(DATETIME, comment='创建时间')
deleted_time = Column(DATETIME, comment="更新时间") deleted_time = Column(DATETIME, comment="更新时间")
__table_args__ = ({'comment': '用户表'}) __table_args__ = ({'comment': '用户表'})

View File

@ -6,7 +6,7 @@
from sqlalchemy import or_ from sqlalchemy import or_
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from config.db_session import db_add from config.db_session import db_add, db_save
from config.factory import AutomationResponse from config.factory import AutomationResponse
from entity.user_entity import UserFrom from entity.user_entity import UserFrom
from enums.enums import RegisterUser from enums.enums import RegisterUser
@ -27,9 +27,15 @@ def register_user(db: Session, user: UserFrom) -> dict:
return AutomationResponse.model_to_dict(us_er, "password") return AutomationResponse.model_to_dict(us_er, "password")
def login_user(db: Session, username, password): async def login_user(db: Session, username, password):
password = get_md5_pwd(password) password = get_md5_pwd(password)
user = db.query(User).filter(User.username == username, User.password == password).first() user = await db.query(User).filter(User.username == username, User.password == password).first()
if user.username and user.email: if user.username and user.email:
return True, user return True, user
return False, RegisterUser.LOGIN_USER_ERROR.value return False, RegisterUser.LOGIN_USER_ERROR.value
async def active_user(db: Session, user_id: int, status: int):
user = await db.query(User).filter(User.id == user_id).first()
user.status = status
db_save(db, user)

View File

@ -11,7 +11,7 @@ from utils.jwt_token import get_md5_pwd
from repository.user_repository import register_user from repository.user_repository import register_user
def user_service(db: Session, user: UserFrom): async def user_service(db: Session, user: UserFrom):
user.password = get_md5_pwd(user.password) user.password = get_md5_pwd(user.password)
is_true = check_user_email(db, user) is_true = check_user_email(db, user)
if is_true: if is_true:

View File

@ -17,6 +17,7 @@ class SSHConnection(object):
self._client = None self._client = None
def __enter__(self): def __enter__(self):
print("客户端开始创建连接.....")
transport = paramiko.Transport((self._host, self._port)) transport = paramiko.Transport((self._host, self._port))
transport.connect(username=self._username, password=self._password) transport.connect(username=self._username, password=self._password)
self._transport = transport self._transport = transport
@ -32,7 +33,7 @@ class SSHConnection(object):
self._sftp = paramiko.SFTPClient.from_transport(self._transport) self._sftp = paramiko.SFTPClient.from_transport(self._transport)
self._sftp.put(local_path, remote_path) self._sftp.put(local_path, remote_path)
def exec_command(self, command): def command(self, command):
if self._client is None: if self._client is None:
self._client = paramiko.SSHClient() self._client = paramiko.SSHClient()
self._client._transport = self._transport self._client._transport = self._transport
@ -51,7 +52,9 @@ class SSHConnection(object):
self._transport.close() self._transport.close()
if self._client: if self._client:
self._client.close() self._client.close()
print("客户端关闭已成功.....")
with SSHConnection("47.96.135.132", 22, "root", "Qyw1994@520") as ssh: with SSHConnection("47.96.135.132", 22, "root", "Qyw1994@520") as ssh:
ls = ssh.exec_command("ls -l") ls = ssh.command("ls -l")
print(ls)