Jensen-holm commited on
Commit
ef83bf7
·
1 Parent(s): 6d466dc

re trained the models on only tournament games and the ChalkSeedDiff

Browse files

added as a feature. This really helped the womens neural network, but
the mens one is maybe a little bit worse

data/AllSuperDetailedGames.csv CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:02656ec3af193d2e1823e9b1d1914a7293d77b957739ec6407b67b5453df7878
3
- size 978121784
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ec95fd20de671096891e8426969303b879e0720e52e6da3dc55a0369ba98787
3
+ size 1046244854
data/AllTeamsAgg.csv CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d1b9dae17132c76ed88c7ea1972c0149cdb954d3bad7b0fc07b90c8bda66fdce
3
  size 31040659
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bdbaeb6c905ad5f6480ffe9acb3ebb75ffe8905954826574410c0d8c94a12826
3
  size 31040659
models/Mnn10k.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7863257c283f71b41da57decaaad44c8175b0428e5b339252f290e9de5f58298
3
- size 18914
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:500d7ddd0596b59cf7fbfab9abc7d8f50278f8269f4399e9cb5222fe843418a7
3
+ size 19170
models/Wnn10k.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ae171639c1af16b2b6a199cbd915641cd77f7f2414d81472b46cab3520c7040a
3
- size 18914
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1de0daa46afa742af4a7d31a026b54917abfca4c972546208d50b827e2fe4119
3
+ size 19170
src/__pycache__/visual_eval.cpython-311.pyc ADDED
Binary file (2.7 kB). View file
 
src/baseline.ipynb CHANGED
@@ -68,7 +68,7 @@
68
  "# games_df[\"BaselinePrediction\"] = games_df.apply(\n",
69
  "# lambda row: predict_baseline(row),\n",
70
  "# axis=1,\n",
71
- "# )\n"
72
  ]
73
  },
74
  {
@@ -461,7 +461,7 @@
461
  "wmns_actual_T = torch.tensor(\n",
462
  " wmns_subset[\"Win\"].values,\n",
463
  " dtype=torch.float32,\n",
464
- ")\n"
465
  ]
466
  },
467
  {
@@ -513,7 +513,7 @@
513
  "plt.ylabel(\"True Positive Rate\")\n",
514
  "\n",
515
  "plt.tight_layout()\n",
516
- "plt.show()\n"
517
  ]
518
  },
519
  {
 
68
  "# games_df[\"BaselinePrediction\"] = games_df.apply(\n",
69
  "# lambda row: predict_baseline(row),\n",
70
  "# axis=1,\n",
71
+ "# )"
72
  ]
73
  },
74
  {
 
461
  "wmns_actual_T = torch.tensor(\n",
462
  " wmns_subset[\"Win\"].values,\n",
463
  " dtype=torch.float32,\n",
464
+ ")"
465
  ]
466
  },
467
  {
 
513
  "plt.ylabel(\"True Positive Rate\")\n",
514
  "\n",
515
  "plt.tight_layout()\n",
516
+ "plt.show()"
517
  ]
518
  },
519
  {
src/nn.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
src/pre_processing.ipynb CHANGED
@@ -10,7 +10,7 @@
10
  "import numpy as np\n",
11
  "import os\n",
12
  "\n",
13
- "DATA_DIR = os.path.join(\"..\", \"data\") "
14
  ]
15
  },
16
  {
@@ -212,10 +212,16 @@
212
  }
213
  ],
214
  "source": [
215
- "detailed_tourney_games_df = pd.concat([\n",
216
- " pd.read_csv(os.path.join(DATA_DIR, \"MNCAATourneyDetailedResults.csv\")).assign(League=\"M\"),\n",
217
- " pd.read_csv(os.path.join(DATA_DIR, \"WNCAATourneyDetailedResults.csv\")).assign(League=\"W\"),\n",
218
- "])\n",
 
 
 
 
 
 
219
  "\n",
220
  "detailed_tourney_games_df.sample(5, random_state=1)"
221
  ]
@@ -419,10 +425,16 @@
419
  }
420
  ],
421
  "source": [
422
- "detailed_reg_games_df = pd.concat([\n",
423
- " pd.read_csv(os.path.join(DATA_DIR, \"MRegularSeasonDetailedResults.csv\")).assign(League=\"M\"), \n",
424
- " pd.read_csv(os.path.join(DATA_DIR, \"WRegularSeasonDetailedResults.csv\")).assign(League=\"W\"),\n",
425
- "])\n",
 
 
 
 
 
 
426
  "\n",
427
  "detailed_reg_games_df.sample(5, random_state=1)"
428
  ]
@@ -445,7 +457,7 @@
445
  "\n",
446
  "detailed_metrics = {\n",
447
  " \"Score\",\n",
448
- " # \"Loc\", \n",
449
  " \"FGM\",\n",
450
  " \"FGA\",\n",
451
  " \"FGM3\",\n",
@@ -460,8 +472,12 @@
460
  " \"PF\",\n",
461
  "}\n",
462
  "\n",
463
- "w_renamed_cols = {f\"W{col}\": f\"Team{col}\" for col in detailed_metrics} | {f\"L{col}\": f\"Opp{col}\" for col in detailed_metrics}\n",
464
- "l_renamed_cols = {f\"L{col}\": f\"Team{col}\" for col in detailed_metrics} | {f\"W{col}\": f\"Opp{col}\" for col in detailed_metrics}"
 
 
 
 
465
  ]
466
  },
467
  {
@@ -520,22 +536,26 @@
520
  }
521
  ],
522
  "source": [
523
- "\n",
524
- "detailed_reg_games_df = pd.concat([\n",
525
- " (\n",
526
- " # detailed_reg_games_df[[col for col in detailed_reg_games_df.columns if col != \"LTeamID\"]]\n",
527
- " detailed_reg_games_df[[col for col in detailed_reg_games_df.columns]]\n",
528
- " .assign(GameResult=\"W\")\n",
529
- " .rename(columns=w_renamed_cols | {\"WTeamID\": \"TeamID\", \"LTeamID\": \"OppTeamID\"})\n",
530
- " ),\n",
531
- " (\n",
532
- " # detailed_reg_games_df[[col for col in detailed_reg_games_df.columns if col != \"WTeamID\"]]\n",
533
- " detailed_reg_games_df[[col for col in detailed_reg_games_df.columns]]\n",
534
- " .assign(GameResult=\"L\")\n",
535
- " .rename(columns=l_renamed_cols | {\"LTeamID\": \"TeamID\", \"WTeamID\": \"OppTeamID\"})\n",
536
- " )\n",
537
- "\n",
538
- "]).reset_index(drop=True)\n",
 
 
 
 
539
  "\n",
540
  "detailed_reg_games_df.info()"
541
  ]
@@ -597,20 +617,30 @@
597
  ],
598
  "source": [
599
  "# do the same thing for the tournament games\n",
600
- "detailed_tourney_games_df = pd.concat([\n",
601
- " (\n",
602
- " # detailed_tourney_games_df[[col for col in detailed_tourney_games_df.columns if col != \"LTeamID\"]]\n",
603
- " detailed_tourney_games_df[[col for col in detailed_tourney_games_df.columns]]\n",
604
- " .assign(GameResult=\"W\")\n",
605
- " .rename(columns=w_renamed_cols | {\"WTeamID\": \"TeamID\", \"LTeamID\": \"OppTeamID\"})\n",
606
- " ),\n",
607
- " (\n",
608
- " # detailed_tourney_games_df[[col for col in detailed_tourney_games_df.columns if col != \"WTeamID\"]]\n",
609
- " detailed_tourney_games_df[[col for col in detailed_tourney_games_df.columns]]\n",
610
- " .assign(GameResult=\"L\")\n",
611
- " .rename(columns=l_renamed_cols | {\"LTeamID\": \"TeamID\", \"WTeamID\": \"OppTeamID\"})\n",
612
- " )\n",
613
- "]).reset_index(drop=True)\n",
 
 
 
 
 
 
 
 
 
 
614
  "\n",
615
  "detailed_tourney_games_df.info()"
616
  ]
@@ -621,7 +651,6 @@
621
  "metadata": {},
622
  "outputs": [],
623
  "source": [
624
- "\n",
625
  "for col in detailed_metrics:\n",
626
  " detailed_reg_games_df[f\"{col}Diff\"] = detailed_reg_games_df.apply(\n",
627
  " lambda row: row[f\"Team{col}\"] - row[f\"Opp{col}\"],\n",
@@ -631,7 +660,7 @@
631
  " detailed_tourney_games_df[f\"{col}Diff\"] = detailed_tourney_games_df.apply(\n",
632
  " lambda row: row[f\"Team{col}\"] - row[f\"Opp{col}\"],\n",
633
  " axis=1,\n",
634
- " )\n"
635
  ]
636
  },
637
  {
@@ -671,16 +700,16 @@
671
  " <th>TeamFGM</th>\n",
672
  " <th>TeamFGA</th>\n",
673
  " <th>...</th>\n",
674
- " <th>PFDiff</th>\n",
675
- " <th>TODiff</th>\n",
676
- " <th>ORDiff</th>\n",
677
- " <th>FGMDiff</th>\n",
678
- " <th>BlkDiff</th>\n",
679
  " <th>FTADiff</th>\n",
680
- " <th>StlDiff</th>\n",
681
- " <th>FGM3Diff</th>\n",
682
  " <th>ScoreDiff</th>\n",
683
  " <th>FGADiff</th>\n",
 
 
 
 
 
 
684
  " </tr>\n",
685
  " </thead>\n",
686
  " <tbody>\n",
@@ -697,16 +726,16 @@
697
  " <td>21</td>\n",
698
  " <td>55</td>\n",
699
  " <td>...</td>\n",
700
- " <td>9</td>\n",
701
- " <td>7</td>\n",
702
  " <td>-11</td>\n",
703
- " <td>-7</td>\n",
 
 
704
  " <td>1</td>\n",
 
705
  " <td>-11</td>\n",
706
  " <td>-7</td>\n",
707
- " <td>-3</td>\n",
708
- " <td>-28</td>\n",
709
- " <td>-12</td>\n",
710
  " </tr>\n",
711
  " <tr>\n",
712
  " <th>100732</th>\n",
@@ -721,16 +750,16 @@
721
  " <td>23</td>\n",
722
  " <td>60</td>\n",
723
  " <td>...</td>\n",
724
- " <td>-9</td>\n",
725
- " <td>-6</td>\n",
726
- " <td>-1</td>\n",
727
- " <td>-1</td>\n",
728
- " <td>2</td>\n",
729
  " <td>17</td>\n",
730
- " <td>4</td>\n",
731
- " <td>-2</td>\n",
732
  " <td>12</td>\n",
733
  " <td>-4</td>\n",
 
 
 
 
 
 
734
  " </tr>\n",
735
  " <tr>\n",
736
  " <th>83150</th>\n",
@@ -745,16 +774,16 @@
745
  " <td>27</td>\n",
746
  " <td>58</td>\n",
747
  " <td>...</td>\n",
 
748
  " <td>-5</td>\n",
 
 
 
749
  " <td>1</td>\n",
750
  " <td>4</td>\n",
751
- " <td>-1</td>\n",
752
- " <td>2</td>\n",
753
- " <td>10</td>\n",
754
  " <td>-5</td>\n",
755
  " <td>1</td>\n",
756
  " <td>13</td>\n",
757
- " <td>-6</td>\n",
758
  " </tr>\n",
759
  " <tr>\n",
760
  " <th>345009</th>\n",
@@ -769,16 +798,16 @@
769
  " <td>19</td>\n",
770
  " <td>55</td>\n",
771
  " <td>...</td>\n",
772
- " <td>7</td>\n",
773
- " <td>-5</td>\n",
774
- " <td>2</td>\n",
775
- " <td>1</td>\n",
776
- " <td>-3</td>\n",
777
  " <td>-11</td>\n",
 
 
 
778
  " <td>-3</td>\n",
779
  " <td>-1</td>\n",
 
 
 
780
  " <td>-7</td>\n",
781
- " <td>13</td>\n",
782
  " </tr>\n",
783
  " <tr>\n",
784
  " <th>318707</th>\n",
@@ -793,16 +822,16 @@
793
  " <td>20</td>\n",
794
  " <td>51</td>\n",
795
  " <td>...</td>\n",
 
796
  " <td>2</td>\n",
797
- " <td>4</td>\n",
798
- " <td>2</td>\n",
799
- " <td>-3</td>\n",
800
  " <td>1</td>\n",
801
- " <td>-11</td>\n",
 
802
  " <td>-7</td>\n",
 
803
  " <td>-3</td>\n",
804
- " <td>-18</td>\n",
805
- " <td>3</td>\n",
806
  " </tr>\n",
807
  " </tbody>\n",
808
  "</table>\n",
@@ -817,19 +846,19 @@
817
  "345009 2019 4 3435 58 3292 65 H 0 \n",
818
  "318707 2013 128 3322 45 3270 63 N 0 \n",
819
  "\n",
820
- " TeamFGM TeamFGA ... PFDiff TODiff ORDiff FGMDiff BlkDiff \\\n",
821
- "337067 21 55 ... 9 7 -11 -7 1 \n",
822
- "100732 23 60 ... -9 -6 -1 -1 2 \n",
823
- "83150 27 58 ... -5 1 4 -1 2 \n",
824
- "345009 19 55 ... 7 -5 2 1 -3 \n",
825
- "318707 20 51 ... 2 4 2 -3 1 \n",
826
  "\n",
827
- " FTADiff StlDiff FGM3Diff ScoreDiff FGADiff \n",
828
- "337067 -11 -7 -3 -28 -12 \n",
829
- "100732 17 4 -2 12 -4 \n",
830
- "83150 10 -5 1 13 -6 \n",
831
- "345009 -11 -3 -1 -7 13 \n",
832
- "318707 -11 -7 -3 -18 3 \n",
833
  "\n",
834
  "[5 rows x 49 columns]"
835
  ]
@@ -878,10 +907,12 @@
878
  "source": [
879
  "# combine the two detailed game dataframes into one for future use\n",
880
  "\n",
881
- "all_detailed_games_df = pd.concat([\n",
882
- " detailed_reg_games_df.assign(GameType=\"reg\"),\n",
883
- " detailed_tourney_games_df.assign(GameType=\"tourney\"),\n",
884
- "])"
 
 
885
  ]
886
  },
887
  {
@@ -1306,7 +1337,13 @@
1306
  "source": [
1307
  "team_reg_agg = (\n",
1308
  " detailed_reg_games_df.groupby([\"TeamID\", \"Season\", \"League\"])\n",
1309
- " .agg({col: agg_funcs for col in detailed_reg_games_df.select_dtypes(\"number\").columns if col not in exclude_agg_cols})\n",
 
 
 
 
 
 
1310
  " .reset_index()\n",
1311
  ")\n",
1312
  "\n",
@@ -1668,15 +1705,23 @@
1668
  }
1669
  ],
1670
  "source": [
1671
- "# aggregate the same metrics for the tournament dataset \n",
1672
  "\n",
1673
  "team_tourney_agg = (\n",
1674
  " detailed_tourney_games_df.groupby([\"TeamID\", \"Season\", \"League\"])\n",
1675
- " .agg({col: agg_funcs for col in detailed_tourney_games_df.select_dtypes(\"number\").columns if col not in exclude_agg_cols})\n",
 
 
 
 
 
 
1676
  " .reset_index()\n",
1677
  ")\n",
1678
  "\n",
1679
- "team_tourney_agg.columns = [\" \".join(col).strip() for col in team_tourney_agg.columns.values]\n",
 
 
1680
  "\n",
1681
  "team_tourney_agg.sample(10, random_state=1)"
1682
  ]
@@ -1727,83 +1772,83 @@
1727
  " </thead>\n",
1728
  " <tbody>\n",
1729
  " <tr>\n",
1730
- " <th>0</th>\n",
1731
- " <td>1985</td>\n",
1732
- " <td>W01</td>\n",
1733
- " <td>1207</td>\n",
1734
- " <td>M</td>\n",
1735
- " <td>big_east</td>\n",
1736
- " <td>Georgetown</td>\n",
1737
- " <td>1985</td>\n",
1738
- " <td>2024</td>\n",
1739
- " <td>1</td>\n",
1740
  " </tr>\n",
1741
  " <tr>\n",
1742
- " <th>1</th>\n",
1743
- " <td>1986</td>\n",
1744
- " <td>X04</td>\n",
1745
- " <td>1207</td>\n",
1746
- " <td>M</td>\n",
1747
- " <td>big_east</td>\n",
1748
- " <td>Georgetown</td>\n",
1749
- " <td>1985</td>\n",
1750
- " <td>2024</td>\n",
1751
- " <td>4</td>\n",
1752
  " </tr>\n",
1753
  " <tr>\n",
1754
- " <th>2</th>\n",
1755
- " <td>1987</td>\n",
1756
- " <td>X01</td>\n",
1757
- " <td>1207</td>\n",
1758
  " <td>M</td>\n",
1759
- " <td>big_east</td>\n",
1760
- " <td>Georgetown</td>\n",
1761
- " <td>1985</td>\n",
1762
- " <td>2024</td>\n",
1763
- " <td>1</td>\n",
1764
  " </tr>\n",
1765
  " <tr>\n",
1766
- " <th>3</th>\n",
1767
- " <td>1988</td>\n",
1768
- " <td>W08</td>\n",
1769
- " <td>1207</td>\n",
1770
  " <td>M</td>\n",
1771
- " <td>big_east</td>\n",
1772
- " <td>Georgetown</td>\n",
1773
- " <td>1985</td>\n",
1774
- " <td>2024</td>\n",
1775
- " <td>8</td>\n",
1776
  " </tr>\n",
1777
  " <tr>\n",
1778
- " <th>4</th>\n",
1779
- " <td>1989</td>\n",
1780
- " <td>W01</td>\n",
1781
- " <td>1207</td>\n",
1782
- " <td>M</td>\n",
1783
  " <td>big_east</td>\n",
1784
- " <td>Georgetown</td>\n",
1785
- " <td>1985</td>\n",
1786
- " <td>2024</td>\n",
1787
- " <td>1</td>\n",
1788
  " </tr>\n",
1789
  " </tbody>\n",
1790
  "</table>\n",
1791
  "</div>"
1792
  ],
1793
  "text/plain": [
1794
- " Season Seed TeamID League ConfAbbrev TeamName FirstD1Season \\\n",
1795
- "0 1985 W01 1207 M big_east Georgetown 1985 \n",
1796
- "1 1986 X04 1207 M big_east Georgetown 1985 \n",
1797
- "2 1987 X01 1207 M big_east Georgetown 1985 \n",
1798
- "3 1988 W08 1207 M big_east Georgetown 1985 \n",
1799
- "4 1989 W01 1207 M big_east Georgetown 1985 \n",
1800
  "\n",
1801
- " LastD1Season ChalkSeed \n",
1802
- "0 2024 1 \n",
1803
- "1 2024 4 \n",
1804
- "2 2024 1 \n",
1805
- "3 2024 8 \n",
1806
- "4 2024 1 "
1807
  ]
1808
  },
1809
  "execution_count": 15,
@@ -1812,26 +1857,34 @@
1812
  }
1813
  ],
1814
  "source": [
1815
- "conference_df = pd.concat([\n",
1816
- " # pd.read_csv(os.path.join(DATA_DIR, \"MTeamConferences.csv\")).assign(League=\"M\"),\n",
1817
- " # pd.read_csv(os.path.join(DATA_DIR, \"WTeamConferences.csv\")).assign(League=\"W\"),\n",
1818
- "\n",
1819
- " pd.read_csv(os.path.join(DATA_DIR, \"MNCAATourneySeeds.csv\")).assign(League=\"M\"),\n",
1820
- " pd.read_csv(os.path.join(DATA_DIR, \"WNCAATourneySeeds.csv\")).assign(League=\"W\"),\n",
1821
- "])\n",
1822
  "\n",
1823
- "team_conf_seeds_df = (\n",
1824
- " conference_df.merge(\n",
1825
- " right=(pd.concat([\n",
1826
- " # pd.read_csv(os.path.join(DATA_DIR, \"MNCAATourneySeeds.csv\")).assign(League=\"M\"),\n",
1827
- " # pd.read_csv(os.path.join(DATA_DIR, \"WNCAATourneySeeds.csv\")).assign(League=\"W\"),\n",
1828
- " pd.read_csv(os.path.join(DATA_DIR, \"MTeamConferences.csv\")).assign(League=\"M\"),\n",
1829
- " pd.read_csv(os.path.join(DATA_DIR, \"WTeamConferences.csv\")).assign(League=\"W\"),\n",
1830
- " ])),\n",
1831
- " on=[\"League\", \"Season\", \"TeamID\"],\n",
1832
- " how=\"left\",\n",
1833
- " )\n",
1834
- " .merge(right=pd.read_csv(os.path.join(DATA_DIR, \"MTeams.csv\")), on=\"TeamID\")\n",
 
 
 
 
 
 
 
 
 
1835
  ")\n",
1836
  "\n",
1837
  "team_conf_seeds_df[\"ChalkSeed\"] = team_conf_seeds_df.apply(\n",
@@ -1839,7 +1892,7 @@
1839
  " axis=1,\n",
1840
  ")\n",
1841
  "\n",
1842
- "team_conf_seeds_df.head()"
1843
  ]
1844
  },
1845
  {
@@ -2221,10 +2274,10 @@
2221
  "source": [
2222
  "# merge the tournament aggregated metrics with the regular season aggregated metrics\n",
2223
  "team_agg_df = pd.merge(\n",
2224
- " left=team_reg_agg, \n",
2225
- " right=team_tourney_agg, \n",
2226
  " how=\"left\",\n",
2227
- " on=[\"TeamID\", \"Season\", \"League\"], \n",
2228
  " suffixes=(\" reg\", \" tourney\"),\n",
2229
  " validate=\"1:1\",\n",
2230
  ")\n",
@@ -2260,10 +2313,10 @@
2260
  "output_type": "stream",
2261
  "text": [
2262
  "<class 'pandas.core.frame.DataFrame'>\n",
2263
- "Int64Index: 12857 entries, 0 to 12856\n",
2264
  "Columns: 459 entries, TeamID to ChalkSeed\n",
2265
- "dtypes: float64(363), int64(92), object(4)\n",
2266
- "memory usage: 45.1+ MB\n"
2267
  ]
2268
  }
2269
  ],
@@ -2283,7 +2336,7 @@
2283
  "<class 'pandas.core.frame.DataFrame'>\n",
2284
  "Int64Index: 377608 entries, 0 to 377607\n",
2285
  "Columns: 508 entries, Season to ChalkSeed\n",
2286
- "dtypes: float64(363), int64(138), object(7)\n",
2287
  "memory usage: 1.4+ GB\n"
2288
  ]
2289
  }
@@ -2344,7 +2397,7 @@
2344
  "<class 'pandas.core.frame.DataFrame'>\n",
2345
  "Int64Index: 377608 entries, 0 to 377607\n",
2346
  "Columns: 509 entries, Season to OppChalkSeed\n",
2347
- "dtypes: float64(364), int64(138), object(7)\n",
2348
  "memory usage: 1.4+ GB\n"
2349
  ]
2350
  }
@@ -2352,7 +2405,9 @@
2352
  "source": [
2353
  "opp_chalk_seed_map = team_conf_seeds_df.groupby(\"TeamID\")[\"ChalkSeed\"].last()\n",
2354
  "\n",
2355
- "super_detailed_games_df[\"OppChalkSeed\"] = super_detailed_games_df[\"OppTeamID\"].map(opp_chalk_seed_map)\n",
 
 
2356
  "\n",
2357
  "super_detailed_games_df.info()"
2358
  ]
@@ -2365,18 +2420,18 @@
2365
  {
2366
  "data": {
2367
  "text/plain": [
2368
- "0 8.0\n",
2369
- "1 11.0\n",
2370
- "2 2.0\n",
2371
- "3 12.0\n",
2372
- "4 10.0\n",
2373
- " ... \n",
2374
- "377603 NaN\n",
2375
- "377604 NaN\n",
2376
- "377605 NaN\n",
2377
- "377606 NaN\n",
2378
- "377607 NaN\n",
2379
- "Name: OppChalkSeed, Length: 377608, dtype: float64"
2380
  ]
2381
  },
2382
  "execution_count": 22,
@@ -2385,7 +2440,9 @@
2385
  }
2386
  ],
2387
  "source": [
2388
- "super_detailed_games_df[\"OppChalkSeed\"]"
 
 
2389
  ]
2390
  },
2391
  {
 
10
  "import numpy as np\n",
11
  "import os\n",
12
  "\n",
13
+ "DATA_DIR = os.path.join(\"..\", \"data\")"
14
  ]
15
  },
16
  {
 
212
  }
213
  ],
214
  "source": [
215
+ "detailed_tourney_games_df = pd.concat(\n",
216
+ " [\n",
217
+ " pd.read_csv(os.path.join(DATA_DIR, \"MNCAATourneyDetailedResults.csv\")).assign(\n",
218
+ " League=\"M\"\n",
219
+ " ),\n",
220
+ " pd.read_csv(os.path.join(DATA_DIR, \"WNCAATourneyDetailedResults.csv\")).assign(\n",
221
+ " League=\"W\"\n",
222
+ " ),\n",
223
+ " ]\n",
224
+ ")\n",
225
  "\n",
226
  "detailed_tourney_games_df.sample(5, random_state=1)"
227
  ]
 
425
  }
426
  ],
427
  "source": [
428
+ "detailed_reg_games_df = pd.concat(\n",
429
+ " [\n",
430
+ " pd.read_csv(os.path.join(DATA_DIR, \"MRegularSeasonDetailedResults.csv\")).assign(\n",
431
+ " League=\"M\"\n",
432
+ " ),\n",
433
+ " pd.read_csv(os.path.join(DATA_DIR, \"WRegularSeasonDetailedResults.csv\")).assign(\n",
434
+ " League=\"W\"\n",
435
+ " ),\n",
436
+ " ]\n",
437
+ ")\n",
438
  "\n",
439
  "detailed_reg_games_df.sample(5, random_state=1)"
440
  ]
 
457
  "\n",
458
  "detailed_metrics = {\n",
459
  " \"Score\",\n",
460
+ " # \"Loc\",\n",
461
  " \"FGM\",\n",
462
  " \"FGA\",\n",
463
  " \"FGM3\",\n",
 
472
  " \"PF\",\n",
473
  "}\n",
474
  "\n",
475
+ "w_renamed_cols = {f\"W{col}\": f\"Team{col}\" for col in detailed_metrics} | {\n",
476
+ " f\"L{col}\": f\"Opp{col}\" for col in detailed_metrics\n",
477
+ "}\n",
478
+ "l_renamed_cols = {f\"L{col}\": f\"Team{col}\" for col in detailed_metrics} | {\n",
479
+ " f\"W{col}\": f\"Opp{col}\" for col in detailed_metrics\n",
480
+ "}"
481
  ]
482
  },
483
  {
 
536
  }
537
  ],
538
  "source": [
539
+ "detailed_reg_games_df = pd.concat(\n",
540
+ " [\n",
541
+ " (\n",
542
+ " # detailed_reg_games_df[[col for col in detailed_reg_games_df.columns if col != \"LTeamID\"]]\n",
543
+ " detailed_reg_games_df[[col for col in detailed_reg_games_df.columns]]\n",
544
+ " .assign(GameResult=\"W\")\n",
545
+ " .rename(\n",
546
+ " columns=w_renamed_cols | {\"WTeamID\": \"TeamID\", \"LTeamID\": \"OppTeamID\"}\n",
547
+ " )\n",
548
+ " ),\n",
549
+ " (\n",
550
+ " # detailed_reg_games_df[[col for col in detailed_reg_games_df.columns if col != \"WTeamID\"]]\n",
551
+ " detailed_reg_games_df[[col for col in detailed_reg_games_df.columns]]\n",
552
+ " .assign(GameResult=\"L\")\n",
553
+ " .rename(\n",
554
+ " columns=l_renamed_cols | {\"LTeamID\": \"TeamID\", \"WTeamID\": \"OppTeamID\"}\n",
555
+ " )\n",
556
+ " ),\n",
557
+ " ]\n",
558
+ ").reset_index(drop=True)\n",
559
  "\n",
560
  "detailed_reg_games_df.info()"
561
  ]
 
617
  ],
618
  "source": [
619
  "# do the same thing for the tournament games\n",
620
+ "detailed_tourney_games_df = pd.concat(\n",
621
+ " [\n",
622
+ " (\n",
623
+ " # detailed_tourney_games_df[[col for col in detailed_tourney_games_df.columns if col != \"LTeamID\"]]\n",
624
+ " detailed_tourney_games_df[\n",
625
+ " [col for col in detailed_tourney_games_df.columns]\n",
626
+ " ]\n",
627
+ " .assign(GameResult=\"W\")\n",
628
+ " .rename(\n",
629
+ " columns=w_renamed_cols | {\"WTeamID\": \"TeamID\", \"LTeamID\": \"OppTeamID\"}\n",
630
+ " )\n",
631
+ " ),\n",
632
+ " (\n",
633
+ " # detailed_tourney_games_df[[col for col in detailed_tourney_games_df.columns if col != \"WTeamID\"]]\n",
634
+ " detailed_tourney_games_df[\n",
635
+ " [col for col in detailed_tourney_games_df.columns]\n",
636
+ " ]\n",
637
+ " .assign(GameResult=\"L\")\n",
638
+ " .rename(\n",
639
+ " columns=l_renamed_cols | {\"LTeamID\": \"TeamID\", \"WTeamID\": \"OppTeamID\"}\n",
640
+ " )\n",
641
+ " ),\n",
642
+ " ]\n",
643
+ ").reset_index(drop=True)\n",
644
  "\n",
645
  "detailed_tourney_games_df.info()"
646
  ]
 
651
  "metadata": {},
652
  "outputs": [],
653
  "source": [
 
654
  "for col in detailed_metrics:\n",
655
  " detailed_reg_games_df[f\"{col}Diff\"] = detailed_reg_games_df.apply(\n",
656
  " lambda row: row[f\"Team{col}\"] - row[f\"Opp{col}\"],\n",
 
660
  " detailed_tourney_games_df[f\"{col}Diff\"] = detailed_tourney_games_df.apply(\n",
661
  " lambda row: row[f\"Team{col}\"] - row[f\"Opp{col}\"],\n",
662
  " axis=1,\n",
663
+ " )"
664
  ]
665
  },
666
  {
 
700
  " <th>TeamFGM</th>\n",
701
  " <th>TeamFGA</th>\n",
702
  " <th>...</th>\n",
 
 
 
 
 
703
  " <th>FTADiff</th>\n",
704
+ " <th>PFDiff</th>\n",
 
705
  " <th>ScoreDiff</th>\n",
706
  " <th>FGADiff</th>\n",
707
+ " <th>BlkDiff</th>\n",
708
+ " <th>FGM3Diff</th>\n",
709
+ " <th>ORDiff</th>\n",
710
+ " <th>StlDiff</th>\n",
711
+ " <th>AstDiff</th>\n",
712
+ " <th>DRDiff</th>\n",
713
  " </tr>\n",
714
  " </thead>\n",
715
  " <tbody>\n",
 
726
  " <td>21</td>\n",
727
  " <td>55</td>\n",
728
  " <td>...</td>\n",
 
 
729
  " <td>-11</td>\n",
730
+ " <td>9</td>\n",
731
+ " <td>-28</td>\n",
732
+ " <td>-12</td>\n",
733
  " <td>1</td>\n",
734
+ " <td>-3</td>\n",
735
  " <td>-11</td>\n",
736
  " <td>-7</td>\n",
737
+ " <td>-1</td>\n",
738
+ " <td>-4</td>\n",
 
739
  " </tr>\n",
740
  " <tr>\n",
741
  " <th>100732</th>\n",
 
750
  " <td>23</td>\n",
751
  " <td>60</td>\n",
752
  " <td>...</td>\n",
 
 
 
 
 
753
  " <td>17</td>\n",
754
+ " <td>-9</td>\n",
 
755
  " <td>12</td>\n",
756
  " <td>-4</td>\n",
757
+ " <td>2</td>\n",
758
+ " <td>-2</td>\n",
759
+ " <td>-1</td>\n",
760
+ " <td>4</td>\n",
761
+ " <td>11</td>\n",
762
+ " <td>1</td>\n",
763
  " </tr>\n",
764
  " <tr>\n",
765
  " <th>83150</th>\n",
 
774
  " <td>27</td>\n",
775
  " <td>58</td>\n",
776
  " <td>...</td>\n",
777
+ " <td>10</td>\n",
778
  " <td>-5</td>\n",
779
+ " <td>13</td>\n",
780
+ " <td>-6</td>\n",
781
+ " <td>2</td>\n",
782
  " <td>1</td>\n",
783
  " <td>4</td>\n",
 
 
 
784
  " <td>-5</td>\n",
785
  " <td>1</td>\n",
786
  " <td>13</td>\n",
 
787
  " </tr>\n",
788
  " <tr>\n",
789
  " <th>345009</th>\n",
 
798
  " <td>19</td>\n",
799
  " <td>55</td>\n",
800
  " <td>...</td>\n",
 
 
 
 
 
801
  " <td>-11</td>\n",
802
+ " <td>7</td>\n",
803
+ " <td>-7</td>\n",
804
+ " <td>13</td>\n",
805
  " <td>-3</td>\n",
806
  " <td>-1</td>\n",
807
+ " <td>2</td>\n",
808
+ " <td>-3</td>\n",
809
+ " <td>4</td>\n",
810
  " <td>-7</td>\n",
 
811
  " </tr>\n",
812
  " <tr>\n",
813
  " <th>318707</th>\n",
 
822
  " <td>20</td>\n",
823
  " <td>51</td>\n",
824
  " <td>...</td>\n",
825
+ " <td>-11</td>\n",
826
  " <td>2</td>\n",
827
+ " <td>-18</td>\n",
828
+ " <td>3</td>\n",
 
829
  " <td>1</td>\n",
830
+ " <td>-3</td>\n",
831
+ " <td>2</td>\n",
832
  " <td>-7</td>\n",
833
+ " <td>2</td>\n",
834
  " <td>-3</td>\n",
 
 
835
  " </tr>\n",
836
  " </tbody>\n",
837
  "</table>\n",
 
846
  "345009 2019 4 3435 58 3292 65 H 0 \n",
847
  "318707 2013 128 3322 45 3270 63 N 0 \n",
848
  "\n",
849
+ " TeamFGM TeamFGA ... FTADiff PFDiff ScoreDiff FGADiff BlkDiff \\\n",
850
+ "337067 21 55 ... -11 9 -28 -12 1 \n",
851
+ "100732 23 60 ... 17 -9 12 -4 2 \n",
852
+ "83150 27 58 ... 10 -5 13 -6 2 \n",
853
+ "345009 19 55 ... -11 7 -7 13 -3 \n",
854
+ "318707 20 51 ... -11 2 -18 3 1 \n",
855
  "\n",
856
+ " FGM3Diff ORDiff StlDiff AstDiff DRDiff \n",
857
+ "337067 -3 -11 -7 -1 -4 \n",
858
+ "100732 -2 -1 4 11 1 \n",
859
+ "83150 1 4 -5 1 13 \n",
860
+ "345009 -1 2 -3 4 -7 \n",
861
+ "318707 -3 2 -7 2 -3 \n",
862
  "\n",
863
  "[5 rows x 49 columns]"
864
  ]
 
907
  "source": [
908
  "# combine the two detailed game dataframes into one for future use\n",
909
  "\n",
910
+ "all_detailed_games_df = pd.concat(\n",
911
+ " [\n",
912
+ " detailed_reg_games_df.assign(GameType=\"reg\"),\n",
913
+ " detailed_tourney_games_df.assign(GameType=\"tourney\"),\n",
914
+ " ]\n",
915
+ ")"
916
  ]
917
  },
918
  {
 
1337
  "source": [
1338
  "team_reg_agg = (\n",
1339
  " detailed_reg_games_df.groupby([\"TeamID\", \"Season\", \"League\"])\n",
1340
+ " .agg(\n",
1341
+ " {\n",
1342
+ " col: agg_funcs\n",
1343
+ " for col in detailed_reg_games_df.select_dtypes(\"number\").columns\n",
1344
+ " if col not in exclude_agg_cols\n",
1345
+ " }\n",
1346
+ " )\n",
1347
  " .reset_index()\n",
1348
  ")\n",
1349
  "\n",
 
1705
  }
1706
  ],
1707
  "source": [
1708
+ "# aggregate the same metrics for the tournament dataset\n",
1709
  "\n",
1710
  "team_tourney_agg = (\n",
1711
  " detailed_tourney_games_df.groupby([\"TeamID\", \"Season\", \"League\"])\n",
1712
+ " .agg(\n",
1713
+ " {\n",
1714
+ " col: agg_funcs\n",
1715
+ " for col in detailed_tourney_games_df.select_dtypes(\"number\").columns\n",
1716
+ " if col not in exclude_agg_cols\n",
1717
+ " }\n",
1718
+ " )\n",
1719
  " .reset_index()\n",
1720
  ")\n",
1721
  "\n",
1722
+ "team_tourney_agg.columns = [\n",
1723
+ " \" \".join(col).strip() for col in team_tourney_agg.columns.values\n",
1724
+ "]\n",
1725
  "\n",
1726
  "team_tourney_agg.sample(10, random_state=1)"
1727
  ]
 
1772
  " </thead>\n",
1773
  " <tbody>\n",
1774
  " <tr>\n",
1775
+ " <th>3591</th>\n",
1776
+ " <td>2004</td>\n",
1777
+ " <td>X02</td>\n",
1778
+ " <td>3243</td>\n",
1779
+ " <td>W</td>\n",
1780
+ " <td>big_twelve</td>\n",
1781
+ " <td>Kansas St</td>\n",
1782
+ " <td>NaN</td>\n",
1783
+ " <td>NaN</td>\n",
1784
+ " <td>2</td>\n",
1785
  " </tr>\n",
1786
  " <tr>\n",
1787
+ " <th>3528</th>\n",
1788
+ " <td>2013</td>\n",
1789
+ " <td>Y01</td>\n",
1790
+ " <td>3124</td>\n",
1791
+ " <td>W</td>\n",
1792
+ " <td>big_twelve</td>\n",
1793
+ " <td>Baylor</td>\n",
1794
+ " <td>NaN</td>\n",
1795
+ " <td>NaN</td>\n",
1796
+ " <td>1</td>\n",
1797
  " </tr>\n",
1798
  " <tr>\n",
1799
+ " <th>1891</th>\n",
1800
+ " <td>2003</td>\n",
1801
+ " <td>W02</td>\n",
1802
+ " <td>1448</td>\n",
1803
  " <td>M</td>\n",
1804
+ " <td>acc</td>\n",
1805
+ " <td>Wake Forest</td>\n",
1806
+ " <td>1985.0</td>\n",
1807
+ " <td>2024.0</td>\n",
1808
+ " <td>2</td>\n",
1809
  " </tr>\n",
1810
  " <tr>\n",
1811
+ " <th>778</th>\n",
1812
+ " <td>2019</td>\n",
1813
+ " <td>Y01</td>\n",
1814
+ " <td>1314</td>\n",
1815
  " <td>M</td>\n",
1816
+ " <td>acc</td>\n",
1817
+ " <td>North Carolina</td>\n",
1818
+ " <td>1985.0</td>\n",
1819
+ " <td>2024.0</td>\n",
1820
+ " <td>1</td>\n",
1821
  " </tr>\n",
1822
  " <tr>\n",
1823
+ " <th>2932</th>\n",
1824
+ " <td>2019</td>\n",
1825
+ " <td>X05</td>\n",
1826
+ " <td>3266</td>\n",
1827
+ " <td>W</td>\n",
1828
  " <td>big_east</td>\n",
1829
+ " <td>Marquette</td>\n",
1830
+ " <td>NaN</td>\n",
1831
+ " <td>NaN</td>\n",
1832
+ " <td>5</td>\n",
1833
  " </tr>\n",
1834
  " </tbody>\n",
1835
  "</table>\n",
1836
  "</div>"
1837
  ],
1838
  "text/plain": [
1839
+ " Season Seed TeamID League ConfAbbrev TeamName FirstD1Season \\\n",
1840
+ "3591 2004 X02 3243 W big_twelve Kansas St NaN \n",
1841
+ "3528 2013 Y01 3124 W big_twelve Baylor NaN \n",
1842
+ "1891 2003 W02 1448 M acc Wake Forest 1985.0 \n",
1843
+ "778 2019 Y01 1314 M acc North Carolina 1985.0 \n",
1844
+ "2932 2019 X05 3266 W big_east Marquette NaN \n",
1845
  "\n",
1846
+ " LastD1Season ChalkSeed \n",
1847
+ "3591 NaN 2 \n",
1848
+ "3528 NaN 1 \n",
1849
+ "1891 2024.0 2 \n",
1850
+ "778 2024.0 1 \n",
1851
+ "2932 NaN 5 "
1852
  ]
1853
  },
1854
  "execution_count": 15,
 
1857
  }
1858
  ],
1859
  "source": [
1860
+ "conference_df = pd.concat(\n",
1861
+ " [\n",
1862
+ " pd.read_csv(os.path.join(DATA_DIR, \"MNCAATourneySeeds.csv\")).assign(League=\"M\"),\n",
1863
+ " pd.read_csv(os.path.join(DATA_DIR, \"WNCAATourneySeeds.csv\")).assign(League=\"W\"),\n",
1864
+ " ]\n",
1865
+ ")\n",
 
1866
  "\n",
1867
+ "team_conf_seeds_df = conference_df.merge(\n",
1868
+ " right=(\n",
1869
+ " pd.concat(\n",
1870
+ " [\n",
1871
+ " pd.read_csv(os.path.join(DATA_DIR, \"MTeamConferences.csv\")).assign(\n",
1872
+ " League=\"M\"\n",
1873
+ " ),\n",
1874
+ " pd.read_csv(os.path.join(DATA_DIR, \"WTeamConferences.csv\")).assign(\n",
1875
+ " League=\"W\"\n",
1876
+ " ),\n",
1877
+ " ]\n",
1878
+ " )\n",
1879
+ " ),\n",
1880
+ " on=[\"League\", \"Season\", \"TeamID\"],\n",
1881
+ " how=\"left\",\n",
1882
+ ").merge(right=(\n",
1883
+ " pd.concat([\n",
1884
+ " pd.read_csv(os.path.join(DATA_DIR, \"MTeams.csv\")),\n",
1885
+ " pd.read_csv(os.path.join(DATA_DIR, \"WTeams.csv\")),\n",
1886
+ " ])),\n",
1887
+ " on=\"TeamID\",\n",
1888
  ")\n",
1889
  "\n",
1890
  "team_conf_seeds_df[\"ChalkSeed\"] = team_conf_seeds_df.apply(\n",
 
1892
  " axis=1,\n",
1893
  ")\n",
1894
  "\n",
1895
+ "team_conf_seeds_df.sample(5, random_state=1)"
1896
  ]
1897
  },
1898
  {
 
2274
  "source": [
2275
  "# merge the tournament aggregated metrics with the regular season aggregated metrics\n",
2276
  "team_agg_df = pd.merge(\n",
2277
+ " left=team_reg_agg,\n",
2278
+ " right=team_tourney_agg,\n",
2279
  " how=\"left\",\n",
2280
+ " on=[\"TeamID\", \"Season\", \"League\"],\n",
2281
  " suffixes=(\" reg\", \" tourney\"),\n",
2282
  " validate=\"1:1\",\n",
2283
  ")\n",
 
2313
  "output_type": "stream",
2314
  "text": [
2315
  "<class 'pandas.core.frame.DataFrame'>\n",
2316
+ "Int64Index: 13305 entries, 0 to 13304\n",
2317
  "Columns: 459 entries, TeamID to ChalkSeed\n",
2318
+ "dtypes: float64(453), int64(2), object(4)\n",
2319
+ "memory usage: 46.7+ MB\n"
2320
  ]
2321
  }
2322
  ],
 
2336
  "<class 'pandas.core.frame.DataFrame'>\n",
2337
  "Int64Index: 377608 entries, 0 to 377607\n",
2338
  "Columns: 508 entries, Season to ChalkSeed\n",
2339
+ "dtypes: float64(453), int64(48), object(7)\n",
2340
  "memory usage: 1.4+ GB\n"
2341
  ]
2342
  }
 
2397
  "<class 'pandas.core.frame.DataFrame'>\n",
2398
  "Int64Index: 377608 entries, 0 to 377607\n",
2399
  "Columns: 509 entries, Season to OppChalkSeed\n",
2400
+ "dtypes: float64(454), int64(48), object(7)\n",
2401
  "memory usage: 1.4+ GB\n"
2402
  ]
2403
  }
 
2405
  "source": [
2406
  "opp_chalk_seed_map = team_conf_seeds_df.groupby(\"TeamID\")[\"ChalkSeed\"].last()\n",
2407
  "\n",
2408
+ "super_detailed_games_df[\"OppChalkSeed\"] = super_detailed_games_df[\"OppTeamID\"].map(\n",
2409
+ " opp_chalk_seed_map\n",
2410
+ ")\n",
2411
  "\n",
2412
  "super_detailed_games_df.info()"
2413
  ]
 
2420
  {
2421
  "data": {
2422
  "text/plain": [
2423
+ "0 2.0\n",
2424
+ "1 -4.0\n",
2425
+ "2 1.0\n",
2426
+ "3 NaN\n",
2427
+ "4 -9.0\n",
2428
+ " ... \n",
2429
+ "377603 1.0\n",
2430
+ "377604 2.0\n",
2431
+ "377605 -1.0\n",
2432
+ "377606 -2.0\n",
2433
+ "377607 -1.0\n",
2434
+ "Name: ChalkSeedDiff, Length: 377608, dtype: float64"
2435
  ]
2436
  },
2437
  "execution_count": 22,
 
2440
  }
2441
  ],
2442
  "source": [
2443
+ "super_detailed_games_df[\"ChalkSeedDiff\"] = (\n",
2444
+ " super_detailed_games_df[\"ChalkSeed\"] - super_detailed_games_df[\"OppChalkSeed\"]\n",
2445
+ ")"
2446
  ]
2447
  },
2448
  {
src/visual_eval.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.metrics import roc_curve, precision_recall_curve
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+
5
+
6
+ def eval_binary_classification(pred: np.array, true: np.array):
7
+ plt.figure(figsize=(12, 6))
8
+ eval_roc_curve(pred, true)
9
+ eval_pr_curve(pred, true)
10
+ plt.tight_layout()
11
+ plt.show()
12
+
13
+
14
+ def eval_pr_curve(pred: np.array, true: np.array):
15
+ precision, recall, _ = precision_recall_curve(true, pred)
16
+ plt.subplot(1, 2, 1)
17
+ plt.plot(recall, precision, label="Precision-Recall Curve", color="red")
18
+ plt.ylim(0)
19
+ plt.xlabel("Recall")
20
+ plt.ylabel("Precision")
21
+ plt.title("Precision-Recall Curve")
22
+ plt.legend(loc="lower right")
23
+
24
+
25
+ def eval_roc_curve(pred: np.array, true: np.array) -> None:
26
+ false_pos_rate, true_pos_rate, _ = roc_curve(true, pred)
27
+ plt.subplot(1, 2, 2)
28
+ plt.plot(false_pos_rate, true_pos_rate, label="ROC Curve")
29
+ plt.plot([0, 1], [0, 1], linestyle="--", label="Random Guessing Model")
30
+ plt.title("ROC Curve vs. Random")
31
+ plt.xlabel("False Positive Rate")
32
+ plt.ylabel("True Positive Rate")
33
+ plt.legend(loc="lower right")
src/visualizations.py DELETED
@@ -1,27 +0,0 @@
1
- import numpy as np
2
- import matplotlib.pyplot as plt
3
- from sklearn.metrics import (
4
- roc_curve,
5
- roc_auc_score,
6
- precision_recall_curve,
7
- )
8
-
9
-
10
- def roc_plot(y_true: np.array, y_pred: np.array):
11
- fpr, tpr, _ = roc_curve(y_true, y_pred)
12
-
13
- # Plot ROC Curve
14
- plt.plot(fpr, tpr, label="ROC Curve")
15
- plt.plot([0, 1], [0, 1], linestyle="--", label="Random Model")
16
- plt.title("ROC Curve vs. Random Model")
17
- plt.xlabel("False Positive Rate")
18
- plt.ylabel("True Positive Rate")
19
- plt.legend("lower right")
20
-
21
- plt.tight_layout()
22
- plt.show()
23
-
24
-
25
- def precision_recall_plot(y_true: np.array, y_pred: np.array, baseline_pred: np.array):
26
- ...
27
-