FastAPI 学习(二)
概要:本篇内容主要是对上篇文章的补充,主要的内容有FastAPI 参数接收、日志记录、用户文件上传、Demo 示例,后续将对这部分内容展开叙述;
1. FastAPI 参数接收
FastAPI 接收参数的方式与django和flask略有不同,下面将依次介绍用到的方法;
1.1 路径参数
- # !/usr/bin/env python
- # -*-coding:utf-8 -*-
- """
- @File : fast_app.py
- @Time : 2025/6/2 22:47
- @Author : zi qing bao jian
- @Version : 1.0
- @Desc : 学习使用 FastAPI 接收用户的参数的信息;
- """
- import uvicorn
- from fastapi import FastAPI
- app = FastAPI()
- @app.get("/items/{item_id}")
- async def read_item(item_id: int):
- """
- 通过 URL 的路径获取到参数的信息; 在路由装饰器中使用路径参数,通过函数参数接收。
- Args:
- item_id:
- Returns:
- """
- print(item_id) # >>> 1
- return {"item_id": item_id}
- if __name__ == '__main__':
- uvicorn.run(app, host='127.0.0.1', port=8000)
复制代码 说明:item_id 是路径参数,类型为 int,FastAPI 会自动解析路径中的参数并转换为指定类型。如果路径中没有匹配的参数,FastAPI 会返回 404 错误。
1.2 url 参数
使用?拼接在 url 末端的参数的信息;- # !/usr/bin/env python
- # -*-coding:utf-8 -*-
- """
- @File : fast_app.py
- @Time : 2025/6/2 22:47
- @Author : zi qing bao jian
- @Version : 1.0
- @Desc : 学习使用 FastAPI 接收用户的参数的信息;
- """
- import uvicorn
- from fastapi import FastAPI, Query
- app = FastAPI()
- @app.get("/index")
- async def index(q: str = Query(None, max_length=50)):
- """
- 通过 URL拼接在函数的尾部; 在函数参数中定义查询参数,使用 Query 类型标注。
- Args:
- q:
- Returns:
- """
- results = {"items": [{"item_id": "Foo"}, {"item_id": "Bar"}]}
- if q:
- results.update({"q": q})
- return results
- if __name__ == '__main__':
- uvicorn.run(app, host='127.0.0.1', port=8000)
复制代码
说明:
- q 是查询参数,默认值为 None,max_length=50 表示查询参数的最大长度为 50。
- 如果没有传递查询参数 q,则使用默认值 None。
1.3 请求体接收
这种方式通常是使用 Pydantic 模型定义请求体的结构,通过函数参数接收- # !/usr/bin/env python
- # -*-coding:utf-8 -*-
- """
- @File : fast_app.py
- @Time : 2025/6/2 22:47
- @Author : zi qing bao jian
- @Version : 1.0
- @Desc : 学习使用 FastAPI 接收用户的参数的信息;
- """
- import uvicorn
- from fastapi import FastAPI, Query
- from pydantic import BaseModel
- app = FastAPI()
- class Item(BaseModel):
- """ 定义 Pydantic 模型参数;
- """
- name: str
- description: str = None
- price: float
- tax: float = None
- @app.post("/items/")
- async def create_item(item: Item):
- print(item)
- return {"item": item}
- if __name__ == '__main__':
- uvicorn.run(app, host='127.0.0.1', port=8000)
复制代码
说明:
- Item是一个 Pydantic 模型,定义了请求体结构;
- FastAPI 会自动解析请求体中的JSON数据,并将其转换成为Item对象。
- 如果请求体中的数据不符合Item模型的定义,FastAPI 会返回 422 错误。
1.4 表单接收
定义方式:使用 Form 类型标注接收表单数据,主要用来接收前端 Form 表单校验形式的数据;- # !/usr/bin/env python
- # -*-coding:utf-8 -*-
- """
- @File : learn_arg.py
- @Time : 2025/6/3 22:42
- @Author : zi qing bao jian
- @Version : 1.0
- @Desc : 学习 FastAPI 使用的参数解析形式;
- """
- import uvicorn
- from fastapi import FastAPI, Form
- app = FastAPI()
- @app.post("/login")
- def login(username: str = Form(...), password: str = Form(...)):
- return {"username": username, "password": password}
- if __name__ == '__main__':
- uvicorn.run(app, host="127.0.0.1", port=8000)
复制代码
- username 和 password 是表单数据,使用 Form 类型标注。
- ... 表示该参数是必填项。
- FastAPI 会自动解析表单数据并将其传递给函数参数。
1.5 文件上传
定义方式:使用 File 和 UploadFile 类型标注接收文件。- import uvicorn
- from fastapi import FastAPI, File, UploadFile
- app = FastAPI()
- # @app.post("/login")
- # def login(username: str = Form(...), password: str = Form(...)):
- # return {"username": username, "password": password}
- @app.post("/files/")
- async def create_user(file: bytes = File(...)):
- # print(len(file.size))
- return {"file": len(file)}
- @app.post("/uploadfile/")
- async def create_upload_file(file: UploadFile = File(...)):
- return {"filename": file.filename}
- if __name__ == '__main__':
- uvicorn.run(app, host="127.0.0.1", port=8000)
复制代码
- file 是上传的文件,使用 File 类型标注。
- UploadFile 是一个更高级的文件对象,包含文件名、内容类型等信息。
- FastAPI 会自动解析上传的文件并将其传递给函数参数。
1.6 请求头接收
在 FastAPI 中可以使用 Header 类型标注请求接收头;- import uvicorn
- from fastapi import FastAPI, File, UploadFile, Header
- app = FastAPI()
- @app.get("/items/")
- async def read_items(user_agent: str = Header(None)):
- return {"ua": user_agent}
- if __name__ == '__main__':
- uvicorn.run(app, host="127.0.0.1", port=8000)
复制代码
1.7 依赖项接收
使用 Depneds 的方式来接收参数的信息;- import uvicorn
- from fastapi import FastAPI, File, UploadFile, Header, Depends
- from sqlalchemy import create_engine
- from sqlalchemy.orm import sessionmaker, Session
- app = FastAPI()
- engine = create_engine("my.db")
- SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
- def get_db():
- db = SessionLocal()
- try:
- yield db
- except Exception as e:
- print(e)
- db.rollback()
- finally:
- db.close()
- @app.get("/items/")
- async def read_items(name: str, db: Session = Depends(get_db)):
- data = db.query() # 此处进行数据库的操作;
- return data
- if __name__ == '__main__':
- uvicorn.run(app, host="127.0.0.1", port=8000)
复制代码 接收依赖项, 对数据库进行操作;- from fastapi import FastAPI, Depends, Header
- app = FastAPI()
- async def verify_token(x_token: str = Header(...)):
- if x_token != "fake-super-secret-token":
- raise HTTPException(status_code=400, detail="X-Token header invalid")
- @app.get("/items/", dependencies=[Depends(verify_token)])
- async def read_items():
- return [{"item": "Foo"}, {"item": "Bar"}]
复制代码
- verify_token 是一个依赖项函数,验证请求头中的 X-Token 字段。
- dependencies=[Depends(verify_token)] 表示在调用路由函数之前,先调用 verify_token 函数。
额外补充:
多次依赖:FastAPI 的依赖注入机制会确保同一个依赖项在一次请求处理中只被调用一次,即使它在多个地方被引用。
1.8 背景任务接收
使用 BackgroundTasks 类型标注接收背景任务;- from fastapi import FastAPI, BackgroundTasks
- app = FastAPI()
- def write_notification(email: str, message=""):
- with open("log.txt", mode="w") as email_file:
- content = f"notification for {email}: {message}"
- email_file.write(content)
- @app.post("/send-notification/{email}")
- async def send_notification(email: str, background_tasks: BackgroundTasks):
- background_tasks.add_task(write_notification, email, message="some notification")
- return {"message": "Notification sent in the background"}
复制代码
- background_tasks 是一个背景任务对象,使用 BackgroundTasks 类型标注。
- background_tasks.add_task 方法可以添加背景任务,任务会在响应返回后执行。
- 这种场景适合做任务执行的开启,挂起任务到后台的时候使用;
1.9 请求对象接收
通过 Request 对象的参数进行接收;
这种方式,与 Flask 和 django 的获取参数的形式比较相似;- from fastapi import FastAPI, Request
- app = FastAPI()
- @app.get("/items/")
- async def read_items(request: Request):
- return {"client_host": request.client.host}
复制代码
这种方式获取的时候通常获取在请求头中的一些数据,在辅助验证的时候有着不错的效果;request 是一个请求对象,包含请求的完整信息,如客户端 IP 地址、请求头等。
通过以上多种方式,FastAPI 提供了灵活且强大的参数接收机制,可以根据实际需求选择合适的方式接收用户参数。
2. FastAPI 日志记录
额外补充:日志记录中的 Logger 与 Handler;
Logger 是日志系统的入口,负责产生日志记录,并可以设置日志级别来控制整个日志记录流程的初始过滤。Handler 则负责将日志记录发送到特定的目标位置,如文件、控制台、邮件等。一个 Logger 对象可以通过addHandler方法添加0到多个 Handler,每个 Handler 可以定义不同的日志级别,以实现日志分级过滤和显示;
- Logger
- 创建和管理:Logger 负责创建或获取用于发送日志记录消息的对象。用户可以通过logging.getLogger()创建或获取一个Logger对象。
- 日志级别:Logger 可以设置日志级别,控制日志记录的初始过滤。只有满足 Logger 级别要求的日志记录才会被处理和输出。
- 日志记录发送:Logger 对象用于记录日志消息,用户通过 Logger 对象发送日志记录。
- Handler
- 目标指定:Handler 指定如何处理和输出日志消息,可以将日志消息写入文件、控制台、网络等不同的目标。常见的 Handler子类包括StreamHandler(输出到控制台)、FileHandler(输出到文件)和SMTPHandler(输出到邮件)等。
- 日志级别设置:Handler 可以设置自己的日志级别,只有满足 Handler 级别要求的日志记录才会被处理和输出。Logger 和 Handler 的级别都需要满足,日志记录才会被处理和输出。
- 格式化:Handler 还可以设置日志格式化器(Formatter),定义日志记录的输出格式,如时间戳、日志级别、日志名称和消息内容等。
2.1 日志输出到文件
- import logging
- from fastapi import FastAPI
- # 配置日志
- logger = logging.getLogger(__name__)
- logger.setLevel(logging.INFO)
- # 创建文件处理器
- file_handler = logging.FileHandler('app.log')
- file_handler.setLevel(logging.INFO)
- # 创建控制台处理器
- console_handler = logging.StreamHandler()
- console_handler.setLevel(logging.INFO)
- # 创建日志格式
- formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
- file_handler.setFormatter(formatter)
- console_handler.setFormatter(formatter) # 设置终端输出的日志的格式;
- # 将文件处理器添加到日志器
- logger.addHandler(file_handler)
- logger.addHandler(console_handler) # 一个 logger 可以添加多个 handler;
- app = FastAPI()
- @app.get("/")
- async def read_root():
- logger.info("Handling root endpoint")
- return {"message": "Hello World"}
- """
- 在这个例子中:
- FileHandler 用于将日志写入文件 app.log。
- Formatter 用于定义日志的格式。
- addHandler 方法将文件处理器添加到日志器。
- """
复制代码 以上是传统的日志记录的方式,可以将我们的日志同时输出到终端和文件中,但是记录的仅仅是我们开发过程中自己定义的日志记录的信息,如果我们想要将终端的日志也输出到文件中,应该怎么操作呢?
基于启动的 uvicorn 实现;
- # !/usr/bin/env python
- # -*-coding:utf-8 -*-
- """
- @File : base.py
- @Time : 2025/6/4 23:10
- @Author : zi qing bao jian
- @Version : 1.0
- @Desc : 基于 uvicorn 实现日志的记录;
- """
- import logging
- from logging.handlers import RotatingFileHandler
- from fastapi import FastAPI
- import uvicorn
- # 创建 FastAPI 应用
- app = FastAPI()
- # 配置日志
- LOGGING_CONFIG = {
- "version": 1,
- "disable_existing_loggers": False,
- "formatters": {
- "default": {
- "()": "uvicorn.logging.DefaultFormatter",
- "fmt": "%(levelprefix)s %(message)s",
- "use_colors": None,
- },
- "access": {
- "()": "uvicorn.logging.AccessFormatter",
- "fmt": '%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s',
- },
- },
- "handlers": {
- "default": {
- "formatter": "default",
- "class": "logging.handlers.TimedRotatingFileHandler",
- "filename": "./fastapi.log"
- },
- "access": {
- "formatter": "access",
- "class": "logging.handlers.TimedRotatingFileHandler",
- "filename": "./fastapi.log"
- },
- },
- "loggers": {
- "": {"handlers": ["default"], "level": "INFO"},
- "uvicorn.error": {"level": "INFO"},
- "uvicorn.access": {"handlers": ["access"], "level": "INFO", "propagate": False},
- },
- }
- @app.get("/")
- async def read_root():
- return {"message": "Hello, World!"}
- if __name__ == "__main__":
- uvicorn.run(app, host="127.0.0.1", port=8000, log_config=LOGGING_CONFIG)
复制代码
2.2 中间件实现
基于 FastAPI 中间件实现日志的记录;- # !/usr/bin/env python
- # -*-coding:utf-8 -*-
- """
- @File : base.py
- @Time : 2025/6/4 23:10
- @Author : zi qing bao jian
- @Version : 1.0
- @Desc : 基于 FastAPI 中间件实现日志的记录;
- """
- from fastapi import FastAPI, Request
- from loguru import logger
- import uvicorn
- # 创建 FastAPI 应用
- app = FastAPI()
- # 配置日志
- logger.add("./request.log", rotation="100 MB") # 自定义日志文件路径和大小限制
- @app.middleware("http")
- async def log_requests(request: Request, call_next):
- """
- 中间件:记录请求日志
- """
- logger.info(f"Request: {request.method} {request.url}")
- response = await call_next(request)
- logger.info(f"Response: {response.status_code}")
- return response
- @app.get("/")
- async def read_root():
- return {"message": "Hello, World!"}
- if __name__ == "__main__":
- uvicorn.run(app, host="127.0.0.1", port=8000)
复制代码
2.2 审计日志的记录
- # !/usr/bin/env python
- # -*-coding:utf-8 -*-
- """
- @File : base.py
- @Time : 2025/6/4 23:10
- @Author : zi qing bao jian
- @Version : 1.0
- @Desc :
- """
- import time
- import uvicorn
- from loguru import logger
- from fastapi import FastAPI, Request
- from starlette.middleware.base import BaseHTTPMiddleware
- # 配置日志
- logger.add("./request.log", rotation="100 MB") # 自定义日志文件路径和大小限制
- # 创建 FastAPI 应用
- app = FastAPI()
- class AuditLogMiddleware(BaseHTTPMiddleware):
- async def dispatch(self, request: Request, call_next):
- # 在请求到达路由处理函数之前执行的代码
- start_time = time.time()
- logger.info(f"Request received: {request.method} {request.url}")
- # 调用下一个中间件或路由处理函数
- response = await call_next(request)
- # 在响应返回客户端之前执行的代码
- process_time = time.time() - start_time
- logger.info(
- f"Request processed: {request.method} {request.url} - Status: {response.status_code} - Time: {process_time:.2f}s")
- return response
- # 添加审计日志中间件
- app.add_middleware(AuditLogMiddleware)
- @app.get("/")
- async def read_root():
- return {"message": "Hello World"}
- if __name__ == "__main__":
- uvicorn.run(app, host="127.0.0.1", port=8000)
复制代码
基于中间件的方式的时候可以记录到一些关于客户请求端的信息到日志中;
2.3 装饰器方式
- # !/usr/bin/env python
- # -*-coding:utf-8 -*-
- """
- @File : base.py
- @Time : 2025/6/4 23:10
- @Author : zi qing bao jian
- @Version : 1.0
- @Desc : 装饰器实现日志的记录
- """
- import logging
- from typing import Callable, Any
- from functools import wraps
- import uvicorn
- from fastapi import FastAPI
- # 配置日志
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- logger = logging.getLogger(__name__)
- def log_decorator(module_name: str, user_description: str):
- """
- 日志装饰器,记录模块名称和用户描述
- :param module_name: 模块名称
- :param user_description: 用户描述
- """
- def decorator(func: Callable) -> Callable:
- @wraps(func)
- async def wrapper(*args: Any, **kwargs: Any) -> Any:
- # 记录请求前的日志, 调用原有的参数
- logger.info(f"Module: {module_name} - Description: {user_description} - Request: {func.__name__}")
- # 调用原始函数
- result = await func(*args, **kwargs)
- # 记录请求后的日志
- logger.info(
- f"Module: {module_name} - Description: {user_description} - Response: {func.__name__} - Result: {result}")
- return result
- return wrapper
- return decorator
- app = FastAPI()
- @app.get("/")
- @log_decorator(module_name="主模块", user_description="访问根路径")
- async def read_root():
- return {"message": "Hello World"}
- @app.get("/items/{item_id}")
- @log_decorator(module_name="项目模块", user_description="获取项目详情")
- async def read_item(item_id: int):
- return {"item_id": item_id, "message": "Item retrieved"}
- if __name__ == '__main__':
- uvicorn.run(app, host="127.0.0.1", port=8000)
复制代码
此时日志已经输出到终端上面,并且包含了自定义的信息;如果需要用户的日志输出到别的地方如:消息队列、数据库、文件、API 接口的时候,用户可以自定义 Handdler 并且配置到 logger 中即可实现;
3. FastAPI 登录验证
登录验证的流程为常见的开发示例,可以开发完成之后备用为已有的封装服务,方便后续的开发工作的完善;
3.1 设定配置文件
本步骤的目的是为了设定存储数据的位置,以及用到的组件的信息;
- # !/usr/bin/env python
- # -*-coding:utf-8 -*-
- """
- @File : settings.py
- @Time : 2025/6/5 23:10
- @Author : zi qing bao jian
- @Version : 1.0
- @Desc : 配置文件
- """
- import os
- from pydantic import Field
- from pydantic_settings import BaseSettings, SettingsConfigDict
- from sqlalchemy import create_engine
- from sqlalchemy.orm import sessionmaker
- class Settings(object):
- """原生的配置 Python 文件的写法;
- """
- # 使用 SQL lite 作为数据库;
- DATABASE_URL = "sqlite:///data.db"
- engine = create_engine(DATABASE_URL, echo=True)
- SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
- # SECRET_KEY
- ALGORITHM = "HS256"
- SECRET_KEY = "fshduifhue"
- ACCESS_TOKEN_EXPIRE_MINUTES = 30
- class SettingsEnv(BaseSettings):
- """ 使用 .env 文件加载配置文件;
- """
- # SettingsConfigDict 如果环境变量存在相同名称的变量, 则覆盖文件内的, 否则使用文件内的键值对的变量形式;
- model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="forbid")
- DB_NAME: str = Field("data.db", description="数据库名称")
- REDIS_HOST: str = Field("127.0.0.1", description="redis 服务地址")
- REDIS_PORT: int = Field(6379, description="redis 端口号地址")
- REDIS_USER: str = Field("", description="redis 用户名")
- REDIS_PASSWORD: str = Field("", description="redis 密码")
- SECRET_KEY: str = Field("", description="FastAPI 加密秘钥")
- class DatabaseSession:
- """定义上下文处理连接;
- """
- def __init__(self):
- self.db = Settings.SessionLocal()
- def __enter__(self):
- return self.db
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.db.close()
复制代码 以上的两种模式可以自行选择,如果使用第二种,可以自行根据字段编写.env文件;
3.2 设置模型
本步骤的目的是用来设定登录用户中常见的数据库的模型以及 Pydantic 模型,方便后续的业务开展;
- # !/usr/bin/env python
- # -*-coding:utf-8 -*-
- """
- @File : models.py
- @Time : 2025/6/5 23:07
- @Author : zi qing bao jian
- @Version : 1.0
- @Desc : 定义数据库模型, 创建数据库表, 登录过程中创建一张表
- """
- from sqlalchemy import , Column, Integer, String
- from sqlalchemy.ext.declarative import declarative_base
- from user_web.settings import Settings
- Base = declarative_base()
- class User(Base):
- """用户表;
- """
- __tablename__ = "user"
- id = Column(Integer, primary_key=True, index=True, comment="用户ID")
- username = Column(String, unique=True, index=True, comment="用户名")
- hash_password = Column(String, comment="hash 密码")
- if __name__ == '__main__':
- # 创建数据库;
- Base.metadata.create_all(bind=Settings.engine)
复制代码 如需其他字段可以自行添加,本博客只做最简单的记录,不做其他的处理,后续补充可能会更新到 Git 上,可通过文章结尾的地址进行访问。- # !/usr/bin/env python
- # -*-coding:utf-8 -*-
- """
- @File : schemas.py
- @Time : 2025/6/5 23:42
- @Author : zi qing bao jian
- @Version : 1.0
- @Desc : 定义 Pydantic 模型,辅助校验;
- """
- from pydantic import BaseModel, Field
- class User(BaseModel):
- id: int
- username: str
复制代码 3.3 ORM 操作封装
实现数据库常见操作的封装,本次操作数据库使用 SQLAlchemy,封装好的工具也可以在其他地方进行使用。
- # !/usr/bin/env python
- # -*-coding:utf-8 -*-
- """
- @File : base_sqlalchemy_utils.py
- @Time : 2025/6/30 14:29
- @Author : zi qing bao jian
- @Version : 1.0
- @Desc : 数据库操作的基类,定义些通用的数据操作方法
- """
- import math
- from sqlalchemy import desc, asc
- from sqlalchemy.orm.query import Query
- from sqlalchemy.orm.session import Session
- from sqlalchemy_utils import Choice
- class BaseDBOperateModel(object):
- @classmethod
- def get_list(cls, db_query: Query, filters: set, order: str = "-id", offset: int = 0, limit: int = 15) -> dict:
- """分页查询数据
- Args:
- db_query: Query;数据库session绑定了查询数据库模型的查询类
- filters: set;过滤条件
- order: str;排序规则,例:"+id,-create_time"
- offset: int;偏移量
- limit: int;取多少条
- Returns:
- dict;数据结果
- """
- # noinspection All
- count_number = db_query.filter(*filters).count()
- result = {
- "page": {
- "count": count_number,
- "total_page": cls.get_page_number(count_number, limit),
- "current_page": offset
- },
- }
- filter_result = list()
- if offset != 0:
- offset = (offset - 1) * limit
- if result["page"]["count"] > 0:
- # 禁用全部检查
- # noinspection All
- filter_query = db_query.filter(*filters)
- order_rules = cls.order_transfer(order)
- # noinspection All
- filter_result = filter_query.order_by(*order_rules).offset(offset).limit(limit).all()
- result["list"] = [cls.to_dict(c) for c in filter_result]
- return result
- @classmethod
- def get_all(cls, db_query: Query, filters: set, order: str = "-id", limit: int = 0) -> list:
- """获取所有符合条件的数据(有问题,暂不使用)
- Args:
- db_query: Query;数据库session绑定了查询数据库模型的查询类
- filters: set;过滤条件
- order: str;排序规则,例:"+id,-create_time"
- limit: int;取多少条
- Returns:
- list;多条查询数据结果
- """
- if not filters:
- result = db_query
- else:
- # noinspection All
- result = db_query.filter(*filters)
- order_rules = cls.order_transfer(order)
- result.order_by(*order_rules)
- if limit != 0:
- # noinspection All
- result = result.limit(limit)
- # noinspection All
- result = result.all()
- result = [cls.to_dict(c) for c in result]
- return result
- @classmethod
- def get_one(cls, db_query: Query, filters: set, order: str = "-id") -> dict:
- """获取一条符合条件的数据
- Args:
- db_query: Query;数据库session绑定了查询数据库模型的查询类
- filters: set;过滤条件
- order: str;排序规则,例:"+id,-create_time"
- Returns:
- dict;单条查询数据结果
- """
- # noinspection All
- result = db_query.filter(*filters)
- order_rules = cls.order_transfer(order)
- # noinspection All
- result = result.order_by(*order_rules).first()
- if result is None:
- return {}
- result = cls.to_dict(result)
- return result
- @staticmethod
- def insert(db_session: Session, model, data: dict) -> int:
- """插入一条数据
- Args:
- db_session: Session;数据库会话连接
- model: BaseModel;数据模型类对象
- data: set;插入数据
- Returns:
- int;插入成功后返回的id编号
- """
- users = model(**data)
- db_session.add(users)
- db_session.flush()
- return users.id
- @staticmethod
- def insert_all(db_session: Session, model, data: list) -> bool:
- """插入多条数据
- Args:
- db_session: Session;数据库会话连接
- model: BaseModel;数据模型类对象
- data: set;插入数据
- Returns:
- bool;插入多条数据成功返回True
- """
- users = list()
- for user_info in data:
- users.append(model(**user_info))
- db_session.add_all(users)
- db_session.commit()
- return True
- @staticmethod
- def update(db_query: Query, data: dict, filters: set) -> int:
- """修改符合条件的数据
- Args:
- db_query: Query;数据库session绑定了查询数据库模型的查询类
- data: dict;插入数据
- filters: set;过滤条件
- Returns:
- int;修改数据成功返回的行数
- """
- # noinspection All
- return db_query.filter(*filters).update(data, synchronize_session=False)
- @staticmethod
- def delete(db_query: Query, filters: set) -> int:
- """删除符合条件的数据
- Args:
- db_query: Query;数据库session绑定了查询数据库模型的查询类
- filters: set;过滤条件
- Returns:
- int;修改数据成功返回的行数
- """
- # noinspection All
- return db_query.filter(*filters).delete(synchronize_session=False)
- @staticmethod
- def get_count(db_query: Query, filters: set, field=False) -> int:
- """获取符合条件的数据数量
- Args:
- db_query: Query;数据库session绑定了查询数据库模型的查询类
- filters: set;过滤条件
- field: bool;是否指定字段计数
- Returns:
- int;
- """
- if field:
- # noinspection All
- return db_query.filter(*filters).scalar()
- else:
- # noinspection All
- return db_query.filter(*filters).count()
- @staticmethod
- def get_page_number(count: int, page_size: int) -> int:
- """获取总页数
- Args:
- count: int;数据总数
- page_size: int;分页大小
- Returns:
- int;总页数
- """
- page_size = abs(page_size)
- if page_size != 0:
- total_page = math.ceil(count / page_size)
- else:
- total_page = math.ceil(count / 5)
- return total_page
- @staticmethod
- def order_transfer(order: str):
- """order排序规则转换
- Args:
- order: str;排序规则,例:"+id,-create_time"
- Returns:
- list;排序规则列表
- """
- order_array = order.split(",")
- order_rules = list()
- for item in order_array:
- sort_rule = item[0]
- if sort_rule == "-":
- order_rules.append(desc(item[1:]))
- else:
- order_rules.append(asc(item[1:]))
- return order_rules
- @staticmethod
- def to_dict(model_obj):
- if not hasattr(model_obj, "_fields"):
- only = model_obj.__table__.columns
- result = {field.name: (
- getattr(model_obj, field.name).value if isinstance(getattr(model_obj, field.name), Choice) else getattr(
- model_obj, field.name)) for field in only}
- else:
- only = model_obj.keys()
- result = {field: (
- getattr(model_obj, field).value if isinstance(getattr(model_obj, field), Choice) else getattr(
- model_obj, field)) for field in only}
- return result
复制代码 上述的操作,没有设置数据库连接的获取与关闭,因为获取连接的方式封装在了配置文件内;
3.4 业务 API 的开发
API 常见业务的封装;
- # !/usr/bin/env python
- # -*-coding:utf-8 -*-
- """
- @File : auth_logic.py
- @Time : 2025/6/5 23:17
- @Author : zi qing bao jian
- @Version : 1.0
- @Desc : 认证的主要逻辑函数;
- """
- from typing import Optional, Dict
- from datetime import datetime, timedelta
- from jose import jwt, JWTError
- from passlib.context import CryptContext
- from fastapi import Depends, HTTPException
- from fastapi.security import OAuth2PasswordBearer
- from models import User
- from base_sqlalchemy_utils import BaseDBOperateModel
- from settings import Settings, DatabaseSession
- # 定义密码上下哈希文
- pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
- # JWT 进行配置
- SECRET_KEY = Settings.SECRET_KEY
- ACCESS_TOKEN_EXPIRE_MINUTES = 30
- # 设定密码模式
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
- def verify_password(plain_password, hashed_password):
- """
- 校验原密码与 hash 存储的值;
- :param plain_password:
- :param hashed_password:
- :return:
- """
- return pwd_context.verify(plain_password, hashed_password)
- def get_password_hash(password):
- """
- 获取密码的 Hash 值;
- :param password:
- :return:
- """
- return pwd_context.hash(password)
- def create_access_token(data: Dict, expires_delta: Optional[timedelta] = None):
- """
- 创建 token 函数;
- :param data:
- :param expires_delta:
- :return:
- """
- to_encode = data.copy()
- expire = datetime.utcnow() + timedelta(minutes=Settings.ACCESS_TOKEN_EXPIRE_MINUTES)
- to_encode.update({"exp": expire})
- encoded_jwt = jwt.encode(to_encode, Settings.SECRET_KEY, algorithm=Settings.ALGORITHM)
- return encoded_jwt
- async def get_current_user(token: str = Depends(oauth2_scheme)) -> Dict:
- """
- 实现 Token 对用户的解析;
- :param token:
- :return:
- """
- credentials_exception = HTTPException(
- status_code=401, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}
- )
- try:
- payload = jwt.decode(token, Settings.SECRET_KEY, algorithms=[Settings.ALGORITHM])
- username: str = payload.get("user_name")
- if username is None:
- raise credentials_exception
- except JWTError:
- raise credentials_exception
- with DatabaseSession() as db_session:
- user = BaseDBOperateModel.get_one(db_session.query(User), {User.username == username})
- if not user:
- raise credentials_exception
- return user
复制代码 核心的加密逻辑如上。- # !/usr/bin/env python
- # -*-coding:utf-8 -*-
- """
- @File : user_api.py
- @Time : 2025/6/5 23:29
- @Author : zi qing bao jian
- @Version : 1.0
- @Desc : 使用用户的路由配置;
- """
- from typing import Dict
- from datetime import timedelta
- from fastapi import APIRouter, Depends, HTTPException
- from auth_logic import create_access_token, get_current_user
- from models import User
- from schemas import UserCreate
- from settings import DatabaseSession, Settings
- from auth_logic import get_password_hash, verify_password
- from base_sqlalchemy_utils import BaseDBOperateModel
- user_router = APIRouter(prefix="/user", tags=["user"])
- @user_router.post("/register")
- async def register(user_data: UserCreate):
- """
- 用户注册;
- user_data: UserCreate; pydantic.BaseModel; 设置参数的基本校验;
- :return:
- """
- # 1. 首先检查用户名是否存在, 存在返回异常的信息
- with DatabaseSession() as db_session:
- user = BaseDBOperateModel.get_one(db_session.query(User), filters={User.username == user_data.username})
- if user:
- raise HTTPException(status_code=400, detail="username already exists.")
- # 2. 用户名不存在, hash 密码,创建用户;
- hashed_password = get_password_hash(user_data.password)
- user = BaseDBOperateModel.insert(db_session, User,
- {"username": user_data.username, "hash_password": hashed_password})
- db_session.commit()
- return {"code": 200, "message": "success", "data": {"id": user, "username": user_data.username}}
- @user_router.post("/login-token")
- async def login(user_data: UserCreate):
- """
- 登录验证;
- :param user_data:
- :return:
- """
- # 1. 对参数的进行第一步的叫校验;
- user_name = user_data.username
- password = user_data.password
- with DatabaseSession() as db_session:
- user_obj = BaseDBOperateModel.get_one(db_session.query(User), filters={User.username == user_name})
- status_pwd = verify_password(password, user_obj.get("hash_password")) # 校验字段密码的信息;
- if not status_pwd:
- raise HTTPException(status_code=400, detail="invalid username or password")
- # 通过密码的校验, 进行 token 的创建;
- access_token = create_access_token(data={"user_id": user_obj.get("id"), "user_name": user_name},
- expires_delta=timedelta(minutes=Settings.ACCESS_TOKEN_EXPIRE_MINUTES))
- return {"status": 200, "message": "success", "data": {"access_token": access_token, "token_type": "bearer"}}
- @user_router.get("/home")
- async def home(current_user: Dict = Depends(get_current_user)):
- """
- 登录成功之后的跳转路由, 基于依赖注入做认证;
- :param current_user: User,
- :return:
- """
- print(current_user.get("username"), "登录成功了")
- return {"code": 200, "message": "Hello success message.", "data": {"id": current_user.get("id")}}
复制代码 API 部分的代码如上显示;
3.5 程序启动
提供启动方式,uvicorn,其他 corn 的启动方式可以百度参考。
- # !/usr/bin/env python
- # -*-coding:utf-8 -*-
- """
- @File : __init__.py.py
- @Time : 2025/6/5 23:03
- @Author : zi qing bao jian
- @Version : 1.0
- @Desc :
- """
- from fastapi import FastAPI
- from .user_api import user_router
- def create_app():
- app = FastAPI()
- app.include_router(user_router, prefix="/api")
- return app
复制代码 编写启动文件。- import uvicorn
- from user_web import create_app # 挂载路由的信息写入 __init__.py 的文件中, 对外部隐藏 app 细节;
- app = create_app()
- if __name__ == '__main__':
- uvicorn.run(app, host="127.0.0.1", port=8000)
复制代码 继续努力,终成大器。
来源:豆瓜网用户自行投稿发布,如果侵权,请联系站长删除 |