{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "df569247", "metadata": {}, "outputs": [], "source": [ "import json\n", "import math\n", "import random\n", "import os\n", "import pickle as pkl\n", "import time\n", "from typing import Dict, List\n", "\n", "import awkward as ak\n", "import fastjet\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import mplhep as hep\n", "import numpy as np\n", "import sklearn\n", "import sklearn.metrics\n", "import torch\n", "import tqdm\n", "import vector\n", "from torch_geometric.data import Batch, Data\n", "\n", "plt.style.use(hep.style.CMS)\n", "plt.rcParams.update({\"font.size\": 20})" ] }, { "cell_type": "code", "execution_count": 2, "id": "edf87dc8", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 8, "id": "16ca6a3a", "metadata": {}, "outputs": [], "source": [ "# import relevant functions from mlpf.pyg\n", "import sys\n", "sys.path.append(\"/home/jovyan/particleflow/mlpf/\")\n", "import pyg\n", "sys.path.append(\"/home/jovyan/particleflow/mlpf/pyg/\")\n", "import utils\n", "\n", "from PFDataset import PFDataset, get_interleaved_dataloaders, Collater\n", "\n", "from pyg.mlpf import MLPF\n", "from pyg.utils import X_FEATURES, Y_FEATURES, unpack_predictions\n", "from jet_utils import match_two_jet_collections\n", "\n", "#################################### must update this function to have the proper p4\n", "def unpack_target(y):\n", " ret = {}\n", " ret[\"cls_id\"] = y[..., 0].long()\n", " ret[\"charge\"] = torch.clamp((y[..., 1] + 1).to(dtype=torch.float32), 0, 2) # -1, 0, 1 -> 0, 1, 2\n", "\n", " for i, feat in enumerate(Y_FEATURES):\n", " if i >= 2: # skip the cls and charge as they are defined above\n", " ret[feat] = y[..., i].to(dtype=torch.float32)\n", " ret[\"phi\"] = torch.atan2(ret[\"sin_phi\"], ret[\"cos_phi\"])\n", "\n", " # do some sanity checks\n", " # assert torch.all(ret[\"pt\"] >= 0.0) # pt\n", " # 
assert torch.all(torch.abs(ret[\"sin_phi\"]) <= 1.0)  # sin_phi\n", "    # assert torch.all(torch.abs(ret[\"cos_phi\"]) <= 1.0)  # cos_phi\n", "    # assert torch.all(ret[\"energy\"] >= 0.0)  # energy\n", "\n", "    # note ~ momentum = [\"pt\", \"eta\", \"sin_phi\", \"cos_phi\", \"energy\"]\n", "    ret[\"momentum\"] = y[..., 2:7].to(dtype=torch.float32)\n", "    ret[\"p4\"] = torch.cat(\n", "        [ret[\"pt\"].unsqueeze(-1), ret[\"eta\"].unsqueeze(-1), ret[\"phi\"].unsqueeze(-1), ret[\"energy\"].unsqueeze(-1)], axis=-1\n", "    )\n", "\n", "    # Y_FEATURES has no jet-index entry (its last name is \"energy\"), so y[..., -1] is a gen-jet\n", "    # index only when y carries columns beyond Y_FEATURES; otherwise it is the energy column and\n", "    # casting it to long would silently produce a meaningless \"genjet_idx\".\n", "    if y.shape[-1] > len(Y_FEATURES):\n", "        ret[\"genjet_idx\"] = y[..., -1].long()\n", "\n", "    return ret" ] },
{ "cell_type": "code", "execution_count": 9, "id": "14875b0a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Will use NVIDIA A100-SXM4-80GB\n" ] } ], "source": [ "# define the global base device\n", "world_size = 1\n", "if torch.cuda.device_count():\n", "    rank = 0\n", "    device = torch.device(\"cuda:0\")\n", "    print(f\"Will use {torch.cuda.get_device_name(device)}\")\n", "else:\n", "    rank = \"cpu\"\n", "    device = \"cpu\"\n", "    print(\"Will use cpu\")" ] },
{ "cell_type": "markdown", "id": "3c4aec55", "metadata": {}, "source": [ "# Load the pre-trained MLPF model" ] },
{ "cell_type": "code", "execution_count": 17, "id": "1879def6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"MLPF(\n",
"  (nn0_id): Sequential(\n",
"    (0): Linear(in_features=17, out_features=512, bias=True)\n",
"    (1): ReLU()\n",
"    (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
"    (3): Dropout(p=0.0, inplace=False)\n",
"    (4): Linear(in_features=512, out_features=512, bias=True)\n",
"  )\n",
"  (nn0_reg): Sequential(\n",
"    (0): Linear(in_features=17, out_features=512, bias=True)\n",
"    (1): ReLU()\n",
"    (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
"    (3): Dropout(p=0.0, inplace=False)\n",
"    (4): Linear(in_features=512, out_features=512, bias=True)\n",
"  )\n",
"  (conv_id): ModuleList(\n",
"    (0-2): 3 x SelfAttentionLayer(\n",
"      (mha): MultiheadAttention(\n",
"        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)\n",
"      )\n",
"      (norm0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
"      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
"      (seq): Sequential(\n",
"        (0): Linear(in_features=512, out_features=512, bias=True)\n",
"        (1): ReLU()\n",
"        (2): Linear(in_features=512, out_features=512, bias=True)\n",
"        (3): ReLU()\n",
"      )\n",
"      (dropout): Dropout(p=0.0, inplace=False)\n",
"    )\n",
"  )\n",
"  (conv_reg): ModuleList(\n",
"    (0-2): 3 x SelfAttentionLayer(\n",
"      (mha): MultiheadAttention(\n",
"        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)\n",
"      )\n",
"      (norm0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
"      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
"      (seq): Sequential(\n",
"        (0): Linear(in_features=512, out_features=512, bias=True)\n",
"        (1): ReLU()\n",
"        (2): Linear(in_features=512, out_features=512, bias=True)\n",
"        (3): ReLU()\n",
"      )\n",
"      (dropout): Dropout(p=0.0, inplace=False)\n",
"    )\n",
"  )\n",
"  (nn_id): Sequential(\n",
"    (0): Linear(in_features=529, out_features=512, bias=True)\n",
"    (1): ReLU()\n",
"    (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
"    (3): Dropout(p=0.0, inplace=False)\n",
"    (4): Linear(in_features=512, out_features=6, bias=True)\n",
"  )\n",
"  (nn_pt): RegressionOutput(\n",
"    (nn): Sequential(\n",
"      (0): Linear(in_features=535, out_features=512, bias=True)\n",
"      (1): ReLU()\n",
"      (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
"      (3): Dropout(p=0.0, inplace=False)\n",
"      (4): Linear(in_features=512, out_features=2, bias=True)\n",
"    )\n",
"  )\n",
"  (nn_eta): RegressionOutput(\n",
"    (nn): Sequential(\n",
"      (0): Linear(in_features=535, out_features=512, bias=True)\n",
"      (1): ReLU()\n",
"      (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
"      (3): Dropout(p=0.0, inplace=False)\n",
"      (4): Linear(in_features=512, out_features=2, bias=True)\n",
"    )\n",
"  )\n",
"  (nn_sin_phi): RegressionOutput(\n",
"    (nn): Sequential(\n",
"      (0): Linear(in_features=535, out_features=512, bias=True)\n",
"      (1): ReLU()\n",
"      (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
"      (3): Dropout(p=0.0, inplace=False)\n",
"      (4): Linear(in_features=512, out_features=2, bias=True)\n",
"    )\n",
"  )\n",
"  (nn_cos_phi): RegressionOutput(\n",
"    (nn): Sequential(\n",
"      (0): Linear(in_features=535, out_features=512, bias=True)\n",
"      (1): ReLU()\n",
"      (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
"      (3): Dropout(p=0.0, inplace=False)\n",
"      (4): Linear(in_features=512, out_features=2, bias=True)\n",
"    )\n",
"  )\n",
"  (nn_energy): RegressionOutput(\n",
"    (nn): Sequential(\n",
"      (0): Linear(in_features=535, out_features=512, bias=True)\n",
"      (1): ReLU()\n",
"      (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
"      (3): Dropout(p=0.0, inplace=False)\n",
"      (4): Linear(in_features=512, out_features=2, bias=True)\n",
"    )\n",
"  )\n",
")\n"
] } ], "source": [
"def load_checkpoint(checkpoint, model, optimizer=None):\n",
"    \"\"\"Load the saved model weights (and optionally the optimizer state) from `checkpoint`.\n",
"\n",
"    Returns `(model, optimizer)` when an optimizer is given, otherwise just `model`.\n",
"    \"\"\"\n",
"    # DistributedDataParallel wraps the actual network in `.module`\n",
"    net = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model\n",
"    net.load_state_dict(checkpoint[\"model_state_dict\"])\n",
"    if optimizer:\n",
"        optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n",
"        return model, optimizer\n",
"    return model\n",
"\n",
"\n",
"loaddir = \"/pfvol/experiments/MLPF_clic_backbone_pyg-clic_20240429_101112_971749\"\n",
"\n",
"# NOTE: pickle.load and torch.load can execute arbitrary code; only load files from a trusted directory.\n",
"with open(f\"{loaddir}/model_kwargs.pkl\", \"rb\") as f:\n",
"    mlpf_kwargs = pkl.load(f)\n",
"\n",
"# flash attention needs a CUDA device; on the CPU fallback keep whatever attention_type model_kwargs.pkl provides\n",
"if torch.cuda.device_count():\n",
"    mlpf_kwargs[\"attention_type\"] = \"flash\"\n",
"\n",
"mlpf = MLPF(**mlpf_kwargs).to(torch.device(rank))\n",
"checkpoint = torch.load(f\"{loaddir}/best_weights.pth\", map_location=torch.device(rank))\n",
"\n",
"mlpf = load_checkpoint(checkpoint, mlpf)\n",
"mlpf.eval()\n", "\n", "print(mlpf) " ] }, { "cell_type": "markdown", "id": "dde5f191", "metadata": {}, "source": [ "# CLIC dataset" ] }, { "cell_type": "code", "execution_count": 18, "id": "d98c6857", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "clic_edm_qq_pf\t\t   cms_pf_qcd_high_pt\t   cms_pf_single_proton\r\n", "clic_edm_ttbar_pf\t   cms_pf_single_electron  cms_pf_single_tau\r\n", "clic_edm_ttbar_pu10_pf\t   cms_pf_single_gamma\t   cms_pf_sms_t1tttt\r\n", "clic_edm_ww_fullhad_pf\t   cms_pf_single_mu\t   cms_pf_ttbar\r\n", "clic_edm_zh_tautau_pf\t   cms_pf_single_neutron   cms_pf_ztt\r\n", "cms_pf_multi_particle_gun  cms_pf_single_pi\t   delphes_qcd_pf\r\n", "cms_pf_qcd\t\t   cms_pf_single_pi0\t   delphes_ttbar_pf\r\n" ] } ], "source": [ "! ls /pfvol/tensorflow_datasets/" ] }, { "cell_type": "code", "execution_count": 19, "id": "49d849e9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['type',\n", " 'pt | et',\n", " 'eta',\n", " 'sin_phi',\n", " 'cos_phi',\n", " 'p | energy',\n", " 'chi2 | position.x',\n", " 'ndf | position.y',\n", " 'dEdx | position.z',\n", " 'dEdxError | iTheta',\n", " 'radiusOfInnermostHit | energy_ecal',\n", " 'tanLambda | energy_hcal',\n", " 'D0 | energy_other',\n", " 'omega | num_hits',\n", " 'Z0 | sigma_x',\n", " 'time | sigma_y',\n", " 'Null | sigma_z']" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# the 17 input features per PF element (recall type is 1 for tracks and 2 for clusters)\n", "X_FEATURES[\"clic\"]" ] }, { "cell_type": "code", "execution_count": 20, "id": "b3781b28", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['cls_id', 'charge', 'pt', 'eta', 'sin_phi', 'cos_phi', 'energy']" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# the 7 named gen (target) features per PF element; Y_FEATURES has no jet index, so any extra trailing ygen columns are not named here\n", "Y_FEATURES" ] }, { "cell_type": "markdown", "id": "d602e8e2", "metadata": 
{}, "source": [ "# Get the dataset (Events)" ] },
{ "cell_type": "code", "execution_count": 24, "id": "dfc6b930", "metadata": {}, "outputs": [], "source": [ "data_dir = \"/home/jovyan/particleflow/tensorflow_datasets/\"\n", "sample = \"clic_edm_ttbar_pf\"\n", "\n", "dataset_train = PFDataset(data_dir, f\"{sample}:1.5.0\", \"train\", num_samples=10_000)\n", "\n", "batch_size = 100\n", "\n", "train_loader = torch.utils.data.DataLoader(\n", "    dataset_train.ds,\n", "    batch_size=batch_size,\n", "    collate_fn=Collater([\"X\", \"ygen\", \"ycand\"]),\n", "    pin_memory=True,\n", "    drop_last=True,\n", ")" ] },
{ "cell_type": "code", "execution_count": 25, "id": "abf11cbc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([100, 246, 17])\n" ] } ], "source": [ "for batch in train_loader:\n", "    batch = batch.to(rank, non_blocking=True)\n", "    break\n", "print(batch.X.shape) " ] },
{ "cell_type": "markdown", "id": "883675d2", "metadata": {}, "source": [ "# Pre-processing (Events -> Jets)" ] },
{ "cell_type": "code", "execution_count": 48, "id": "f8b7ad0b", "metadata": {}, "outputs": [], "source": [ "# True: run the pre-trained MLPF backbone without gradient tracking in the jet-building loop\n", "freeze_backbone = True" ] },
{ "cell_type": "code", "execution_count": 53, "id": "d22f4bdf", "metadata": {}, "outputs": [], "source": [
"############################### set up forward hooks to retrieve the latent representations of MLPF\n",
"latent_reps = {}\n",
"\n",
"# drop hooks registered by a previous run of this cell so re-running it does not stack duplicates\n",
"for _handle in globals().get(\"hook_handles\", []):\n",
"    _handle.remove()\n",
"\n",
"\n",
"def get_activations(name):\n",
"    \"\"\"Return a forward hook that stores the hooked module's output in `latent_reps[name]`.\"\"\"\n",
"\n",
"    def hook(module, input, output):\n",
"        latent_reps[name] = output\n",
"\n",
"    return hook\n",
"\n",
"\n",
"hook_handles = [\n",
"    mlpf.conv_reg[0].dropout.register_forward_hook(get_activations(\"conv_reg0\")),\n",
"    mlpf.conv_reg[1].dropout.register_forward_hook(get_activations(\"conv_reg1\")),\n",
"    mlpf.conv_reg[2].dropout.register_forward_hook(get_activations(\"conv_reg2\")),\n",
"    mlpf.nn_id.register_forward_hook(get_activations(\"nn_id\")),\n",
"]\n",
"###############################\n",
"\n",
"\n",
"def get_latent_reps(batch, latent_reps):\n",
"    \"\"\"Concatenate the input features `batch.X` with the hooked MLPF activations along the last axis.\n",
"\n",
"    The `conv_*` activations are multiplied by `batch.mask` so masked-out (presumably padded)\n",
"    elements are zeroed. Masked copies are built instead of modifying `latent_reps` in place:\n",
"    the hooked tensors are the ones the forward pass produced, and in-place edits to them can\n",
"    break autograd when gradients are enabled.\n",
"    \"\"\"\n",
"    mask = batch.mask.unsqueeze(-1)\n",
"    masked = {name: (rep * mask if \"conv\" in name else rep) for name, rep in latent_reps.items()}\n",
"\n",
"    latentX = torch.cat(\n",
"        [\n",
"            batch.X.to(rank),\n",
"            masked[\"conv_reg0\"],\n",
"            masked[\"conv_reg1\"],\n",
"            masked[\"conv_reg2\"],\n",
"            masked[\"nn_id\"],\n",
"        ],\n",
"        axis=-1,\n",
"    )\n",
"    return latentX" ] },
{ "cell_type": "code", "execution_count": 54, "id": "1bf42610", "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running MLPF inference on batch 0\n" ] } ], "source": [
"sample_to_lab = {\n",
"    \"clic_edm_ttbar_pf\": 1,\n",
"    \"clic_edm_qq_pf\": 0, \n",
"}\n",
"\n",
"####################### Config\n",
"jetdef = fastjet.JetDefinition(fastjet.ee_genkt_algorithm, 0.7, -1.0)\n",
"jet_ptcut = 15.0\n",
"jet_match_dr = 0.1\n",
"\n",
"save_every_X_batch = 3  # will save to disk every \"X\" batches\n",
"# NOTE(review): the (commented-out) save check below uses ibatch % (save_every_X_batch-1), i.e. every X-1 batches\n",
"\n",
"######################## Build the dataset\n",
"jet_dataset = []  # will save on disk and reinitialize at the end of the loop\n",
"saving_i = 0  # will just increment with every save\n",
"\n",
"for ibatch, batch in enumerate(train_loader):\n",
"\n",
"    # run the MLPF model in inference mode to get the MLPF cands / latent representations\n",
"    print(f\"Running MLPF inference on batch {ibatch}\")\n",
"    batch = batch.to(rank, non_blocking=True)\n",
"    # only build an autograd graph through the backbone when it is being fine-tuned\n",
"    with torch.set_grad_enabled(not freeze_backbone):\n",
"        # autocast on the device the batch lives on (the device cell allows a CPU fallback)\n",
"        with torch.autocast(device_type=batch.X.device.type, dtype=torch.bfloat16, enabled=True):\n",
"            ymlpf = mlpf(batch.X, batch.mask)\n",
"            ymlpf = unpack_predictions(ymlpf)\n",
"\n",
"        # get the latent representations\n",
"        ymlpf[\"latentX\"] = get_latent_reps(batch, latent_reps)\n",
"\n",
"#     for k, v in ymlpf.items():\n",
"#         ymlpf[k] = v.detach().cpu()\n",
"\n",
"#     msk_ymlpf = ymlpf[\"cls_id\"] != 0\n",
"#     ymlpf[\"p4\"] = ymlpf[\"p4\"] * msk_ymlpf.unsqueeze(-1)\n",
"\n",
"#     jets_coll = {}\n",
"#     ####################### get the reco jet collection\n",
"#     vec = vector.awk(\n",
"#         ak.zip(\n",
"#             {\n",
"#                 \"pt\": ymlpf[\"p4\"][:, :, 0].to(\"cpu\"),\n",
"# \"eta\": ymlpf[\"p4\"][:, :, 1].to(\"cpu\"),\n", "# \"phi\": ymlpf[\"p4\"][:, :, 2].to(\"cpu\"),\n", "# \"e\": ymlpf[\"p4\"][:, :, 3].to(\"cpu\"),\n", "# }\n", "# )\n", "# )\n", "# cluster = fastjet.ClusterSequence(vec.to_xyzt(), jetdef)\n", "# jets_coll[\"reco\"] = cluster.inclusive_jets(min_pt=jet_ptcut)\n", " \n", "# # get the constituents to mask the MLPF candidates and build the input for the downstream \n", "# reco_constituents = cluster.constituent_index(min_pt=jet_ptcut)\n", "# ####################### \n", " \n", "# ####################### get the gen jet collection\n", "# ygen = unpack_target(batch.ygen) \n", "# vec = vector.awk(\n", "# ak.zip(\n", "# {\n", "# \"pt\": ygen[\"p4\"][:, :, 0].to(\"cpu\"),\n", "# \"eta\": ygen[\"p4\"][:, :, 1].to(\"cpu\"),\n", "# \"phi\": ygen[\"p4\"][:, :, 2].to(\"cpu\"),\n", "# \"e\": ygen[\"p4\"][:, :, 3].to(\"cpu\"),\n", "# }\n", "# )\n", "# )\n", "# cluster = fastjet.ClusterSequence(vec.to_xyzt(), jetdef)\n", "# jets_coll[\"gen\"] = cluster.inclusive_jets(min_pt=jet_ptcut) \n", "# #######################\n", " \n", "# matched_jets = match_two_jet_collections(jets_coll, \"reco\", \"gen\", jet_match_dr)\n", " \n", "# # build the big jet list\n", "# for iev in tqdm.tqdm(range(len(matched_jets[\"gen\"]))):\n", " \n", "# num_matched_jets = len(matched_jets[\"gen\"][iev]) # number of gen jets matched to reco\n", " \n", "# jets_per_event = []\n", "# for j in range(num_matched_jets):\n", " \n", "# # get the actual indices of the matched jets\n", "# igenjet = matched_jets[\"gen\"][iev][j]\n", "# irecojet = matched_jets[\"reco\"][iev][j]\n", "\n", "# # build a mask tensor that will select the particles that belong to the gen jet \n", "# msk_indices = reco_constituents[iev][irecojet].to_numpy()\n", "\n", "# if len(msk_indices)<3:\n", "# # don't save jets with very few particles\n", "# continue\n", "\n", "# jets_per_event += [\n", "\n", "# Data(\n", "# # Target for jet tagging\n", "# 
gen_jet_label=torch.tensor(sample_to_lab[sample]).unsqueeze(0).to(dtype=torch.float32),\n", "\n", "# # Target for jet p4 regression \n", "# gen_jet_pt=torch.tensor(jets_coll[\"gen\"][iev][igenjet].pt, dtype=torch.float32).unsqueeze(0),\n", "# gen_jet_eta=torch.tensor(jets_coll[\"gen\"][iev][igenjet].eta, dtype=torch.float32).unsqueeze(0),\n", "# gen_jet_phi=torch.tensor(jets_coll[\"gen\"][iev][igenjet].phi, dtype=torch.float32).unsqueeze(0),\n", "# gen_jet_energy=torch.tensor(jets_coll[\"gen\"][iev][igenjet].energy, dtype=torch.float32).unsqueeze(0),\n", " \n", "# # could be part of the target\n", "# reco_jet_pt=torch.tensor(jets_coll[\"reco\"][iev][irecojet].pt, dtype=torch.float32).unsqueeze(0),\n", "# reco_jet_eta=torch.tensor(jets_coll[\"reco\"][iev][irecojet].eta, dtype=torch.float32).unsqueeze(0),\n", "# reco_jet_phi=torch.tensor(jets_coll[\"reco\"][iev][irecojet].phi, dtype=torch.float32).unsqueeze(0),\n", "# reco_jet_energy=torch.tensor(jets_coll[\"reco\"][iev][irecojet].energy, dtype=torch.float32).unsqueeze(0),\n", " \n", "# # Input\n", "# mlpfcands_momentum=ymlpf[\"momentum\"][iev][msk_indices],\n", "# mlpfcands_pid=ymlpf[\"cls_id_onehot\"][iev][msk_indices],\n", "# mlpfcands_charge=ymlpf[\"charge\"][iev][msk_indices],\n", "# mlpfcands_latentX=ymlpf[\"latentX\"][iev][msk_indices],\n", "# )\n", "# ]\n", "\n", "# # break # per jet \n", " \n", "# # random.shuffle(jets_per_event)\n", "# jet_dataset += jets_per_event\n", "\n", "# break # per event\n", "\n", "# # random.shuffle(jet_dataset)\n", "# if (ibatch % (save_every_X_batch-1) == 0) and (ibatch != 0):\n", "# print(f\"saving at iteration {ibatch} on disk /pfvol/jetdataset/{sample}/train/{saving_i}.pt\")\n", "# torch.save(jet_dataset, f\"/pfvol/jetdataset/{sample}/train/{saving_i}.pt\")\n", "# saving_i += 1\n", "# jet_dataset = []\n", " break # per batch" ] }, { "cell_type": "code", "execution_count": 56, "id": "71bfbe1d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[[ 1.0000e+00, 
4.0499e+00, 9.5381e-01, ..., -5.0312e+00,\n", " -4.0625e-01, -5.9766e-01],\n", " [ 1.0000e+00, 1.8992e+01, 6.8335e-02, ..., -6.7500e+00,\n", " 2.4414e-01, 1.2500e+00],\n", " [ 1.0000e+00, 4.8831e+00, 1.2232e-01, ..., -5.4688e+00,\n", " -2.0469e+00, 6.9375e+00],\n", " ...,\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00]],\n", "\n", " [[ 1.0000e+00, 5.2279e+01, 2.1137e-01, ..., -6.8125e+00,\n", " -1.6875e+00, 1.1597e-03],\n", " [ 1.0000e+00, 1.2576e+01, -3.7162e-02, ..., -6.2812e+00,\n", " 2.4219e+00, -2.2500e+00],\n", " [ 1.0000e+00, 8.7163e+00, -6.4841e-02, ..., -7.1250e+00,\n", " 3.0781e+00, 6.9922e-01],\n", " ...,\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00]],\n", "\n", " [[ 1.0000e+00, 1.2418e+01, -1.5321e-02, ..., -7.7188e+00,\n", " 1.7031e+00, 2.3594e+00],\n", " [ 1.0000e+00, 1.5101e+01, -3.2354e-02, ..., -8.1875e+00,\n", " 1.3594e+00, 5.6250e+00],\n", " [ 1.0000e+00, 2.4137e+01, 8.7432e-01, ..., -6.5312e+00,\n", " -7.1484e-01, 7.3750e+00],\n", " ...,\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00]],\n", "\n", " ...,\n", "\n", " [[ 1.0000e+00, 2.7388e+01, 1.2865e+00, ..., -4.1250e+00,\n", " 5.6250e+00, -4.6875e+00],\n", " [ 1.0000e+00, 2.1186e+01, 1.8190e-01, ..., -6.2812e+00,\n", " 6.7500e+00, -2.5938e+00],\n", " [ 1.0000e+00, 1.5213e+01, -4.2097e-01, ..., 
-5.9688e+00,\n", " 1.8848e-01, -7.4219e-01],\n", " ...,\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00]],\n", "\n", " [[ 1.0000e+00, 1.1207e+01, -7.9346e-01, ..., -6.0000e+00,\n", " 1.4688e+00, -1.1016e+00],\n", " [ 1.0000e+00, 6.9457e+00, 9.0808e-02, ..., -7.4375e+00,\n", " 1.6953e+00, 7.5781e-01],\n", " [ 1.0000e+00, 6.7352e+00, -5.3468e-01, ..., -7.6875e+00,\n", " 1.0703e+00, 2.6094e+00],\n", " ...,\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00]],\n", "\n", " [[ 1.0000e+00, 1.4524e+01, -4.8954e-01, ..., -7.2188e+00,\n", " 6.3438e+00, -2.3594e+00],\n", " [ 1.0000e+00, 1.0940e+01, -6.0876e-01, ..., -6.5938e+00,\n", " 4.4062e+00, -3.7500e+00],\n", " [ 1.0000e+00, 1.0282e+01, -6.9122e-01, ..., -7.3438e+00,\n", " 8.2031e-01, 7.6953e-01],\n", " ...,\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00],\n", " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.3125e+00,\n", " -5.4688e+00, -7.5625e+00]]], device='cuda:0')" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ymlpf[\"latentX\"].detach()" ] }, { "cell_type": "code", "execution_count": 45, "id": "956e9d8b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ymlpf[\"cls_id_onehot\"].requires_grad" ] }, { "cell_type": "code", "execution_count": 46, "id": "ce0911a4", 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ymlpf[\"latentX\"].requires_grad" ] }, { "cell_type": "code", "execution_count": null, "id": "03aa548e", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "fe1dd2af", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "891aa275", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "b1383441", "metadata": {}, "outputs": [], "source": [ "print(len(reco_constituents[0][0]))\n", "print(len(reco_constituents[0][1]))\n", "print(len(reco_constituents[0][2]))\n", "print(len(reco_constituents[0][3]))\n", "print(len(reco_constituents[0][4]))" ] }, { "cell_type": "code", "execution_count": 77, "id": "8e4d08ad", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
{reco: [0, 1, 2, 3, 4],\n", " gen: [0, 1, 2, 3, 4]}\n", "-----------------------\n", "type: {\n", " reco: var * int64,\n", " gen: var * int64\n", "}" ], "text/plain": [ "
{reco: [0, 1, 2, 3],\n", " gen: [0, 1, 2, 3]}\n", "----------------------\n", "type: {\n", " reco: var * int64,\n", " gen: var * int64\n", "}" ], "text/plain": [ "