2团
Published on 2026-03-06 / 1 Visits
0
0

基于Chainlit实现Qwen3问答Web界面:钩子机制与功能实现

原本想使用flask加上模板引擎来实现一个简单的Web界面来记录AI历史会话,但是整了半天都没调整好样式。于是考虑找一个简单的前端开源项目使用,最好是python的,这样就可以和模型推理代码放在一起,结果发现了Chainlit这个专为AI应用设计的Web框架。很好,不用自己写就行。

项目代码已上传至Github,地址qwen3-local-chat

摘要:本文详细介绍如何利用 Chainlit 框架的钩子机制构建 Qwen3 Web 问答界面,涵盖数据持久化、用户认证、流式输出、思考过程展示及性能监控等核心功能的实现方案,为快速搭建 AI 聊天应用提供完整参考。

1. Chainlit简介

1.1 什么是Chainlit?

Chainlit是一个专为AI聊天应用设计的Python框架,旨在简化大语言模型(LLM)应用的Web界面开发。它提供了一套完整的解决方案,让开发者能够快速构建类似ChatGPT的对话界面,而无需处理复杂的前端开发工作。

1.2 Chainlit的核心特性

  • 内置聊天UI组件: 提供美观、响应式的聊天界面,支持Markdown渲染、代码高亮、图片展示等;

  • 流式输出支持: 原生支持LLM的流式token输出,提供流畅的打字机效果;

  • 数据持久化层: 内置数据层抽象,支持多种数据库(PostgreSQL、SQLite等)存储会话数据;

  • 用户认证机制: 提供多种认证方式(密码、OAuth、JWT等),支持多用户管理;

  • 自定义步骤(Step)展示: 支持将复杂的处理过程分解为多个步骤,清晰展示AI的思考过程;

  • 钩子机制: 提供丰富的生命周期钩子,允许开发者在特定时机插入自定义逻辑。

项目中,因为是自己整的玩项目,那当然是选择SQLite作为数据存储层,能降低系统占用当然是好事。

2. Chainlit钩子机制

2.1 什么是钩子(Hook)?

钩子是Chainlit框架提供的生命周期回调函数,允许开发者在特定事件发生时执行自定义逻辑。Chainlit通过装饰器(Decorator)的方式定义钩子,开发者只需在函数上添加相应的装饰器即可。

2.2 Chainlit核心钩子列表

钩子名称

装饰器

触发时机

用途

数据层钩子

@cl.data_layer

应用启动时

注册数据持久化层

认证钩子

@cl.password_auth_callback

用户登录时

验证用户身份

会话开始钩子

@cl.on_chat_start

新会话创建时

初始化会话状态

消息钩子

@cl.on_message

用户发送消息时

处理用户输入

会话恢复钩子

@cl.on_chat_resume

恢复历史会话时

重建会话上下文

3. 围绕钩子的功能实现

3.1 数据层钩子:@cl.data_layer

数据层钩子用于注册Chainlit的数据持久化层。Chainlit本身不直接操作数据库,而是通过数据层抽象接口与各种数据库交互。

3.1.1 实现代码

@cl.data_layer
def get_data_layer():
    """使用SQLAlchemy + aiosqlite将Chainlit会话持久化到本地SQLite"""
    db_url = f"sqlite+aiosqlite:///{Config.CHAINLIT_DB_PATH}"
    return SQLAlchemyDataLayer(conninfo=db_url)

实现考量:

  • 使用aiosqlite驱动,支持异步操作

  • 数据库路径配置为用户主目录下的.qwen3_chainlit.db

  • SQLAlchemyDataLayer是Chainlit提供的通用数据层实现

3.1.2 数据库表初始化

在项目运行初始化阶段,需要确保相关数据库表已经创建,具体建表语句如下:

def _init_chainlit_db():
    """确保 Chainlit 数据层所需的 SQLite 表已创建。
    SQLAlchemyDataLayer 本身不建表(设计面向已有 PostgreSQL schema),
    因此需要在应用启动前用 sqlite3 手动初始化。
    """
    import sqlite3

    db_path = str(Config.CHAINLIT_DB_PATH)
    log.info(f"初始化 Chainlit 数据库: {db_path}")
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    cursor.executescript(
        """
        CREATE TABLE IF NOT EXISTS users (
            "id"         TEXT PRIMARY KEY,
            "identifier" TEXT NOT NULL UNIQUE,
            "metadata"   TEXT NOT NULL,
            "createdAt"  TEXT
        );

        CREATE TABLE IF NOT EXISTS threads (
            "id"             TEXT PRIMARY KEY,
            "createdAt"      TEXT,
            "name"           TEXT,
            "userId"         TEXT,
            "userIdentifier" TEXT,
            "tags"           TEXT,
            "metadata"       TEXT,
            FOREIGN KEY ("userId") REFERENCES users("id") ON DELETE CASCADE
        );

        CREATE TABLE IF NOT EXISTS steps (
            "id"            TEXT PRIMARY KEY,
            "name"          TEXT NOT NULL,
            "type"          TEXT NOT NULL,
            "threadId"      TEXT NOT NULL,
            "parentId"      TEXT,
            "streaming"     BOOLEAN NOT NULL DEFAULT 0,
            "waitForAnswer" BOOLEAN,
            "isError"       BOOLEAN,
            "metadata"      TEXT,
            "tags"          TEXT,
            "input"         TEXT,
            "output"        TEXT,
            "createdAt"     TEXT,
            "command"       TEXT,
            "start"         TEXT,
            "end"           TEXT,
            "generation"    TEXT,
            "showInput"     TEXT,
            "language"      TEXT,
            "indent"        INT,
            "defaultOpen"   BOOLEAN,
            FOREIGN KEY ("threadId") REFERENCES threads("id") ON DELETE CASCADE
        );
    """
    )

    conn.commit()
    conn.close()
    log.success("Chainlit 数据库表已就绪")

表结构说明:

  • users: 存储用户信息(ID、标识符、元数据);

  • threads: 存储会话线程信息(会话ID、名称、用户关联);

  • steps: 存储对话步骤(用户消息、助手回复、思考过程、性能统计)。

3.2 认证钩子:@cl.password_auth_callback

认证钩子用于验证用户登录凭据。Chainlit的历史会话功能需要用户认证才能按用户区分会话,否则所有会话会混在一起。

实现代码:

@cl.password_auth_callback
def auth_callback(username: str, password: str):
    """简单密码认证。可按需替换为数据库校验。"""
    if (username, password) == ("admin", "admin"):
        return cl.User(
            identifier="admin",
            metadata={"role": "admin", "provider": "credentials"},
        )
    return None

实现细节:

  • 接收用户名和密码作为参数;

  • 验证成功返回cl.User对象;

  • 验证失败返回None

  • identifier是用户的唯一标识,用于区分不同用户。

这里为了简化实现,直接使用了固定的用户名和密码;如果不设置用户名和密码的话,就无从解锁历史会话功能了。

3.3 会话开始钩子:@cl.on_chat_start

会话开始钩子在用户打开新会话时触发。这是初始化会话状态、加载模型的最佳时机。

3.3.1 实现代码

@cl.on_chat_start
async def on_chat_start():
    """用户打开新会话时触发:首次加载模型。"""
    log.info("新会话开始")
    cl.user_session.set("messages", [])

    if srv.model is None:
        loading = cl.Message(content="⏳ 正在加载模型,请稍候…")
        await loading.send()
        success = await cl.make_async(_ensure_model)()
        content = (
            "✅ 模型加载完成,开始对话吧!"
            if success
            else "❌ 模型加载失败,请检查日志。"
        )
        loading.content = content
        await loading.update()
    else:
        log.info("模型已就绪,跳过加载")
  • 使用cl.user_session.set()存储会话级别的数据;

  • 显示加载提示消息,提升用户体验;

  • 使用cl.make_async()将同步函数包装为异步函数;

  • 通过await loading.update()更新消息内容。

3.3.2 模型加载函数

def _ensure_model() -> bool:
    """如果模型尚未加载则加载;返回是否成功。"""
    if srv.model is not None:
        log.info("模型已加载,跳过初始化")
        return True
    model_name = os.environ.get("QWEN3_MODEL", Config.DEFAULT_MODEL)
    log.info(f"开始加载模型: {model_name}")
    try:
        srv.init_model(model_name)
        if srv.model is not None:
            log.success(f"模型加载成功: {model_name}")
        else:
            log.error("模型加载后仍为 None")
        return srv.model is not None
    except Exception as e:
        log.error(f"模型加载失败: {e}")
        srv.cleanup_model()
        return False

3.4 消息钩子:@cl.on_message

消息钩子是整个应用的核心,处理用户发送的每条消息。它负责:

  • 获取历史消息上下文;

  • 调用模型进行推理;

  • 流式展示思考过程和回答;

  • 记录性能统计;

  • 保存会话数据。

3.4.1 实现代码

@cl.on_message
async def on_message(message: cl.Message):
    """
    处理用户消息:
      1. 流式展示思考过程(Step)
      2. 流式展示答案(Message)
      3. 显示性能统计(Step)
      4. 持久化到 SQLite
    """
    if srv.model is None:
        await cl.make_async(_ensure_model)()
        if srv.model is None:
            await cl.Message(content="❌ 模型未就绪,请重启服务。").send()
            return

    messages: List[Dict] = list(cl.user_session.get("messages", []))

    user_text = message.content.strip()
    messages.append({"role": "user", "content": user_text})
    log.info(f"收到用户消息: {user_text[:80]}{'…' if len(user_text) > 80 else ''}")

    raw_chunks: List[str] = []
    think_done = False
    answer_started = False
    t_start = time.time()

    # ── 性能监控 ──
    perf_monitor = PerformanceMonitor("Generation")
    perf_monitor.start()
    perf_monitor.mark_think_start()

    # ── 1. 创建思考过程 Step(先创建,后续流式更新)──
    think_step = cl.Step(name="💭 思考过程", type="tool")
    await think_step.send()

    # ── 2. 创建答案 Message(初始为空,检测到</think> 后开始流式更新)──
    answer_msg = cl.Message(content="")

    # ── 3. 流式处理 token ──
    for token in _stream_tokens(messages):
        raw_chunks.append(token)
        perf_monitor.record_token(is_think=not think_done)
        joined = "".join(raw_chunks)

        if not think_done:
            # 仍在思考阶段,检测 </think>
            end_match = re.search(r"</think>", joined, re.IGNORECASE)
            if end_match:
                # 检测到思考结束
                think_done = True
                perf_monitor.mark_think_end()
                perf_monitor.mark_answer_start()
                elapsed = time.time() - t_start
                log.info(f"检测到 </think>,思考阶段耗时 {elapsed:.1f}s")

                # 提取完整思考内容并更新 Step
                think_match = re.search(r"<think>(.*?)</think>", joined, re.DOTALL | re.IGNORECASE)
                if think_match:
                    think_step.output = think_match.group(1).strip()
                else:
                    think_step.output = "(无法提取思考内容)"
                await think_step.update()

                # 开始发送答案消息(先发送空消息,后续流式更新)
                await answer_msg.send()
                answer_started = True
            else:
                # 仍在思考中,流式更新思考内容
                think_start = re.search(r"<think>(.*)", joined, re.DOTALL | re.IGNORECASE)
                if think_start:
                    partial_think = think_start.group(1).strip()
                    if partial_think:
                        think_step.output = partial_think
                        await think_step.update()
                await asyncio.sleep(0)
        else:
            # 思考完成,流式输出答案
            if not answer_started:
                await answer_msg.send()
                answer_started = True

            # 从 token 中过滤特殊标记后流式发送
            cleaned_token = _STRIP_SPECIAL.sub("", token)
            if cleaned_token:
                await answer_msg.stream_token(cleaned_token)

        if Config.ENABLE_RESOURCE_MONITORING and think_done:
            perf_monitor.log_streaming_stats()

    perf_monitor.mark_answer_end()

    # ── 4. 确保答案内容正确(最终解析) ──
    raw_full = "".join(raw_chunks)
    raw_full = _STRIP_SPECIAL.sub("", raw_full)
    elapsed_total = time.time() - t_start
    log.info(f"流式生成结束,总耗时 {elapsed_total:.1f}s")

    think_content, clean_answer = srv.parse_response(raw_full)
    log.info(f"思考内容长度: {len(think_content)} 字符")
    log.info(f"答案长度: {len(clean_answer)} 字符")

    # 如果没有检测到思考过程,补充处理
    if not think_done:
        think_step.output = think_content or "(无思考过程)"
        await think_step.update()
        if not answer_started:
            answer_msg.content = clean_answer or "(空回答)"
            await answer_msg.send()

    # 最终确保答案内容正确(修正可能的流式累积误差)
    if answer_started and clean_answer:
        answer_msg.content = clean_answer
        await answer_msg.update()

    # 保存多轮上下文
    messages.append({"role": "assistant", "content": clean_answer})
    cl.user_session.set("messages", messages)

    # ── 3. 显示性能统计 ──
    perf_monitor.log_stats("total")

    # 创建性能统计 Step
    think_stats = perf_monitor.get_think_stats()
    answer_stats = perf_monitor.get_answer_stats()
    total_stats = perf_monitor.get_total_stats()
    sys_stats = perf_monitor.get_system_stats()

    perf_lines = ["⚡ **性能统计**\n\n"]

    if think_stats:
        perf_lines.append(
            f"**思考阶段**: {think_stats['duration']:.2f}s, "
            f"{think_stats['token_count']} tokens, "
            f"{think_stats['tokens_per_second']:.1f} tokens/s\n\n"
        )

    if answer_stats:
        perf_lines.append(
            f"**回答阶段**: {answer_stats['duration']:.2f}s, "
            f"{answer_stats['token_count']} tokens, "
            f"{answer_stats['tokens_per_second']:.1f} tokens/s\n\n"
        )

    if total_stats:
        perf_lines.append(
            f"**总计**: {total_stats['duration']:.2f}s, "
            f"{total_stats['token_count']} tokens, "
            f"{total_stats['tokens_per_second']:.1f} tokens/s\n\n"
        )

    if sys_stats:
        perf_lines.append(f"**系统资源**\n")
        perf_lines.append(
            f"CPU: {sys_stats.get('cpu_percent', 0):.1f}%, "
            f"Memory: {sys_stats.get('memory_percent', 0):.1f}% "
            f"({sys_stats.get('memory_used_gb', 0):.1f}GB/{sys_stats.get('memory_total_gb', 0):.1f}GB)\n"
        )

        if sys_stats.get("gpu_stats"):
            for gpu_name, gpu_data in sys_stats["gpu_stats"].items():
                perf_lines.append(
                    f"{gpu_name}: {gpu_data['allocated_gb']:.1f}/{gpu_data['total_gb']:.1f}GB "
                    f"({gpu_data['memory_percent']:.1f}%)\n"
                )

    perf_step = cl.Step(name="📊 性能统计", type="tool")
    perf_step.output = "".join(perf_lines)
    await perf_step.send()

3.4.2 流式生成实现

# Qwen3 特殊 token,需要从输出中过滤(但保留 <think>/</ think>)
_STRIP_SPECIAL = re.compile(r"<\|im_start\|>|<\|im_end\|>|<\|endoftext\|>")

def _stream_tokens(messages: List[Dict]) -> Generator[str, None, None]:
    """
    使用 TextIteratorStreamer 逐 token 生成,在独立线程中跑 model.generate。
    yields str tokens.
    """
    import torch

    chat_messages = [{"role": "system", "content": "You are a helpful assistant."}]
    chat_messages.extend(messages)

    if hasattr(srv.tokenizer, "apply_chat_template"):
        prompt_text = srv.tokenizer.apply_chat_template(
            chat_messages, tokenize=False, add_generation_prompt=True
        )
    else:
        prompt_text = "\n".join(f"{m['role']}: {m['content']}" for m in chat_messages)
        prompt_text += "\nassistant:"

    log.info(f"Prompt 长度: {len(prompt_text)} 字符")

    inputs = srv.tokenizer(prompt_text, return_tensors="pt").to(srv.model.device)
    input_len = inputs["input_ids"].shape[-1]
    log.info(f"输入 token 数: {input_len}")

    # ── 关键修复: skip_special_tokens=False ──
    # Qwen3 将 <think>/</ think> 注册为 special tokens,
    # 若 skip_special_tokens=True 会导致这些 tag 被过滤,
    # 使得流式解析永远检测不到 </think>,看不到回答。
    streamer = TextIteratorStreamer(
        srv.tokenizer, skip_prompt=True, skip_special_tokens=False, timeout=60.0
    )

    # 收集所有可能的 eos token id(含 <|im_end|> 等)
    eos_ids = srv.tokenizer.eos_token_id
    if isinstance(eos_ids, int):
        eos_ids = [eos_ids]
    im_end_token = srv.tokenizer.convert_tokens_to_ids("<|im_end|>")
    if isinstance(im_end_token, int) and im_end_token not in eos_ids:
        eos_ids.append(im_end_token)
    log.info(f"eos_token_ids: {eos_ids}")

    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=Config.MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=eos_ids,
    )

    thread = threading.Thread(target=srv.model.generate, kwargs=gen_kwargs, daemon=True)
    thread.start()
    log.info("生成线程已启动,开始流式输出…")

    token_count = 0
    for token in streamer:
        # 过滤掉 <|im_end|> 等非 think 类特殊 token
        cleaned = _STRIP_SPECIAL.sub("", token)
        if cleaned:
            token_count += 1
            yield cleaned

    thread.join()
    log.info(f"生成完成,共输出 {token_count} 个有效 token 片段")

3.5 会话恢复钩子:@cl.on_chat_resume

会话恢复钩子在用户从历史会话列表点击恢复时触发。它负责从数据库读取历史消息,重建内存中的消息上下文。

3.5.1 实现代码

@cl.on_chat_resume
async def on_chat_resume(thread):
    """用户从历史会话列表点击恢复时触发,重建内存中的消息上下文。"""
    log.info(f"恢复历史会话: {thread.get('id', 'unknown')}")
    messages = []
    for step in thread.get("steps", []):
        if step["type"] == "user_message":
            # 用户消息的内容可能在 input 或 output 字段中,尝试两者
            content = step.get("output") or step.get("input") or ""
            if content:  # 只有在有内容时才添加
                messages.append({"role": "user", "content": content})
            else:
                log.warning("用户消息缺少内容,跳过")
        elif step["type"] == "assistant_message":
            # 助手消息的内容应该在 output 字段中
            content = step.get("output") or step.get("input") or ""
            if content:  # 只有在有内容时才添加
                messages.append({"role": "assistant", "content": content})
            else:
                log.warning("助手消息缺少内容,跳过")
    cl.user_session.set("messages", messages)
    log.info(f"已恢复 {len(messages)} 条消息到内存上下文")

实现细节:

  • thread参数包含会话的所有steps;

  • 需要区分user_messageassistant_message类型;

  • 消息内容可能在outputinput字段中,需要兼容处理;

  • 恢复后保存到cl.user_session,供后续对话使用。

4. 总结

通过Chainlit的钩子机制,我们构建了一个功能完善的Qwen3 Web问答界面,相当Easy:

钩子

功能

实现效果

@cl.data_layer

数据持久化

SQLite存储会话数据

@cl.password_auth_callback

用户认证

简单密码登录

@cl.on_chat_start

会话初始化

加载模型、初始化上下文

@cl.on_message

消息处理

流式生成、思考展示、性能统计

@cl.on_chat_resume

会话恢复

历史会话上下文重建


Comment