{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Subgroup analysis\n", "\n", "In many real-world applications, we are interested not only in the calibration of the overall population but also in the calibration of subgroups within it. calzone provides a simple way to perform subgroup analysis, provided the input data follows a specific format. To perform subgroup analysis, the input CSV file should contain the following columns:\n", "\n", "proba_0, proba_1, ..., proba_n, subgroup_1, subgroup_2, ..., subgroup_m, label\n", "\n", "\n", "where n >= 1 and m >= 1.\n", "\n", "In this example, we use the simulated example dataset included in the calzone package, which has a single subgroup field containing two subgroups. See the quickstart for more details.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['proba_0,proba_1,subgroup_1,label'\n", " '0.1444156178040511,0.8555843821959489,A,0'\n", " '0.8552048445812981,0.1447951554187019,A,0'\n", " '0.2569696048872897,0.7430303951127103,A,0'\n", " '0.39931305655530125,0.6006869434446988,A,1']\n", "Whether the dataset has subgroup: True\n" ] } ], "source": [ "### Import the packages and read the data\n", "import numpy as np\n", "from calzone.utils import data_loader\n", "from calzone.metrics import CalibrationMetrics\n", "\n", "dataset = data_loader('../../../example_data/simulated_data_subgroup.csv')\n", "print(np.loadtxt('../../../example_data/simulated_data_subgroup.csv',dtype=str)[:5]) # first 5 lines of the CSV file\n", "print(\"Whether the dataset has subgroup:\",dataset.have_subgroup)\n", "\n", "### Create a CalibrationMetrics instance\n", "metrics_cal = CalibrationMetrics(class_to_calculate=1)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "subgroup 1\n", "subgroup 1 class A\n", "SpiegelhalterZ score: 0.3763269161877356\n", 
"SpiegelhalterZ p-value: 0.7066738713391099\n", "ECE-H topclass: 0.009608653731328977\n", "ECE-H: 0.01208775955804901\n", "MCE-H topclass: 0.03926468843081976\n", "MCE-H: 0.04848338618970194\n", "HL-H score: 8.884991559088098\n", "HL-H p-value: 0.35209071874348785\n", "ECE-C topclass: 0.009458033653818828\n", "ECE-C: 0.008733966945443138\n", "MCE-C topclass: 0.020515047600205505\n", "MCE-C: 0.02324031223486256\n", "HL-C score: 3.694947603203135\n", "HL-C p-value: 0.8835446575708198\n", "COX coef: 0.9942499557748269\n", "COX intercept: -0.04497652296600376\n", "COX coef lowerci: 0.9372902801721911\n", "COX coef upperci: 1.0512096313774626\n", "COX intercept lowerci: -0.12348577118577644\n", "COX intercept upperci: 0.03353272525376893\n", "COX ICI: 0.005610391483826338\n", "Loess ICI: 0.00558856942568957\n", "subgroup 1 class B\n", "SpiegelhalterZ score: 27.93575342117766\n", "SpiegelhalterZ p-value: 0.0\n", "ECE-H topclass: 0.07658928982434714\n", "ECE-H: 0.0765892898243467\n", "MCE-H topclass: 0.1327565894838103\n", "MCE-H: 0.16250572519432438\n", "HL-H score: 910.4385762101924\n", "HL-H p-value: 0.0\n", "ECE-C topclass: 0.07429481165606829\n", "ECE-C: 0.07479369479609524\n", "MCE-C topclass: 0.14090872416947742\n", "MCE-C: 0.14045600565696226\n", "HL-C score: 2246.1714434139853\n", "HL-C p-value: 0.0\n", "COX coef: 0.5071793536874274\n", "COX intercept: 0.00037947714112375366\n", "COX coef lowerci: 0.47838663128188996\n", "COX coef upperci: 0.5359720760929648\n", "COX intercept lowerci: -0.07796623141885761\n", "COX intercept upperci: 0.07872518570110512\n", "COX ICI: 0.07746407648179383\n", "Loess ICI: 0.06991428582761099\n" ] } ], "source": [ "### subgroup analysis for each group\n", "### You can perform other analysis during the loop (e.g. 
plotting the reliability diagram etc)\n", "for i,subgroup_column in enumerate(dataset.subgroup_indices):\n", "    print(f\"subgroup {i+1}\")\n", "    for j,subgroup_class in enumerate(dataset.subgroups_class[i]):\n", "        print(f\"subgroup {i+1} class {subgroup_class}\")\n", "        proba = dataset.probs[dataset.subgroups_index[i][j],:]\n", "        label = dataset.labels[dataset.subgroups_index[i][j]]\n", "        result = metrics_cal.calculate_metrics(label, proba,metrics='all')\n", "        for metric in result:\n", "            print(f\"{metric}: {result[metric]}\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Metrics:\n", "SpiegelhalterZ score: 18.327\n", "SpiegelhalterZ p-value: 0.\n", "ECE-H topclass: 0.042\n", "ECE-H: 0.042\n", "MCE-H topclass: 0.055\n", "MCE-H: 0.063\n", "HL-H score: 429.732\n", "HL-H p-value: 0.\n", "ECE-C topclass: 0.042\n", "ECE-C: 0.038\n", "MCE-C topclass: 0.065\n", "MCE-C: 0.064\n", "HL-C score: 1138.842\n", "HL-C p-value: 0.\n", "COX coef: 0.668\n", "COX intercept: -0.02\n", "COX coef lowerci: 0.641\n", "COX coef upperci: 0.696\n", "COX intercept lowerci: -0.074\n", "COX intercept upperci: 0.034\n", "COX ICI: 0.049\n", "Loess ICI: 0.037\n", "Metrics for subgroup subgroup_1_group_A:\n", "SpiegelhalterZ score: 0.376\n", "SpiegelhalterZ p-value: 0.707\n", "ECE-H topclass: 0.01\n", "ECE-H: 0.012\n", "MCE-H topclass: 0.039\n", "MCE-H: 0.048\n", "HL-H score: 8.885\n", "HL-H p-value: 0.352\n", "ECE-C topclass: 0.009\n", "ECE-C: 0.009\n", "MCE-C topclass: 0.021\n", "MCE-C: 0.023\n", "HL-C score: 3.695\n", "HL-C p-value: 0.884\n", "COX coef: 0.994\n", "COX intercept: -0.045\n", "COX coef lowerci: 0.937\n", "COX coef upperci: 1.051\n", "COX intercept lowerci: -0.123\n", "COX intercept upperci: 0.034\n", "COX ICI: 0.006\n", "Loess ICI: 0.006\n", "Metrics for subgroup subgroup_1_group_B:\n", "SpiegelhalterZ score: 27.936\n", "SpiegelhalterZ p-value: 0.\n", "ECE-H topclass: 0.077\n", "ECE-H: 0.077\n", 
"MCE-H topclass: 0.133\n", "MCE-H: 0.163\n", "HL-H score: 910.439\n", "HL-H p-value: 0.\n", "ECE-C topclass: 0.074\n", "ECE-C: 0.075\n", "MCE-C topclass: 0.141\n", "MCE-C: 0.140\n", "HL-C score: 2246.171\n", "HL-C p-value: 0.\n", "COX coef: 0.507\n", "COX intercept: 0.000\n", "COX coef lowerci: 0.478\n", "COX coef upperci: 0.536\n", "COX intercept lowerci: -0.078\n", "COX intercept upperci: 0.079\n", "COX ICI: 0.077\n", "Loess ICI: 0.07\n" ] } ], "source": [ "### An alternative way to do the same thing is through the command-line interface\n", "\n", "%run ../../../cal_metrics.py \\\n", "--csv_file '../../../example_data/simulated_data_subgroup.csv' \\\n", "--metrics all \\\n", "--class_to_calculate 1 \\\n", "--num_bins 10 \\\n", "--verbose" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "uq", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.0" } }, "nbformat": 4, "nbformat_minor": 2 }