{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "df569247", "metadata": {}, "outputs": [], "source": [ "import json\n", "import math\n", "import random\n", "import os\n", "import pickle as pkl\n", "import time\n", "from typing import Dict, List\n", "\n", "import awkward as ak\n", "import fastjet\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import mplhep as hep\n", "import numpy as np\n", "import sklearn\n", "import sklearn.metrics\n", "import torch\n", "import tqdm\n", "import vector\n", "from torch_geometric.data import Batch, Data\n", "\n", "plt.style.use(hep.style.CMS)\n", "plt.rcParams.update({\"font.size\": 20})" ] }, { "cell_type": "code", "execution_count": 2, "id": "edf87dc8", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 8, "id": "16ca6a3a", "metadata": {}, "outputs": [], "source": [ "# import relevant functions from mlpf.pyg\n", "import sys\n", "sys.path.append(\"/home/jovyan/particleflow/mlpf/\")\n", "import pyg\n", "sys.path.append(\"/home/jovyan/particleflow/mlpf/pyg/\")\n", "import utils\n", "\n", "from PFDataset import PFDataset, get_interleaved_dataloaders, Collater\n", "\n", "from pyg.mlpf import MLPF\n", "from pyg.utils import X_FEATURES, Y_FEATURES, unpack_predictions\n", "from jet_utils import match_two_jet_collections\n", "\n", "#################################### must update this function to have the proper p4\n", "def unpack_target(y):\n", " ret = {}\n", " ret[\"cls_id\"] = y[..., 0].long()\n", " ret[\"charge\"] = torch.clamp((y[..., 1] + 1).to(dtype=torch.float32), 0, 2) # -1, 0, 1 -> 0, 1, 2\n", "\n", " for i, feat in enumerate(Y_FEATURES):\n", " if i >= 2: # skip the cls and charge as they are defined above\n", " ret[feat] = y[..., i].to(dtype=torch.float32)\n", " ret[\"phi\"] = torch.atan2(ret[\"sin_phi\"], ret[\"cos_phi\"])\n", "\n", " # do some sanity checks\n", " # assert torch.all(ret[\"pt\"] >= 0.0) # pt\n", " # 
assert torch.all(torch.abs(ret[\"sin_phi\"]) <= 1.0) # sin_phi\n", " # assert torch.all(torch.abs(ret[\"cos_phi\"]) <= 1.0) # cos_phi\n", " # assert torch.all(ret[\"energy\"] >= 0.0) # energy\n", "\n", " # note ~ momentum = [\"pt\", \"eta\", \"sin_phi\", \"cos_phi\", \"energy\"]\n", " ret[\"momentum\"] = y[..., 2:7].to(dtype=torch.float32)\n", " ret[\"p4\"] = torch.cat(\n", " [ret[\"pt\"].unsqueeze(-1), ret[\"eta\"].unsqueeze(-1), ret[\"phi\"].unsqueeze(-1), ret[\"energy\"].unsqueeze(-1)], axis=-1\n", " )\n", "\n", " ret[\"genjet_idx\"] = y[..., -1].long()\n", "\n", " return ret" ] }, { "cell_type": "code", "execution_count": 9, "id": "14875b0a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Will use NVIDIA A100-SXM4-80GB\n" ] } ], "source": [ "# define the global base device\n", "world_size = 1\n", "if torch.cuda.device_count():\n", " rank = 0\n", " device = torch.device(\"cuda:0\")\n", " print(f\"Will use {torch.cuda.get_device_name(device)}\")\n", "else:\n", " rank = \"cpu\"\n", " device = \"cpu\"\n", " print(\"Will use cpu\")" ] }, { "cell_type": "markdown", "id": "3c4aec55", "metadata": {}, "source": [ "# Load the pre-trained MLPF model" ] }, { "cell_type": "code", "execution_count": 17, "id": "1879def6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MLPF(\n", " (nn0_id): Sequential(\n", " (0): Linear(in_features=17, out_features=512, bias=True)\n", " (1): ReLU()\n", " (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (3): Dropout(p=0.0, inplace=False)\n", " (4): Linear(in_features=512, out_features=512, bias=True)\n", " )\n", " (nn0_reg): Sequential(\n", " (0): Linear(in_features=17, out_features=512, bias=True)\n", " (1): ReLU()\n", " (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (3): Dropout(p=0.0, inplace=False)\n", " (4): Linear(in_features=512, out_features=512, bias=True)\n", " )\n", " (conv_id): ModuleList(\n", " (0-2): 3 x 
SelfAttentionLayer(\n", " (mha): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)\n", " )\n", " (norm0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (seq): Sequential(\n", " (0): Linear(in_features=512, out_features=512, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=512, out_features=512, bias=True)\n", " (3): ReLU()\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (conv_reg): ModuleList(\n", " (0-2): 3 x SelfAttentionLayer(\n", " (mha): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)\n", " )\n", " (norm0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (seq): Sequential(\n", " (0): Linear(in_features=512, out_features=512, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=512, out_features=512, bias=True)\n", " (3): ReLU()\n", " )\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " (nn_id): Sequential(\n", " (0): Linear(in_features=529, out_features=512, bias=True)\n", " (1): ReLU()\n", " (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (3): Dropout(p=0.0, inplace=False)\n", " (4): Linear(in_features=512, out_features=6, bias=True)\n", " )\n", " (nn_pt): RegressionOutput(\n", " (nn): Sequential(\n", " (0): Linear(in_features=535, out_features=512, bias=True)\n", " (1): ReLU()\n", " (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (3): Dropout(p=0.0, inplace=False)\n", " (4): Linear(in_features=512, out_features=2, bias=True)\n", " )\n", " )\n", " (nn_eta): RegressionOutput(\n", " (nn): Sequential(\n", " (0): Linear(in_features=535, out_features=512, bias=True)\n", " (1): ReLU()\n", " (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (3): Dropout(p=0.0, 
inplace=False)\n", " (4): Linear(in_features=512, out_features=2, bias=True)\n", " )\n", " )\n", " (nn_sin_phi): RegressionOutput(\n", " (nn): Sequential(\n", " (0): Linear(in_features=535, out_features=512, bias=True)\n", " (1): ReLU()\n", " (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (3): Dropout(p=0.0, inplace=False)\n", " (4): Linear(in_features=512, out_features=2, bias=True)\n", " )\n", " )\n", " (nn_cos_phi): RegressionOutput(\n", " (nn): Sequential(\n", " (0): Linear(in_features=535, out_features=512, bias=True)\n", " (1): ReLU()\n", " (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (3): Dropout(p=0.0, inplace=False)\n", " (4): Linear(in_features=512, out_features=2, bias=True)\n", " )\n", " )\n", " (nn_energy): RegressionOutput(\n", " (nn): Sequential(\n", " (0): Linear(in_features=535, out_features=512, bias=True)\n", " (1): ReLU()\n", " (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (3): Dropout(p=0.0, inplace=False)\n", " (4): Linear(in_features=512, out_features=2, bias=True)\n", " )\n", " )\n", ")\n" ] } ], "source": [ "def load_checkpoint(checkpoint, model, optimizer=None):\n", " if isinstance(model, torch.nn.parallel.DistributedDataParallel):\n", " model.module.load_state_dict(checkpoint[\"model_state_dict\"])\n", " else:\n", " model.load_state_dict(checkpoint[\"model_state_dict\"])\n", " if optimizer:\n", " optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n", " return model, optimizer\n", " else:\n", " return model\n", " \n", " \n", "loaddir = \"/pfvol/experiments/MLPF_clic_backbone_pyg-clic_20240429_101112_971749\"\n", "\n", "with open(f\"{loaddir}/model_kwargs.pkl\", \"rb\") as f:\n", " mlpf_kwargs = pkl.load(f)\n", "\n", "mlpf_kwargs[\"attention_type\"] = \"flash\"\n", "\n", "mlpf = MLPF(**mlpf_kwargs).to(torch.device(rank))\n", "checkpoint = torch.load(f\"{loaddir}/best_weights.pth\", map_location=torch.device(rank))\n", "\n", "mlpf = load_checkpoint(checkpoint, mlpf)\n", 
"mlpf.eval()\n", "\n", "print(mlpf) " ] }, { "cell_type": "markdown", "id": "dde5f191", "metadata": {}, "source": [ "# CLIC dataset" ] }, { "cell_type": "code", "execution_count": 18, "id": "d98c6857", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "clic_edm_qq_pf\t\t cms_pf_qcd_high_pt\t cms_pf_single_proton\r\n", "clic_edm_ttbar_pf\t cms_pf_single_electron cms_pf_single_tau\r\n", "clic_edm_ttbar_pu10_pf\t cms_pf_single_gamma\t cms_pf_sms_t1tttt\r\n", "clic_edm_ww_fullhad_pf\t cms_pf_single_mu\t cms_pf_ttbar\r\n", "clic_edm_zh_tautau_pf\t cms_pf_single_neutron cms_pf_ztt\r\n", "cms_pf_multi_particle_gun cms_pf_single_pi\t delphes_qcd_pf\r\n", "cms_pf_qcd\t\t cms_pf_single_pi0\t delphes_ttbar_pf\r\n" ] } ], "source": [ "! ls /pfvol/tensorflow_datasets/" ] }, { "cell_type": "code", "execution_count": 19, "id": "49d849e9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['type',\n", " 'pt | et',\n", " 'eta',\n", " 'sin_phi',\n", " 'cos_phi',\n", " 'p | energy',\n", " 'chi2 | position.x',\n", " 'ndf | position.y',\n", " 'dEdx | position.z',\n", " 'dEdxError | iTheta',\n", " 'radiusOfInnermostHit | energy_ecal',\n", " 'tanLambda | energy_hcal',\n", " 'D0 | energy_other',\n", " 'omega | num_hits',\n", " 'Z0 | sigma_x',\n", " 'time | sigma_y',\n", " 'Null | sigma_z']" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# we can see the 17th features here (recall type is 1 for tracks and 2 for clusters)\n", "X_FEATURES[\"clic\"]" ] }, { "cell_type": "code", "execution_count": 20, "id": "b3781b28", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['cls_id', 'charge', 'pt', 'eta', 'sin_phi', 'cos_phi', 'energy']" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# we can see the 8 gen features per pf element here (notice the jet_index which may be useful)\n", "Y_FEATURES" ] }, { "cell_type": "markdown", "id": "d602e8e2", "metadata": 
{}, "source": [ "# Get the dataset (Events)" ] }, { "cell_type": "code", "execution_count": 24, "id": "dfc6b930", "metadata": {}, "outputs": [], "source": [ "data_dir = \"/home/jovyan/particleflow/tensorflow_datasets/\"\n", "sample = \"clic_edm_ttbar_pf\"\n", "\n", "dataset_train = PFDataset(data_dir, f\"{sample}:1.5.0\", \"train\", num_samples=10_000)\n", "\n", "batch_size = 100\n", "\n", "train_loader = torch.utils.data.DataLoader(\n", " dataset_train.ds,\n", " batch_size=batch_size,\n", " collate_fn=Collater([\"X\", \"ygen\", \"ycand\"]),\n", " pin_memory=True,\n", " drop_last=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 25, "id": "abf11cbc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([100, 246, 17])\n" ] } ], "source": [ "for batch in train_loader:\n", " batch = batch.to(rank, non_blocking=True)\n", " break\n", "print(batch.X.shape) " ] }, { "cell_type": "markdown", "id": "883675d2", "metadata": {}, "source": [ "# Pre-processing (Events -> Jets)" ] }, { "cell_type": "code", "execution_count": 48, "id": "f8b7ad0b", "metadata": {}, "outputs": [], "source": [ "freeze_backbone = True" ] }, { "cell_type": "code", "execution_count": 53, "id": "d22f4bdf", "metadata": {}, "outputs": [], "source": [ "############################### set up forward hooks to retrive the latent representations of MLPF\n", "latent_reps = {}\n", "def get_activations(name):\n", " def hook(mlpf, input, output):\n", " latent_reps[name] = output\n", "\n", " return hook\n", "\n", "mlpf.conv_reg[0].dropout.register_forward_hook(get_activations(\"conv_reg0\"))\n", "mlpf.conv_reg[1].dropout.register_forward_hook(get_activations(\"conv_reg1\"))\n", "mlpf.conv_reg[2].dropout.register_forward_hook(get_activations(\"conv_reg2\"))\n", "mlpf.nn_id.register_forward_hook(get_activations(\"nn_id\")) \n", "###############################\n", "\n", "def get_latent_reps(batch, latent_reps):\n", " for layer in latent_reps:\n", " if \"conv\" in 
layer:\n", " latent_reps[layer] *= batch.mask.unsqueeze(-1)\n", "\n", " latentX = torch.cat(\n", " [\n", " batch.X.to(rank),\n", " latent_reps[\"conv_reg0\"],\n", " latent_reps[\"conv_reg1\"],\n", " latent_reps[\"conv_reg2\"],\n", " latent_reps[\"nn_id\"],\n", " ],\n", " axis=-1,\n", " )\n", " return latentX" ] }, { "cell_type": "code", "execution_count": 54, "id": "1bf42610", "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running MLPF inference on batch 0\n" ] } ], "source": [ "sample_to_lab = {\n", " \"clic_edm_ttbar_pf\": 1,\n", " \"clic_edm_qq_pf\": 0, \n", "}\n", "\n", "####################### Config\n", "jetdef = fastjet.JetDefinition(fastjet.ee_genkt_algorithm, 0.7, -1.0)\n", "jet_ptcut = 15.0\n", "jet_match_dr = 0.1\n", "\n", "save_every_X_batch = 3 # will save to disk every \"X\" batches\n", "\n", "######################## Build the dataset\n", "jet_dataset = [] # will save on disk and reinitialize at the end of the loop\n", "saving_i = 0 # will just increment with every save\n", "\n", "for ibatch, batch in enumerate(train_loader):\n", "\n", " # run the MLPF model in inference mode to get the MLPF cands / latent representations \n", " print(f\"Running MLPF inference on batch {ibatch}\")\n", " batch = batch.to(rank, non_blocking=True)\n", "# with torch.no_grad():\n", " with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16, enabled=True):\n", " ymlpf = mlpf(batch.X, batch.mask)\n", " ymlpf = unpack_predictions(ymlpf)\n", " \n", " # get the latent representations\n", " ymlpf[\"latentX\"] = get_latent_reps(batch, latent_reps)\n", "\n", "# for k, v in ymlpf.items():\n", "# ymlpf[k] = v.detach().cpu()\n", "\n", "# msk_ymlpf = ymlpf[\"cls_id\"] != 0\n", "# ymlpf[\"p4\"] = ymlpf[\"p4\"] * msk_ymlpf.unsqueeze(-1)\n", "\n", "# jets_coll = {}\n", "# ####################### get the reco jet collection\n", "# vec = vector.awk(\n", "# ak.zip(\n", "# {\n", "# \"pt\": ymlpf[\"p4\"][:, :, 0].to(\"cpu\"),\n", 
"# \"eta\": ymlpf[\"p4\"][:, :, 1].to(\"cpu\"),\n", "# \"phi\": ymlpf[\"p4\"][:, :, 2].to(\"cpu\"),\n", "# \"e\": ymlpf[\"p4\"][:, :, 3].to(\"cpu\"),\n", "# }\n", "# )\n", "# )\n", "# cluster = fastjet.ClusterSequence(vec.to_xyzt(), jetdef)\n", "# jets_coll[\"reco\"] = cluster.inclusive_jets(min_pt=jet_ptcut)\n", " \n", "# # get the constituents to mask the MLPF candidates and build the input for the downstream \n", "# reco_constituents = cluster.constituent_index(min_pt=jet_ptcut)\n", "# ####################### \n", " \n", "# ####################### get the gen jet collection\n", "# ygen = unpack_target(batch.ygen) \n", "# vec = vector.awk(\n", "# ak.zip(\n", "# {\n", "# \"pt\": ygen[\"p4\"][:, :, 0].to(\"cpu\"),\n", "# \"eta\": ygen[\"p4\"][:, :, 1].to(\"cpu\"),\n", "# \"phi\": ygen[\"p4\"][:, :, 2].to(\"cpu\"),\n", "# \"e\": ygen[\"p4\"][:, :, 3].to(\"cpu\"),\n", "# }\n", "# )\n", "# )\n", "# cluster = fastjet.ClusterSequence(vec.to_xyzt(), jetdef)\n", "# jets_coll[\"gen\"] = cluster.inclusive_jets(min_pt=jet_ptcut) \n", "# #######################\n", " \n", "# matched_jets = match_two_jet_collections(jets_coll, \"reco\", \"gen\", jet_match_dr)\n", " \n", "# # build the big jet list\n", "# for iev in tqdm.tqdm(range(len(matched_jets[\"gen\"]))):\n", " \n", "# num_matched_jets = len(matched_jets[\"gen\"][iev]) # number of gen jets matched to reco\n", " \n", "# jets_per_event = []\n", "# for j in range(num_matched_jets):\n", " \n", "# # get the actual indices of the matched jets\n", "# igenjet = matched_jets[\"gen\"][iev][j]\n", "# irecojet = matched_jets[\"reco\"][iev][j]\n", "\n", "# # build a mask tensor that will select the particles that belong to the gen jet \n", "# msk_indices = reco_constituents[iev][irecojet].to_numpy()\n", "\n", "# if len(msk_indices)<3:\n", "# # don't save jets with very few particles\n", "# continue\n", "\n", "# jets_per_event += [\n", "\n", "# Data(\n", "# # Target for jet tagging\n", "# 
gen_jet_label=torch.tensor(sample_to_lab[sample]).unsqueeze(0).to(dtype=torch.float32),\n", "\n", "# # Target for jet p4 regression \n", "# gen_jet_pt=torch.tensor(jets_coll[\"gen\"][iev][igenjet].pt, dtype=torch.float32).unsqueeze(0),\n", "# gen_jet_eta=torch.tensor(jets_coll[\"gen\"][iev][igenjet].eta, dtype=torch.float32).unsqueeze(0),\n", "# gen_jet_phi=torch.tensor(jets_coll[\"gen\"][iev][igenjet].phi, dtype=torch.float32).unsqueeze(0),\n", "# gen_jet_energy=torch.tensor(jets_coll[\"gen\"][iev][igenjet].energy, dtype=torch.float32).unsqueeze(0),\n", " \n", "# # could be part of the target\n", "# reco_jet_pt=torch.tensor(jets_coll[\"reco\"][iev][irecojet].pt, dtype=torch.float32).unsqueeze(0),\n", "# reco_jet_eta=torch.tensor(jets_coll[\"reco\"][iev][irecojet].eta, dtype=torch.float32).unsqueeze(0),\n", "# reco_jet_phi=torch.tensor(jets_coll[\"reco\"][iev][irecojet].phi, dtype=torch.float32).unsqueeze(0),\n", "# reco_jet_energy=torch.tensor(jets_coll[\"reco\"][iev][irecojet].energy, dtype=torch.float32).unsqueeze(0),\n", " \n", "# # Input\n", "# mlpfcands_momentum=ymlpf[\"momentum\"][iev][msk_indices],\n", "# mlpfcands_pid=ymlpf[\"cls_id_onehot\"][iev][msk_indices],\n", "# mlpfcands_charge=ymlpf[\"charge\"][iev][msk_indices],\n", "# mlpfcands_latentX=ymlpf[\"latentX\"][iev][msk_indices],\n", "# )\n", "# ]\n", "\n", "# # break # per jet \n", " \n", "# # random.shuffle(jets_per_event)\n", "# jet_dataset += jets_per_event\n", "\n", "# break # per event\n", "\n", "# # random.shuffle(jet_dataset)\n", "# if (ibatch % (save_every_X_batch-1) == 0) and (ibatch != 0):\n", "# print(f\"saving at iteration {ibatch} on disk /pfvol/jetdataset/{sample}/train/{saving_i}.pt\")\n", "# torch.save(jet_dataset, f\"/pfvol/jetdataset/{sample}/train/{saving_i}.pt\")\n", "# saving_i += 1\n", "# jet_dataset = []\n", " break # per batch" ] }, { "cell_type": "code", "execution_count": 56, "id": "71bfbe1d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[[ 1.0000e+00, 
4.0499e+00, 9.5381e-01, ..., -5.0312e+00,\n", " -4.0625e-01, -5.9766e-01],\n", " [ 1.0000e+00, 1.8992e+01, 6.8335e-02, ..., -6.7500e+00,\n", " 2.4414e-01, 1.2500e+00],\n", " [ 1.0000e+00, 4.8831e+00, 1.2232e-01, ..., -5.4688e+00,\n", " -2.0469e+00, 6.9375e+00],\n", " ...,\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00]],\n", "\n", " [[ 1.0000e+00, 5.2279e+01, 2.1137e-01, ..., -6.8125e+00,\n", " -1.6875e+00, 1.1597e-03],\n", " [ 1.0000e+00, 1.2576e+01, -3.7162e-02, ..., -6.2812e+00,\n", " 2.4219e+00, -2.2500e+00],\n", " [ 1.0000e+00, 8.7163e+00, -6.4841e-02, ..., -7.1250e+00,\n", " 3.0781e+00, 6.9922e-01],\n", " ...,\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00]],\n", "\n", " [[ 1.0000e+00, 1.2418e+01, -1.5321e-02, ..., -7.7188e+00,\n", " 1.7031e+00, 2.3594e+00],\n", " [ 1.0000e+00, 1.5101e+01, -3.2354e-02, ..., -8.1875e+00,\n", " 1.3594e+00, 5.6250e+00],\n", " [ 1.0000e+00, 2.4137e+01, 8.7432e-01, ..., -6.5312e+00,\n", " -7.1484e-01, 7.3750e+00],\n", " ...,\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00]],\n", "\n", " ...,\n", "\n", " [[ 1.0000e+00, 2.7388e+01, 1.2865e+00, ..., -4.1250e+00,\n", " 5.6250e+00, -4.6875e+00],\n", " [ 1.0000e+00, 2.1186e+01, 1.8190e-01, ..., -6.2812e+00,\n", " 6.7500e+00, -2.5938e+00],\n", " [ 1.0000e+00, 1.5213e+01, -4.2097e-01, ..., 
-5.9688e+00,\n", " 1.8848e-01, -7.4219e-01],\n", " ...,\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00]],\n", "\n", " [[ 1.0000e+00, 1.1207e+01, -7.9346e-01, ..., -6.0000e+00,\n", " 1.4688e+00, -1.1016e+00],\n", " [ 1.0000e+00, 6.9457e+00, 9.0808e-02, ..., -7.4375e+00,\n", " 1.6953e+00, 7.5781e-01],\n", " [ 1.0000e+00, 6.7352e+00, -5.3468e-01, ..., -7.6875e+00,\n", " 1.0703e+00, 2.6094e+00],\n", " ...,\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00]],\n", "\n", " [[ 1.0000e+00, 1.4524e+01, -4.8954e-01, ..., -7.2188e+00,\n", " 6.3438e+00, -2.3594e+00],\n", " [ 1.0000e+00, 1.0940e+01, -6.0876e-01, ..., -6.5938e+00,\n", " 4.4062e+00, -3.7500e+00],\n", " [ 1.0000e+00, 1.0282e+01, -6.9122e-01, ..., -7.3438e+00,\n", " 8.2031e-01, 7.6953e-01],\n", " ...,\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00]]], device='cuda:0')" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ymlpf[\"latentX\"].detach()" ] }, { "cell_type": "code", "execution_count": 45, "id": "956e9d8b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ymlpf[\"cls_id_onehot\"].requires_grad" ] }, { "cell_type": "code", "execution_count": 46, "id": "ce0911a4", 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ymlpf[\"latentX\"].requires_grad" ] }, { "cell_type": "code", "execution_count": null, "id": "03aa548e", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "fe1dd2af", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "891aa275", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "b1383441", "metadata": {}, "outputs": [], "source": [ "print(len(reco_constituents[0][0]))\n", "print(len(reco_constituents[0][1]))\n", "print(len(reco_constituents[0][2]))\n", "print(len(reco_constituents[0][3]))\n", "print(len(reco_constituents[0][4]))" ] }, { "cell_type": "code", "execution_count": 77, "id": "8e4d08ad", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
{reco: [0, 1, 2, 3, 4],\n",
       " gen: [0, 1, 2, 3, 4]}\n",
       "-----------------------\n",
       "type: {\n",
       "    reco: var * int64,\n",
       "    gen: var * int64\n",
       "}
" ], "text/plain": [ "" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "matched_jets[0]" ] }, { "cell_type": "code", "execution_count": 70, "id": "b3b4d823", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4\n", "18\n", "29\n", "32\n", "143\n" ] } ], "source": [ "print(len(reco_constituents[0][0]))\n", "print(len(reco_constituents[0][1]))\n", "print(len(reco_constituents[0][2]))\n", "print(len(reco_constituents[0][3]))\n", "print(len(reco_constituents[0][4]))" ] }, { "cell_type": "code", "execution_count": 72, "id": "3fda2809", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
{reco: [0, 1, 2, 3],\n",
       " gen: [0, 1, 2, 3]}\n",
       "----------------------\n",
       "type: {\n",
       "    reco: var * int64,\n",
       "    gen: var * int64\n",
       "}
" ], "text/plain": [ "" ] }, "execution_count": 72, "metadata": {}, "output_type": "execute_result" } ], "source": [ "matched_jets[0]" ] }, { "cell_type": "markdown", "id": "543e188a", "metadata": {}, "source": [ "# Load the dataset" ] }, { "cell_type": "code", "execution_count": 33, "id": "b92d8c4b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "97M\t/pfvol/jetdataset/clic_edm_ttbar_pf/train/0.pt\r\n" ] } ], "source": [ "! du -sh /pfvol/jetdataset/clic_edm_ttbar_pf/train/0.pt" ] }, { "cell_type": "code", "execution_count": 5, "id": "1a97f4e5", "metadata": {}, "outputs": [], "source": [ "# load one of the train files\n", "jet_dataset = torch.load(\"/pfvol/jetdataset/clic_edm_ttbar_pf/train/0.pt\")" ] }, { "cell_type": "code", "execution_count": 6, "id": "0ad26b72", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "19821" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(jet_dataset)" ] }, { "cell_type": "code", "execution_count": 11, "id": "756bed70", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Jet # 0\n", "keys ['mlpfcands_pid', 'gen_jet_eta', 'gen_jet_phi', 'mlpfcands_charge', 'gen_jet_energy', 'gen_jet_pt', 'reco_jet_energy', 'mlpfcands_momentum', 'gen_jet_label', 'reco_jet_pt', 'reco_jet_phi', 'mlpfcands_latentX', 'reco_jet_eta']\n", "--------------------------\n", " gen jet pt: 41.288578033447266\n", " reco jet pt: 41.34593963623047\n", "--------------------------\n", " mlpfcands_momentum: torch.Size([18, 5]) - 18 particles with 5 p4 features (pt, eta, sphi, cphi, energy)\n", " mlpfcands_pid: torch.Size([18, 6]) - 18 particles with 6 PID features\n", " mlpfcands_charge: torch.Size([18, 3]) - 18 particles with 3 charge features\n", " mlpfcands_latentX: torch.Size([18, 791]) - 18 particles with 791 latent features\n" ] } ], "source": [ "# inspect one jet\n", "ijet = 0\n", "\n", "print(f\"Jet # {ijet}\")\n", "print(f\"keys 
{jet_dataset[ijet].keys()}\")\n", "print(\"--------------------------\")\n", "print(\" gen jet pt:\", jet_dataset[ijet][\"gen_jet_pt\"].item())\n", "print(\" reco jet pt:\", jet_dataset[ijet][\"reco_jet_pt\"].item())\n", "print(\"--------------------------\")\n", "print(\" mlpfcands_momentum:\", jet_dataset[0][\"mlpfcands_momentum\"].shape, f\"- {jet_dataset[0]['mlpfcands_momentum'].shape[0]} particles with 5 p4 features (pt, eta, sphi, cphi, energy)\")\n", "print(\" mlpfcands_pid:\", jet_dataset[0][\"mlpfcands_pid\"].shape, f\"- {jet_dataset[0]['mlpfcands_momentum'].shape[0]} particles with 6 PID features\")\n", "print(\" mlpfcands_charge:\", jet_dataset[0][\"mlpfcands_charge\"].shape, f\"- {jet_dataset[0]['mlpfcands_momentum'].shape[0]} particles with 3 charge features\")\n", "print(\" mlpfcands_latentX:\", jet_dataset[0][\"mlpfcands_latentX\"].shape, f\"- {jet_dataset[0]['mlpfcands_momentum'].shape[0]} particles with 791 latent features\")" ] }, { "cell_type": "markdown", "id": "09d976c1", "metadata": {}, "source": [ "# Build a DataLoader" ] }, { "cell_type": "code", "execution_count": 21, "id": "7b5708a8", "metadata": {}, "outputs": [], "source": [ "# Define your custom collate function to add a batch key\n", "def collate_fn(data_list):\n", " batch = Batch.from_data_list(data_list)\n", " \n", " batch_list = []\n", " for ijet, jet in enumerate(data_list):\n", " num_MLPFcands = len(jet[\"mlpfcands_momentum\"]) # number of MLPFcands\n", " batch_list += [ijet] * num_MLPFcands\n", "\n", " batch.batch = torch.tensor(batch_list)\n", " \n", " return batch" ] }, { "cell_type": "code", "execution_count": 52, "id": "f6ae14a4", "metadata": {}, "outputs": [], "source": [ "jetloader = torch.utils.data.DataLoader(jet_dataset, batch_size=5, collate_fn=collate_fn)" ] }, { "cell_type": "code", "execution_count": 53, "id": "b2c8511b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['gen_jet_eta',\n", " 'reco_jet_pt',\n", " 'mlpfcands_pid',\n", " 'gen_jet_phi',\n", " 
'gen_jet_energy',\n", " 'batch',\n", " 'reco_jet_energy',\n", " 'gen_jet_label',\n", " 'reco_jet_eta',\n", " 'reco_jet_phi',\n", " 'gen_jet_pt',\n", " 'mlpfcands_charge',\n", " 'mlpfcands_momentum',\n", " 'mlpfcands_latentX']" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "for batch in jetloader:\n", " break\n", "batch.keys()" ] }, { "cell_type": "code", "execution_count": 48, "id": "62ac64c4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gen_jet_eta\n", "reco_jet_pt\n", "gen_jet_phi\n", "gen_jet_energy\n", "reco_jet_energy\n", "gen_jet_label\n", "reco_jet_eta\n", "reco_jet_phi\n", "gen_jet_pt\n" ] } ], "source": [ "for key in batch.keys():\n", " if \"jet\" in key:\n", " print(key)" ] }, { "cell_type": "code", "execution_count": 47, "id": "5c076f88", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mlpfcands_pid\n", "mlpfcands_charge\n", "mlpfcands_momentum\n", "mlpfcands_latentX\n" ] } ], "source": [ "for key in batch.keys():\n", " if \"mlpf\" in key:\n", " print(key)" ] }, { "cell_type": "code", "execution_count": 24, "id": "8e4821fa", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([38.7100, 16.4567, 31.4733, 45.9380, 66.5003])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch[\"gen_jet_pt\"]" ] }, { "cell_type": "code", "execution_count": 25, "id": "33dab9e0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([36.2084, 15.6482, 41.6876, 48.7271, 80.7797])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch[\"reco_jet_pt\"]" ] }, { "cell_type": "code", "execution_count": 49, "id": "acb9b0c5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([154, 5])" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch[\"mlpfcands_momentum\"].shape # 154 
particles per jet; 5 p4 info (pt, eta, sphi, cphi, e)" ] }, { "cell_type": "code", "execution_count": 51, "id": "8fd812b8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([154, 791])" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [
"batch[\"mlpfcands_latentX\"].shape  # 154 particles in total across the 5 jets of this batch; 791 latent features"
] }, { "cell_type": "code", "execution_count": 27, "id": "7f399a0d", "metadata": {}, "outputs": [ { "data": { "text/plain": [
"tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3,\n",
" 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
" 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4,\n",
" 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,\n",
" 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])"
] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [
"batch[\"batch\"]  # for every particle, the index of the jet it belongs to"
] }, { "cell_type": "markdown", "id": "2ec70de5", "metadata": {}, "source": [ "# Quick validation" ] }, { "cell_type": "code", "execution_count": 28, "id": "e784ee72", "metadata": {}, "outputs": [], "source": [
"# collect the gen/reco jet kinematics over the whole loader\n",
"jet_vars = [\"pt\", \"eta\", \"phi\", \"energy\"]\n",
"genjets = {x: [] for x in jet_vars}\n",
"recojets = {x: [] for x in jet_vars}\n",
"\n",
"# dedicated loop variable so that the `batch` inspected in the cells above is not overwritten\n",
"for jet_batch in jetloader:\n",
"    for x in jet_vars:\n",
"        genjets[x].extend(jet_batch[f\"gen_jet_{x}\"].tolist())\n",
"        recojets[x].extend(jet_batch[f\"reco_jet_{x}\"].tolist())"
] }, { "cell_type": "code", "execution_count": 29, "id": "859396a0", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"fig, axes = plt.subplots(figsize=(10, 10), nrows=2, ncols=2)\n",
"\n",
"# one (variable, bin edges) entry per panel, filled row by row: [0,0], [0,1], [1,0], [1,1]\n",
"panels = [\n",
"    (\"pt\", np.linspace(15, 130, 40)),\n",
"    (\"eta\", np.linspace(-1.8, 1.8, 40)),\n",
"    (\"phi\", np.linspace(-3.14, 3.14, 40)),\n",
"    (\"energy\", np.linspace(15, 130, 40)),\n",
"]\n",
"for ax, (var, bins) in zip(axes.flat, panels):\n",
"    ax.hist2d(genjets[var], recojets[var], bins)\n",
"    ax.set_title(var)\n",
"\n",
"# shared axis labels for all four panels\n",
"fig.text(0.5, 0, 'Gen Jet', ha='center')\n",
"fig.text(0, 0.5, 'Reco Jet', va='center', rotation='vertical')\n",
"\n",
"plt.tight_layout()"
] }, { "cell_type": "markdown", "id": "e9baced1", "metadata": {}, "source": [ "# Setup the downstream task" ] }, { "cell_type": "code", "execution_count": 30, "id": "89613a90", "metadata": {}, "outputs": [], "source": [
"import torch.nn as nn\n",
"from torch_geometric.nn.pool import global_add_pool\n",
"\n",
"\n",
"def ffn(input_dim, output_dim, width, act, dropout):\n",
"    \"\"\"Build the MLP: Linear -> act -> LayerNorm -> Dropout -> Linear.\n",
"\n",
"    `act` is an activation class that is instantiated here; `width` is the hidden size.\n",
"    \"\"\"\n",
"    return nn.Sequential(\n",
"        nn.Linear(input_dim, width),\n",
"        act(),\n",
"        nn.LayerNorm(width),\n",
"        nn.Dropout(dropout),\n",
"        nn.Linear(width, output_dim),\n",
"    )\n",
"\n",
"\n",
"class JetRegressor(nn.Module):\n",
"    \"\"\"Per-jet regressor on top of sum-pooled particle embeddings.\n",
"\n",
"    Takes as input either (1) the MLPF candidates OR (2) the latent representations of the\n",
"    MLPF candidates, embeds every particle with an MLP, sums the embeddings of the particles\n",
"    of each jet and runs a second MLP to predict an output per jet: \"ptcorr\".\n",
"\n",
"    In this notebook \"ptcorr\" enters the loss as follows:\n",
"        target = log(gen_jet_pt / reco_jet_pt)\n",
"        LOSS = Huber(ptcorr, target)\n",
"    i.e. the corrected jet pt corresponds to reco_jet_pt * exp(ptcorr).\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        input_dim=14,\n",
"        embedding_dim=64,\n",
"        output_dim=1,\n",
"        width=256,\n",
"        dropout=0,\n",
"    ):\n",
"        super().__init__()\n",
"\n",
"        self.act = nn.ELU\n",
"        self.nn1 = ffn(input_dim, embedding_dim, width, self.act, dropout)  # per-particle embedding\n",
"        self.nn2 = ffn(embedding_dim, output_dim, width, self.act, dropout)  # per-jet head\n",
"\n",
"    # @torch.compile\n",
"    def forward(self, X, batch):\n",
"        \"\"\"Return one prediction row per jet.\n",
"\n",
"        X holds one row of features per particle; `batch` holds, for every particle,\n",
"        the index of the jet it belongs to.\n",
"        \"\"\"\n",
"        embeddings = self.nn1(X)\n",
"\n",
"        # sum the particle embeddings that share the same jet index\n",
"        pooled_embeddings = global_add_pool(embeddings, batch)\n",
"\n",
"        return self.nn2(pooled_embeddings)"
] }, { "cell_type": "code", "execution_count": 31, "id": "2e2e9af7", "metadata": {}, "outputs": [ { "data": { "text/plain": [
"JetRegressor(\n",
" (nn1): Sequential(\n",
" (0): Linear(in_features=791, out_features=256, bias=True)\n",
" (1): ELU(alpha=1.0)\n",
" (2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (3): Dropout(p=0, inplace=False)\n",
" (4): Linear(in_features=256, out_features=64, bias=True)\n",
" )\n",
" (nn2): Sequential(\n",
" (0): Linear(in_features=64, out_features=256, bias=True)\n",
" (1): ELU(alpha=1.0)\n",
" (2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
" (3): Dropout(p=0, inplace=False)\n",
" (4): Linear(in_features=256, out_features=1, bias=True)\n",
" )\n",
")"
] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [
"run_with_latentX = True\n",
"\n",
"# 791 = width of `mlpfcands_latentX`;\n",
"# 14 = presumably the width of the concatenated momentum/pid/charge features (TODO confirm)\n",
"input_dim = 791 if run_with_latentX else 14\n",
"\n",
"model = JetRegressor(input_dim).to(rank)\n",
"model.train()"
] }, { "cell_type": "code", "execution_count": 32, "id": "1143a728", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(0.0532, device='cuda:0', grad_fn=)\n" ] } ], "source": [
"# smoke test: a single forward pass + loss on the first batch of the loader\n",
"batch = next(iter(jetloader)).to(rank)\n",
"\n",
"if run_with_latentX:\n",
"    X = batch[\"mlpfcands_latentX\"]\n",
"else:\n",
"    X = torch.cat([batch[\"mlpfcands_momentum\"], batch[\"mlpfcands_pid\"], batch[\"mlpfcands_charge\"]], dim=-1)\n",
"\n",
"ptcorr = model(X, batch.batch).squeeze(1)\n",
"\n",
"# regression target: log of the gen/reco jet pt ratio\n",
"target = torch.log(batch[\"gen_jet_pt\"] / batch[\"reco_jet_pt\"])\n",
"\n",
"# huber_loss takes (input, target): the prediction goes first\n",
"loss = torch.nn.functional.huber_loss(ptcorr, target)\n",
"print(loss)"
] }, { "cell_type": "code", "execution_count": null, "id": "926cebd1", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "0f7325ed", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "94fe4139", "metadata": {}, "source": [ "########\n" ] }, { "cell_type": "code", "execution_count": null, "id": "8cea34d3", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }