Tensorflow PrefetchDatasetからターゲットを抽出する 質問する

Tensorflow PrefetchDatasetからターゲットを抽出する 質問する

私はまだ tensorflow と keras を学習中ですが、この質問には非常に簡単な答えがあるのに、慣れていないためにそれを見逃しているのではないかと思います。

オブジェクトがありますPrefetchDataset:

> print(tf_test)
$ <PrefetchDataset shapes: ((None, 99), (None,)), types: (tf.float32, tf.int64)>

...機能とターゲットで構成されています。forループを使用してこれを反復処理できます。

> for example in tf_test:
>     print(example[0].numpy())
>     print(example[1].numpy())
>     exit()
$ [[-0.31 -0.94 -1.12 ... 0.18 -0.27]
   [-0.22 -0.54 -0.14 ... 0.33 -0.55]
   [-0.60 -0.02 -1.41 ... 0.21 -0.63]
   ...
   [-0.03 -0.91 -0.12 ... 0.77 -0.23]
   [-0.76 -1.48 -0.15 ... 0.38 -0.35]
   [-0.55 -0.08 -0.69 ... 0.44 -0.36]]
  [0 0 1 0 1 0 0 0 1 0 1 1 0 1 0 0 0
   ...
   0 1 1 0]

しかし、これは非常に遅いです。私がやりたいのは、クラス ラベルに対応するテンソルにアクセスし、それを numpy 配列、リスト、または scikit-learn の分類レポートや混同行列に入力できる任意の反復可能オブジェクトに変換することです。

> y_pred = model.predict(tf_test)
> print(y_pred)
$ [[0.01]
   [0.14]
   [0.00]
   ...
   [0.32]
   [0.03]
   [0.00]]
> y_pred_list = [int(x[0]) for x in y_pred]             # assumes value >= 0.5 is positive prediction
> y_true = []                                           # what I need help with
> print(sklearn.metrics.confusion_matrix(y_true, y_pred_list)

...または、Tensorflow の混同行列で使用できるようにデータにアクセスします。

> labels = []                                           # what I need help with
> predictions = y_pred_list                             # could we just use a tensor?
> print(tf.math.confusion_matrix(labels, predictions)

どちらの場合も、計算コストがかからない方法で元のオブジェクトからターゲット データを取得できる一般的な機能は非常に役立ちます (また、Tensorflow と Keras に関する私の基本的な直感にも役立つ可能性があります)。

アドバイスをいただければ幸いです。

ベストアンサー1

これを を使用してリストに変換しlist(ds)、 を使用して通常のデータセットとして再コンパイルすることができますtf.data.Dataset.from_tensor_slices(list(ds))。そこから悪夢が再び始まりますが、少なくともそれは他の人が以前に経験した悪夢です。

より複雑なデータセット (ネストされた辞書など) の場合は、 を呼び出した後にさらに前処理が必要になりますlist(ds)が、これは質問の例では機能するはずです。

これは満足のいく答えからは程遠いですが、残念ながらこのクラスはまったく文書化されておらず、標準的なデータセットのトリックはどれも機能しません。

おすすめ記事