airflow.providers.google.cloud.utils.mlengine_prediction_summary

一個由 DataFlowPythonOperator 呼叫的模板,用於彙總批次預測 (BatchPrediction)。

它接受一個使用者函式來計算預測結果中每個例項的指標,然後進行聚合並輸出彙總。

它接受以下引數

  • --prediction_path: 包含批次預測結果的 GCS 資料夾,其中包含 JSON 格式的 prediction.results-NNNNN-of-NNNNN 檔案。輸出也將儲存在此資料夾中,名稱為 ‘prediction.summary.json’。

  • --metric_fn_encoded: 一個編碼的函式,用於計算並返回給定例項(以字典形式)的一個或多個指標元組。它應該透過 base64.b64encode(dill.dumps(fn, recurse=True)) 進行編碼。

  • --metric_keys: 彙總輸出中聚合指標的逗號分隔鍵。鍵的順序和數量必須與 metric_fn 的輸出匹配。彙總將包含一個額外的鍵 ‘count’,表示例項總數,因此鍵不應包含 ‘count’。

使用示例

當輸入檔案如下所示時

{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}

輸出檔案將是

{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}

在 DAG 外部測試

subprocess.check_call(
    [
        "python",
        "-m",
        "airflow.providers.google.cloud.utils.mlengine_prediction_summary",
        "--prediction_path=gs://...",
        "--metric_fn_encoded=" + metric_fn_encoded,
        "--metric_keys=log_loss,mse",
        "--runner=DataflowRunner",
        "--staging_location=gs://...",
        "--temp_location=gs://...",
    ]
)

JsonCoder

JSON 編碼器/解碼器。

函式

MakeSummary(pcoll, metric_fn, metric_keys)

在 Dataflow 中使用的彙總 PTransform。

run([argv])

獲取預測彙總。

模組內容

class airflow.providers.google.cloud.utils.mlengine_prediction_summary.JsonCoder[原始碼]

基類: apache_beam.coders.coders.Coder

JSON 編碼器/解碼器。

static encode(x)[原始碼]

JSON 編碼器。

static decode(x)[原始碼]

JSON 解碼器。

airflow.providers.google.cloud.utils.mlengine_prediction_summary.MakeSummary(pcoll, metric_fn, metric_keys)[原始碼]

在 Dataflow 中使用的彙總 PTransform。

airflow.providers.google.cloud.utils.mlengine_prediction_summary.run(argv=None)[原始碼]

獲取預測彙總。

此條目有幫助嗎?