找回密码
 立即注册
首页 业界区 科技 Pytorch基础问题RuntimeError: Expected all tensors to ...

Pytorch基础问题RuntimeError: Expected all tensors to be on the same device

庾芷秋 2025-7-6 21:59:20
Pytorch基础问题RuntimeError: Expected all tensors to be on the same device

Introduction

今天让 Claude 4 Sonnet 给我写Nogo的reinforcement learning的训练代码,结果就直接报错:
  1. RuntimeError: Expected all tensors to be on the same device
复制代码
我平时不怎么注意细节,为了养成不依赖LLM的好习惯,以后对报错写博客记录一下
Main part

这个错误 RuntimeError: Expected all tensors to be on the same device 通常出现在把不同设备(比如 CPU 和 GPU)上的张量放在一起进行操作时:
  1. import torch
  2. # 一个张量在 CPU 上
  3. a = torch.tensor([1.0, 2.0])
  4. # 一个张量在 GPU 上(假设有 CUDA)
  5. b = torch.tensor([3.0, 4.0]).to("cuda")
  6. # 尝试把它们加起来会报错
  7. c = a + b  # RuntimeError: Expected all tensors to be on the same device
复制代码
解决方法是用 .to() 或 .cuda() 等方法把它们放到同一个设备上。
torch.device

还是上述代码
  1. print(a.device)
  2. print(b.device)
  3. #打印结果:
  4. # cpu
  5. # cuda:0
复制代码
设置device
  1. device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
复制代码
一般默认是第一块即 cuda:0
但也可以指定具体哪一个:
  1. # 把张量放在第1个GPU(编号0)
  2. a = torch.tensor([1.0, 2.0]).to("cuda:0")
  3. # 放在第2个GPU(编号1)
  4. b = torch.tensor([3.0, 4.0]).to("cuda:1")
复制代码
当然用我的笔记本必然报错:
  1. RuntimeError: CUDA error: invalid device ordinal**
复制代码
多卡的时候可以打印一下数量
  1. print(torch.cuda.device_count())
复制代码
torch.tensor().to()

torch.tensor().to() 是 PyTorch 中将张量转移到指定设备(如 CPU 或 GPU)上的方法。
这个函数可以用来显式地将数据放到某个设备上,以便后续运算不报错。
  1. dvc=torch.device("cuda" if torch.cuda.is_available() else "cpu")
  2. #设置指定显卡
  3. a = torch.tensor([1.0, 2.0]).to(dvc)
  4. b = torch.tensor([3.0, 4.0]).to(dvc)
  5. print(a.device)
  6. print(b.device)
  7. c = a + b
  8. print(c)
  9. #打印结果:
  10. # cuda:0
  11. # cuda:0
  12. # tensor([4., 6.], device='cuda:0')
复制代码
Summary

显示调用张量位置是个好习惯,尤其是多卡训练的情况

  • 排查错误的关键线索
多卡环境下最常见的错误是:
  1. RuntimeError: Expected all tensors to be on the same device
复制代码
如果你在关键位置加上 .device 显示,可以快速发现是谁跑偏了。

  • 防止设备错配(如模型在 GPU0,数据在 GPU1)
  1. print("model on:", next(model.parameters()).device)
  2. print("inputs on:", inputs.device)
复制代码
这些信息一眼就能看出是否匹配。

  • 帮助调试和日志记录
在训练日志中打印 device 信息,比如:
  1. print(f"Epoch {epoch}: input.device={inputs.device}, label.device={labels.device}")
复制代码
可以让你在远程服务器、异步运行、多卡调度环境中更清楚程序状态。

来源:豆瓜网用户自行投稿发布,如果侵权,请联系站长删除

相关推荐

您需要登录后才可以回帖 登录 | 立即注册