【原理+使用】DeepCache: Accelerating Diffusion Models for Free

论文:arxiv.org/pdf/2312.00858

代码:horseee/DeepCache: [CVPR 2024] DeepCache: Accelerating Diffusion Models for Free (github.com)

介绍


DeepCache是一种新颖的无训练且几乎无损的范式,从模型架构的角度加速了扩散模型。DeepCache利用 扩散模型顺序去噪步骤中观察到的固有时间冗余,缓存和检索相邻去噪阶段的特征,从而减少冗余计算。利用U-Net的特性,重用高级特征,同时以低成本的方式更新低级特征。将 Stable Diffusion v1.5 加速了 2.3 倍,CLIP 分数仅下降了 0.05 倍,LDM-4-G(ImageNet) 加速了 4.1 倍,FID 降低了 0.22。

动机:

由于顺序去噪过程和繁琐的模型尺寸,训练扩散模型会产生大量的计算成本。本文希望在没有额外训练的情况下,减少每个去噪步骤的计算开销,从而实现对扩散模型的无成本压缩。

背景:

反向扩散过程的加速。反向扩散过程的固有性质减慢了推理速度。目前的研究主要集中在两种加速扩散模型推理的方法上:

  1. 优化采样效率。侧重于减少采样步骤的数量。DDIM、一致性模型将随机噪声转换为初始图像,只需要进行一次模型评估。
  2. 优化结构效率。减少每个采样步骤的推理时间。

U-Net的高级和低级特征。由于跳跃式连接,UNet具有很强的合并低级和高级特征的能力。U-Net构建在堆叠的下采样和上采样块上,将输入图像编码为高级表示,然后对其进行解码,用于下游任务。表示为的块对,通过额外的跳过路径连接,直接将低级的信息从Di转发到Ui。在U-Net体系结构的前向传播过程中,数据通过两条路径并发地遍历:主分支和跳过分支。这些分支汇聚在一个连接模块,主分支提供处理过的高级特征,这些特征来自前面的上采样块Ui+1,而跳过分支提供来自对称块Di的相应特征。因此,U-Net模型的核心是来自跳过分支的低级特征和来自主分支的高级特征的连接:

原理

序列去噪中的特征冗余

去噪过程中的相邻步骤在高级特征上表现出显著的时间相似性。

图2实验揭示了两个主要观点:

  1. 在去噪过程中,相邻步骤之间,存在明显的时间特征相似性,表明连续步骤之间的变化通常较小。
  2. 无论使用哪种扩散模型,如稳定扩散、LDM和DDPM,对于每个时间步长,至少有10%的相邻时间步长与当前步长表现出高度相似(>0.95),这表明某些高级特征以渐进的速度变化。

每次计算,得到的特征都与前一步相似,存在大量冗余计算,产生边际效益。本文目标是利用这一特性来加速去噪过程。

扩散模型的深度缓存

DeepCache利用反向扩散过程中步骤之间的时间冗余来加速推理。从计算机系统中的缓存机制中获得灵感,结合了为随时间变化最小的元素设计的存储组件。应用于扩散模型,通过缓存那些变化缓慢的特征,来消除冗余计算,从而无需在后续步骤中重复计算

实现重点为U-Net中的跳过连接,它本质上提供了双路径优势:主分支需要大量的计算来遍行整个网络,而跳过分支只需要通过一些浅层,从而产生非常小的计算负载。主要分支中突出的特征相似性允许重用已经计算的结果,而不是为所有时间步重复计算。

去噪中的可缓存特性。

在两个连续时间步长 𝑡 和 𝑡−1 之间,根据反向过程,𝑥𝑡−1 将基于先前的结果 𝑥𝑡​ 进行条件生成。实验:首先生成 𝑥𝑡​,计算跨整个U-Net进行。为了获得下一个输出 𝑥𝑡−1,我们检索在先前时间步长 𝑡 中生成的高层次特征。即,考虑U-Net中的一个跳跃分支 𝑚,它连接 𝐷𝑚​ 和 𝑈𝑚​,在时间 𝑡 从先前的上采样块缓存特征图:

这是时间步长 𝑡 的主分支中的特征。这些缓存的特征将在后续推理中使用。

在下一个时间步长 𝑡−1 中,推理并不在整个网络上进行,只计算 m-th 跳跃分支中所需的部分,并用缓存中的特征替代主分支的计算。因此,时间步长 𝑡−1 中 𝑈𝑡−1𝑚​ 的输入可表示为:

𝐷𝑡−1𝑚​ 代表 m-th 下采样块的输出,如果选择一个较小的 𝑚,则只包含几层。例如,如果我们在第一层执行 DeepCache 并选择 𝑚=1,则只需要执行一个下采样块以获得。至于第二个特征 ​,由于可以简单地从缓存中检索,因此不需要额外的计算成本。过程如图3.

在第t - 1步,通过重用第t步缓存的特征,来生成xt - 1,并且为了更有效的推理,不执行D2, D3, U2, U3块。

扩展到1:N推理。缓存的特征计算一次,可以在后续的N−1步中重用,以取代原始的。对于所有去噪的T步,执行完全推理的时间步长序列为:

非均匀1:N推理。

基于1:N策略,在假定高级特征在连续N步中不变的前提下,成功地加速了扩散推理。然而,并非总是如此,特别是对于N,如图2(c)所示,特征的相似性并不是在所有步骤中都保持不变。对于像LDM这样的模型,特征的时间相似性会在去噪过程中显著降低40%左右。

因此,对于非均匀的1:N推理,我们倾向于对那些与相邻步骤相似度相对较小的步骤进行更多采样。在这里,执行完整推理的时间步长序列变为:

使用

import torch
from diffusers import StableDiffusionPipeline
from DeepCache import DeepCacheSDHelper

# 加载 Stable Diffusion 模型
pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16).to("cuda:0")

# 创建 DeepCacheSDHelper 对象
helper = DeepCacheSDHelper(pipe=pipe)

# 设置缓存参数
helper.set_params(
    cache_interval=3,
    cache_branch_id=0,
)

# 启用缓存机制
helper.enable()

# 定义输入提示词
prompt = "a beautiful landscape with mountains and rivers"

# 生成图像
deepcache_image = pipe(
    prompt,
    output_type='pt'
).images[0]

# 禁用缓存机制
helper.disable()

 库:

diffusers==0.24.0
transformer

仅需要用DeepCache提供的Pipeline替换Diffusers库的Pipeline,即可实现扩散模型加速。目前支持 StableDiffusionPipeline 可以加载的模型。可以通过参数指定模型名称。

尝试1:将 DeepCacheSDHelper 应用于整个 pipeline,并确保缓存机制只启用一次

 pipe = Pose2VideoPipeline(
        vae=vae,
        image_encoder=image_enc,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        pose_guider=pose_guider,
        scheduler=scheduler,
    )
    pipe = pipe.to("cuda", dtype=weight_dtype)

    # 初始化 DeepCacheSDHelper
    helper = DeepCacheSDHelper(pipe=pipe)
    # 设置缓存参数
    helper.set_params(
        cache_interval=3,
        cache_branch_id=0,
    )
    # 启用缓存机制
    helper.enable()

报错:

AttributeError: 'Pose2VideoPipeline' object has no attribute 'unet'

报错信息显示 Pose2VideoPipeline 对象没有 unet 属性,这说明 DeepCacheSDHelper 无法找到所需的 UNet 模型。要解决这个问题,必须确保传递给 DeepCacheSDHelper 的 pipeline 具有 unet 属性,并且该属性指向实际的 UNet 模型。

而 Pose2VideoPipeline 包含多个 UNet 模型( reference_unetdenoising_unet),需要对 DeepCacheSDHelper 进行修改,使其能够处理这种情况。一种解决方法是扩展 DeepCacheSDHelper 以接受多个 UNet 模型。解决方案:修改DeepCacheSDHelper类,pipe 和包含所有 UNet 模型的列表传递给 DeepCacheSDHelper:

    pipe = Pose2VideoPipeline(
        vae=vae,
        image_encoder=image_enc,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        pose_guider=pose_guider,
        scheduler=scheduler,
    )
    pipe = pipe.to("cuda", dtype=weight_dtype)

    # Initialize DeepCacheSDHelper with both UNet models
    helper = DeepCacheSDHelper(pipe=pipe, unets=[reference_unet, denoising_unet])
    helper.set_params(
        cache_interval=3,
        cache_branch_id=0,
    )
    helper.enable()

尝试2:分别对 reference_unetdenoising_unet 初始化并启用 DeepCacheSDHelper

 reference_unet = UNet2DConditionModel.from_pretrained(
        config.pretrained_base_model_path,
        subfolder="unet",
    ).to(dtype=weight_dtype, device="cuda")
    
    # Import the DeepCacheSDHelper
    helper = DeepCacheSDHelper(reference_unet=reference_unet)
    helper.set_params(
        cache_interval=3,
        cache_branch_id=0,
    )
    helper.enable()

    inference_config_path = config.inference_config
    infer_config = OmegaConf.load(inference_config_path)
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        config.pretrained_base_model_path,
        config.motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    ).to(dtype=weight_dtype, device="cuda")
    
    helper = DeepCacheSDHelper(denoising_unet=denoising_unet)
    helper.set_params(
        cache_interval=3,
        cache_branch_id=0,   # 指定缓存的分支 ID,上下两个unet是否需要不同分支?
    )
    helper.enable()

TypeError: DeepCacheSDHelper.__init__() got an unexpected keyword argument 'reference_unet', DeepCacheSDHelper 需要对 pipeline 中所有相关的 UNet 模型进行统一处理,而不是分别处理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/782105.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

树(相关知识点)

目录 结点的度:某一个结点所含有字数的个数 叶节点:最后一个结点 非终端节点:不是叶结点 兄弟结点:亲兄弟结点 树的度:最大节点的度 层次:根为第一层,根的子结点为第二层,以此类推 森林&am…

[附源码]基于Flask的演唱会购票系统

摘要 随着互联网技术的普及和发展,传统购票方式因其效率低下、流程繁琐等问题已难以满足现代社会的需求。本文设计并实现了一个基于Flask框架的演唱会购票系统,该系统集成了用户管理、演唱会信息管理、票务管理以及数据统计与分析等功能模块&#xff0c…

linux centos7.9 安装mysql5.7;root设置客户端登录、配置并发、表名大小写敏感等

查看centos版本 cat /etc/centos-releasecentos版本为7.9 查看是否已安装mariadb,安装了需要先删除 1.查看是否安装了mariadb和mysql,安装了需要先删除 mariadb是mysql的一个分支,但要安装mysql需要删除它 执行rpm -qa|grep mariadb,查看mariadb情况…

Hi6602 恒压恒流SSR电源方案

Hi6602是一款针对离线式反激电源设计的高性能PWM控制器。Hi6602内集成有通用的原边恒流控制技术,可支持断续模式和连续模式工作,适用于恒流输出的隔离型电源应用中。Hi6602内部具有高精度65kHz开关频率振荡器,且带有抖频功能可优化EMI性能。H…

AI大模型技术分析

一文读懂:AI大模型! 引言 近年来,随着深度学习技术的迅猛发展,AI大模型已经成为人工智能领域的重要研究方向和热点话题。AI大模型,指的是拥有巨大参数规模和强大学习能力的神经网络模型,如BERT、GPT等&…

java IO流(1)

一. 文件类 java中提供了一个File类来表示一个文件或目录(文件夹),并提供了一些方法可以操作该文件 1. 文件类的常用方法 File(String pathname)构造方法,里面传一个路径名,用来表示一个文件boolean canRead()判断文件是否是可读文件boolean canWrite()判断文件是否是可写文…

spring boot读取yml配置注意点记录

问题1:yml中配置的值加载到代码后值变了。 现场yml配置如下: type-maps:infos:data_register: 0ns_xzdy: 010000ns_zldy: 020000ns_yl: 030000ns_jzjz: 040000ns_ggglyggfwjz: 050000ns_syffyjz: 060000ns_gyjz: 070000ns_ccywljz: 080000ns_qtjz: 090…

【论文通读】RuleR: Improving LLM Controllability by Rule-based Data Recycling

RuleR: Improving LLM Controllability by Rule-based Data Recycling 前言AbstractMotivationSolutionMethodExperimentsConclusion 前言 一篇关于提升LLMs输出可控性的短文,对SFT数据以规则的方式进行增强,从而提升SFT数据的质量,进而间接帮…

数组算法(二):交替子数组计数

1. 官方描述 给你一个二进制数组nums 。如果一个子数组中 不存在 两个 相邻 元素的值 相同 的情况,我们称这样的子数组为 交替子数组 。 返回数组 nums 中交替子数组的数量。 示例 1: 输入: nums [0,1,1,1] 输出: 5 解释&#…

数学系C++ 排序算法简述(八)

目录 排序 选择排序 O(n2) 不稳定:48429 归并排序 O(n log n) 稳定 插入排序 O(n2) 堆排序 O(n log n) 希尔排序 O(n log2 n) 图书馆排序 O(n log n) 冒泡排序 O(n2) 优化: 基数排序 O(n k) 快速排序 O(n log n)【分治】 不稳定 桶排序 O(n…

一.2.(4)放大电路静态工作点的稳定;(未完待续)

1.Rb对Q点及Au的影响 输入特性曲线:Rb减少,IBQ,UBEQ增大 输出特性曲线:ICQ增大,UCEQ减少 AUUO/Ui分子减少,分母增大,但由于分子带负号,所以|Au|减少 2.Rc对Q点及Au的影响 输入特性曲…

【密码学】什么是密码?什么是密码学?

一、密码的定义 根据《中华人民共和国密码法》对密码的定义如下: 密码是指采用特定变换的方法对信息等进行加密保护、安全认证的技术、产品和服务。 二、密码学的定义 密码学是研究编制密码和破译密码的技术科学。由定义可以知道密码学分为两个主要分支&#x…

【做一道算一道】和为 K 的子数组

给你一个整数数组 nums 和一个整数 k ,请你统计并返回 该数组中和为 k 的子数组的个数 。 子数组是数组中元素的连续非空序列。 示例 1: 输入:nums [1,1,1], k 2 输出:2 示例 2: 输入:nums [1,2,3],…

深度学习图像生成与分割模型详解:从StyleGAN到PSPNet

文章目录 Style GANDeeplab-v3FCNAdversarial AutoencodersHigh-Resolution Image Synthesis with Latent Diffusion ModelsNeRF: Representing Scenes as Neural Radiance Fields for View SynthesisPyramid Scene Parsing Network Style GAN 输入是一个潜在向量 (z)&#xff…

嵌入式开发SPI基本介绍与应用

目录 #SPI通信协议 #SPI基础概念 #SPI通信模式 #SPI通信时序类型 前言:本篇笔记参考嘉立创的开发文档,连接放在最后。 #SPI通信协议 #SPI基础概念 Serial Peripheral Interface 缩写SPI 翻译:串行外设接口 同步串行通信协议&…

FMEA在大型光伏电站安全生产管理中的应用

一、FMEA概述 FMEA(Failure Modes and Effects Analysis)即失效模式和影响分析,是一种用于识别和分析产品或过程中潜在故障模式及其影响的方法。它通过对产品或过程中可能出现的故障模式进行系统性地梳理和分析,评估其可能的影响…

Miniconda的常见用法——以Isaacgym为例

1. ubuntu24.04安装minicondda mkdir -p ~/miniconda3 wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh解释下这段代码 bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3~/miniconda3/miniconda.sh: 指向Mi…

【笔记】记一次redis将从节点变成主节点 主节点变成从节点

1.连上虚拟机centos7 2.打开finalshell连接虚拟机 将从节点变为主节点 输出redis-cli -p 要变成主节点的从节点 -a此从节点的密码 输入 replicaof no one 查看端口状态 info replication 总结: redis-cli -p 端口号 -a 密码 replicaof no one info replicati…

STM32第十七课:连接云平台进行数据传输

目录 需求一、云平台项目创建二、代码编写1.导入MQTT包2.连接阿里云3.发布数据 三、关键代码总结 需求 1.通过生活物联网平台设计一个空气质量检测仪app。 2.连接阿里云平台将硬件数据传输到云端,使手机端能够实时收到。 一、云平台项目创建 先进入阿里云生活服务…

cs231n 作业3

使用普通RNN进行图像标注 单个RNN神经元行为 前向传播: 反向传播: def rnn_step_backward(dnext_h, cache):dx, dprev_h, dWx, dWh, db None, None, None, None, Nonex, Wx, Wh, prev_h, next_h cachedtanh 1 - next_h**2dx (dnext_h*dtanh).dot(…