Pyspark で列をグループ化し、最大値を持つ行をフィルターする 質問する

Pyspark で列をグループ化し、最大値を持つ行をフィルターする 質問する

おそらく以前にもこの質問があったと思いますが、stackoverflowで検索私の質問に答えなかった。重複ではありません[2]最も頻繁な項目ではなく、最大値が必要なためです。私は pyspark を初めて使用しており、非常に単純なことを実行しようとしています。列 "A" をグループ化し、列 "B" に最大値を持つ各グループの行のみを保持します。次のようになります。

df_cleaned = df.groupBy("A").agg(F.max("B"))

残念ながら、これにより他のすべての列が破棄されます。df_cleaned には列「A」と列 B の最大値のみが含まれます。代わりに行を保持するにはどうすればよいですか? (「A」、「B」、「C」...)

ベストアンサー1

udfを使用せずにこれを行うことができますWindow

次の例を考えてみましょう。

import pyspark.sql.functions as f
data = [
    ('a', 5),
    ('a', 8),
    ('a', 7),
    ('b', 1),
    ('b', 3)
]
df = sqlCtx.createDataFrame(data, ["A", "B"])
df.show()
#+---+---+
#|  A|  B|
#+---+---+
#|  a|  5|
#|  a|  8|
#|  a|  7|
#|  b|  1|
#|  b|  3|
#+---+---+

Window列ごとにパーティションを作成し、これを使用して各グループの最大値を計算します。次に、列の値が最大値と等しくなるAように行をフィルター処理します。B

from pyspark.sql import Window
w = Window.partitionBy('A')
df.withColumn('maxB', f.max('B').over(w))\
    .where(f.col('B') == f.col('maxB'))\
    .drop('maxB')\
    .show()
#+---+---+
#|  A|  B|
#+---+---+
#|  a|  8|
#|  b|  3|
#+---+---+

または、次のように同等に使用できますpyspark-sql

df.registerTempTable('table')
q = "SELECT A, B FROM (SELECT *, MAX(B) OVER (PARTITION BY A) AS maxB FROM table) M WHERE B = maxB"
sqlCtx.sql(q).show()
#+---+---+
#|  A|  B|
#+---+---+
#|  b|  3|
#|  a|  8|
#+---+---+

おすすめ記事