四时宝库

程序员的知识宝库

分享一波实用的PyTorch常用代码段

1、固定随机种子

torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

2、指定程序运行在特定 GPU 卡上

在命令行指定环境变量

CUDA_VISIBLE_DEVICES=0,1 python train.py

3、如果想要避免这种结果波动,设置

torch.backends.cudnn.deterministic = True

4、np.ndarray 与 PIL.Image 转换

# np.ndarray -> PIL.Image.
image = PIL.Image.fromarray(ndarray.astypde(np.uint8))
# PIL.Image -> np.ndarray.
ndarray = np.asarray(PIL.Image.open(path))

5、水平翻转

PyTorch 不支持 tensor[::-1] 这样的负步长操作,水平翻转可以用张量索引实现。

# Assume tensor has shape N*D*H*W.tensor = tensor[:, :, :, torch.arange

6、矩阵乘法

# Matrix multiplication: (m*n) * (n*p) -> (m*p).
result = torch.mm(tensor1, tensor2)
# Batch matrix multiplication: (b*m*n) * (b*n*p) -> (b*m*p).
result = torch.bmm(tensor1, tensor2)
# Element-wise multiplication.
result = tensor1 * tensor2

↓↓↓

发表评论:

控制面板
您好,欢迎到访网站!
  查看权限
网站分类
最新留言
    友情链接