您現在的位置是:網站首頁>JAVApytorch和numpy默認浮點類型位數詳解

pytorch和numpy默認浮點類型位數詳解

宸宸2024-06-01JAVA78人已圍觀

給大家整理一篇相關的編程文章,網友史英睿根據主題投稿了本篇教程內容,涉及到pytorch numpy、numpy默認浮點類型位數、pytorch默認浮點類型、pytorch和numpy默認浮點類型位數相關內容,已被978網友關注,相關難點技巧可以閲讀下方的電子資料。

pytorch和numpy默認浮點類型位數

pytorch和numpy默認浮點類型位數

numpy中默認浮點類型爲64位,pytorch中默認浮點類型位32位

測試代碼如下

  • numpy版本:1.19.2
  • pytorch版本:1.2.0
In [1]: import torch
In [2]: import numpy as np
# 版本信息
In [3]: "pytorch version: {}, numpy version: {}".format(torch.__version__, np.__version__)
Out[3]: 'pytorch version: 1.2.0, numpy version: 1.19.2'

# numpy
In [4]: dat_np = np.array([1,2,3], dtype="float")
In [5]: dat_np.dtype
Out[5]: dtype('float64')

# pytorch
In [6]: dat_torch = torch.tensor([1,2,3])
In [7]: dat_torch = dat_torch.float()
In [8]: dat_torch.dtype
Out[8]: torch.float32

pytorch和numpy的默認類型與轉換問題

pytorch對於浮點類型默認爲float32,而numpy的默認類型是float64,轉換的代碼:

torch.from_numpy(a).type(torch.FloatTensor)
torch.from_numpy(np.float32(a))

縂結

以上爲個人經騐,希望能給大家一個蓡考,也希望大家多多支持碼辳之家。

我的名片

網名:星辰

職業:程式師

現居:河北省-衡水市

Email:[email protected]