当前位置:首页 > 文章列表 > 文章 > python教程 > PyTorch 无循环张量多对一求和方法

PyTorch 无循环张量多对一求和方法

2026-04-09 18:45:44 0浏览 收藏
本文揭秘了如何利用 PyTorch 的 `scatter_add_` 原语,结合 `repeat_interleave` 和索引展平技巧,以完全向量化、零 Python 循环的方式高效实现一维张量到另一维张量的“一对多”映射累加(如多源值聚合至目标位置),不仅大幅提升 GPU 并行计算效率、保持梯度可导性,还显著简化代码逻辑——告别慢速循环与手动索引遍历,让复杂映射操作变得简洁、健壮且生产就绪。

如何在 PyTorch 中高效实现张量的一对多映射求和(无显式循环)

本文介绍使用 torch.Tensor.scatter_add_ 配合索引展开与值重复,高效完成一维张量到另一维张量的一对多映射累加操作,避免 Python 循环,完全基于向量化运算。

本文介绍使用 `torch.Tensor.scatter_add_` 配合索引展开与值重复,高效完成一维张量到另一维张量的一对多映射累加操作,避免 Python 循环,完全基于向量化运算。

在 PyTorch 中处理「一对多」映射关系(即每个输入元素贡献至多个输出位置)并执行聚合(如求和)时,若采用 Python 循环或列表推导,不仅代码冗长,更会严重拖慢训练速度、破坏计算图完整性,且无法充分利用 GPU 并行能力。幸运的是,PyTorch 提供了高度优化的原语——scatter_add,它专为这类“按索引分散累加”场景设计,可一次性完成全部映射与聚合。

核心思想是将不规则映射结构(如嵌套列表 mapping)转化为两个齐次一维张量:

  • src:待累加的源值序列,其中每个 input[i] 根据其映射目标数量被重复;
  • index:对应的目标位置索引序列,与 src 严格对齐;
  • out:初始化为零的输出张量,长度由最大目标索引决定。

以下为完整实现示例:

import torch

# 输入定义
input = torch.tensor([0, 1, 2, 3], dtype=torch.float32)
mapping = [[1], [0, 2, 4], [0, 3], [1, 2]]

# 步骤 1:计算各输入项的重复次数(即每个 input[i] 映射到多少个 output 位置)
reps = torch.tensor([len(x) for x in mapping])

# 步骤 2:构建 src —— 按 reps 重复 input 中每个元素
src = input.repeat_interleave(reps)  # tensor([0, 1, 1, 1, 2, 2, 3, 3])

# 步骤 3:构建 index —— 展平 mapping,得到所有 (src[i] → output[j]) 的 j 序列
index = torch.tensor([j for sublist in mapping for j in sublist])  # tensor([1, 0, 2, 4, 0, 3, 1, 2])

# 步骤 4:初始化输出张量(长度 = max(index) + 1)
out = torch.zeros(max(index) + 1, dtype=src.dtype)

# 步骤 5:执行向量化累加:out[j] += src[i] for each (i,j) pair
result = out.scatter_add(dim=0, index=index, src=src)

print(result)  # tensor([3., 3., 4., 2., 1.])

关键优势

  • 全程无 Python 循环,100% 张量操作,支持 CUDA 加速;
  • 时间复杂度为 O(∑|mapping[i]|),空间复杂度为 O(len(output)),理论最优;
  • 自动兼容梯度传播(scatter_add 是可微分操作),适用于模型中间层。

⚠️ 注意事项

  • index 中的索引必须是非负整数,且严格小于 out.size(dim),否则抛出 RuntimeError;
  • 若 mapping 可能为空(如 []),需提前过滤或用 max(index, default=0) 防御;
  • 当 output 维度极大但稀疏时,该方法仍会分配全量内存;如需极致稀疏支持,可考虑结合 torch.sparse 或自定义 CUDA kernel,但绝大多数场景 scatter_add 已足够高效。

总结而言,scatter_add 是解决 PyTorch 中「一对多映射+聚合」问题的标准、简洁且高性能方案。掌握其与 repeat_interleave、索引展平等组合技巧,能显著提升数据预处理与自定义层的表达力与执行效率。

好了,本文到此结束,带大家了解了《PyTorch 无循环张量多对一求和方法》,希望本文对你有所帮助!关注golang学习网公众号,给大家分享更多文章知识!

Win11关闭U盘自动播放方法Win11关闭U盘自动播放方法
上一篇
Win11关闭U盘自动播放方法
Go语言解析JSON文件流程全解析
下一篇
Go语言解析JSON文件流程全解析
查看更多
最新文章
查看更多
课程推荐
  • 前端进阶之JavaScript设计模式
    前端进阶之JavaScript设计模式
    设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
    543次学习
  • GO语言核心编程课程
    GO语言核心编程课程
    本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
    516次学习
  • 简单聊聊mysql8与网络通信
    简单聊聊mysql8与网络通信
    如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
    500次学习
  • JavaScript正则表达式基础与实战
    JavaScript正则表达式基础与实战
    在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
    487次学习
  • 从零制作响应式网站—Grid布局
    从零制作响应式网站—Grid布局
    本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
    485次学习
查看更多
AI推荐
  • ljg-skills -
    ljg-skills
    ljg-skills 是李继刚开源的 AI 技能与提示词集合,面向大模型使用者整理了一批可复用的 prompt、角色设定和任务技能模板,适合用于学习提示词设计、搭建个人 AI 工作流和沉淀团队常用智能体能力。
    215次使用
  • MELO音乐 - AI 音乐生成平台,支持多模态创作能力
    MELO音乐
    MELO音乐是一站式AI视频与音乐制作助手,对标suno, udio的高品质体验。提供伴奏生成、原创写词、无损导出、哼唱识曲、混音变声等全套音频与短视频编辑工具。无论是流行Kpop、电音说唱、民谣古风、摇滚儿歌还是商用轻音乐,MELO为你免费谱曲,轻松做同款!
    237次使用
  • UniScribe - AI 免费在线音视频转文字平台
    UniScribe
    UniScribe 是一款 AI 音视频转文字与内容整理工具,支持上传音频、视频文件或粘贴 YouTube 链接,自动生成转写文本、摘要、思维导图和关键问题,并支持多格式导出,适合会议记录、课程学习、访谈整理和内容创作复盘。
    207次使用
  • 剧云 - 免费 AI 智能中文剧本创作平台
    剧云
    剧云是专业中文剧本创作平台,安全稳定运行十余年,集成AI编剧、剧本医生审核、人物小传、剧情关系图、大纲编写、多人协作、Word导入导出、版权管控功能,数据安全防护,轻松高效创作剧本。
    372次使用
  • 万象有声 - AI 一站式有声内容创作平台
    万象有声
    万象有声,一个专为有声创作者打造的新一代智能有声内容创作平台。平台提供专业的智能拆章、智能画本编辑、AI配音、AI生成音效、后期制作、智能对轨、智能审听等有声创作全流程工具,可以帮助创作者高效、低成本创作出引人入胜的有声作品。立即体验,让有声书制作更简单!
    371次使用
微信登录更方便
  • 密码登录
  • 注册账号
登录即同意 用户协议隐私政策
返回登录
  • 重置密码