详解 @符号在 PyTorch 中的矩阵乘法规则

news/2025/2/24 14:48:59

详解 @ 符号在 PyTorch 中的矩阵乘法规则

在 PyTorch 和 NumPy 中,@ 符号被用作矩阵乘法运算符,它本质上等价于 torch.matmul()numpy.matmul(),用于执行张量之间的矩阵乘法。

在本篇博客中,我们将深入探讨:

  • @ 运算符的基本概念
  • @ 在不同维度张量上的计算规则
  • @(d, k) @ (d, 1) 这种情况下的运算细节
  • PyTorch 自动广播机制
  • 代码示例与直观理解

1. 什么是 @

在 Python 3.5 之后,@ 被引入作为 矩阵乘法运算符,它在 NumPyPyTorch 中与 matmul() 等价。例如:

import numpy as np

A = np.array([[1, 2], [3, 4]])
B = np.array([[5], [6]])

C = A @ B  # 矩阵乘法
print(C)

输出:

[[17]
 [39]]

等价于

C = np.matmul(A, B)

PyTorch 中,@ 也适用于张量计算:

import torch
A = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
B = torch.tensor([[5], [6]], dtype=torch.float32)

C = A @ B  # PyTorch 版本的矩阵乘法
print(C)

2. @ 在不同维度张量上的计算规则

2.1 规则概述

@ 的运算规则依赖于输入张量的维度:

  1. 两个标量(0D):返回标量
  2. 标量和张量:标量与张量的元素逐个相乘
  3. 一维向量(1D)
    • (N,) @ (N,) → 标量(点积)
    • (N,) @ (N, M) → (M,)(左向量 × 矩阵
    • (N, M) @ (M,) → (N,)矩阵 × 右向量)
  4. 二维矩阵(2D)
    • (N, M) @ (M, K) → (N, K)(标准矩阵乘法)
  5. 高维张量(≥3D)
    • (A, B, C) @ (C, D) → (A, B, D)(批量矩阵乘法)

3. 重点解析 (d, k) @ (d, 1)

PyTorch 中,如果 A.shape = (d, k)B.shape = (d, 1)A @ B非法操作,因为矩阵乘法要求 A 的列数(k)等于 B 的行数(d),但这里 B 的形状 (d, 1) 无法与 (d, k) 匹配。

3.1 (d, k) @ (d, 1) 为什么不合法?

假设:

import torch
d, k = 4, 3

A = torch.randn(d, k)  # (4, 3)
B = torch.randn(d, 1)  # (4, 1)

C = A @ B  # ❌ 错误:形状不匹配

会报错:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x3 and 4x1)

原因:

  • 矩阵乘法规则: A 的列数(k)必须等于 B 的行数(d)。
  • (d, k) @ (d, 1) 不符合这个规则,因为 d ≠ k

3.2 如何让 (d, k) @ (d, 1) 变成合法操作?

我们需要 调整矩阵的形状,使其满足矩阵乘法的规则。

方法 1:交换操作数顺序

如果计算 B.T @ A

C = B.T @ A  # shape (1, d) @ (d, k) → (1, k)

就变成了合法操作。

方法 2:转置 A

如果我们计算:

C = A.T @ B  # shape (k, d) @ (d, 1) → (k, 1)

这个计算是 合法的,因为 A.T.shape = (k, d)B.shape = (d, 1),满足矩阵乘法规则。

示例:

C = A.T @ B  # (k, d) @ (d, 1) → (k, 1)

现在 A.T 变成 (k, d)B 仍然是 (d, 1),最终 C 的形状是 (k, 1)


3.3 PyTorch 如何正确处理 (d, k) @ (d,)

在 PyTorch 代码中,我们常见这样的计算:

q = P_q @ x  # (h, d, k) @ (d,)

为什么这里不需要转置 P_q

  • x.shape = (d,),PyTorch 自动扩展为 (d, 1) 使其成为列向量
  • 计算 (d, k) @ (d, 1)非法的,PyTorch 自动调整计算规则
  • PyTorch 实际执行的是 P_q.T @ x,确保计算正确
  • 最终返回 (h, k),去掉了多余的维度

因此 PyTorch 不需要我们手动转置 P_q,它会自动处理 x 为列向量进行计算!


4. 代码示例

import torch

d, k = 4, 3
torch.manual_seed(42)

A = torch.randn(d, k)  # (4, 3)
x = torch.randn(d)     # (4,)

# PyTorch 自动扩展 x,使其符合矩阵乘法规则
C = A.T @ x  # (k, d) @ (d,) → (k,)

print("A shape:", A.shape)  # (4, 3)
print("x shape:", x.shape)  # (4,)
print("C shape:", C.shape)  # (3,)

5. 结论

  • @矩阵乘法运算符,等价于 torch.matmul(A, B)
  • (d, k) @ (d, 1) 是不合法的矩阵乘法
  • PyTorch 会自动扩展 (d,) → (d, 1) 并进行正确的矩阵计算
  • (d, k) @ (d,) 实际等价于 (k, d) @ (d, 1),避免了显式转置

🚀 PyTorch 的 @ 计算规则很智能,能够自动扩展维度,让矩阵乘法符合数学规则! 🎯

q = P_q @ x 计算中,P_q.T 转置的是哪个维度?如何判断?

在 PyTorch 代码:

q = P_q @ x  # (h, d, k) @ (d,)

核心问题

  • P_q.shape = (h, d, k)
  • x.shape = (d,)

为什么 不需要手动转置 P_q?以及 PyTorch 在计算 P_q @ x 时转置了哪个维度


1. @ 运算规则

PyTorch 处理 torch.matmul(A, B) 时,遵循 广播机制矩阵乘法规则

  1. 最后两个维度 参与矩阵乘法
  2. 如果 B 是 1D 张量(即 B.shape = (d,)),PyTorch 会自动扩展为 (d, 1) 但不会影响计算逻辑

2. q = P_q @ x 具体计算

2.1 P_q.shape = (h, d, k), x.shape = (d,)

按照 PyTorch 规则:

  1. 扩展 x 形状
    • x.shape = (d,) 自动扩展为 (d, 1),使其符合矩阵乘法规则:
    x = x.unsqueeze(-1)  # (d,) → (d, 1)
    
  2. 选择 P_q 参与矩阵乘法的维度
    • P_q.shape = (h, d, k),表示:
      • h:注意力头数(不参与矩阵计算)
      • d:输入维度(x 匹配
      • k:查询维度(计算目标)
    • P_q @ x 的计算目标是:
      ( h , d , k ) @ ( d , 1 ) (h, d, k) @ (d, 1) (h,d,k)@(d,1)
      需要 P_q d 维度与 xd 维度对齐,才能进行矩阵乘法。

2.2 PyTorch 自动调整 P_q 计算方式

PyTorch 不会转置完整的 P_q,但会 调整最后两个维度 (d, k) 进行计算

  • 等价于
    q = ( h , k , d ) @ ( d , 1 ) = ( h , k , 1 ) q = (h, k, d) @ (d, 1) = (h, k, 1) q=(h,k,d)@(d,1)=(h,k,1)
  • 等价于
    q = torch.matmul(P_q.transpose(-2, -1), x.unsqueeze(-1))  # shape (h, k, 1)
    
    其中 P_q.transpose(-2, -1) 交换 (d, k)(k, d)

最终 PyTorch 计算:

q = (h, d, k) @ (d,) = (h, k)

其中 PyTorch 自动去除了 1 维度,返回 (h, k),而不是 (h, k, 1)


3. 如何判断 PyTorch 进行了哪些维度调整?

我们可以用 transpose()matmul() 手动验证

import torch

h, d, k = 2, 4, 3  # 2 个注意力头, 输入维度 4, 投影到 3 维
torch.manual_seed(42)

P_q = torch.randn(h, d, k)  # shape (h, d, k)
x = torch.randn(d)  # shape (d,)

# PyTorch 计算
q1 = P_q @ x  # (h, d, k) @ (d,) → (h, k)

# 手动转置 + matmul
q2 = torch.matmul(P_q.transpose(-2, -1), x.unsqueeze(-1)).squeeze(-1)  # (h, k)

print("q1 shape:", q1.shape)  # (h, k)
print("q2 shape:", q2.shape)  # (h, k)
print(torch.allclose(q1, q2))  # True

结果:

q1 shape: torch.Size([2, 3])
q2 shape: torch.Size([2, 3])
True

说明 PyTorch 自动进行了 P_q.transpose(-2, -1),使 d 维度匹配 xd 维度


4. 结论

💡 PyTorch 只会转置 P_qd, k 维度,确保矩阵乘法合法,但不会改变 h 维度

判断 PyTorch 何时自动调整维度

操作等效 PyTorch 计算
(d, k) @ (d,)自动转置 (d, k)(k, d), 计算 (k, d) @ (d, 1)
(h, d, k) @ (d,)自动调整 (d, k)(k, d), 计算 (h, k, d) @ (d, 1)
(d, k) @ (k, 1)直接符合矩阵乘法规则,正常计算
(h, d, k) @ (k, 1)符合矩阵乘法规则,正常计算

5. 关键点总结

P_qd, k 维度会被 PyTorch 自动调整,以匹配 x.shape = (d,)
PyTorch 计算 (h, d, k) @ (d,),本质等价于 P_q.transpose(-2, -1) @ x.unsqueeze(-1)
最终 q.shape = (h, k),符合多头注意力计算要求

🚀 PyTorch 的 @ 操作非常智能,会自动调整张量的形状,使矩阵乘法符合数学规则! 🎯

后记

2025年2月23日07点49分于上海,在GPT4o大模型辅助下完成。


http://www.niftyadmin.cn/n/5864485.html

相关文章

Kafka系列之:记录一次源头数据库刷数据,造成数据丢失的原因

Kafka系列之:记录一次源头数据库刷数据,造成数据丢失的原因 一、背景二、查看topic日志信息三、结论四、解决方法一、背景 源头数据库在很短的时间内刷了大量的数据,部分数据在hdfs丢失了 理论上debezium数据采集不会丢失,就需要排查数据链路某个节点是否有数据丢失。 数据…

Unity Android SDK 升级、安装 build-tools、platform-tools

Unity Android SDK 升级、安装 build-tools、platform-tools 通过 Unity Hub 安装的 Android SDK 需要下载 特定版本的 build-tools、platform-tools 如何操作? 以 Unity 2022.3.26f1 为例,打开安装目录,找到如下目录 2022.3.26f1\Editor\…

0083.基于springboot+uni-app的社区车位租赁系统小程序+论文

一、系统说明 基于springbootuni-app的社区车位租赁系统小程序,系统功能齐全, 代码简洁易懂,适合小白学编程。 现如今,信息种类变得越来越多,信息的容量也变得越来越大,这就是信息时代的标志。近些年,计算机科学发展…

Windows 主机与安卓设备网线直连配置教程

在一些特殊场景下,我们可能需要在 Windows 主机没有联网的情况下,与安卓设备通过网线直连进行通信。本文将详细介绍具体的配置步骤。 一、硬件准备 一根网线(直通线或交叉线,具体取决于设备接口)。 一台支持以太网连…

20-R 绘图 - 饼图

R 绘图 - 饼图 R 语言提供来大量的库来实现绘图功能。 饼图,或称饼状图,是一个划分为几个扇形的圆形统计图表,用于描述量、频率或百分比之间的相对关系。 R 语言使用 pie() 函数来实现饼图,语法格式如下: pie(x, l…

基于Springboot医院预约挂号小程序系统【附源码】

基于Springboot医院预约挂号小程序系统 效果如下: 小程序主页面 帖子页面 医生账号页面 留言内容页面 管理员主页面 用户管理页面 我的挂号页面 医生管理页面 研究背景 随着信息技术的飞速发展和互联网医疗的兴起,传统的医疗服务模式正面临着深刻的变…

一周学会Flask3 Python Web开发-flask3上下文全局变量session,g和current_app

锋哥原创的Flask3 Python Web开发 Flask3视频教程: 2025版 Flask3 Python web开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili flask3提供了session,g和current_app上下文全局变量来方便我们操作访问数据。 以下是一个表格,用于比较Flask中的…

《Operating System Concepts》阅读笔记:p87-p94

《Operating System Concepts》学习第 12 天,p87-p94 总结,总计 8 页。 一、技术总结 1.Android The Android operating system was designed by the Open Handset Alliance (led primarily by Google) and was developed for Android smartphones an…